Commit 39096bf4 authored by wujl5's avatar wujl5
Browse files

marlin情况下eager+rmsquant输出正常

parent 8ff0c0d2
......@@ -1670,7 +1670,7 @@ class FusedMoE(CustomOp):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
i_s: Optional[torch.Tensor] = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
og_hidden_states = hidden_states.shape[-1]
if self.hidden_size != og_hidden_states:
......@@ -1742,7 +1742,7 @@ class FusedMoE(CustomOp):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
i_s: Optional[torch.Tensor] = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(hidden_states, router_logits, i_q=i_q, i_s=i_s)
......@@ -1912,7 +1912,6 @@ class FusedMoE(CustomOp):
# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
if self.gate is not None:
print("YYYY: unsupported using gate in FusedMOE.forward_impl") # true
router_logits, _ = self.gate(hidden_states)
if use_chunked_impl:
......@@ -2188,12 +2187,6 @@ def moe_forward_shared(
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
import traceback
import torch.distributed as dist
print("=======")
# if dist.get_rank() == 0:
# traceback.print_stack()
print("=======")
self = get_layer_from_name(layer_name)
assert self.shared_experts is not None
if envs.USE_FUSED_RMS_QUANT:
......
......@@ -62,8 +62,6 @@ class SharedFusedMoE(FusedMoE):
router_logits: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
if self._shared_experts is not None: print("YYYY: 处理下游专家输入, self.use_overlapped is :", self.use_overlapped)
if not self.use_overlapped:
if self._shared_experts is not None:
shared_out = self._shared_experts(hidden_states,
......@@ -80,8 +78,6 @@ class SharedFusedMoE(FusedMoE):
shared_out = tensor_model_parallel_all_reduce(shared_out)
else:
shared_out = None
print("YYYY: i suppose not this branch!!!")
fused_out = super().forward(
hidden_states=hidden_states,
......
......@@ -21,7 +21,6 @@ from vllm import envs
from lightop.op import rms_norm_dynamic_per_token_quant as ligtop_rms_norm_dynamic_per_token_quant
def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
......@@ -395,7 +394,7 @@ def fused_rmsquant(input: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor] = None,
update_input: Optional[bool] = True):
from lmslim.quantize.quant_ops import lm_faster_rmsquant
from lmslim.quantize.quant_ops import lm_faster_rmsquant # TODO:wjl
i_q, _scales = lm_faster_rmsquant(input=input,
rms_weight=rms_weight,
epsilon=epsilon,
......
......@@ -717,7 +717,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Matrix multiply.
assert self.quant_method is not None
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
print("YYYYY: mlp.gate_up self.quant_method.apply: ", self.quant_method.apply)
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
else:
output_parallel = self.quant_method.apply(self, input_, bias)
......
......@@ -131,7 +131,6 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
assert self.q_b_proj is not None, (
"q_b_proj is required when q_lora_rank is not None"
)
# print("YYYY: self.fused_qkv_a_proj is:", self.fused_qkv_a_proj) # MergedColumnParallelLinear(in_features=7168, output_features=2112, bias=False, tp_size=1, gather_output=False)
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
qkv_lora = self.fused_qkv_a_proj(hidden_states, iqis=iqis)[0]
else:
......@@ -143,7 +142,6 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
else:
# print("YYYY: self.q_lora_rank is None. unsupported for now!!!")
assert self.kv_a_proj_with_mqa is not None, (
"kv_a_proj_with_mqa is required when q_lora_rank is None"
)
......
......@@ -1255,6 +1255,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -1271,6 +1273,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
use_nn_moe=use_nn_moe,
i_q=i_q,
i_s=i_s
)
......
......@@ -7,6 +7,7 @@ import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON
from vllm import envs
try:
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
......@@ -167,15 +168,24 @@ def apply_int8_linear(
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input)
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
x_zp =None
x_q, x_scale = input_quant_args
elif envs.USE_FUSED_RMS_QUANT and silu_quant_args is not None:
assert len(silu_quant_args) == 2
x_zp =None
x_q, x_scale = silu_quant_args
else:
x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale,
input_zero_point,
symmetric=symmetric)
symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input)
x_zp =None
else:
x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale,
input_zero_point,
symmetric=symmetric)
if x_zp is not None:
# Currently, static is always per-tensor and dynamic is per-token
......
......@@ -413,15 +413,11 @@ class DeepseekV2MoE(nn.Module):
hidden_states = sequence_parallel_chunk(hidden_states)
if self.experts.is_internal_router:
# print("YYYY: self.experts.is_internal_router is True.") # True
# In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=hidden_states, iqis=iqis
)
else: # NO
print("YYYY: self.experts.is_internal_router is False.")
# router_logits: (num_tokens, n_experts)
print("YYYY: RQ MOE.gate quant: self.gate.quant_method.apply: ", self.gate.quant_method.apply)
else:
router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits, iqis=iqis
......@@ -1093,7 +1089,6 @@ class DeepseekV2DecoderLayer(nn.Module):
) -> torch.Tensor:
# Self Attention
# Fix residual FP16 overflow
# print("YYYYY: forward_RQ is called")
residual_fix_overflow = False
assert self.input_layernorm.has_weight is True
if residual is None:
......@@ -1101,40 +1096,20 @@ class DeepseekV2DecoderLayer(nn.Module):
i_q, i_s, _ = self.input_layernorm(x=hidden_states,
residual=None,
quant_dtype=torch.int8,
# update_input=False # wjl, del
update_input=True
update_input=False
)
# print("YYYY: i_q:", i_q.flatten()[:5], i_q.shape)
# print("YYYY: i_s:", i_s.flatten()[:5], i_s.shape)
# i_q, i_s = lm_faster_rmsquant(input=hidden_states,
# rms_weight = self.input_layernorm.weight.data,
# epsilon=self._eps,
# quant_dtype=torch.int8,
# residual=None,
# update_input=False)
residual_fix_overflow = True
else:
# hidden_states, residual = self.input_layernorm(hidden_states, residual)
# i_q, i_s = lm_faster_rmsquant(input=hidden_states,
# rms_weight = self.input_layernorm.weight.data,
# epsilon=self._eps,
# quant_dtype=torch.int8,
# residual=residual,
# update_input=False)
# print("YYYY: input rms residual bf", residual.flatten()[:5])
i_q, i_s, residual = self.input_layernorm(x=hidden_states,
residual=residual,
quant_dtype=torch.int8,
# update_input=False, # wjl, del
update_input=True,
update_input=False
)
# print("YYYY:input rms residual af", residual.flatten()[:5])
attn_kwargs = {
"positions": positions,
"hidden_states": hidden_states,
# "iqis": (i_q, i_s) # wjl, del
"iqis": (i_q, i_s)
}
if not self.use_mha:
attn_kwargs["llama_4_scaling"] = llama_4_scaling
......@@ -1144,7 +1119,6 @@ class DeepseekV2DecoderLayer(nn.Module):
not isinstance(self.self_attn, DeepseekAttention)
and hidden_states.dtype == torch.float16
):
print("YYYYY: FP16 overflow fix is applied")
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
......@@ -1155,25 +1129,17 @@ class DeepseekV2DecoderLayer(nn.Module):
residual *= 1.0 / self.routed_scaling_factor
# Fully Connected
# hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False
# _i_q, _i_s = lm_faster_rmsquant(input=hidden_states,
# rms_weight=self.post_attention_layernorm.weight.data,
# epsilon=self._eps,
# quant_dtype=torch.int8,
# residual=residual,
# update_input=update_hs)
assert self.post_attention_layernorm.has_weight is True
_i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states,
residual=residual,
quant_dtype=torch.int8,
update_input=update_hs, # wjl, del
# update_input=True,
update_input=update_hs
)
new_resi = residual
hidden_states = self.mlp(hidden_states,
# iqis=(_i_q, _i_s) # wjl, del
# iqis=(_i_q, _i_s) # TODO:wjl
)
if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......
......@@ -75,9 +75,6 @@ from .utils import (
)
from vllm.utils.torch_utils import direct_register_custom_op
if envs.USE_FUSED_RMS_QUANT:
from lightop import rms_norm_dynamic_per_token_quant
logger = init_logger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment