Commit 9f201bc1 authored by zhuwenwen's avatar zhuwenwen
Browse files

deepseek-r1-w4a8使用rmsquant融合算子及横向融合

parent 243b2f0c
...@@ -204,6 +204,7 @@ if TYPE_CHECKING: ...@@ -204,6 +204,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHT_OP: bool = False VLLM_USE_LIGHT_OP: bool = False
VLLM_USE_TRITON_CAT: bool = False VLLM_USE_TRITON_CAT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1396,6 +1397,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1396,6 +1397,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
("true", "1")), ("true", "1")),
# vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -33,6 +33,12 @@ from vllm.platforms import current_platform ...@@ -33,6 +33,12 @@ from vllm.platforms import current_platform
import os import os
from vllm.model_executor.utils import gemm_bank_conf from vllm.model_executor.utils import gemm_bank_conf
if envs.USE_FUSED_RMS_QUANT:
try:
from lmslim.quantize.quant_ops import lm_faster_rmsquant
except Exception as e:
print(f"Error: Import fused rmsquant error: {e}")
logger = init_logger(__name__) logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [ WEIGHT_LOADER_V2_SUPPORTED = [
...@@ -325,6 +331,7 @@ class ReplicatedLinear(LinearBase): ...@@ -325,6 +331,7 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
...@@ -338,6 +345,7 @@ class ReplicatedLinear(LinearBase): ...@@ -338,6 +345,7 @@ class ReplicatedLinear(LinearBase):
prefix=prefix, prefix=prefix,
return_bias=return_bias, return_bias=return_bias,
disable_tp=disable_tp) disable_tp=disable_tp)
self.eps = eps
# All the linear layer supports quant method. # All the linear layer supports quant method.
assert self.quant_method is not None assert self.quant_method is not None
...@@ -385,15 +393,53 @@ class ReplicatedLinear(LinearBase): ...@@ -385,15 +393,53 @@ class ReplicatedLinear(LinearBase):
param.data.copy_(loaded_weight) param.data.copy_(loaded_weight)
def forward( def forward(
self, x: torch.Tensor self,
input_: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
quant_args: Optional[list] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None if envs.USE_FUSED_RMS_QUANT and (rms_weight is not None or quant_args is not None):
assert self.quant_method is not None if quant_args is not None:
output = self.quant_method.apply(self, x, bias) input_quant_args = quant_args
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias: bias = self.bias if not self.skip_bias_add else None
return output assert self.quant_method is not None
return output, output_bias output = self.quant_method.apply(self, input_, bias, input_quant_args)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
else:
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd
)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias, input_quant_args
else:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"in_features={self.input_size}" s = f"in_features={self.input_size}"
...@@ -439,6 +485,7 @@ class ColumnParallelLinear(LinearBase): ...@@ -439,6 +485,7 @@ class ColumnParallelLinear(LinearBase):
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[list[int]] = None, output_sizes: Optional[list[int]] = None,
eps: Optional[float] = 1e-6,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
...@@ -468,6 +515,7 @@ class ColumnParallelLinear(LinearBase): ...@@ -468,6 +515,7 @@ class ColumnParallelLinear(LinearBase):
return_bias=return_bias, return_bias=return_bias,
disable_tp=disable_tp) disable_tp=disable_tp)
self.eps = eps
self.gather_output = gather_output self.gather_output = gather_output
if output_sizes is None: if output_sizes is None:
...@@ -553,22 +601,49 @@ class ColumnParallelLinear(LinearBase): ...@@ -553,22 +601,49 @@ class ColumnParallelLinear(LinearBase):
param.load_column_parallel_weight(loaded_weight=loaded_weight) param.load_column_parallel_weight(loaded_weight=loaded_weight)
def forward( def forward(
self, input_ self, input_,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None
# Matrix multiply. assert rms_weight is not None
assert self.quant_method is not None i_q, _scales = lm_faster_rmsquant(input=input_,
output_parallel = self.quant_method.apply(self, input_, bias) rms_weight=rms_weight,
if self.gather_output and self.tp_size > 1: epsilon=self.eps,
# All-gather across the partitions. quant_dtype=torch.int8,
output = tensor_model_parallel_all_gather(output_parallel) residual=residual,
update_input=update_hd)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args)
if self.gather_output and self.tp_size > 1:
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias
else: else:
output = output_parallel bias = self.bias if not self.skip_bias_add else None
output_bias = self.bias if self.skip_bias_add else None # Matrix multiply.
if not self.return_bias: assert self.quant_method is not None
return output output_parallel = self.quant_method.apply(self, input_, bias)
return output, output_bias if self.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"in_features={self.input_size}" s = f"in_features={self.input_size}"
...@@ -605,6 +680,54 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -605,6 +680,54 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
will be treated as a "Replicated" MergedLinear. will be treated as a "Replicated" MergedLinear.
""" """
def forward(
self, input_,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None
assert residual is not None and rms_weight is not None
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias
else: # not USE_FUSED_RMS_QUANT
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def __init__( def __init__(
self, self,
input_size: int, input_size: int,
...@@ -614,11 +737,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -614,11 +737,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
disable_tp: bool = False, disable_tp: bool = False,
): ):
self.eps = eps
self.output_sizes = output_sizes self.output_sizes = output_sizes
self.tp_size = (get_tensor_model_parallel_world_size() self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1) if not disable_tp else 1)
......
...@@ -16,11 +16,11 @@ from vllm.model_executor.parameter import (BasevLLMParameter, ...@@ -16,11 +16,11 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
from lmslim.layers.gemm.int8_utils import ( from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_group_quant_int8,
per_token_quant_int8) per_token_quant_int8)
from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
import os import os
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs
W8A8_TRITONJSON=W8a8GetCacheJSON() W8A8_TRITONJSON=W8a8GetCacheJSON()
...@@ -153,8 +153,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -153,8 +153,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
input_quant_args: Optional[list[torch.Tensor]] = None
): ):
x_q, x_scale = per_token_quant_int8(x) if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args
else:
x_q, x_scale = per_token_quant_int8(x)
if self.w8a8_strategy==1: if self.w8a8_strategy==1:
m=x_q.shape[0] m=x_q.shape[0]
......
...@@ -109,11 +109,21 @@ class DeepseekV2MLP(nn.Module): ...@@ -109,11 +109,21 @@ class DeepseekV2MLP(nn.Module):
"Only silu is supported for now.") "Only silu is supported for now.")
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
def forward(self, x): def forward(self, x,
gate_up, _ = self.gate_up_proj(x) rms_weight: Optional[torch.Tensor] = None,
x = self.act_fn(gate_up) residual: Optional[torch.Tensor] = None,
x, _ = self.down_proj(x) update_hd: Optional[bool] = False
return x ):
if envs.USE_FUSED_RMS_QUANT:
gate_up, new_resi, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x, new_resi
else:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
# Chunk x along the num_tokens axis for sequence parallelism # Chunk x along the num_tokens axis for sequence parallelism
...@@ -282,7 +292,10 @@ class DeepseekV2MoE(nn.Module): ...@@ -282,7 +292,10 @@ class DeepseekV2MoE(nn.Module):
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
...@@ -696,47 +709,88 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -696,47 +709,88 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention if envs.USE_FUSED_RMS_QUANT:
# Fix residual FP16 overflow
# Fix residual FP16 overflow residual_fix_overflow = False
residual_fix_overflow = False
if residual is None: assert self.input_layernorm.has_weight is True
residual = hidden_states if residual is None:
hidden_states = self.input_layernorm(hidden_states) residual = hidden_states
residual_fix_overflow = True hidden_states, _ = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = None
)
residual_fix_overflow = True
else:
hidden_states, new_residual = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = residual
)
residual = new_residual
if hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states *= 1. / self.routed_scaling_factor
if self.layer_idx == 0 or residual_fix_overflow:
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1. / self.routed_scaling_factor
hidden_states, new_resi = self.mlp(hidden_states, self.post_attention_layernorm.weight.data, residual)
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, new_resi
else: else:
hidden_states, residual = self.input_layernorm( # Self Attention
hidden_states, residual) # Fix residual FP16 overflow
hidden_states = self.self_attn( residual_fix_overflow = False
positions=positions, if residual is None:
hidden_states=hidden_states, residual = hidden_states
) hidden_states = self.input_layernorm(hidden_states)
residual_fix_overflow = True
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
if hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick: if hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# Fix FP16 overflow # Fix FP16 overflow
# We scale both hidden_states and residual before # We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale. # rmsnorm, and rmsnorm result would not affect by scale.
hidden_states *= 1. / self.routed_scaling_factor hidden_states *= 1. / self.routed_scaling_factor
if self.layer_idx == 0 or residual_fix_overflow: if self.layer_idx == 0 or residual_fix_overflow:
# The residual is shared by all layers, we only scale it on # The residual is shared by all layers, we only scale it on
# first layer. # first layer.
residual *= 1. / self.routed_scaling_factor residual *= 1. / self.routed_scaling_factor
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual) hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp,
if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
DeepseekV2MLP) and hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick: # Fix FP16 overflow
# Fix FP16 overflow # Scaling the DeepseekV2MLP output, it is the input of
# Scaling the DeepseekV2MLP output, it is the input of # input_layernorm of next decoder layer.
# input_layernorm of next decoder layer. # The scaling of DeepseekV2MOE output would be done in the forward
# The scaling of DeepseekV2MOE output would be done in the forward # of DeepseekV2MOE
# of DeepseekV2MOE hidden_states *= 1. / self.routed_scaling_factor
hidden_states *= 1. / self.routed_scaling_factor return hidden_states, residual
return hidden_states, residual
@support_torch_compile @support_torch_compile
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment