Commit c2ef7fdd authored by zhuwenwen's avatar zhuwenwen
Browse files

remove USE_FUSED_RMS_QUANT and USE_FUSED_SILU_MUL_QUANT

parent 383f2ce8
...@@ -233,8 +233,6 @@ if TYPE_CHECKING: ...@@ -233,8 +233,6 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_MOE_SUM: bool = False VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_USE_PD_SPLIT: bool = False VLLM_USE_PD_SPLIT: bool = False
VLLM_USE_PP_SYNC: bool = False VLLM_USE_PP_SYNC: bool = False
VLLM_USE_PIECEWISE: bool = False VLLM_USE_PIECEWISE: bool = False
...@@ -1639,14 +1637,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1639,14 +1637,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
("true", "1")), ("true", "1")),
# vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in
("true", "1")),
# vllm will use silu_mul_quant fused op
"USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")),
# vLLM will split prefill and decode, not mix up # vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT": "VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "False").lower() in lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "False").lower() in
......
...@@ -35,18 +35,6 @@ from vllm.utils import GiB_bytes ...@@ -35,18 +35,6 @@ from vllm.utils import GiB_bytes
import os import os
from vllm.model_executor.utils import gemm_bank_conf from vllm.model_executor.utils import gemm_bank_conf
if envs.USE_FUSED_RMS_QUANT:
try:
from lmslim.quantize.quant_ops import lm_faster_rmsquant
except Exception as e:
print(f"Error: Import fused rmsquant error: {e}")
if envs.USE_FUSED_SILU_MUL_QUANT:
try:
# from lightop import fuse_silu_mul_quant
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
except Exception as e:
print(f"Error: Import fused silu_mul_qunat error: {e}")
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -441,54 +429,18 @@ class ReplicatedLinear(LinearBase): ...@@ -441,54 +429,18 @@ class ReplicatedLinear(LinearBase):
param.data.copy_(loaded_weight) param.data.copy_(loaded_weight)
def forward( def forward(
self, self,
input_: torch.Tensor, x: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
quant_args: Optional[list] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and (rms_weight is not None or quant_args is not None): bias = self.bias if not self.skip_bias_add else None
if quant_args is not None: assert self.quant_method is not None
input_quant_args = quant_args
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
else: output = self.quant_method.apply(self, x, bias)
i_q, _scales = lm_faster_rmsquant(input=input_, output_bias = self.bias if self.skip_bias_add else None
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd
)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias, input_quant_args
else: if not self.return_bias:
bias = self.bias if not self.skip_bias_add else None return output
assert self.quant_method is not None return output, output_bias
output = self.quant_method.apply(self, input_, bias)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"in_features={self.input_size}" s = f"in_features={self.input_size}"
...@@ -652,48 +604,22 @@ class ColumnParallelLinear(LinearBase): ...@@ -652,48 +604,22 @@ class ColumnParallelLinear(LinearBase):
def forward( def forward(
self, self,
input_, input_,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None: bias = self.bias if not self.skip_bias_add else None
input_quant_args = None
assert rms_weight is not None # Matrix multiply.
i_q, _scales = lm_faster_rmsquant(input=input_, assert self.quant_method is not None
rms_weight=rms_weight, output_parallel = self.quant_method.apply(self, input_, bias)
epsilon=self.eps,
quant_dtype=torch.int8, if self.gather_output and self.tp_size > 1:
residual=residual, # All-gather across the partitions.
update_input=update_hd) output = tensor_model_parallel_all_gather(output_parallel)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args)
if self.gather_output and self.tp_size > 1:
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias
else: else:
bias = self.bias if not self.skip_bias_add else None output = output_parallel
# Matrix multiply. output_bias = self.bias if self.skip_bias_add else None
assert self.quant_method is not None if not self.return_bias:
output_parallel = self.quant_method.apply(self, input_, bias) return output
if self.gather_output and self.tp_size > 1: return output, output_bias
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"in_features={self.input_size}" s = f"in_features={self.input_size}"
...@@ -730,54 +656,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -730,54 +656,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
will be treated as a "Replicated" MergedLinear. will be treated as a "Replicated" MergedLinear.
""" """
def forward(
self, input_,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None
assert residual is not None and rms_weight is not None
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias
else: # not USE_FUSED_RMS_QUANT
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def __init__( def __init__(
self, self,
input_size: int, input_size: int,
...@@ -1548,7 +1426,6 @@ class RowParallelLinear(LinearBase): ...@@ -1548,7 +1426,6 @@ class RowParallelLinear(LinearBase):
def forward( def forward(
self, self,
input_, input_,
use_fused_silu_mul_quant: Optional[bool] = False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
...@@ -1562,16 +1439,7 @@ class RowParallelLinear(LinearBase): ...@@ -1562,16 +1439,7 @@ class RowParallelLinear(LinearBase):
# Only fuse bias add into GEMM for rank 0 (this ensures that # Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case) # bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
if use_fused_silu_mul_quant: output_parallel = self.quant_method.apply(self, input_parallel, bias_)
xq, xs = lm_fuse_silu_mul_quant(input_parallel)
silu_quant_args = [xq, xs]
output_parallel = self.quant_method.apply(self,
input_parallel,
bias_,
silu_quant_args=silu_quant_args)
else:
output_parallel = self.quant_method.apply(self, input_parallel, bias_)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel) output = tensor_model_parallel_all_reduce(output_parallel)
......
...@@ -155,16 +155,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -155,16 +155,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None
): ):
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None: x_q, x_scale = per_token_quant_int8(x)
assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args
elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None:
x_q, x_scale = silu_quant_args
else:
x_q, x_scale = per_token_quant_int8(x)
if self.w8a8_strategy==1: if self.w8a8_strategy==1:
m=x_q.shape[0] m=x_q.shape[0]
......
...@@ -468,20 +468,16 @@ def apply_int8_linear( ...@@ -468,20 +468,16 @@ def apply_int8_linear(
# ops.scaled_int8_quant supports both dynamic and static quant. # ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x. # * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale. # * static, layer.input_scale is scalar and x_scale is input_scale.
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2 symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input)
x_zp =None x_zp =None
x_q, x_scale = input_quant_args else:
else: # not USE_FUSED_RMS_QUANT x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
symmetric = azp_adj is None input_scale,
if input_scale is None and input_zero_point is None and symmetric is True: input_zero_point,
x_q, x_scale=per_token_quant_int8(input) symmetric=symmetric)
x_zp =None
else:
x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale,
input_zero_point,
symmetric=symmetric)
if x_zp is not None: if x_zp is not None:
# Currently, static is always per-tensor and dynamic is per-token # Currently, static is always per-tensor and dynamic is per-token
......
...@@ -133,24 +133,11 @@ class DeepseekV2MLP(nn.Module): ...@@ -133,24 +133,11 @@ class DeepseekV2MLP(nn.Module):
"Only silu is supported for now.") "Only silu is supported for now.")
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
def forward(self, x, def forward(self, x):
rms_weight: Optional[torch.Tensor] = None, gate_up, _ = self.gate_up_proj(x)
residual: Optional[torch.Tensor] = None, x = self.act_fn(gate_up)
update_hd: Optional[bool] = False x, _ = self.down_proj(x)
): return x
if envs.USE_FUSED_RMS_QUANT:
gate_up, new_resi, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd)
if envs.USE_FUSED_SILU_MUL_QUANT:
x, _ = self.down_proj(gate_up, use_fused_silu_mul_quant=True)
else:
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x, new_resi
else:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class DeepseekV2MoE(nn.Module): class DeepseekV2MoE(nn.Module):
...@@ -1128,87 +1115,44 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1128,87 +1115,44 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
if envs.USE_FUSED_RMS_QUANT: # Self Attention
# Fix residual FP16 overflow # Fix residual FP16 overflow
residual_fix_overflow = False residual_fix_overflow = False
if residual is None:
assert self.input_layernorm.has_weight is True residual = hidden_states.clone()
if residual is None: hidden_states = self.input_layernorm(hidden_states)
residual = hidden_states residual_fix_overflow = True
hidden_states, _ = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = None
)
residual_fix_overflow = True
else:
hidden_states, new_residual = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = residual
)
residual = new_residual
if hidden_states.dtype == torch.float16:
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states *= 1. / self.routed_scaling_factor
if self.layer_idx == 0 or residual_fix_overflow:
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1. / self.routed_scaling_factor
hidden_states, new_resi = self.mlp(hidden_states, self.post_attention_layernorm.weight.data, residual)
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, new_resi
else: else:
# Self Attention hidden_states, residual = self.input_layernorm(
# Fix residual FP16 overflow
residual_fix_overflow = False
if residual is None:
residual = hidden_states.clone()
hidden_states = self.input_layernorm(hidden_states)
residual_fix_overflow = True
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
if hidden_states.dtype == torch.float16:
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states *= 1. / self.routed_scaling_factor
if self.layer_idx == 0 or residual_fix_overflow:
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1. / self.routed_scaling_factor
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual) hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.self_attn(
if isinstance(self.mlp, positions=positions,
DeepseekV2MLP) and hidden_states.dtype == torch.float16: hidden_states=hidden_states,
# Fix FP16 overflow )
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer. if hidden_states.dtype == torch.float16:
# The scaling of DeepseekV2MOE output would be done in the forward # Fix FP16 overflow
# of DeepseekV2MOE # We scale both hidden_states and residual before
hidden_states *= 1. / self.routed_scaling_factor # rmsnorm, and rmsnorm result would not affect by scale.
return hidden_states, residual hidden_states *= 1. / self.routed_scaling_factor
if self.layer_idx == 0 or residual_fix_overflow:
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1. / self.routed_scaling_factor
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, residual
@support_torch_compile @support_torch_compile
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment