Commit 826f22e1 authored by wujl5's avatar wujl5
Browse files

接入siluMulQuant融合

parent 39096bf4
...@@ -302,6 +302,7 @@ if TYPE_CHECKING: ...@@ -302,6 +302,7 @@ if TYPE_CHECKING:
VLLM_USE_MOE_W16A16_TRITON: bool = False VLLM_USE_MOE_W16A16_TRITON: bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1897,6 +1898,13 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1897,6 +1898,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
).lower() ).lower()
in ("true", "1") in ("true", "1")
), ),
# Whether vLLM uses the fused silu_mul_quant op.
# Defaults to true, but only takes effect when the RQ
# (presumably USE_FUSED_RMS_QUANT) path is also enabled.
"USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', 'True').lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -370,7 +370,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -370,7 +370,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor, x: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False, **_
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward( return self.forward(
layer=layer, layer=layer,
......
...@@ -39,6 +39,7 @@ from vllm.platforms import current_platform ...@@ -39,6 +39,7 @@ from vllm.platforms import current_platform
import os import os
from vllm.model_executor.utils import gemm_bank_conf from vllm.model_executor.utils import gemm_bank_conf
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1548,6 +1549,8 @@ class RowParallelLinear(LinearBase): ...@@ -1548,6 +1549,8 @@ class RowParallelLinear(LinearBase):
def forward( def forward(
self, self,
input_, input_,
*,
use_fused_silu_mul_quant: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
...@@ -1562,7 +1565,16 @@ class RowParallelLinear(LinearBase): ...@@ -1562,7 +1565,16 @@ class RowParallelLinear(LinearBase):
# Only fuse bias add into GEMM for rank 0 (this ensures that # Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case) # bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self, input_parallel, bias_) if use_fused_silu_mul_quant:
xq, xs = lm_fuse_silu_mul_quant(input_parallel)
silu_quant_args = [xq, xs]
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_,
silu_quant_args=silu_quant_args)
else:
output_parallel = self.quant_method.apply(self, input_parallel, bias_)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel) output = tensor_model_parallel_all_reduce(output_parallel)
......
...@@ -227,8 +227,11 @@ class DeepseekV2MLP(nn.Module): ...@@ -227,8 +227,11 @@ class DeepseekV2MLP(nn.Module):
): ):
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
gate_up, _ = self.gate_up_proj(x, iqis=iqis) gate_up, _ = self.gate_up_proj(x, iqis=iqis)
x = self.act_fn(gate_up) if envs.USE_FUSED_SILU_MUL_QUANT:
x, _ = self.down_proj(x) x, _ = self.down_proj(gate_up, use_fused_silu_mul_quant=True)
else:
x = self.act_fn(gate_up, use_fused_silu_mul_quant=True)
x, _ = self.down_proj(x)
else: else:
gate_up, _ = self.gate_up_proj(x) gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment