Commit cab4877a authored by shangxl
Browse files

fix SGLANG_USE_FUSED_SILU_MUL_QUANT bug


Signed-off-by: wanghan <wanghan5@sugon.com>
parent 31653dd9
@@ -459,7 +459,8 @@ class DeepseekV2MLP(nn.Module):
 x = (x, None, y)
 gate_up, _ = self.gate_up_proj(x)
-if _use_fused_silu_mul_quant:
+d = gate_up.shape[-1]
+if _use_fused_silu_mul_quant and d % 8 == 0 and d <= 16384:
     x, _ = self.down_proj(gate_up, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter, use_fused_silu_mul_quant=True)
 else:
     x = self.act_fn(gate_up)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment