Commit 74c6e218 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch '092dev_DS_v2_w4a8_silu_mul_quant_reabase' into 'v0.9.2-dev'

deepseek_v2_w4a8模型forward_CRQ分支逻辑增加slilu_mul_quant融合

See merge request dcutoolkit/deeplearing/vllm!261
parents dbbc0b2e 421850cb
......@@ -173,7 +173,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = True
VLLM_P2P_ASYNC: bool = False
VLLM_P2P_BUF_TOKENS: int = 30000
VLLM_SCHED_ENABLE_MINIMAL_INJECTION: bool = False
......@@ -1141,10 +1141,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"USE_FUSED_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in
("true", "1")),
# vllm will use silu_mul_quant fused op
# vllm will use silu_mul_quant fused op,
# This variable has a default value of true,
# but it is still controlled by CRQ and RQ.
"USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")),
lambda: bool(int(os.getenv("USE_FUSED_SILU_MUL_QUANT", "1"))),
# vllm pd separation will be used async
"VLLM_P2P_ASYNC":
lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))),
......
......@@ -1536,6 +1536,15 @@ class RowParallelLinear(LinearBase):
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
if use_fused_silu_mul_quant:
xq, xs = lm_fuse_silu_mul_quant(input_parallel)
silu_quant_args = [xq, xs]
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_,
silu_quant_args=silu_quant_args)
else:
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
......@@ -1561,7 +1570,7 @@ class RowParallelLinear(LinearBase):
return output
return output, resi, xq, xs, output_bias
else:
else: # RQ and Defualt forward
if self.input_is_parallel:
input_parallel = input_
else:
......
......@@ -167,6 +167,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and silu_quant_args is not None:
x_q, x_scale = silu_quant_args
else:
x_q, x_scale = per_token_quant_int8(x)
......
......@@ -109,6 +109,9 @@ class DeepseekV2MLP(nn.Module):
return x, new_resi
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
gate_up, _ = self.gate_up_proj(x, xqxs=xqxs)
if envs.USE_FUSED_SILU_MUL_QUANT:
x, _ = self.down_proj(gate_up, use_fused_silu_mul_quant=True)
else:
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
......@@ -651,6 +654,9 @@ class DeepseekV2MLAAttention(nn.Module):
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if envs.VLLM_USE_LIGHTOP:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
else:
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
......@@ -927,8 +933,6 @@ class DeepseekV2DecoderLayer(nn.Module):
return forward_func(positions=positions, hidden_states=hidden_states, residual=residual )
@support_torch_compile
class DeepseekV2Model(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment