Commit 3155bf9e authored by zhuwenwen's avatar zhuwenwen
Browse files

deepseek-r1-w4a8 mlp/moe调用silu-mul-quant融合

parent 98cf1a9e
......@@ -205,6 +205,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_CAT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
def get_default_cache_root():
......@@ -1401,6 +1402,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"USE_FUSED_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in
("true", "1")),
# vllm will use silu_mul_quant fused op
"USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -39,6 +39,13 @@ if envs.USE_FUSED_RMS_QUANT:
except Exception as e:
print(f"Error: Import fused rmsquant error: {e}")
if envs.USE_FUSED_SILU_MUL_QUANT:
try:
# from lightop import fuse_silu_mul_quant
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
except Exception as e:
print(f"Error: Import fused silu_mul_qunat error: {e}")
logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
......@@ -1500,7 +1507,8 @@ class RowParallelLinear(LinearBase):
param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward(
self, input_
self, input_,
use_fused_silu_mul_quant: Optional[bool] = False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
......@@ -1514,6 +1522,15 @@ class RowParallelLinear(LinearBase):
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
if use_fused_silu_mul_quant:
xq, xs = lm_fuse_silu_mul_quant(input_parallel)
silu_quant_args = [xq, xs]
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_,
silu_quant_args=silu_quant_args)
else:
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
......
......@@ -153,11 +153,14 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
input_quant_args: Optional[list[torch.Tensor]] = None
input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None
):
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args
elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None:
x_q, x_scale = silu_quant_args
else:
x_q, x_scale = per_token_quant_int8(x)
......
......@@ -116,6 +116,9 @@ class DeepseekV2MLP(nn.Module):
):
if envs.USE_FUSED_RMS_QUANT:
gate_up, new_resi, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd)
if envs.USE_FUSED_SILU_MUL_QUANT:
x, _ = self.down_proj(gate_up, use_fused_silu_mul_quant=True)
else:
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x, new_resi
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment