Unverified commit db2aaa9e authored by kwyss-nvidia, committed by GitHub

Subchannel Block quantized GEMM (#1545)



* Add GEMM logic for blockwise quantized tensors.

GEMM test cases included in pytorch integration.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update NVTE_BLOCK_SCALING for GEMM.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Gate feature on CUDA 12.9.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix GEMM typo.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Remove unnecessary type converter change.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Reflect epilogue availability and test supported epilogues.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* GEMM simplifications from recipe branch.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Format py code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update GEMM DGelu tests to match support depending on output dtype.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Force pow2Scales in GEMM.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add GEMM test to pytorch test suite.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add copyright to GEMM test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update import for GEMM test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add license.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update test gemm supported predicate.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use sgemm-like interfaces and naming.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Rewrite GEMM comment.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Refactor GEMM param canonicalization

Configure A and B matrices separately. Have a separate code path for each scaling mode.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Prune number of tests.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

---------

Signed-off-by: Keith Wyss <kwyss@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent b362a6e0
@@ -32,6 +32,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Tuple
import torch
import triton
import triton.language as tl
@triton.jit
def fused_fma_kernel(y_ptr, x_ptr, s_ptr, M, N, y_str0, y_str1, BLOCK: tl.constexpr = 128):
pid = tl.program_id(0)
idx = pid * BLOCK + tl.arange(0, BLOCK)
mask = idx < M * N
row = idx // N
col = idx % N
y_offset = row * y_str0 + col * y_str1
x_offset = row * N + col
s_offset = row * N + col
y = tl.load(y_ptr + y_offset, mask=mask)
x = tl.load(x_ptr + x_offset, mask=mask)
s = tl.load(s_ptr + s_offset, mask=mask)
tl.store(y_ptr + y_offset, tl.fma(x, s, y), mask=mask)
def fused_fma(y, x, s, BLOCK=128):
"""
Fused multiply-add operation (y = y + x * s).
PyTorch does not provide a direct FMA equivalent (torch.addcmul is not bitwise equivalent to this operation).
This function also supports cases where 'y' is non-contiguous in memory.
"""
assert (
y.shape == x.shape == s.shape and y.dim() == 2
), "All tensors must be 2D with the same shape"
assert x.is_contiguous() and s.is_contiguous(), "x and s must be contiguous"
M, N = y.shape
grid = ((M * N + BLOCK - 1) // BLOCK,)
fused_fma_kernel[grid](y, x, s, M, N, *y.stride(), BLOCK)
return y
class CuBLASRefBlockwiseGemm:
"""
A cuBLAS compatible reference implementation of subchannel GEMM.
"""
def qgemm(
self,
qx: torch.Tensor,
qw: torch.Tensor,
out_dtype: torch.dtype,
demunged_sx: torch.Tensor,
demunged_sw: torch.Tensor,
quant_tile_shape_x: Tuple[int, int],
quant_tile_shape_w: Tuple[int, int],
bias: torch.Tensor | None = None,
out: torch.Tensor | None = None,
accumulate: bool = False,
use_split_accumulator: bool = False,
) -> torch.Tensor:
# demunge scale shapes for cuBLAS
is_a_1d_scaled = quant_tile_shape_x[0] == 1
is_b_1d_scaled = quant_tile_shape_w[0] == 1
M, K = qx.shape
N, _ = qw.shape  # don't shadow K from qx
# mm_tile_shape = (tile_m, tile_n, tile_k)
mm_tile_shape = (
quant_tile_shape_x[0],
quant_tile_shape_w[0],
quant_tile_shape_w[1],
)
if bias is not None and bias.numel():
# To match cuBLAS more closely when bias is applied, the reference
# accumulates in float32, and the cast to the output dtype is
# deferred until after the GEMM.
out_dtype_for_ref = torch.float32
else:
out_dtype_for_ref = out_dtype
y = self.qgemm_blockwise_2d(
qx,
qw,
out_dtype_for_ref,
demunged_sx,
demunged_sw,
mm_tile_shape,
use_split_accumulator,
is_a_1d_scaled,
is_b_1d_scaled,
)
if bias is not None and bias.numel():
y += bias
y = y.to(dtype=out_dtype)
# cuBLAS accumulation: first convert to the output dtype, then accumulate.
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y
@classmethod
def qgemm_blockwise_2d(
cls,
qx: torch.Tensor,
qw: torch.Tensor,
out_dtype: torch.dtype,
sx: torch.Tensor,
sw: torch.Tensor,
mm_tile_shape: Tuple[int, int, int],
use_split_accumulator: bool,
is_a_1d_scaled: bool,
is_b_1d_scaled: bool,
) -> torch.Tensor:
"""
Difference between cuBLAS and CUTLASS GEMM implementations:
- cuBLAS accumulation equation: use different equation for each scaling mode.
- For accumulation C in epiloge, it first convert C to output dtype, then accumulate.
"""
M, K = qx.shape
N, K_w = qw.shape
assert K == K_w, "K dimension mismatch between qx and qw"
tile_len = 128
# Calculate grid sizes without padding
grid_m = (M + tile_len - 1) // tile_len
grid_n = (N + tile_len - 1) // tile_len
grid_k = (K + tile_len - 1) // tile_len
block_m, block_n, block_k = mm_tile_shape
scale_m_per_tile = tile_len // block_m
scale_n_per_tile = tile_len // block_n
assert block_k == tile_len, "block_k must be equal to tile_len"
# Notes on making the reference implementation numerically equivalent to Cast Blockwise FP8 GEMM:
# 1) When using split_accumulate in FP8 GEMM, every 4 QMMA partial accumulation results are accumulated into float32 registers.
# 2) Partial accumulation results are accumulated using FMA (Fused Multiply-Add) instructions to apply scaling factors, as in: y += partial_y * scale
y = torch.zeros(M, N, dtype=torch.float32, device=qx.device)
# Validate shapes of sx and sw
scale_m_per_tensor = (M + block_m - 1) // block_m
scale_n_per_tensor = (N + block_n - 1) // block_n
assert sx.shape == (
scale_m_per_tensor,
grid_k,
), f"sx shape mismatch: expected ({scale_m_per_tensor}, {grid_k}), got {sx.shape}"
assert sw.shape == (
scale_n_per_tensor,
grid_k,
), f"sw shape mismatch: expected ({scale_n_per_tensor}, {grid_k}), got {sw.shape}"
for i in range(grid_m):
m_start = i * tile_len
m_end = min(m_start + tile_len, M)
m_size = m_end - m_start
for j in range(grid_n):
n_start = j * tile_len
n_end = min(n_start + tile_len, N)
n_size = n_end - n_start
y_block = y[m_start:m_end, n_start:n_end]
for k in range(grid_k):
k_start = k * tile_len
k_end = min(k_start + tile_len, K)
k_size = k_end - k_start
qx_block = (
qx[m_start:m_end, k_start:k_end].clone().contiguous()
) # Shape: [m_size, k_size]
qw_block = (
qw[n_start:n_end, k_start:k_end].clone().contiguous()
) # Shape: [n_size, k_size]
# Extract scaling factors for the current blocks
sx_block = sx[i * scale_m_per_tile : (i + 1) * scale_m_per_tile, k].unsqueeze(
-1
)
sw_block = sw[j * scale_n_per_tile : (j + 1) * scale_n_per_tile, k].unsqueeze(0)
# Perform qgemm with scaling factors fused in the GEMM
# Accumulate should be in float32 format, which aligns with the split_accumulate in FP8 GEMM
one = torch.tensor(1.0, dtype=torch.float32, device=qx.device)
y_partial = torch._scaled_mm(
qx_block,
qw_block.t(),
scale_a=one,
scale_b=one,
out_dtype=torch.float32,
use_fast_accum=not use_split_accumulator,
)
# Accumulate the partial result
if is_a_1d_scaled and is_b_1d_scaled:
# 1Dx1D
# cuBLAS accumulation equation: y += (y_partial * scale_a) * scale_b
y_partial = y_partial * sx_block
# Fuse multiplication and addition to align with the split_accumulate in FP8 GEMM
# y_block.add_(y_partial, alpha=scale.item())
fused_fma(
y_block,
y_partial,
sw_block.expand_as(y_partial).contiguous(),
)
elif not is_a_1d_scaled and is_b_1d_scaled:
# 2Dx1D
# cuBLAS accumulation equation: y += (y_partial * scale_b) * scale_a
y_partial = y_partial * sw_block
fused_fma(
y_block,
y_partial,
sx_block.expand_as(y_partial).contiguous(),
)
elif is_a_1d_scaled and not is_b_1d_scaled:
# 1Dx2D
# cuBLAS accumulation equation: y += (y_partial * scale_a) * scale_b
y_partial = y_partial * sx_block
fused_fma(
y_block,
y_partial,
sw_block.expand_as(y_partial).contiguous(),
)
else:
scale = sx_block * sw_block
fused_fma(y_block, y_partial, scale.expand_as(y_partial).contiguous())
y = y.to(out_dtype)
return y
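# Exposition-only sketch (hypothetical helper; float32 stands in for FP8 data):
# the loop above is an ordinary tiled matmul in which each K-tile's partial
# product is rescaled by the per-block scale factors before being accumulated.
# Shown here for the 1D-by-1D scaling case, ignoring the bitwise FMA detail
# that fused_fma handles.
def _blockwise_gemm_sketch(qx, qw, sx, sw, tile=128):
    # qx: (M, K), qw: (N, K); sx: (M, K // tile) and sw: (N, K // tile) hold one
    # float32 scale per 1 x tile span of qx and qw respectively.
    M, K = qx.shape
    N, _ = qw.shape
    y = torch.zeros(M, N, dtype=torch.float32)
    for k0 in range(0, K, tile):
        partial = qx[:, k0:k0 + tile].float() @ qw[:, k0:k0 + tile].float().t()
        y += partial * sx[:, k0 // tile, None] * sw[None, :, k0 // tile]
    return y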
@@ -49,6 +49,7 @@ class CuBLASScaleMunger:
s_t = _pad_inner_to_align(unmunged.scale_t, transpose=tile_shape[0] == 1)
return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t)
@classmethod
def demunge_scale_shape_from_backend(
cls,
qtensor_shape: Tuple[int, int],
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.utils import get_device_compute_capability
from references.blockwise_quantizer_reference import CuBLASScaleMunger
from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm
def fp8_blockwise_gemm_supported() -> bool:
    # Parse "major.minor" numerically; float() would misorder versions like "12.10".
    cuda_version = tuple(int(v) for v in torch.version.cuda.split(".")[:2])
    return (9, 0) <= get_device_compute_capability() < (10, 0) and cuda_version >= (12, 9)
def cublas_gemm_fp8_blockwise_case(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
*,
x_columnwise: bool = False,
w_columnwise: bool = False,
use_bias: bool = False,
use_gelu: bool = False,
use_grad: bool = False,
atol: float = 0.0,
rtol: float = 0.0
):
if x_dtype == torch.float8_e5m2 and w_dtype == torch.float8_e5m2:
pytest.skip("FP8 GEMM doesn't support both a and b types being torch.float8_e5m2")
if not (is_x_1d_scaled or is_w_1d_scaled):
pytest.skip("FP8 GEMM doesn't support 2dimensional qtile by 2dimensional qtile")
if not fp8_blockwise_gemm_supported():
pytest.skip("CUDA version does not support blockwise FP8 gemm.")
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
x_shape = (K, M) if x_columnwise else (M, K)
w_shape = (K, N) if w_columnwise else (N, K)
# generate random input and weight
if noise_type == "uniform":
x = torch.rand(x_shape, dtype=torch.float32, device=device) * x_magnitude * 2 - x_magnitude
w = torch.rand(w_shape, dtype=torch.float32, device=device) * w_magnitude * 2 - w_magnitude
elif noise_type == "normal":
x = torch.randn(x_shape, dtype=torch.float32, device=device) * x_magnitude
w = torch.randn(w_shape, dtype=torch.float32, device=device) * w_magnitude
else:
assert False
# Setup out tensor if accumulate is True
if accumulate:
out = torch.randn((M, N), dtype=out_dtype, device=device) * x_magnitude
else:
out = None
assert not (use_bias and use_grad), "Bias grad not supported by GEMM"
# Set quantize_op and quantization parameters
x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
x_block_scaling_dim = 1 if is_x_1d_scaled else 2
w_block_scaling_dim = 1 if is_w_1d_scaled else 2
x_te_dtype = TE_DType[x_dtype]
w_te_dtype = TE_DType[w_dtype]
x_quantizer = Float8BlockQuantizer(
fp8_dtype=x_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=x_block_scaling_dim,
)
w_quantizer = Float8BlockQuantizer(
fp8_dtype=w_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=w_block_scaling_dim,
)
# Quantize x and w
qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False)
qx = x_quantizer.update_quantized(x, qx)
qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False)
qw = w_quantizer.update_quantized(w, qw)
if not use_bias:
bias = None
else:
bias = torch.randn((1, N), dtype=torch.bfloat16, device=device)
# Reference GEMM
ref_gemm = CuBLASRefBlockwiseGemm()
scale_decoder = CuBLASScaleMunger()
qx_data = (
qx._columnwise_data.view(dtype=x_dtype)
if x_columnwise
else qx._rowwise_data.view(dtype=x_dtype)
)
qw_data = (
qw._columnwise_data.view(dtype=w_dtype)
if w_columnwise
else qw._rowwise_data.view(dtype=w_dtype)
)
ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
y_ref = ref_gemm.qgemm(
qx=qx_data,
qw=qw_data,
out_dtype=out_dtype,
demunged_sx=CuBLASScaleMunger.demunge_scale_shape_from_backend(
qtensor_shape=(M, K), scales=ref_scales_x, tile_shape=x_quant_tile_shape
),
demunged_sw=CuBLASScaleMunger.demunge_scale_shape_from_backend(
qtensor_shape=(N, K), scales=ref_scales_w, tile_shape=w_quant_tile_shape
),
quant_tile_shape_x=x_quant_tile_shape,
quant_tile_shape_w=w_quant_tile_shape,
bias=bias,
out=out.clone() if accumulate else None,
accumulate=accumulate,
use_split_accumulator=use_split_accumulator,
)
# Allocate cuBLAS workspace
workspace_size = 0
workspace = torch.empty(0, dtype=torch.uint8, device=device)
transa = not w_columnwise
transb = x_columnwise
out_quantizer = None
assert not (use_gelu and use_bias), "Bias and GELU not supported by GEMM"
aux_tensor = torch.randn((M, N), dtype=out_dtype, device=device) if use_gelu else None
aux_tensor_ref = aux_tensor.clone() if use_gelu else None
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
y = tex.generic_gemm(
qw,
transa,
qx,
transb,
out.clone() if accumulate else None,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
# just in case of accumulation, make sure y_ref and y are not the same tensor
assert y_ref is not y, "y_ref and y should not be the same tensor"
# Reset NaNs to zeros so that torch.testing.assert_close does not fail on NaNs at mismatched locations
assert not torch.isnan(y_ref.float()).all(), "All elements are nan"
y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref)
y = torch.where(y.isnan(), torch.zeros_like(y), y)
if use_gelu:
# Check
if use_grad:
# With use_grad, GEMM should use aux tensor to calculate
# gradient
gelu_ref = tex.dgelu(y_ref, aux_tensor_ref, None)
# TODO: How do we decide whether this is acceptably close?
# Could also try to put the activation inside the reference
# before the output cast to see different tolerances.
torch.testing.assert_close(y, gelu_ref, atol=1e-3, rtol=1e-2)
else:
# aux tensor is pre-gelu aux output. Verify against y_ref.
torch.testing.assert_close(aux_tensor, y_ref, atol=atol, rtol=rtol)
act = torch.nn.GELU()
gelu_ref = act(y_ref)
# gelu_ref = tex.gelu(y_ref, None)
torch.testing.assert_close(y, gelu_ref, atol=atol, rtol=rtol)
else:
torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol)
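# Note on the tex.generic_gemm call above (exposition only): Transformer Engine
# tensors are row-major while cuBLAS is column-major. The row-major product
# Y = X @ W^T is computed as the column-major product Y^T = (W^T)^T @ X^T, which
# is why the weight tensor is passed as the first operand (with transa=True in
# the default row-wise layout) and the input as the second (with transb=False).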
def cublas_gemm_test_constraint_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
*,
x_columnwise: bool = False,
w_columnwise: bool = False,
use_bias: bool = False,
use_gelu: bool = False,
use_grad: bool = False,
expected_err_msg="CUBLAS_STATUS_NOT_SUPPORTED",
expected_err_cls=RuntimeError
):
if not fp8_blockwise_gemm_supported():
pytest.skip("CUDA version does not support blockwise FP8 gemm.")
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
x_shape = (K, M) if x_columnwise else (M, K)
w_shape = (K, N) if w_columnwise else (N, K)
# generate random input and weight
x = torch.rand(x_shape, dtype=torch.float32, device=device) * 2.0 - 1.0
w = torch.rand(w_shape, dtype=torch.float32, device=device) * 2.0 - 1.0
# Setup out tensor if accumulate is True
if accumulate:
out = torch.randn((M, N), dtype=out_dtype, device=device)
else:
out = None
# Set quantize_op and quantization parameters
x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
x_block_scaling_dim = 1 if is_x_1d_scaled else 2
w_block_scaling_dim = 1 if is_w_1d_scaled else 2
x_te_dtype = TE_DType[x_dtype]
w_te_dtype = TE_DType[w_dtype]
x_quantizer = Float8BlockQuantizer(
fp8_dtype=x_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=x_block_scaling_dim,
)
w_quantizer = Float8BlockQuantizer(
fp8_dtype=w_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=w_block_scaling_dim,
)
# Quantize x and w
qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False)
qx = x_quantizer.update_quantized(x, qx)
qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False)
qw = w_quantizer.update_quantized(w, qw)
if not use_bias:
bias = None
else:
bias = torch.randn((1, N), dtype=torch.bfloat16, device=device)
# Allocate cuBLAS workspace
workspace_size = 0
workspace = torch.empty(0, dtype=torch.uint8, device=device)
transa = not w_columnwise
transb = x_columnwise
out_quantizer = None
grad = use_grad
gelu_in = None if not use_gelu else torch.randn((M, N), dtype=out_dtype, device=device)
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
with pytest.raises(expected_err_cls, match=expected_err_msg):
y = tex.generic_gemm(
qw,
transa,
qx,
transb,
out.clone() if accumulate else None,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
gelu_in,
grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)
@pytest.mark.parametrize(
"M, K, N",
[
# k = 128
(128, 128, 128),
(256, 128, 256),
# non 128x128 divisible input shapes
(320, 128, 336),
(320, 64, 336),
# k > 128
(256, 256, 256),
(320, 256, 336),
(1024, 4096, 1024),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("noise_type", ["normal"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1], ids=str)
@pytest.mark.parametrize("w_magnitude", [1], ids=str)
@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_cublas_gemm_fp8_blockwise_shape_varying(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
):
cublas_gemm_fp8_blockwise_case(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
)
@pytest.mark.parametrize(
"M, K, N",
[
(256, 128, 256),
(320, 256, 336),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("noise_type", ["normal", "uniform"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1e-28, 1, 1e3], ids=str)
@pytest.mark.parametrize("w_magnitude", [1], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_cublas_gemm_fp8_blockwise_accumulate_magnitude_varying(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
):
cublas_gemm_fp8_blockwise_case(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
)
@pytest.mark.parametrize(
"M, K, N",
[
# k = 128
(256, 128, 256),
# non 128x128 divisible input shapes
(320, 64, 336),
# k > 128
(256, 256, 256),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("noise_type", ["normal"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1e-3], ids=str)
@pytest.mark.parametrize("w_magnitude", [1], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_cublas_gemm_fp8_blockwise_bias(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
):
cublas_gemm_fp8_blockwise_case(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
use_bias=True,
)
@pytest.mark.parametrize(
"M, K, N",
[
# k = 128
(256, 128, 256),
# non 128x128 divisible input shapes
(16, 128, 128),
(320, 64, 336),
# k > 128
(4096, 128, 4096),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("noise_type", ["normal"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1], ids=str)
@pytest.mark.parametrize("w_magnitude", [1], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
@pytest.mark.parametrize(
"is_x_columnwise, is_w_columnwise",
[
(True, False),
(True, True),
(False, True),
],
ids=["colxrow", "colxcol", "rowxcol"],
)
def test_cublas_gemm_fp8_blockwise_columnwise(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
is_x_columnwise,
is_w_columnwise,
):
cublas_gemm_fp8_blockwise_case(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
x_columnwise=is_x_columnwise,
w_columnwise=is_w_columnwise,
)
@pytest.mark.parametrize(
"M, K, N",
[
# k = 128
(256, 128, 256),
# non 128x128 divisible input shapes
(320, 64, 336),
# k > 128
(256, 256, 256),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("noise_type", ["normal"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1], ids=str)
@pytest.mark.parametrize("w_magnitude", [1], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
@pytest.mark.parametrize(
"use_grad",
[
True,
],
ids=["grad"],
)
def test_cublas_gemm_fp8_gelu(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
use_grad,
):
# NOTE: cuBLAS does not complain when use_grad is False, but the tests do not
# pass, so the epilogue is disabled on the Transformer Engine side.
if not use_grad and not (is_x_1d_scaled and not is_w_1d_scaled):
pytest.skip(
"CUBLASLT_EPILOGUE_GELU_AUX epilogue is only supported for 1Dx2D (cuBLAS 2Dx1D)."
)
cublas_gemm_fp8_blockwise_case(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
use_gelu=True,
use_grad=use_grad,
)
@pytest.mark.parametrize(
"M, K, N",
[
# k = 128
(256, 128, 256),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [False], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_split_accumulator_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
) -> None:
cublas_gemm_test_constraint_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
)
@pytest.mark.parametrize(
"M, K, N",
[
# k = 128
(256, 128, 256),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_bgrad_not_supported(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
) -> None:
# NOTE: BGRAD epilogue is not supported for fp8.
cublas_gemm_test_constraint_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
use_grad=True,
use_bias=True,
expected_err_msg="Epilogue requested outside of the available",
)
@pytest.mark.parametrize(
"M, K, N",
[
# k = 128
(256, 128, 256),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no_bias"])
@pytest.mark.parametrize("use_grad", [True, False], ids=["grad", "no_grad"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_gelu_unsupported_cases_error(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_bias,
use_grad,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
) -> None:
if use_grad and not use_bias and out_dtype == torch.bfloat16:
pytest.skip("DGELU epilogue is supported for bfloat16.")
elif use_grad and not use_bias:
expected_err = "an unsupported value or parameter was passed"
else:
expected_err = "Epilogue requested outside of the available"
cublas_gemm_test_constraint_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
use_grad=use_grad,
use_bias=use_bias,
use_gelu=True,
expected_err_msg=expected_err,
)
@pytest.mark.parametrize(
"M, K, N",
[
(256, 128, 256),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(True, True),
(False, True),
],
ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_illegal_dtype_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
) -> None:
# e5m2 by e5m2 not supported.
cublas_gemm_test_constraint_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
)
@pytest.mark.parametrize(
"M, K, N",
[
(256, 128, 256),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(False, False),
],
ids=["2Dx2D"],
)
def test_illegal_2D_by_2D_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
) -> None:
# 2D block quantization by 2D block quantization is not supported.
expected_err_msg = "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported"
cublas_gemm_test_constraint_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
expected_err_msg=expected_err_msg,
)
@pytest.mark.parametrize(
"M, K, N, legalX1d, legalX2d",
[
# M dim unconstrained when X is 2D.
(255, 128, 256, False, True),
# K must be multiple of 16
(256, 120, 256, False, False),
# N must be a multiple of 8
(256, 128, 252, False, False),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
"is_x_1d_scaled, is_w_1d_scaled",
[
(True, False),
(False, True),
(True, True),
],
ids=["1Dx2D", "2Dx1D", "1Dx1D"],
)
def test_unaligned_shapes(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
legalX1d,
legalX2d,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
) -> None:
legal = legalX1d if is_x_1d_scaled else legalX2d
if not legal:
cublas_gemm_test_constraint_enforced(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
expected_err_msg="dimension requirement",
)
else:
cublas_gemm_fp8_blockwise_case(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
"uniform", # noise type
1.0, # x_magnitude
1.0, # w_magnitude
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
)
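# To run just this file (exposition only; path as used by the CI script above):
#   python3 -m pytest -v tests/pytorch/test_float8_blockwise_gemm_exact.py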
@@ -52,97 +52,173 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) {
NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
}
/* Parameters for cuBLAS GEMM
 *
 * cuBLAS follows the BLAS convention of column-major ordering. This
 * differs from the row-major ordering typically used in Transformer
 * Engine.
 */
struct GemmParam {
  void *A = nullptr;
  void *B = nullptr;
  cublasOperation_t transA = CUBLAS_OP_N;
  cublasOperation_t transB = CUBLAS_OP_N;
  transformer_engine::DType Atype = transformer_engine::DType::kNumTypes;
  transformer_engine::DType Btype = transformer_engine::DType::kNumTypes;
  void *A_scale_inv = nullptr;
  void *B_scale_inv = nullptr;
  int lda = 0;  // leading dimension (column stride) of A
  int ldb = 0;  // leading dimension (column stride) of B
};
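// Worked shape example (exposition only): for row-major X (M x K) and W (N x K),
// the row-major product Y = X * W^T is computed by cuBLAS in column-major terms
// as Y^T = op(A) * op(B), with A = W viewed as a K x N column-major matrix
// (transA = CUBLAS_OP_T, lda = K) and B = X viewed as K x M (transB = CUBLAS_OP_N,
// ldb = K); the resulting N x M column-major Y^T is exactly Y in row-major memory.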
/* Populate parameters for cuBLAS GEMM
 *
 * cuBLAS follows the BLAS convention of column-major ordering. This
 * differs from the row-major ordering typically used in Transformer
 * Engine.
 */
GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA,
                                const transformer_engine::Tensor &B, const cublasOperation_t transB,
                                int m, int n, int k) {
  using namespace transformer_engine;
  NVTE_CHECK(
      A.scaling_mode == B.scaling_mode ||
          (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) ||
          (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D),
      "Inputs A and B to GEMM need to have compatible scaling modes!");
  NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
  NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
  GemmParam ret;

  // Device compute capability
  const int arch = cuda::sm_arch();

  // Transpose mode with column-major ordering
  bool transa_bool = transA == CUBLAS_OP_T;
  bool transb_bool = transB == CUBLAS_OP_T;

  // Configure A matrix
  if (is_tensor_scaling(A.scaling_mode)) {
    // Unscaled or FP8 tensor scaling
    ret.A = A.data.dptr;
    ret.transA = transA;
    ret.Atype = A.data.dtype;
    ret.A_scale_inv = A.scale_inv.dptr;
    ret.lda = transa_bool ? k : m;
    if (arch < 100 && !transa_bool) {
      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is the transpose of the data.
      if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
        ret.A = A.columnwise_data.dptr;
        ret.transA = CUBLAS_OP_T;
        ret.Atype = A.columnwise_data.dtype;
        ret.A_scale_inv = A.columnwise_scale_inv.dptr;
        ret.lda = k;
      } else {
        NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
      }
    }
  } else if (is_mxfp_scaling(A.scaling_mode)) {
    // MXFP8
    // Note: Row-wise and column-wise data are scaled along different
    // dimensions (with the matrix interpreted in row-major order).
    if (transa_bool) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = transA;
    ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = m;
  } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // FP8 block scaling
    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is the transpose of the data.
    if (transa_bool) {
      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
    } else {
      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
    }
    ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr;
    ret.transA = CUBLAS_OP_T;
    ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.lda = k;
    // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    NVTE_CHECK((ret.lda % 16) == 0,
               "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    // Divisibility by 8 derived from the FP8 (m * CTypeSize) % 16 == 0 requirement.
    // The smallest supported CType is 2 bytes in this scaling mode.
    NVTE_CHECK((m % 8) == 0,
               "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
  } else {
    NVTE_ERROR("A has unsupported scaling mode");
  }

  // Configure B matrix
  if (is_tensor_scaling(B.scaling_mode)) {
    // Unscaled or FP8 tensor scaling
    ret.B = B.data.dptr;
    ret.transB = transB;
    ret.Btype = B.data.dtype;
    ret.B_scale_inv = B.scale_inv.dptr;
    ret.ldb = transb_bool ? n : k;
    if (arch < 100 && transb_bool) {
      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is the transpose of the data.
      if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
        ret.B = B.columnwise_data.dptr;
        ret.transB = CUBLAS_OP_N;
        ret.Btype = B.columnwise_data.dtype;
        ret.B_scale_inv = B.columnwise_scale_inv.dptr;
        ret.ldb = k;
      } else {
        NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
      }
    }
  } else if (is_mxfp_scaling(B.scaling_mode)) {
    // MXFP8
    // Note: Row-wise and column-wise data are scaled along different
    // dimensions (with the matrix interpreted in row-major order).
    if (transb_bool) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = transB;
    ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = k;
  } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // FP8 block scaling
    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is the transpose of the data.
    if (transb_bool) {
      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
    } else {
      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
    }
    ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr;
    ret.transB = CUBLAS_OP_N;
    ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
    ret.ldb = k;
    // Requirements from
    // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
    NVTE_CHECK((ret.ldb % 16) == 0,
               "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) {
      // This requirement has only been observed when the B tensor is 1D quantized.
      NVTE_CHECK((n % 8) == 0,
                 "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
    }
  } else {
    NVTE_ERROR("B has unsupported scaling mode");
  }
  return ret;
}
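// Scale-factor layout sketch (exposition only, for the 128-element quantization
// blocks used here): NVTE_BLOCK_SCALING_1D carries one float32 scale per 1 x 128
// span of FP8 data, i.e. a rows x ceil(cols / 128) scale tensor, while
// NVTE_BLOCK_SCALING_2D carries one scale per 128 x 128 tile, i.e.
// ceil(rows / 128) x ceil(cols / 128). These map onto the cuBLAS scale modes
// CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F and
// CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F selected below.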
@@ -153,18 +229,33 @@ namespace transformer_engine {
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                 const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
                 cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
                 bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split,
                 int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) {
// Tensor dims in row-major order
const int A0 = inputA->flat_first_dim();
const int A1 = inputA->flat_last_dim();
const int B0 = inputB->flat_first_dim();
const int B1 = inputB->flat_last_dim();
// GEMM dims in column-major order
const int m = transa == CUBLAS_OP_T ? A0 : A1;
const int n = transb == CUBLAS_OP_T ? B1 : B0;
const int k = transa == CUBLAS_OP_T ? A1 : A0;
NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k,
"GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1,
")");
const int ldd = m;
// Return immediately if GEMM is trivial
if (m <= 0 || n <= 0) {
return;
}
NVTE_CHECK(k > 0);
const GemmParam &param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, k, lda, ldb);
const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);
void *C = outputD->data.dptr;
void *D = outputD->data.dptr;
void *D_scale = outputD->scale.dptr;
@@ -226,6 +317,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
param.transA == CUBLAS_OP_N ? k : m, param.lda));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n,
param.transB == CUBLAS_OP_N ? n : k, param.ldb));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
@@ -249,12 +341,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
&fastAccuMode, sizeof(fastAccuMode)));
  // Scaling factors.
#if CUDA_VERSION >= 12080
  cublasLtMatmulMatrixScale_t scaling_mode_a;
  cublasLtMatmulMatrixScale_t scaling_mode_b;
#endif
if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) {
void *A_scale_inverse = param.A_scale_inv;
@@ -266,8 +356,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse, sizeof(B_scale_inverse)));
#if CUDA_VERSION >= 12080
    scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
    scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
  } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) {
fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
@@ -276,7 +367,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse, sizeof(B_scale_inverse)));
    scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
    scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
// Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
// CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
if (cublasLtGetVersion() <= 120803) {
@@ -285,7 +377,32 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
sizeof(dummy_a_vec_stride)));
}
} else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ||
inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) &&
(inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
#if CUDA_VERSION >= 12090
float *A_scale_inverse = reinterpret_cast<float *>(param.A_scale_inv);
float *B_scale_inverse = reinterpret_cast<float *>(param.B_scale_inv);
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&A_scale_inverse, sizeof(A_scale_inverse)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse, sizeof(B_scale_inverse)));
NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)),
"Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got 2D by 2D");
scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D
? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
: CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D
? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
: CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
#else
NVTE_ERROR("FP8 block scaling requires CUDA 12.9+");
#endif // CUDA_VERSION >= 12090
#endif // CUDA_VERSION >= 12080
} else {
NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " +
to_string(inputB->scaling_mode) + ".");
@@ -293,9 +410,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#if CUDA_VERSION >= 12080
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
      operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
      operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b)));
#endif
if (is_fp8_dtype(outputD->data.dtype)) {
// Accumulation mode not supported for FP8 output
@@ -305,8 +422,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
#if CUDA_VERSION >= 12080
    // NOTE: In all current cases where FP8 output is supported, the input is
    // scaled identically to the output.
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_D_SCALE_MODE,
                                                     &scaling_mode_a, sizeof(scaling_mode_a)));
#endif
// For FP8 output, cuBLAS requires C_type to match bias_type and
// be FP16/BF16
@@ -364,6 +484,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
}
if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D) ||
(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS ||
epilogue == CUBLASLT_EPILOGUE_DGELU),
"Epilogue requested outside of the available and tested cuBLAS functionality for "
"float8 block scaled GEMM");
}
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue)));
@@ -411,7 +539,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
"Unable to find suitable cuBLAS GEMM algorithm");
NVTE_CHECK_CUBLAS(status);
if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");
// D = alpha * (A * B) + beta * C
@@ -469,35 +596,9 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
}
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
@@ -525,31 +626,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
  NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
                 is_delayed_tensor_scaling(inputB->scaling_mode),
             "Atomic GEMM only supports delayed scaling.");

  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
              (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
              accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
              inputCounter, stream);
}
void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
......
@@ -27,7 +27,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) {
if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
      !is_mxfp_scaling(z->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
}
@@ -57,7 +57,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
if (cudnn_backend) {
// TODO: add check for GPU ARCH
......
@@ -23,7 +23,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
Tensor *rsigma, Tensor *workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream) {
if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
      !is_mxfp_scaling(z->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
}
@@ -47,7 +47,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
bool training =
is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr;
......