Commit 05eca476 authored by wangmin6's avatar wangmin6
Browse files

Merge branch 'v0.15.1-dev_GLM4.7_moe_call_RQ' into 'v0.15.1-dev'

perf: GLM4.7增加MOE调用rmsQuant, fix: 修掉fused_moe向后传递None导致的报错

See merge request dcutoolkit/deeplearing/vllm!505
parents 57979f97 0f6b9a19
......@@ -1721,7 +1721,7 @@ class FusedMoE(CustomOp):
hidden_states, router_logits
)
else:
if envs.USE_FUSED_RMS_QUANT:
if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None:
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, encode_layer_name(),
i_q=i_q,
......@@ -1749,7 +1749,10 @@ class FusedMoE(CustomOp):
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(hidden_states, router_logits, i_q=i_q, i_s=i_s)
if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None:
return self.forward_native(hidden_states, router_logits, i_q=i_q, i_s=i_s)
else:
return self.forward_native(hidden_states, router_logits)
def forward_impl_chunked(
self,
......@@ -1981,7 +1984,7 @@ class FusedMoE(CustomOp):
# because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:
assert self.shared_experts is not None
if envs.USE_FUSED_RMS_QUANT:
if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None:
shared_output = self.shared_experts(hidden_states, iqis=(i_q, i_s))
else:
shared_output = self.shared_experts(hidden_states)
......@@ -2196,7 +2199,7 @@ def moe_forward_shared(
) -> tuple[torch.Tensor, torch.Tensor]:
self = get_layer_from_name(layer_name)
assert self.shared_experts is not None
if envs.USE_FUSED_RMS_QUANT:
if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None:
return self.forward_impl(hidden_states, router_logits, i_q=i_q, i_s=i_s)
else:
return self.forward_impl(hidden_states, router_logits)
......
......@@ -8,6 +8,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm import envs
# TODO(bnell): Add shared + fused combo function? e.g. +
......@@ -64,9 +65,14 @@ class SharedFusedMoE(FusedMoE):
) -> tuple[torch.Tensor, torch.Tensor]:
if not self.use_overlapped:
if self._shared_experts is not None:
shared_out = self._shared_experts(hidden_states,
i_q=iqis[0] if iqis is not None else None,
i_s=iqis[1] if iqis is not None else None)
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
assert iqis[0] is not None
assert iqis[1] is not None
shared_out = self._shared_experts(hidden_states,
i_q=iqis[0],
i_s=iqis[1])
else:
shared_out = self._shared_experts(hidden_states)
# Reduce shared expert outputs if necessary, since the MLP
# should have been created with reduce_results=False.
......@@ -79,19 +85,35 @@ class SharedFusedMoE(FusedMoE):
else:
shared_out = None
fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
i_q=iqis[0] if iqis is not None else None,
i_s=iqis[1] if iqis is not None else None,
)
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
assert iqis[0] is not None
assert iqis[1] is not None
fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
i_q=iqis[0],
i_s=iqis[1]
)
else:
fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits
)
else:
shared_out, fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
i_q=iqis[0] if iqis is not None else None,
i_s=iqis[1] if iqis is not None else None,
)
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
assert iqis[0] is not None
assert iqis[1] is not None
shared_out, fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
i_q=iqis[0],
i_s=iqis[1]
)
else:
shared_out, fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits
)
# ensure early TP reduction of shared expert outputs when required
if (
shared_out is not None
......
......@@ -170,10 +170,14 @@ def apply_int8_linear(
# * static, layer.input_scale is scalar and x_scale is input_scale.
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
assert input_quant_args[0] is not None
assert input_quant_args[1] is not None
x_zp =None
x_q, x_scale = input_quant_args
elif envs.USE_FUSED_RMS_QUANT and silu_quant_args is not None:
assert len(silu_quant_args) == 2
assert silu_quant_args[0] is not None
assert silu_quant_args[1] is not None
x_zp =None
x_q, x_scale = silu_quant_args
else:
......
......@@ -110,12 +110,13 @@ class Glm4MoeMLP(nn.Module):
def forward(self, x,
iqis: tuple[torch.Tensor, torch.Tensor] | None = None):
gate_up, _ = self.gate_up_proj(x, iqis=iqis)
if envs.USE_FUSED_SILU_MUL_QUANT:
if envs.USE_FUSED_SILU_MUL_QUANT and iqis is not None:
gate_up, _ = self.gate_up_proj(x, iqis=iqis)
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
xq, xs = lm_fuse_silu_mul_quant(gate_up)
x, _ = self.down_proj(gate_up, iqis=(xq, xs))
else:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
......@@ -207,16 +208,22 @@ class Glm4MoE(nn.Module):
router_logits_dtype=torch.float32,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(self,
hidden_states: torch.Tensor,
iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states.to(dtype=torch.float32))
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits,
iqis=iqis)
else:
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits)
if self.shared_experts is not None:
shared_output, final_hidden_states = fused_moe_out
......@@ -420,7 +427,7 @@ class Glm4MoeDecoderLayer(nn.Module):
else:
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if envs.USE_FUSED_RMS_QUANT and isinstance(self.mlp, Glm4MoeMLP):
if envs.USE_FUSED_RMS_QUANT:
self.post_attention_layernorm = FusedRMSNormQuant(
config.hidden_size, eps=config.rms_norm_eps
)
......@@ -459,8 +466,12 @@ class Glm4MoeDecoderLayer(nn.Module):
)
hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states, iqis=(i_q, i_s))
if envs.USE_FUSED_RMS_QUANT and isinstance(self.mlp, Glm4MoeMLP):
i_q, i_s, _ = self.post_attention_layernorm(hidden_states, residual)
if envs.USE_FUSED_RMS_QUANT:
update_hs = True if isinstance(self.mlp, Glm4MoE) else False
i_q, i_s, residual = self.post_attention_layernorm(x=hidden_states,
residual=residual,
quant_dtype=torch.int8,
update_input=update_hs)
hidden_states = self.mlp(hidden_states, iqis=(i_q, i_s))
else:
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment