Commit 60b37c6b authored by zhuwenwen's avatar zhuwenwen
Browse files

remove USE_FUSED_RMS_QUANT and USE_FUSED_SILU_MUL_QUANT

parent c964b9ad
...@@ -270,8 +270,6 @@ if TYPE_CHECKING: ...@@ -270,8 +270,6 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_MOE_SUM: bool = False VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_USE_PD_SPLIT: bool = False VLLM_USE_PD_SPLIT: bool = False
VLLM_USE_PP_SYNC: bool = False VLLM_USE_PP_SYNC: bool = False
VLLM_USE_PIECEWISE: bool = False VLLM_USE_PIECEWISE: bool = False
...@@ -1726,14 +1724,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1726,14 +1724,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
("true", "1")), ("true", "1")),
# vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in
("true", "1")),
# vllm will use silu_mul_quant fused op
"USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")),
# vLLM will split prefill and decode, not mix up # vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT": "VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "False").lower() in lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "False").lower() in
......
...@@ -1592,7 +1592,6 @@ class RowParallelLinear(LinearBase): ...@@ -1592,7 +1592,6 @@ class RowParallelLinear(LinearBase):
def forward( def forward(
self, self,
input_, input_,
use_fused_silu_mul_quant: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
...@@ -1607,16 +1606,7 @@ class RowParallelLinear(LinearBase): ...@@ -1607,16 +1606,7 @@ class RowParallelLinear(LinearBase):
# Only fuse bias add into GEMM for rank 0 (this ensures that # Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case) # bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
if use_fused_silu_mul_quant: output_parallel = self.quant_method.apply(self, input_parallel, bias_)
xq, xs = lm_fuse_silu_mul_quant(input_parallel)
silu_quant_args = [xq, xs]
output_parallel = self.quant_method.apply(self,
input_parallel,
bias_,
silu_quant_args=silu_quant_args)
else:
output_parallel = self.quant_method.apply(self, input_parallel, bias_)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel) output = tensor_model_parallel_all_reduce(output_parallel)
......
...@@ -159,14 +159,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -159,14 +159,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
input_quant_args: Optional[list[torch.Tensor]] = None, input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None silu_quant_args: Optional[list[torch.Tensor]] = None
): ):
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None: x_q, x_scale = per_token_quant_int8(x)
assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args
elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None:
assert len(silu_quant_args) == 2
x_q, x_scale = silu_quant_args
else:
x_q, x_scale = per_token_quant_int8(x)
if self.w8a8_strategy==1: if self.w8a8_strategy==1:
m=x_q.shape[0] m=x_q.shape[0]
......
...@@ -196,11 +196,11 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], ...@@ -196,11 +196,11 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
os.environ['VLLM_USE_LIGHTOP'] = '1' os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}: # if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
if not envs.is_set("USE_FUSED_RMS_QUANT"): # if not envs.is_set("USE_FUSED_RMS_QUANT"):
os.environ['USE_FUSED_RMS_QUANT'] = '1' # os.environ['USE_FUSED_RMS_QUANT'] = '1'
if not envs.is_set("USE_FUSED_SILU_MUL_QUANT"): # if not envs.is_set("USE_FUSED_SILU_MUL_QUANT"):
os.environ['USE_FUSED_SILU_MUL_QUANT'] = '1' # os.environ['USE_FUSED_SILU_MUL_QUANT'] = '1'
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' os.environ['VLLM_USE_PD_SPLIT'] = '1'
...@@ -228,11 +228,11 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], ...@@ -228,11 +228,11 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
os.environ['VLLM_USE_LIGHTOP'] = '1' os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}: # if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
if not envs.is_set("USE_FUSED_RMS_QUANT"): # if not envs.is_set("USE_FUSED_RMS_QUANT"):
os.environ['USE_FUSED_RMS_QUANT'] = '1' # os.environ['USE_FUSED_RMS_QUANT'] = '1'
if not envs.is_set("USE_FUSED_SILU_MUL_QUANT"): # if not envs.is_set("USE_FUSED_SILU_MUL_QUANT"):
os.environ['USE_FUSED_SILU_MUL_QUANT'] = '1' # os.environ['USE_FUSED_SILU_MUL_QUANT'] = '1'
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' os.environ['VLLM_USE_PD_SPLIT'] = '1'
......
...@@ -232,24 +232,11 @@ class DeepseekV2MLP(nn.Module): ...@@ -232,24 +232,11 @@ class DeepseekV2MLP(nn.Module):
) )
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
def forward(self, x, def forward(self, x):
rms_weight: torch.Tensor | None = None, gate_up, _ = self.gate_up_proj(x)
residual: torch.Tensor | None = None, x = self.act_fn(gate_up)
update_hd: bool | None = False x, _ = self.down_proj(x)
): return x
if envs.USE_FUSED_RMS_QUANT:
gate_up, new_resi, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd)
if envs.USE_FUSED_SILU_MUL_QUANT:
x, _ = self.down_proj(gate_up, use_fused_silu_mul_quant=True)
else:
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x, new_resi
else:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class DeepseekV2MoE(nn.Module): class DeepseekV2MoE(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment