Commit daa7f7d7 authored by gaoqiong's avatar gaoqiong
Browse files

增加w8a8的triton环境变量控制

parent 2c7f740a
......@@ -691,6 +691,24 @@ def cutlass_scaled_mm(a: torch.Tensor,
return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
def rocblas_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
def triton_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return quant_ops.triton_scaled_mm(a, b,scale_a,scale_b,out_dtype,bias)
def cutlass_scaled_mm_azp(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
......
......@@ -20,6 +20,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '0'))
@classmethod
def get_min_capability(cls) -> int:
......@@ -97,4 +98,5 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias)
bias=bias,
w8a8_strategy=self.w8a8_strategy)
......@@ -192,19 +192,34 @@ def apply_int8_linear(
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
w8a8_strategy:Optional[int]=0,
):
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale)
if w8a8_strategy==1:
return ops.triton_scaled_mm(x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
bias=bias)
elif w8a8_strategy==2:
return ops.cutlass_scaled_mm(x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
bias=bias)
else:
return ops.rocblas_scaled_mm(x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
bias=bias)
def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment