Commit be22412f authored by zhuwenwen
Browse files

Merge branch '092dev_DS_add_silu_mul_quant_w8a8' into 'v0.9.2-dev'

deepseek_v2_w8a8 增加 silu_mul_quant融合 (add silu_mul_quant fusion for deepseek_v2_w8a8)

See merge request dcutoolkit/deeplearing/vllm!265
parents 2b47bce9 216e414b
...@@ -172,7 +172,7 @@ if TYPE_CHECKING: ...@@ -172,7 +172,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_MOE_SUM: bool = False VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False USE_FUSED_RMS_QUANT: bool = True
USE_FUSED_SILU_MUL_QUANT: bool = True USE_FUSED_SILU_MUL_QUANT: bool = True
VLLM_P2P_ASYNC: bool = False VLLM_P2P_ASYNC: bool = False
VLLM_P2P_BUF_TOKENS: int = 30000 VLLM_P2P_BUF_TOKENS: int = 30000
...@@ -1142,8 +1142,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1142,8 +1142,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
("true", "1")), ("true", "1")),
# vllm will use rmsquant fused op # vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT": "USE_FUSED_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "1"))),
("true", "1")),
# vllm will use silu_mul_quant fused op, # vllm will use silu_mul_quant fused op,
# This variable has a default value of true, # This variable has a default value of true,
# but it is still controlled by CRQ and RQ. # but it is still controlled by CRQ and RQ.
......
...@@ -669,7 +669,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -669,7 +669,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
input_quant_args: Optional[list[torch.Tensor]] = None): input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None):
""" """
Use the output of create_weights and the CompressedTensorsScheme Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the associated with the layer to apply the forward pass with the
...@@ -680,7 +681,10 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -680,7 +681,10 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
scheme = layer.scheme scheme = layer.scheme
if scheme is None: if scheme is None:
raise ValueError("A scheme must be defined for each layer") raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias, input_quant_args=input_quant_args) return scheme.apply_weights(layer, x,
bias=bias,
input_quant_args=input_quant_args,
silu_quant_args=silu_quant_args)
class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
......
...@@ -112,7 +112,9 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -112,7 +112,9 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor], bias: Optional[torch.Tensor],
input_quant_args: Optional[list[torch.Tensor]] = None) -> torch.Tensor: input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None
) -> torch.Tensor:
# return self.kernel.apply_weights(layer, x, bias) # return self.kernel.apply_weights(layer, x, bias)
...@@ -124,4 +126,5 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -124,4 +126,5 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
azp_adj=layer.azp_adj, azp_adj=layer.azp_adj,
bias=bias, bias=bias,
w8a8_strategy=self.w8a8_strategy, w8a8_strategy=self.w8a8_strategy,
input_quant_args=input_quant_args) input_quant_args=input_quant_args,
\ No newline at end of file silu_quant_args=silu_quant_args)
\ No newline at end of file
...@@ -168,6 +168,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -168,6 +168,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
assert len(input_quant_args) == 2 assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args x_q, x_scale = input_quant_args
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and silu_quant_args is not None: elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and silu_quant_args is not None:
assert len(silu_quant_args) == 2
x_q, x_scale = silu_quant_args x_q, x_scale = silu_quant_args
else: else:
x_q, x_scale = per_token_quant_int8(x) x_q, x_scale = per_token_quant_int8(x)
......
...@@ -406,7 +406,8 @@ def apply_int8_linear( ...@@ -406,7 +406,8 @@ def apply_int8_linear(
azp_adj: Optional[torch.Tensor] = None, azp_adj: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
w8a8_strategy:Optional[int]=0, w8a8_strategy:Optional[int]=0,
input_quant_args: Optional[list[torch.Tensor]] = None input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None
): ):
# ops.scaled_int8_quant supports both dynamic and static quant. # ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x. # * dynamic, layer.input_scale is None and x_scale computed from x.
...@@ -416,7 +417,11 @@ def apply_int8_linear( ...@@ -416,7 +417,11 @@ def apply_int8_linear(
assert len(input_quant_args) == 2 assert len(input_quant_args) == 2
x_zp =None x_zp =None
x_q, x_scale = input_quant_args x_q, x_scale = input_quant_args
else: # not USE_FUSED_RMS_QUANT elif envs.USE_FUSED_RMS_QUANT and silu_quant_args is not None:
assert len(silu_quant_args) == 2
x_zp =None
x_q, x_scale = silu_quant_args
else: # default
symmetric = azp_adj is None symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True: if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input) x_q, x_scale=per_token_quant_int8(input)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment