Commit a5043e83 authored by wujl5
Browse files

DeepSeek-R1-Channel-INT8调用rmsquant融合。

parent 5eec6110
......@@ -666,7 +666,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None):
bias: Optional[torch.Tensor] = None,
input_quant_args: Optional[list[torch.Tensor]] = None):
"""
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
......@@ -677,7 +678,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias)
return scheme.apply_weights(layer, x, bias=bias, input_quant_args=input_quant_args)
class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
......
......@@ -111,7 +111,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
bias: Optional[torch.Tensor],
input_quant_args: Optional[list[torch.Tensor]] = None) -> torch.Tensor:
# return self.kernel.apply_weights(layer, x, bias)
......@@ -122,5 +123,5 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_zero_point=layer.input_zero_point,
azp_adj=layer.azp_adj,
bias=bias,
w8a8_strategy=self.w8a8_strategy)
w8a8_strategy=self.w8a8_strategy,
input_quant_args=input_quant_args)
\ No newline at end of file
......@@ -392,15 +392,21 @@ def apply_int8_linear(
azp_adj: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
w8a8_strategy:Optional[int]=0,
input_quant_args: Optional[list[torch.Tensor]] = None
):
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
x_zp =None
x_q, x_scale = input_quant_args
else: # not USE_FUSED_RMS_QUANT
symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input)
x_zp =None
else:
x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment