Commit 0f6b9a19 authored by wujl5's avatar wujl5
Browse files

perf: GLM4.7增加MOE调用rmsQuant, fix: 修掉fused_moe向后传递None导致的报错

parent 3f414133
...@@ -1721,7 +1721,7 @@ class FusedMoE(CustomOp): ...@@ -1721,7 +1721,7 @@ class FusedMoE(CustomOp):
hidden_states, router_logits hidden_states, router_logits
) )
else: else:
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None:
shared_output, fused_output = torch.ops.vllm.moe_forward_shared( shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, encode_layer_name(), hidden_states, router_logits, encode_layer_name(),
i_q=i_q, i_q=i_q,
...@@ -1749,7 +1749,10 @@ class FusedMoE(CustomOp): ...@@ -1749,7 +1749,10 @@ class FusedMoE(CustomOp):
i_q: torch.Tensor | None = None, i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(hidden_states, router_logits, i_q=i_q, i_s=i_s) if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None:
return self.forward_native(hidden_states, router_logits, i_q=i_q, i_s=i_s)
else:
return self.forward_native(hidden_states, router_logits)
def forward_impl_chunked( def forward_impl_chunked(
self, self,
...@@ -1981,7 +1984,7 @@ class FusedMoE(CustomOp): ...@@ -1981,7 +1984,7 @@ class FusedMoE(CustomOp):
# because matrix multiply maybe modify the hidden_states. # because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream: if has_separate_shared_experts and not use_shared_experts_stream:
assert self.shared_experts is not None assert self.shared_experts is not None
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None:
shared_output = self.shared_experts(hidden_states, iqis=(i_q, i_s)) shared_output = self.shared_experts(hidden_states, iqis=(i_q, i_s))
else: else:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
...@@ -2196,7 +2199,7 @@ def moe_forward_shared( ...@@ -2196,7 +2199,7 @@ def moe_forward_shared(
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
self = get_layer_from_name(layer_name) self = get_layer_from_name(layer_name)
assert self.shared_experts is not None assert self.shared_experts is not None
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None:
return self.forward_impl(hidden_states, router_logits, i_q=i_q, i_s=i_s) return self.forward_impl(hidden_states, router_logits, i_q=i_q, i_s=i_s)
else: else:
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
......
...@@ -8,6 +8,7 @@ from vllm.distributed import ( ...@@ -8,6 +8,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm import envs
# TODO(bnell): Add shared + fused combo function? e.g. + # TODO(bnell): Add shared + fused combo function? e.g. +
...@@ -64,9 +65,14 @@ class SharedFusedMoE(FusedMoE): ...@@ -64,9 +65,14 @@ class SharedFusedMoE(FusedMoE):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if not self.use_overlapped: if not self.use_overlapped:
if self._shared_experts is not None: if self._shared_experts is not None:
shared_out = self._shared_experts(hidden_states, if envs.USE_FUSED_RMS_QUANT and iqis is not None:
i_q=iqis[0] if iqis is not None else None, assert iqis[0] is not None
i_s=iqis[1] if iqis is not None else None) assert iqis[1] is not None
shared_out = self._shared_experts(hidden_states,
i_q=iqis[0],
i_s=iqis[1])
else:
shared_out = self._shared_experts(hidden_states)
# Reduce shared expert outputs if necessary, since the MLP # Reduce shared expert outputs if necessary, since the MLP
# should have been created with reduce_results=False. # should have been created with reduce_results=False.
...@@ -79,19 +85,35 @@ class SharedFusedMoE(FusedMoE): ...@@ -79,19 +85,35 @@ class SharedFusedMoE(FusedMoE):
else: else:
shared_out = None shared_out = None
fused_out = super().forward( if envs.USE_FUSED_RMS_QUANT and iqis is not None:
hidden_states=hidden_states, assert iqis[0] is not None
router_logits=router_logits, assert iqis[1] is not None
i_q=iqis[0] if iqis is not None else None, fused_out = super().forward(
i_s=iqis[1] if iqis is not None else None, hidden_states=hidden_states,
) router_logits=router_logits,
i_q=iqis[0],
i_s=iqis[1]
)
else:
fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits
)
else: else:
shared_out, fused_out = super().forward( if envs.USE_FUSED_RMS_QUANT and iqis is not None:
hidden_states=hidden_states, assert iqis[0] is not None
router_logits=router_logits, assert iqis[1] is not None
i_q=iqis[0] if iqis is not None else None, shared_out, fused_out = super().forward(
i_s=iqis[1] if iqis is not None else None, hidden_states=hidden_states,
) router_logits=router_logits,
i_q=iqis[0],
i_s=iqis[1]
)
else:
shared_out, fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits
)
# ensure early TP reduction of shared expert outputs when required # ensure early TP reduction of shared expert outputs when required
if ( if (
shared_out is not None shared_out is not None
......
...@@ -170,10 +170,14 @@ def apply_int8_linear( ...@@ -170,10 +170,14 @@ def apply_int8_linear(
# * static, layer.input_scale is scalar and x_scale is input_scale. # * static, layer.input_scale is scalar and x_scale is input_scale.
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None: if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2 assert len(input_quant_args) == 2
assert input_quant_args[0] is not None
assert input_quant_args[1] is not None
x_zp =None x_zp =None
x_q, x_scale = input_quant_args x_q, x_scale = input_quant_args
elif envs.USE_FUSED_RMS_QUANT and silu_quant_args is not None: elif envs.USE_FUSED_RMS_QUANT and silu_quant_args is not None:
assert len(silu_quant_args) == 2 assert len(silu_quant_args) == 2
assert silu_quant_args[0] is not None
assert silu_quant_args[1] is not None
x_zp =None x_zp =None
x_q, x_scale = silu_quant_args x_q, x_scale = silu_quant_args
else: else:
......
...@@ -110,12 +110,13 @@ class Glm4MoeMLP(nn.Module): ...@@ -110,12 +110,13 @@ class Glm4MoeMLP(nn.Module):
def forward(self, x, def forward(self, x,
iqis: tuple[torch.Tensor, torch.Tensor] | None = None): iqis: tuple[torch.Tensor, torch.Tensor] | None = None):
gate_up, _ = self.gate_up_proj(x, iqis=iqis) if envs.USE_FUSED_SILU_MUL_QUANT and iqis is not None:
if envs.USE_FUSED_SILU_MUL_QUANT: gate_up, _ = self.gate_up_proj(x, iqis=iqis)
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
xq, xs = lm_fuse_silu_mul_quant(gate_up) xq, xs = lm_fuse_silu_mul_quant(gate_up)
x, _ = self.down_proj(gate_up, iqis=(xq, xs)) x, _ = self.down_proj(gate_up, iqis=(xq, xs))
else: else:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
x, _ = self.down_proj(x) x, _ = self.down_proj(x)
return x return x
...@@ -207,16 +208,22 @@ class Glm4MoE(nn.Module): ...@@ -207,16 +208,22 @@ class Glm4MoE(nn.Module):
router_logits_dtype=torch.float32, router_logits_dtype=torch.float32,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self,
hidden_states: torch.Tensor,
iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states.to(dtype=torch.float32)) router_logits = self.gate(hidden_states.to(dtype=torch.float32))
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
fused_moe_out = self.experts( fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits,
) iqis=iqis)
else:
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits)
if self.shared_experts is not None: if self.shared_experts is not None:
shared_output, final_hidden_states = fused_moe_out shared_output, final_hidden_states = fused_moe_out
...@@ -420,7 +427,7 @@ class Glm4MoeDecoderLayer(nn.Module): ...@@ -420,7 +427,7 @@ class Glm4MoeDecoderLayer(nn.Module):
else: else:
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if envs.USE_FUSED_RMS_QUANT and isinstance(self.mlp, Glm4MoeMLP): if envs.USE_FUSED_RMS_QUANT:
self.post_attention_layernorm = FusedRMSNormQuant( self.post_attention_layernorm = FusedRMSNormQuant(
config.hidden_size, eps=config.rms_norm_eps config.hidden_size, eps=config.rms_norm_eps
) )
...@@ -459,8 +466,12 @@ class Glm4MoeDecoderLayer(nn.Module): ...@@ -459,8 +466,12 @@ class Glm4MoeDecoderLayer(nn.Module):
) )
hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states, iqis=(i_q, i_s)) hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states, iqis=(i_q, i_s))
if envs.USE_FUSED_RMS_QUANT and isinstance(self.mlp, Glm4MoeMLP): if envs.USE_FUSED_RMS_QUANT:
i_q, i_s, _ = self.post_attention_layernorm(hidden_states, residual) update_hs = True if isinstance(self.mlp, Glm4MoE) else False
i_q, i_s, residual = self.post_attention_layernorm(x=hidden_states,
residual=residual,
quant_dtype=torch.int8,
update_input=update_hs)
hidden_states = self.mlp(hidden_states, iqis=(i_q, i_s)) hidden_states = self.mlp(hidden_states, iqis=(i_q, i_s))
else: else:
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment