Commit 46e26bf1 authored by 王敏
Browse files

修复部分代码

parent 83f2f396
...@@ -158,11 +158,14 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -158,11 +158,14 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
input_quant_args: Optional[list[torch.Tensor]] = None input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None
): ):
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None: if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2 assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args x_q, x_scale = input_quant_args
elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None:
x_q, x_scale = silu_quant_args
else: else:
x_q, x_scale = per_token_quant_int8(x) x_q, x_scale = per_token_quant_int8(x)
...@@ -373,7 +376,7 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -373,7 +376,7 @@ class SlimQuantW4A8Int8MoEMethod:
) )
def apply(# tp def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
...@@ -394,6 +397,7 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -394,6 +397,7 @@ class SlimQuantW4A8Int8MoEMethod:
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
**_ **_
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -434,4 +438,6 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -434,4 +438,6 @@ class SlimQuantW4A8Int8MoEMethod:
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
...@@ -103,8 +103,12 @@ class DeepseekV2MLP(nn.Module): ...@@ -103,8 +103,12 @@ class DeepseekV2MLP(nn.Module):
): ):
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
gate_up, new_resi, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd) gate_up, new_resi, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd)
if envs.USE_FUSED_SILU_MUL_QUANT:
x, _ = self.down_proj(gate_up, use_fused_silu_mul_quant=True)
else:
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
x, _ = self.down_proj(x) x, _ = self.down_proj(x)
return x, new_resi return x, new_resi
else: else:
gate_up, _ = self.gate_up_proj(x) gate_up, _ = self.gate_up_proj(x)
...@@ -574,6 +578,9 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -574,6 +578,9 @@ class DeepseekV2MLAAttention(nn.Module):
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0].split( kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if envs.VLLM_USE_LIGHTOP:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
else:
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim) q = q.view(-1, self.num_local_heads, self.qk_head_dim)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment