Commit 2009d4a1 authored by zhuwenwen's avatar zhuwenwen
Browse files

增加w8a8的triton环境变量控制

parent f7512877
...@@ -715,6 +715,25 @@ def cutlass_scaled_mm(a: torch.Tensor, ...@@ -715,6 +715,25 @@ def cutlass_scaled_mm(a: torch.Tensor,
return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias) return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
def rocblas_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
def triton_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return quant_ops.triton_scaled_mm(a, b,scale_a,scale_b,out_dtype,bias)
def cutlass_scaled_mm_azp(a: torch.Tensor, def cutlass_scaled_mm_azp(a: torch.Tensor,
b: torch.Tensor, b: torch.Tensor,
scale_a: torch.Tensor, scale_a: torch.Tensor,
......
...@@ -3,6 +3,7 @@ from typing import Callable, List, Optional ...@@ -3,6 +3,7 @@ from typing import Callable, List, Optional
import torch import torch
from compressed_tensors.quantization import QuantizationStrategy from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Parameter from torch.nn import Parameter
import os
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
...@@ -24,6 +25,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -24,6 +25,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self.strategy = strategy self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric self.input_symmetric = input_symmetric
self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '0'))
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
...@@ -145,4 +147,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -145,4 +147,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_scale=layer.input_scale, input_scale=layer.input_scale,
input_zero_point=layer.input_zero_point, input_zero_point=layer.input_zero_point,
azp_adj=layer.azp_adj, azp_adj=layer.azp_adj,
bias=bias) bias=bias,
w8a8_strategy=self.w8a8_strategy)
...@@ -195,6 +195,7 @@ def apply_int8_linear( ...@@ -195,6 +195,7 @@ def apply_int8_linear(
input_zero_point: Optional[torch.Tensor] = None, input_zero_point: Optional[torch.Tensor] = None,
azp_adj: Optional[torch.Tensor] = None, azp_adj: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
w8a8_strategy: Optional[int] = 0,
): ):
# ops.scaled_int8_quant supports both dynamic and static quant. # ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x. # * dynamic, layer.input_scale is None and x_scale computed from x.
...@@ -214,12 +215,27 @@ def apply_int8_linear( ...@@ -214,12 +215,27 @@ def apply_int8_linear(
azp_adj=azp_adj, azp_adj=azp_adj,
azp=x_zp, azp=x_zp,
bias=bias) bias=bias)
return ops.cutlass_scaled_mm(x_q, if w8a8_strategy == 1:
weight, return ops.triton_scaled_mm(x_q,
scale_a=x_scale, weight,
scale_b=weight_scale, scale_a=x_scale,
out_dtype=input.dtype, scale_b=weight_scale,
bias=bias) out_dtype=input.dtype,
bias=bias)
elif w8a8_strategy == 2:
return ops.cutlass_scaled_mm(x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
bias=bias)
else:
return ops.rocblas_scaled_mm(x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
bias=bias)
def normalize_e4m3fn_to_e4m3fnuz( def normalize_e4m3fn_to_e4m3fnuz(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment