"vscode:/vscode.git/clone" did not exist on "87185c88d54bd97c4c08f1fd3c5a8564e4924e2a"
Commit 075841f3 authored by wujl5
Browse files

DS_v2_w4a8_CRQ增加silu_mul_quant支持

parent 8c646ebe
......@@ -1536,9 +1536,18 @@ class RowParallelLinear(LinearBase):
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if use_fused_silu_mul_quant:
xq, xs = lm_fuse_silu_mul_quant(input_parallel)
silu_quant_args = [xq, xs]
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_,
silu_quant_args=silu_quant_args)
else:
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel)
......@@ -1561,7 +1570,7 @@ class RowParallelLinear(LinearBase):
return output
return output, resi, xq, xs, output_bias
else:
else: # RQ and Default forward
if self.input_is_parallel:
input_parallel = input_
else:
......
......@@ -167,6 +167,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and silu_quant_args is not None:
x_q, x_scale = silu_quant_args
else:
x_q, x_scale = per_token_quant_int8(x)
......
......@@ -109,8 +109,11 @@ class DeepseekV2MLP(nn.Module):
return x, new_resi
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
gate_up, _ = self.gate_up_proj(x, xqxs=xqxs)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
if envs.USE_FUSED_SILU_MUL_QUANT:
x, _ = self.down_proj(gate_up, use_fused_silu_mul_quant=True)
else:
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
else:
gate_up, _ = self.gate_up_proj(x)
......@@ -929,8 +932,6 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_func = self.choose_forward()
return forward_func(positions=positions, hidden_states=hidden_states, residual=residual )
@support_torch_compile
class DeepseekV2Model(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment