Commit 39096bf4 authored by wujl5's avatar wujl5
Browse files

marlin情况下eager+rmsquant输出正常

parent 8ff0c0d2
...@@ -1670,7 +1670,7 @@ class FusedMoE(CustomOp): ...@@ -1670,7 +1670,7 @@ class FusedMoE(CustomOp):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
i_q: Optional[torch.Tensor] = None, i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_ i_s: Optional[torch.Tensor] = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
og_hidden_states = hidden_states.shape[-1] og_hidden_states = hidden_states.shape[-1]
if self.hidden_size != og_hidden_states: if self.hidden_size != og_hidden_states:
...@@ -1742,7 +1742,7 @@ class FusedMoE(CustomOp): ...@@ -1742,7 +1742,7 @@ class FusedMoE(CustomOp):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
i_q: Optional[torch.Tensor] = None, i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_ i_s: Optional[torch.Tensor] = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(hidden_states, router_logits, i_q=i_q, i_s=i_s) return self.forward_native(hidden_states, router_logits, i_q=i_q, i_s=i_s)
...@@ -1912,7 +1912,6 @@ class FusedMoE(CustomOp): ...@@ -1912,7 +1912,6 @@ class FusedMoE(CustomOp):
# parallel execution of shared experts with the FusedMoE via # parallel execution of shared experts with the FusedMoE via
# separate cuda stream) # separate cuda stream)
if self.gate is not None: if self.gate is not None:
print("YYYY: unsupported using gate in FusedMOE.forward_impl") # true
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if use_chunked_impl: if use_chunked_impl:
...@@ -2188,12 +2187,6 @@ def moe_forward_shared( ...@@ -2188,12 +2187,6 @@ def moe_forward_shared(
i_q: Optional[torch.Tensor] = None, i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None i_s: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
import traceback
import torch.distributed as dist
print("=======")
# if dist.get_rank() == 0:
# traceback.print_stack()
print("=======")
self = get_layer_from_name(layer_name) self = get_layer_from_name(layer_name)
assert self.shared_experts is not None assert self.shared_experts is not None
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
......
...@@ -62,8 +62,6 @@ class SharedFusedMoE(FusedMoE): ...@@ -62,8 +62,6 @@ class SharedFusedMoE(FusedMoE):
router_logits: torch.Tensor, router_logits: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if self._shared_experts is not None: print("YYYY: 处理下游专家输入, self.use_overlapped is :", self.use_overlapped)
if not self.use_overlapped: if not self.use_overlapped:
if self._shared_experts is not None: if self._shared_experts is not None:
shared_out = self._shared_experts(hidden_states, shared_out = self._shared_experts(hidden_states,
...@@ -81,8 +79,6 @@ class SharedFusedMoE(FusedMoE): ...@@ -81,8 +79,6 @@ class SharedFusedMoE(FusedMoE):
else: else:
shared_out = None shared_out = None
print("YYYY: i suppose not this branch!!!")
fused_out = super().forward( fused_out = super().forward(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
......
...@@ -21,7 +21,6 @@ from vllm import envs ...@@ -21,7 +21,6 @@ from vllm import envs
from lightop.op import rms_norm_dynamic_per_token_quant as ligtop_rms_norm_dynamic_per_token_quant from lightop.op import rms_norm_dynamic_per_token_quant as ligtop_rms_norm_dynamic_per_token_quant
def rms_norm( def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -395,7 +394,7 @@ def fused_rmsquant(input: torch.Tensor, ...@@ -395,7 +394,7 @@ def fused_rmsquant(input: torch.Tensor,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
update_input: Optional[bool] = True): update_input: Optional[bool] = True):
from lmslim.quantize.quant_ops import lm_faster_rmsquant from lmslim.quantize.quant_ops import lm_faster_rmsquant # TODO:wjl
i_q, _scales = lm_faster_rmsquant(input=input, i_q, _scales = lm_faster_rmsquant(input=input,
rms_weight=rms_weight, rms_weight=rms_weight,
epsilon=epsilon, epsilon=epsilon,
......
...@@ -717,7 +717,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -717,7 +717,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Matrix multiply. # Matrix multiply.
assert self.quant_method is not None assert self.quant_method is not None
if envs.USE_FUSED_RMS_QUANT and iqis is not None: if envs.USE_FUSED_RMS_QUANT and iqis is not None:
print("YYYYY: mlp.gate_up self.quant_method.apply: ", self.quant_method.apply)
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args=iqis) output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
else: else:
output_parallel = self.quant_method.apply(self, input_, bias) output_parallel = self.quant_method.apply(self, input_, bias)
......
...@@ -131,7 +131,6 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -131,7 +131,6 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
assert self.q_b_proj is not None, ( assert self.q_b_proj is not None, (
"q_b_proj is required when q_lora_rank is not None" "q_b_proj is required when q_lora_rank is not None"
) )
# print("YYYY: self.fused_qkv_a_proj is:", self.fused_qkv_a_proj) # MergedColumnParallelLinear(in_features=7168, output_features=2112, bias=False, tp_size=1, gather_output=False)
if envs.USE_FUSED_RMS_QUANT and iqis is not None: if envs.USE_FUSED_RMS_QUANT and iqis is not None:
qkv_lora = self.fused_qkv_a_proj(hidden_states, iqis=iqis)[0] qkv_lora = self.fused_qkv_a_proj(hidden_states, iqis=iqis)[0]
else: else:
...@@ -143,7 +142,6 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -143,7 +142,6 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
q_c = self.q_a_layernorm(q_c) q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0] q = self.q_b_proj(q_c)[0]
else: else:
# print("YYYY: self.q_lora_rank is None. unsupported for now!!!")
assert self.kv_a_proj_with_mqa is not None, ( assert self.kv_a_proj_with_mqa is not None, (
"kv_a_proj_with_mqa is required when q_lora_rank is None" "kv_a_proj_with_mqa is required when q_lora_rank is None"
) )
......
...@@ -1255,6 +1255,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1255,6 +1255,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -1271,6 +1273,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1271,6 +1273,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
expert_map=layer.expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
i_q=i_q,
i_s=i_s
) )
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from vllm import envs
try: try:
from lmslim.layers.gemm.int8_utils import per_token_quant_int8 from lmslim.layers.gemm.int8_utils import per_token_quant_int8
...@@ -167,6 +168,15 @@ def apply_int8_linear( ...@@ -167,6 +168,15 @@ def apply_int8_linear(
# ops.scaled_int8_quant supports both dynamic and static quant. # ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x. # * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale. # * static, layer.input_scale is scalar and x_scale is input_scale.
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
x_zp =None
x_q, x_scale = input_quant_args
elif envs.USE_FUSED_RMS_QUANT and silu_quant_args is not None:
assert len(silu_quant_args) == 2
x_zp =None
x_q, x_scale = silu_quant_args
else:
symmetric = azp_adj is None symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True: if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input) x_q, x_scale=per_token_quant_int8(input)
......
...@@ -413,15 +413,11 @@ class DeepseekV2MoE(nn.Module): ...@@ -413,15 +413,11 @@ class DeepseekV2MoE(nn.Module):
hidden_states = sequence_parallel_chunk(hidden_states) hidden_states = sequence_parallel_chunk(hidden_states)
if self.experts.is_internal_router: if self.experts.is_internal_router:
# print("YYYY: self.experts.is_internal_router is True.") # True
# In this case, the gate/router runs inside the FusedMoE class # In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts( fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=hidden_states, iqis=iqis hidden_states=hidden_states, router_logits=hidden_states, iqis=iqis
) )
else: # NO else:
print("YYYY: self.experts.is_internal_router is False.")
# router_logits: (num_tokens, n_experts)
print("YYYY: RQ MOE.gate quant: self.gate.quant_method.apply: ", self.gate.quant_method.apply)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts( fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits, iqis=iqis hidden_states=hidden_states, router_logits=router_logits, iqis=iqis
...@@ -1093,7 +1089,6 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1093,7 +1089,6 @@ class DeepseekV2DecoderLayer(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
# Fix residual FP16 overflow # Fix residual FP16 overflow
# print("YYYYY: forward_RQ is called")
residual_fix_overflow = False residual_fix_overflow = False
assert self.input_layernorm.has_weight is True assert self.input_layernorm.has_weight is True
if residual is None: if residual is None:
...@@ -1101,40 +1096,20 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1101,40 +1096,20 @@ class DeepseekV2DecoderLayer(nn.Module):
i_q, i_s, _ = self.input_layernorm(x=hidden_states, i_q, i_s, _ = self.input_layernorm(x=hidden_states,
residual=None, residual=None,
quant_dtype=torch.int8, quant_dtype=torch.int8,
# update_input=False # wjl, del update_input=False
update_input=True )
)
# print("YYYY: i_q:", i_q.flatten()[:5], i_q.shape)
# print("YYYY: i_s:", i_s.flatten()[:5], i_s.shape)
# i_q, i_s = lm_faster_rmsquant(input=hidden_states,
# rms_weight = self.input_layernorm.weight.data,
# epsilon=self._eps,
# quant_dtype=torch.int8,
# residual=None,
# update_input=False)
residual_fix_overflow = True residual_fix_overflow = True
else: else:
# hidden_states, residual = self.input_layernorm(hidden_states, residual)
# i_q, i_s = lm_faster_rmsquant(input=hidden_states,
# rms_weight = self.input_layernorm.weight.data,
# epsilon=self._eps,
# quant_dtype=torch.int8,
# residual=residual,
# update_input=False)
# print("YYYY: input rms residual bf", residual.flatten()[:5])
i_q, i_s, residual = self.input_layernorm(x=hidden_states, i_q, i_s, residual = self.input_layernorm(x=hidden_states,
residual=residual, residual=residual,
quant_dtype=torch.int8, quant_dtype=torch.int8,
# update_input=False, # wjl, del update_input=False
update_input=True,
) )
# print("YYYY:input rms residual af", residual.flatten()[:5])
attn_kwargs = { attn_kwargs = {
"positions": positions, "positions": positions,
"hidden_states": hidden_states, "hidden_states": hidden_states,
# "iqis": (i_q, i_s) # wjl, del "iqis": (i_q, i_s)
} }
if not self.use_mha: if not self.use_mha:
attn_kwargs["llama_4_scaling"] = llama_4_scaling attn_kwargs["llama_4_scaling"] = llama_4_scaling
...@@ -1144,7 +1119,6 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1144,7 +1119,6 @@ class DeepseekV2DecoderLayer(nn.Module):
not isinstance(self.self_attn, DeepseekAttention) not isinstance(self.self_attn, DeepseekAttention)
and hidden_states.dtype == torch.float16 and hidden_states.dtype == torch.float16
): ):
print("YYYYY: FP16 overflow fix is applied")
# Fix FP16 overflow # Fix FP16 overflow
# We scale both hidden_states and residual before # We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale. # rmsnorm, and rmsnorm result would not affect by scale.
...@@ -1155,25 +1129,17 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1155,25 +1129,17 @@ class DeepseekV2DecoderLayer(nn.Module):
residual *= 1.0 / self.routed_scaling_factor residual *= 1.0 / self.routed_scaling_factor
# Fully Connected # Fully Connected
# hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False
# _i_q, _i_s = lm_faster_rmsquant(input=hidden_states,
# rms_weight=self.post_attention_layernorm.weight.data,
# epsilon=self._eps,
# quant_dtype=torch.int8,
# residual=residual,
# update_input=update_hs)
assert self.post_attention_layernorm.has_weight is True assert self.post_attention_layernorm.has_weight is True
_i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states, _i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states,
residual=residual, residual=residual,
quant_dtype=torch.int8, quant_dtype=torch.int8,
update_input=update_hs, # wjl, del update_input=update_hs
# update_input=True,
) )
new_resi = residual new_resi = residual
hidden_states = self.mlp(hidden_states, hidden_states = self.mlp(hidden_states,
# iqis=(_i_q, _i_s) # wjl, del # iqis=(_i_q, _i_s) # TODO:wjl
) )
if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......
...@@ -75,9 +75,6 @@ from .utils import ( ...@@ -75,9 +75,6 @@ from .utils import (
) )
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
if envs.USE_FUSED_RMS_QUANT:
from lightop import rms_norm_dynamic_per_token_quant
logger = init_logger(__name__) logger = init_logger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment