Commit 8ff0c0d2 authored by wujl5's avatar wujl5
Browse files

rmsquant融合的add和rms功能可以使用了,quant的tensor输入到模型后输出错误, moe的gate-up层未量化

parent 2dd0894f
......@@ -6,7 +6,7 @@ import os
from collections.abc import Callable, Iterable
from contextlib import nullcontext
from enum import Enum
from typing import Literal, cast, get_args, overload
from typing import Literal, cast, get_args, overload, Optional
import torch
import torch.nn.functional as F
......@@ -1669,6 +1669,8 @@ class FusedMoE(CustomOp):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
og_hidden_states = hidden_states.shape[-1]
if self.hidden_size != og_hidden_states:
......@@ -1720,7 +1722,9 @@ class FusedMoE(CustomOp):
)
else:
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, encode_layer_name()
hidden_states, router_logits, encode_layer_name(),
i_q=i_q,
i_s=i_s
)
return (
reduce_output(shared_output)[..., :og_hidden_states],
......@@ -1737,8 +1741,10 @@ class FusedMoE(CustomOp):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(hidden_states, router_logits)
return self.forward_native(hidden_states, router_logits, i_q=i_q, i_s=i_s)
def forward_impl_chunked(
self,
......@@ -1880,6 +1886,8 @@ class FusedMoE(CustomOp):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
use_fused_gate: bool | None = False,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None
......@@ -1904,6 +1912,7 @@ class FusedMoE(CustomOp):
# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
if self.gate is not None:
print("YYYY: unsupported using gate in FusedMOE.forward_impl") # true
router_logits, _ = self.gate(hidden_states)
if use_chunked_impl:
......@@ -2005,12 +2014,24 @@ class FusedMoE(CustomOp):
if self.capture is not None:
self.capture(topk_ids)
if envs.USE_FUSED_RMS_QUANT:
final_hidden_states = self.quant_method.apply(
layer=self,
x=x, # The type signture of this is wrong due to the hack.
topk_weights=topk_weights,
topk_ids=topk_ids,
use_nn_moe=self.use_nn_moe,
i_q=i_q,
i_s=i_s
)
else:
final_hidden_states = self.quant_method.apply(
layer=self,
x=x, # The type signture of this is wrong due to the hack.
topk_weights=topk_weights,
topk_ids=topk_ids,
use_nn_moe=self.use_nn_moe
)
if has_separate_shared_experts:
......@@ -2133,16 +2154,20 @@ def moe_forward(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> torch.Tensor:
self = get_layer_from_name(layer_name)
assert self.shared_experts is None
return self.forward_impl(hidden_states, router_logits)
return self.forward_impl(hidden_states, router_logits, i_q=i_q, i_s=i_s)
def moe_forward_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -2160,9 +2185,20 @@ def moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
import traceback
import torch.distributed as dist
print("=======")
# if dist.get_rank() == 0:
# traceback.print_stack()
print("=======")
self = get_layer_from_name(layer_name)
assert self.shared_experts is not None
if envs.USE_FUSED_RMS_QUANT:
return self.forward_impl(hidden_states, router_logits, i_q=i_q, i_s=i_s)
else:
return self.forward_impl(hidden_states, router_logits)
......@@ -2170,6 +2206,8 @@ def moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states)
......
......@@ -62,9 +62,13 @@ class SharedFusedMoE(FusedMoE):
router_logits: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
if self._shared_experts is not None: print("YYYY: 处理下游专家输入, self.use_overlapped is :", self.use_overlapped)
if not self.use_overlapped:
if self._shared_experts is not None:
shared_out = self._shared_experts(hidden_states)
shared_out = self._shared_experts(hidden_states,
i_q=iqis[0] if iqis is not None else None,
i_s=iqis[1] if iqis is not None else None)
# Reduce shared expert outputs if necessary, since the MLP
# should have been created with reduce_results=False.
......@@ -77,14 +81,20 @@ class SharedFusedMoE(FusedMoE):
else:
shared_out = None
print("YYYY: i suppose not this branch!!!")
fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
i_q=iqis[0] if iqis is not None else None,
i_s=iqis[1] if iqis is not None else None,
)
else:
shared_out, fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
i_q=iqis[0] if iqis is not None else None,
i_s=iqis[1] if iqis is not None else None,
)
# ensure early TP reduction of shared expert outputs when required
if (
......
......@@ -339,11 +339,11 @@ class FusedRMSNormQuant(nn.Module):
update_input: Optional[bool] = True
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
x, x_scales = fused_rmsquant(x, self.weight,
i_q, i_s = fused_rmsquant(x, self.weight,
self.variance_epsilon,
quant_dtype, residual,
update_input)
return x, x_scales, residual
return i_q, i_s, residual
def fused_rmsquant_impl(
......
......@@ -716,6 +716,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Matrix multiply.
assert self.quant_method is not None
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
print("YYYYY: mlp.gate_up self.quant_method.apply: ", self.quant_method.apply)
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
else:
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output and self.tp_size > 1:
......
......@@ -8,6 +8,7 @@ from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm import envs
@dataclass
......@@ -130,6 +131,10 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
assert self.q_b_proj is not None, (
"q_b_proj is required when q_lora_rank is not None"
)
# print("YYYY: self.fused_qkv_a_proj is:", self.fused_qkv_a_proj) # MergedColumnParallelLinear(in_features=7168, output_features=2112, bias=False, tp_size=1, gather_output=False)
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
qkv_lora = self.fused_qkv_a_proj(hidden_states, iqis=iqis)[0]
else:
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_lora = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
......@@ -138,6 +143,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
else:
# print("YYYY: self.q_lora_rank is None. unsupported for now!!!")
assert self.kv_a_proj_with_mqa is not None, (
"kv_a_proj_with_mqa is required when q_lora_rank is None"
)
......
......@@ -398,6 +398,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -420,6 +422,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
i_q=i_q,
i_s=i_s
)
def select_gemm_impl(
......
......@@ -345,7 +345,7 @@ class DeepseekV2MoE(nn.Module):
else None,
)
def forward(self, hidden_states: torch.Tensor,
def forward_default(self, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
......@@ -413,12 +413,15 @@ class DeepseekV2MoE(nn.Module):
hidden_states = sequence_parallel_chunk(hidden_states)
if self.experts.is_internal_router:
# print("YYYY: self.experts.is_internal_router is True.") # True
# In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=hidden_states, iqis=iqis
)
else:
else: # NO
print("YYYY: self.experts.is_internal_router is False.")
# router_logits: (num_tokens, n_experts)
print("YYYY: RQ MOE.gate quant: self.gate.quant_method.apply: ", self.gate.quant_method.apply)
router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits, iqis=iqis
......@@ -453,6 +456,18 @@ class DeepseekV2MoE(nn.Module):
return final_hidden_states.view(num_tokens, hidden_dim)
def choose_forward(self):
if envs.USE_FUSED_RMS_QUANT:
return self.forward_RQ
else:
return self.forward_default
def forward(self, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
forward_func = self.choose_forward()
return forward_func(hidden_states, iqis=iqis)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
import math
......@@ -1078,14 +1093,19 @@ class DeepseekV2DecoderLayer(nn.Module):
) -> torch.Tensor:
# Self Attention
# Fix residual FP16 overflow
# print("YYYYY: forward_RQ is called")
residual_fix_overflow = False
assert self.input_layernorm.has_weight is True
if residual is None:
residual = hidden_states.clone()
i_q, i_s, _ = self.input_layernorm(x=hidden_states,
residual=None,
quant_dtype=torch.int8,
update_input=False)
# update_input=False # wjl, del
update_input=True
)
# print("YYYY: i_q:", i_q.flatten()[:5], i_q.shape)
# print("YYYY: i_s:", i_s.flatten()[:5], i_s.shape)
# i_q, i_s = lm_faster_rmsquant(input=hidden_states,
# rms_weight = self.input_layernorm.weight.data,
# epsilon=self._eps,
......@@ -1101,17 +1121,20 @@ class DeepseekV2DecoderLayer(nn.Module):
# quant_dtype=torch.int8,
# residual=residual,
# update_input=False)
# print("YYYY: input rms residual bf", residual.flatten()[:5])
i_q, i_s, residual = self.input_layernorm(x=hidden_states,
residual=residual,
quant_dtype=torch.int8,
update_input=False)
# update_input=False, # wjl, del
update_input=True,
)
# print("YYYY:input rms residual af", residual.flatten()[:5])
attn_kwargs = {
"positions": positions,
"hidden_states": hidden_states,
"iqis": (i_q, i_s)
# "iqis": (i_q, i_s) # wjl, del
}
if not self.use_mha:
attn_kwargs["llama_4_scaling"] = llama_4_scaling
......@@ -1121,6 +1144,7 @@ class DeepseekV2DecoderLayer(nn.Module):
not isinstance(self.self_attn, DeepseekAttention)
and hidden_states.dtype == torch.float16
):
print("YYYYY: FP16 overflow fix is applied")
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
......@@ -1139,14 +1163,18 @@ class DeepseekV2DecoderLayer(nn.Module):
# quant_dtype=torch.int8,
# residual=residual,
# update_input=update_hs)
assert self.post_attention_layernorm.has_weight is True
_i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states,
residual=residual,
quant_dtype=torch.int8,
update_input=update_hs)
update_input=update_hs, # wjl, del
# update_input=True,
)
new_resi = residual
hidden_states = self.mlp(hidden_states,
iqis=(_i_q, _i_s))
# iqis=(_i_q, _i_s) # wjl, del
)
if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment