Commit 09c2856a authored by zhuwenwen's avatar zhuwenwen
Browse files

Fix blaslt miss bias

parent 9be76efd
...@@ -1091,6 +1091,8 @@ def blaslt_scaled_mm(a: torch.Tensor, ...@@ -1091,6 +1091,8 @@ def blaslt_scaled_mm(a: torch.Tensor,
n = b.shape[0] n = b.shape[0]
k = a.shape[1] k = a.shape[1]
_, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k, 'NT', out_dtype) _, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k, 'NT', out_dtype)
if bias is not None:
out += bias
return out return out
def triton_scaled_mm(a: torch.Tensor, def triton_scaled_mm(a: torch.Tensor,
......
...@@ -555,7 +555,7 @@ def apply_int8_linear( ...@@ -555,7 +555,7 @@ def apply_int8_linear(
scale_a=x_scale, scale_a=x_scale,
scale_b=weight_scale, scale_b=weight_scale,
out_dtype=input.dtype, out_dtype=input.dtype,
bias=None) bias=bias)
else: else:
return ops.rocblas_scaled_mm( return ops.rocblas_scaled_mm(
x_q, x_q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment