Commit 6d461a10 authored by yuguo's avatar yuguo
Browse files

[DCU] fix

parent 0a8072fa
......@@ -7,6 +7,7 @@ from typing import Tuple
import torch
import triton
import triton.language as tl
from torch.utils.cpp_extension import IS_HIP_EXTENSION
@triton.jit
......@@ -195,6 +196,9 @@ class CuBLASRefBlockwiseGemm:
# Perform qgemm with scaling factors fused in the GEMM
# Accumulate should be in float32 format, which aligns with the split_accumulate in FP8 GEMM
one = torch.tensor(1.0, dtype=torch.float32, device=qx.device)
if IS_HIP_EXTENSION:
y_partial = torch.mm(qx_block.to(torch.float), qw_block.t().to(torch.float))
else:
y_partial = torch._scaled_mm(
qx_block,
qw_block.t(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment