Commit 9f087f8b authored by zhuwenwen
Browse files

DeepSeek-R1-Channel-INT8调用rmsquant融合

parent 1e911dbd
......@@ -767,7 +767,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None):
bias: Optional[torch.Tensor] = None,
input_quant_args: Optional[list[torch.Tensor]] = None):
"""
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
......@@ -777,7 +778,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias)
return scheme.apply_weights(layer, x, bias=bias, input_quant_args=input_quant_args)
class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
......
......@@ -111,7 +111,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
bias: Optional[torch.Tensor],
input_quant_args: Optional[list[torch.Tensor]] = None) -> torch.Tensor:
# return self.kernel.apply_weights(layer, x, bias)
......@@ -122,5 +123,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_zero_point=layer.input_zero_point,
azp_adj=layer.azp_adj,
bias=bias,
w8a8_strategy=self.w8a8_strategy)
w8a8_strategy=self.w8a8_strategy,
input_quant_args=input_quant_args)
......@@ -447,20 +447,25 @@ def apply_int8_linear(
azp_adj: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
w8a8_strategy:Optional[int]=0,
input_quant_args: Optional[list[torch.Tensor]] = None
):
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input)
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
x_zp =None
else:
x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale,
input_zero_point,
symmetric=symmetric)
x_q, x_scale = input_quant_args
else: # not USE_FUSED_RMS_QUANT
symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input)
x_zp =None
else:
x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale,
input_zero_point,
symmetric=symmetric)
if x_zp is not None:
# Currently, static is always per-tensor and dynamic is per-token
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment