Commit 638296df authored by yuguo

Merge branch 'develop_v2.4' into 'main'

[DCU] fix: fall back from torch._scaled_mm to a torch.mm float32 GEMM on HIP builds

See merge request dcutoolkit/deeplearing/TransformerEngine!24
parents ccb9a1b1 6d461a10
@@ -7,6 +7,7 @@ from typing import Tuple
 import torch
 import triton
 import triton.language as tl
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 
 
 @triton.jit
@@ -195,14 +196,17 @@ class CuBLASRefBlockwiseGemm:
         # Perform qgemm with scaling factors fused in the GEMM
         # Accumulate should be in float32 format, which aligns with the split_accumulate in FP8 GEMM
         one = torch.tensor(1.0, dtype=torch.float32, device=qx.device)
-        y_partial = torch._scaled_mm(
-            qx_block,
-            qw_block.t(),
-            scale_a=one,
-            scale_b=one,
-            out_dtype=torch.float32,
-            use_fast_accum=not use_split_accumulator,
-        )
+        if IS_HIP_EXTENSION:
+            y_partial = torch.mm(qx_block.to(torch.float), qw_block.t().to(torch.float))
+        else:
+            y_partial = torch._scaled_mm(
+                qx_block,
+                qw_block.t(),
+                scale_a=one,
+                scale_b=one,
+                out_dtype=torch.float32,
+                use_fast_accum=not use_split_accumulator,
+            )
         # Accumulate the partial result
         if is_a_1d_scaled and is_b_1d_scaled:
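For context, the change guards the FP8 reference GEMM: on ROCm/DCU builds of PyTorch, `torch._scaled_mm` is not usable here, so the reference path upcasts the quantized blocks to float32 and runs a plain `torch.mm` instead. Below is a minimal sketch (not part of the patch) of why the two paths agree when both scales are 1.0; it assumes a PyTorch build with float8 dtypes, and the block shapes and comparison harness are invented for illustration:

```python
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION

# Hypothetical small blocks; the real code slices these out of larger
# blockwise-quantized tensors.
qx_block = torch.randn(16, 32).to(torch.float8_e4m3fn)
qw_block = torch.randn(64, 32).to(torch.float8_e4m3fn)

# Fallback path (what the patch uses on ROCm/DCU): upcast to float32 and
# run an ordinary GEMM, which accumulates in float32.
y_ref = torch.mm(qx_block.to(torch.float), qw_block.t().to(torch.float))

# Fused path (CUDA builds): _scaled_mm with unit scales computes the same
# GEMM, since dequantizing with a scale of 1.0 leaves the values unchanged.
if torch.cuda.is_available() and not IS_HIP_EXTENSION:
    one = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    y_fused = torch._scaled_mm(
        qx_block.cuda(),
        qw_block.cuda().t(),  # _scaled_mm expects a column-major second operand
        scale_a=one,
        scale_b=one,
        out_dtype=torch.float32,
    )
    # Products of fp8 values are exactly representable in float32; only the
    # accumulation order can differ between the two paths.
    torch.testing.assert_close(y_fused.cpu(), y_ref, rtol=1e-4, atol=1e-4)
```

With `scale_a = scale_b = 1.0` the fused kernel applies no rescaling, so the fallback trades the fused FP8 kernel for an upcast GEMM with the same float32 accumulation, at the cost of extra memory traffic for the upcast copies.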