Commit d261a1e6 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-hotfix' into 'v0.9.2-dev'

Fix blaslt miss bias.

See merge request dcutoolkit/deeplearing/vllm!270
parents ade7db0c 2164aab4
...@@ -1150,6 +1150,8 @@ def blaslt_scaled_mm(a: torch.Tensor, ...@@ -1150,6 +1150,8 @@ def blaslt_scaled_mm(a: torch.Tensor,
n = b.shape[0] n = b.shape[0]
k = a.shape[1] k = a.shape[1]
_, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k, 'NT', out_dtype) _, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k, 'NT', out_dtype)
if bias is not None:
out += bias
return out return out
def triton_scaled_mm(a: torch.Tensor, def triton_scaled_mm(a: torch.Tensor,
...@@ -2486,4 +2488,4 @@ direct_register_custom_op( ...@@ -2486,4 +2488,4 @@ direct_register_custom_op(
op_func=awq_gemm, op_func=awq_gemm,
mutates_args=[], mutates_args=[],
fake_impl=awq_gemm_fake, fake_impl=awq_gemm_fake,
) )
\ No newline at end of file
...@@ -504,7 +504,7 @@ def apply_int8_linear( ...@@ -504,7 +504,7 @@ def apply_int8_linear(
scale_a=x_scale, scale_a=x_scale,
scale_b=weight_scale, scale_b=weight_scale,
out_dtype=input.dtype, out_dtype=input.dtype,
bias=None) bias=bias)
else: else:
return ops.rocblas_scaled_mm( return ops.rocblas_scaled_mm(
x_q, x_q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment