Commit 33a5ce88 authored by zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-blaslt-w8a8-GEMM' into 'v0.9.2-dev'

Support blaslt w8a8 GEMM op.

See merge request dcutoolkit/deeplearing/vllm!238
parents 8d6b0b0a 836dee3b
...@@ -1140,6 +1140,18 @@ def rocblas_scaled_mm(a: torch.Tensor, ...@@ -1140,6 +1140,18 @@ def rocblas_scaled_mm(a: torch.Tensor,
return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias) return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
def blaslt_scaled_mm(a: torch.Tensor,
                     b: torch.Tensor,
                     scale_a: torch.Tensor,
                     scale_b: torch.Tensor,
                     out_dtype: torch.dtype,
                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run a scaled W8A8 GEMM through ``quant_ops.hipblaslt_w8a8_gemm``.

    The problem sizes handed to the kernel are read from the operands:
    ``m = a.shape[0]``, ``k = a.shape[1]`` and ``n = b.shape[0]``.

    Args:
        a: Left operand; must be at least 2-D, since dims 0 and 1 are read.
            Presumably the quantized activation of shape (m, k) -- confirm
            against the caller.
        b: Right operand; its first dimension is taken as ``n``.
            Presumably the quantized weight laid out as (n, k) -- confirm
            against the caller.
        scale_a: Scale tensor for ``a``, forwarded unchanged to the kernel.
        scale_b: Scale tensor for ``b``, forwarded unchanged to the kernel.
        out_dtype: Requested output dtype, forwarded unchanged to the kernel.
        bias: Optional bias. When given, it is added to the GEMM result
            (previously it was accepted but silently ignored).

    Returns:
        The second element of the tuple returned by the kernel, with
        ``bias`` added when one was supplied.
    """
    m, k = a.shape[0], a.shape[1]
    n = b.shape[0]
    # 'NT' is the layout flag the kernel is called with; its exact meaning
    # is defined by quant_ops (presumably a non-transposed, b transposed --
    # TODO confirm).
    _, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k,
                                           'NT', out_dtype)
    if bias is not None:
        # The kernel call takes no bias, so apply it here instead of dropping
        # it. Cast first so the result keeps the kernel's output dtype.
        # NOTE(review): assumes `out` is (m, n) and `bias` broadcasts over
        # its last dimension -- confirm against the kernel.
        out = out + bias.to(out.dtype)
    return out
def triton_scaled_mm(a: torch.Tensor, def triton_scaled_mm(a: torch.Tensor,
b: torch.Tensor, b: torch.Tensor,
scale_a: torch.Tensor, scale_a: torch.Tensor,
......
...@@ -635,6 +635,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -635,6 +635,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
for key, value in configs_dict.items(): for key, value in configs_dict.items():
m=int(key.split('_')[0]) m=int(key.split('_')[0])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value) ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
elif self.w8a8_strategy==3:
layer.weight.data = layer.weight.data.T
else: else:
weight_data=layer.weight.data weight_data=layer.weight.data
_weight=weight_data.T.contiguous().reshape(n,-1) _weight=weight_data.T.contiguous().reshape(n,-1)
......
...@@ -475,31 +475,39 @@ def apply_int8_linear( ...@@ -475,31 +475,39 @@ def apply_int8_linear(
else: else:
best_config=None best_config=None
# if best_config==None:
# print("m:{},n:{},k:{}".format(m,n,k))
# print("config not found!")
return ops.triton_scaled_mm(
return ops.triton_scaled_mm(x_q, x_q,
weight, weight,
scale_a=x_scale, scale_a=x_scale,
scale_b=weight_scale, scale_b=weight_scale,
out_dtype=input.dtype, out_dtype=input.dtype,
bias=bias,best_config=best_config) bias=bias,best_config=best_config)
elif w8a8_strategy==2: elif w8a8_strategy==2:
return ops.cutlass_scaled_mm(x_q, return ops.cutlass_scaled_mm(
x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
bias=bias)
elif w8a8_strategy==3:
# x_q: shape (m, k) stride (k, 1)
# weight: shape (n, k) stride (k, 1)
return ops.blaslt_scaled_mm(x_q,
weight, weight,
scale_a=x_scale, scale_a=x_scale,
scale_b=weight_scale, scale_b=weight_scale,
out_dtype=input.dtype, out_dtype=input.dtype,
bias=bias) bias=None)
else: else:
return ops.rocblas_scaled_mm(x_q, return ops.rocblas_scaled_mm(
weight, x_q,
scale_a=x_scale, weight,
scale_b=weight_scale, scale_a=x_scale,
out_dtype=input.dtype, scale_b=weight_scale,
bias=bias) out_dtype=input.dtype,
bias=bias)
def normalize_e4m3fn_to_e4m3fnuz( def normalize_e4m3fn_to_e4m3fnuz(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment