Commit 836dee3b authored by wanglong3's avatar wanglong3 Committed by zhuwenwen
Browse files

Support blaslt w8a8 GEMM op.

parent aefb81d8
......@@ -1140,6 +1140,18 @@ def rocblas_scaled_mm(a: torch.Tensor,
return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
def blaslt_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
m = a.shape[0]
n = b.shape[0]
k = a.shape[1]
_, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k, 'NT', out_dtype)
return out
def triton_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
......
......@@ -635,6 +635,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
for key, value in configs_dict.items():
m=int(key.split('_')[0])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
elif self.w8a8_strategy==3:
layer.weight.data = layer.weight.data.T
else:
weight_data=layer.weight.data
_weight=weight_data.T.contiguous().reshape(n,-1)
......
......@@ -461,31 +461,39 @@ def apply_int8_linear(
else:
best_config=None
# if best_config==None:
# print("m:{},n:{},k:{}".format(m,n,k))
# print("config not found!")
return ops.triton_scaled_mm(x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
bias=bias,best_config=best_config)
return ops.triton_scaled_mm(
x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
bias=bias,best_config=best_config)
elif w8a8_strategy==2:
return ops.cutlass_scaled_mm(x_q,
return ops.cutlass_scaled_mm(
x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
bias=bias)
elif w8a8_strategy==3:
# x_q: shape (m, k) stride (k, 1)
# weight: shape (n, k) stride (k, 1)
return ops.blaslt_scaled_mm(x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
bias=bias)
bias=None)
else:
return ops.rocblas_scaled_mm(x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
bias=bias)
return ops.rocblas_scaled_mm(
x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
bias=bias)
def normalize_e4m3fn_to_e4m3fnuz(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment