Unverified commit 99df8810 authored by Tim Moon, committed by GitHub

Add logic for block-scaled tensors with GEMM swizzled scales (#2486)

* Add general C API for setting tensor params
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Implement general accessors for NVTETensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor tex swizzling to skip if scales are already swizzled
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add checks for non-swizzled scales in MXFP8 and NVFP4 kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support pre-swizzled scales in MXFP8Tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tex function to swizzle MXFP8 scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in in-place swizzle function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak comments to use "compact/swizzled format"
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* MXFP8 quantize kernel with pre-swizzled scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expose pre-swizzled scales in modules
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in multi-swizzle
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support MXFP8 gated activations with swizzled scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add PyTorch infrastructure for pre-swizzled NVFP4 tensors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Deprecate DSv3-specific quantization logic in C API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove support for DSv3 compact data from quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove DSv3 compact data format from core lib
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in FP8 all-gather
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update JAX to use new swizzled scale API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update C++ swizzle test with swizzled scales API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Return default tensor params when querying params for an invalid NVTETensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug DSv3 FP8 test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug Userbuffers test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure gated activations populate FP8 transpose if needed
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Review suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable pre-swizzling with debug quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix merge conflicts and review suggestions

Update copyright years. Tweak comments. Fix various complaints from @greptile-apps.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use explicitly sized types in config accessors

Miscellaneous review suggestions from @ptrendx.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Make util header for function that computes swizzled scale index
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from @greptile-apps
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Update expected error message in FP8 block-scaling test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @yaox12
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent a652730f
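For reference, a minimal sketch (not part of the commit) of how the new generic tensor-param accessor can be used. It mirrors the nvte_get_tensor_param_v2 call that appears in the CommOverlap change below; the wrapper function name is hypothetical.

#include <transformer_engine/transformer_engine.h>

// Query whether an NVTETensor carries GEMM-swizzled scaling factors.
// `tensor` is assumed to be a valid MXFP8 or NVFP4 NVTETensor; bools are
// marshaled through the C API as uint8_t.
bool has_gemm_swizzled_scales(NVTETensor tensor) {
  uint8_t flag = 0;
  nvte_get_tensor_param_v2(tensor, NVTETensorParam::kNVTEWithGEMMSwizzledScales,
                           &flag, sizeof(flag), nullptr);
  return flag != 0;
}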
@@ -85,6 +85,7 @@ void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool row
   std::vector<int> scaling_mode = {SF_MODE_X, SF_MODE_Y, 0};
   Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
   Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
+  output.set_with_gemm_swizzled_scales(true);
   fillUniform(&input);
......
@@ -286,6 +286,10 @@ class Tensor {
     tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape);
   }
 
+  void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales) {
+    tensor_.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
+  }
+
   void to_cpu() const;
   void from_cpu() const;
   void set_scale(float scale);
......
@@ -884,7 +884,7 @@ def test_illegal_2D_by_2D_enforced(
         is_w_1d_scaled,
     ) -> None:
         # 2D block quantization by 2D block quantization is not supported.
-        expected_err_msg = "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported"
+        expected_err_msg = "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling GEMM is supported"
         cublas_gemm_test_constraint_enforced(
             x_dtype,
             w_dtype,
......
@@ -87,126 +87,6 @@ def initialize_for_many_scales(
     return result
 
-@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
-@pytest.mark.parametrize(
-    "M, N",
-    [
-        # full tile cases
-        (128, 128),
-        (256, 256),
-        (256, 1024),
-        (1024, 256),
-        # Padding required cases
-        (256, 272),
-        (303, 300),
-        (305, 256),
-        # Some larger tiles.
-        (2000, 2000),
-        (2048, 2000),
-        (2000, 1024),
-        (2048, 1024),
-    ],
-)
-@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
-@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
-@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
-@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
-def test_quantization_1D_block_tiling_with_compact_data_and_scales(
-    x_dtype: torch.dtype,
-    M: int,
-    N: int,
-    quant_dtype: torch.dtype,
-    eps: float,
-    pow_2_scales: bool,
-) -> None:
-    te_dtype = TE_DType[quant_dtype]
-    tile_size = (1, 128)
-    # This test runs a comparison of the ref class versus the class using
-    # CUDA kernels to quantize. They should quantize identically for pixels
-    # that are not DC values in the scale factor shape.
-    ref_quantizer = BlockwiseQuantizerReference()
-    sut_quantizer = Float8BlockQuantizer(
-        fp8_dtype=te_dtype,
-        rowwise=True,
-        columnwise=True,
-        amax_epsilon=eps,
-        force_pow_2_scales=pow_2_scales,
-        block_scaling_dim=1,
-        all_gather_usage=True,
-    )
-    # Setup device and random seed
-    device = "cuda"
-    seed = 0
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    # Input
-    x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device)
-    x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False)
-    x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut)
-    x_fp8_sut_cpp_alloc = sut_quantizer(x)
-    assert x_fp8_sut._rowwise_data is not None
-    qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype)
-    assert x_fp8_sut._rowwise_scale_inv is not None
-    sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv
-    qx_t = x_fp8_sut._columnwise_data
-    sx_t = x_fp8_sut._columnwise_scale_inv
-    qresult_ref = ref_quantizer.quantize(
-        x,
-        quant_dtype=quant_dtype,
-        return_transpose=True,
-        eps=eps,
-        pow_2_scales=pow_2_scales,
-        quant_tile_shape=tile_size,
-        munge_scale_shapes=False,
-    )
-    qx_ref, sx_ref, qx_t_ref, sx_t_ref = (
-        qresult_ref.data,
-        qresult_ref.scale,
-        qresult_ref.data_t,
-        qresult_ref.scale_t,
-    )
-    # match the reference quantize transpose output with the columnwise non-transpose method
-    qx_t_ref = qx_t_ref.transpose(-1, -2).contiguous()
-    sx_t_ref = sx_t_ref.transpose(-1, -2).contiguous()
-    # Check
-    torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
-    torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
-    assert qx_t is not None
-    qx_t = qx_t.view(dtype=quant_dtype)
-    assert qx_t_ref is not None
-    assert sx_t is not None
-    assert sx_t_ref is not None
-    torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
-    torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)
-    # check that the C++ and Python allocators are equivalent
-    torch.testing.assert_close(
-        x_fp8_sut._rowwise_data, x_fp8_sut_cpp_alloc._rowwise_data, atol=0.0, rtol=0.0
-    )
-    torch.testing.assert_close(
-        x_fp8_sut._rowwise_scale_inv, x_fp8_sut_cpp_alloc._rowwise_scale_inv, atol=0.0, rtol=0.0
-    )
-    torch.testing.assert_close(
-        x_fp8_sut._columnwise_data, x_fp8_sut_cpp_alloc._columnwise_data, atol=0.0, rtol=0.0
-    )
-    torch.testing.assert_close(
-        x_fp8_sut._columnwise_scale_inv,
-        x_fp8_sut_cpp_alloc._columnwise_scale_inv,
-        atol=0.0,
-        rtol=0.0,
-    )
-    # check if the fp8 output between C++ and Python are the same
-    assert x_fp8_sut._data_format == x_fp8_sut_cpp_alloc._data_format
 
 def check_quantization_block_tiling_versus_reference(
     x_dtype: torch.dtype,
     M: int,
......
@@ -175,16 +175,12 @@ class TestFloat8BlockwiseTensor:
     )
     @pytest.mark.parametrize("block_scaling_dim", [1, 2])
     @pytest.mark.parametrize("dq_columnwise", [True, False])
-    @pytest.mark.parametrize("all_gather_usage", [True, False])
     def test_quantize_dequantize_dims(
         self,
         dims: DimsType,
         block_scaling_dim: int,
         dq_columnwise: bool,
-        all_gather_usage: bool,
     ) -> None:
-        if all_gather_usage and block_scaling_dim != 1:
-            pytest.skip("all_gather_usage only implemented for 1D block quantization.")
         atol = _tols[tex.DType.kFloat8E4M3]["atol"]
         rtol = _tols[tex.DType.kFloat8E4M3]["rtol"]
         quantizer = Float8BlockQuantizer(
@@ -192,7 +188,6 @@ class TestFloat8BlockwiseTensor:
             rowwise=True,
             columnwise=dq_columnwise,
             block_scaling_dim=block_scaling_dim,
-            all_gather_usage=all_gather_usage,
         )
         self._test_quantize_dequantize(
             quantizer=quantizer,
@@ -218,7 +213,6 @@ class TestFloat8BlockwiseTensor:
             rowwise=True,
             columnwise=dq_columnwise,
             block_scaling_dim=block_scaling_dim,
-            all_gather_usage=(block_scaling_dim == 1),
         )
         self._test_quantize_dequantize(
             quantizer=quantizer,
@@ -283,13 +277,8 @@ class TestFloat8BlockwiseTensor:
     @pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
     @pytest.mark.parametrize("block_scaling_dim", [1, 2])
-    @pytest.mark.parametrize("all_gather_usage", [True, False])
-    def test_serialization(
-        self, dims: DimsType, block_scaling_dim: int, all_gather_usage: bool
-    ) -> None:
+    def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None:
         """Test serialization of Float8BlockwiseQTensor"""
-        if all_gather_usage and block_scaling_dim != 1:
-            pytest.skip("all_gather_usage only implemented for 1D block quantization.")
         device = "cuda"
         dtype = torch.bfloat16
         x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
@@ -298,7 +287,6 @@ class TestFloat8BlockwiseTensor:
             rowwise=True,
             columnwise=True,
             block_scaling_dim=block_scaling_dim,
-            all_gather_usage=all_gather_usage,
         )
 
         # Create FP8 tensor
@@ -322,7 +310,6 @@ class TestFloat8BlockwiseTensor:
         assert x_fp8_loaded._is_2D_scaled == x_fp8._is_2D_scaled
         assert x_fp8_loaded.dtype == x_fp8.dtype
         assert x_fp8_loaded._fp8_dtype == x_fp8._fp8_dtype
-        assert x_fp8_loaded._data_format == x_fp8._data_format
 
         # Test that dequantized values match
         x_fp8_dequant = x_fp8.dequantize()
......
@@ -2737,7 +2737,11 @@ class TestCheckpointing:
         # Check that original and loaded model match exactly
         tols = {"rtol": 0, "atol": 0}
         for param_load, param_save in zip(model_load.parameters(), model_save.parameters()):
-            torch.testing.assert_close(param_load, param_save, **tols)
+            torch.testing.assert_close(  # Force dequantization by casting to FP64
+                param_load.to(dtype=torch.float64, device="cpu"),
+                param_save.to(dtype=torch.float64, device="cpu"),
+                **tols,
+            )
             torch.testing.assert_close(param_load.grad, param_save.grad, **tols)
         for y_load, y_save in zip(ys_load, ys_save):
             torch.testing.assert_close(y_load, y_save, **tols)
@@ -2754,7 +2758,6 @@ class TestSequentialModules:
     @pytest.mark.parametrize("requires_grad", (False, True))
     @pytest.mark.parametrize("bias", (False, True))
-    @pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm"))
     @pytest.mark.parametrize("quantized_compute", (False, True))
     @pytest.mark.parametrize("quantized_weight", (False, True))
     @pytest.mark.parametrize("dtype", _dtypes)
@@ -2764,25 +2767,18 @@ class TestSequentialModules:
         *,
         requires_grad: bool,
         bias: bool,
-        normalization: str,
         quantized_compute: bool,
         quantized_weight: bool,
         dtype: torch.dtype,
         quantization: Optional[str],
         device: torch.device = "cuda",
-        hidden_size: int = 32,
-        sequence_length: int = 512,
+        hidden_size: int = 256,
+        sequence_length: int = 48,
         batch_size: int = 4,
-        ffn_hidden_size: int = 64,
+        ffn_hidden_size: int = 384,
         layernorm_epsilon: float = 1e-5,
     ) -> None:
-        """
-        LayerNorm/RMSNorm + Linear + GELU + Linear
-
-        Note that this test checks only if the module runs
-        as when chaining multiple modules it is hard to validate
-        numerical accuracy.
-        """
+        """LayerNorm/RMSNorm + Linear + SwiGLU + Linear"""
 
         # Make input shape
         in_shape = (sequence_length, batch_size, hidden_size)
@@ -2798,38 +2794,90 @@ class TestSequentialModules:
             pytest.skip("Quantization scheme is not used")
 
         # Random data
-        _, x_test = make_reference_and_test_tensors(
+        x_ref, x_test = make_reference_and_test_tensors(
             in_shape,
             quantization=quantization,
             test_dtype=dtype,
             test_device=device,
             requires_grad=requires_grad,
         )
-        _, dy_test = make_reference_and_test_tensors(
+        norm_w_ref, norm_w_test = make_reference_and_test_tensors(
+            hidden_size,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        norm_b_ref, norm_b_test = make_reference_and_test_tensors(
+            hidden_size,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        w1_ref, w1_test = make_reference_and_test_tensors(
+            (ffn_hidden_size, hidden_size),
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        w2_ref, w2_test = make_reference_and_test_tensors(
+            (hidden_size, ffn_hidden_size // 2),
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        b1_ref, b1_test, b2_ref, b2_test = None, None, None, None
+        if bias:
+            b1_ref, b1_test = make_reference_and_test_tensors(
+                ffn_hidden_size,
+                test_dtype=dtype,
+                test_device=device,
+            )
+            b2_ref, b2_test = make_reference_and_test_tensors(
+                hidden_size,
+                test_dtype=dtype,
+                test_device=device,
+            )
+        dy_ref, dy_test = make_reference_and_test_tensors(
             in_shape,
             quantization=quantization,
             test_dtype=dtype,
             test_device=device,
             requires_grad=False,
         )
+        with torch.no_grad():
+            for t in (norm_w_ref, norm_w_test, norm_b_ref, norm_b_test):
+                t -= 0.5
+            for t in (w1_ref, w1_test, w2_ref, w2_test):
+                t *= 1 / 64
+            if bias:
+                for t in (b1_ref, b1_test, b2_ref, b2_test):
+                    t -= 0.5
+            for t in (dy_ref, dy_test):
+                t -= 0.5
 
+        # Reference implementation
+        x = x_ref
+        x = torch.nn.functional.layer_norm(
+            x,
+            (hidden_size,),
+            weight=norm_w_ref,
+            bias=norm_b_ref,
+            eps=layernorm_epsilon,
+        )
+        x = torch.nn.functional.linear(x, w1_ref, bias=b1_ref)
+        x1, x2 = x.chunk(2, dim=-1)
+        x = torch.nn.functional.silu(x1) * x2
+        x = torch.nn.functional.linear(x, w2_ref, bias=b2_ref)
+        y_ref = x
+        y_ref.backward(dy_ref)
 
-        # Implementation with fusible operations
+        # Construct operations
         recipe = make_recipe(quantization)
         with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
-            if normalization == "LayerNorm":
-                norm = te_ops.LayerNorm(
-                    hidden_size,
-                    eps=layernorm_epsilon,
-                    device=device,
-                    dtype=dtype,
-                )
-            else:
-                norm = te_ops.RMSNorm(
-                    hidden_size,
-                    eps=layernorm_epsilon,
-                    device=device,
-                    dtype=dtype,
-                )
+            norm = te_ops.LayerNorm(
+                hidden_size,
+                eps=layernorm_epsilon,
+                device=device,
+                dtype=dtype,
+            )
             ffn1 = te_ops.Linear(
                 hidden_size,
                 ffn_hidden_size,
@@ -2837,15 +2885,48 @@ class TestSequentialModules:
                 device=device,
                 dtype=dtype,
             )
-            act = te_ops.GELU()
+            act = te_ops.SwiGLU()
             ffn2 = te_ops.Linear(
-                ffn_hidden_size,
+                ffn_hidden_size // 2,
                 hidden_size,
                 bias=bias,
                 device=device,
                 dtype=dtype,
            )
 
+        # Copy weights
+        with torch.no_grad():
+            norm.weight.copy_(norm_w_test)
+            norm.bias.copy_(norm_b_test)
+            ffn1.weight.copy_(w1_test)
+            ffn2.weight.copy_(w2_test)
+            if bias:
+                ffn1.bias.copy_(b1_test)
+                ffn2.bias.copy_(b2_test)
+        del norm_w_test, norm_b_test, w1_test, b1_test, w2_test, b2_test
 
+        # Fuse ops and perform forward and backward pass
         forward = te_ops.Sequential(norm, ffn1, act, ffn2)
         with te.autocast(enabled=quantized_compute, recipe=recipe):
             y_test = forward(x_test)
         y_test.backward(dy_test)
 
+        def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+            """Convert to FP64 CPU tensor"""
+            if tensor is None:
+                return None
+            out = tensor.detach().to(dtype=torch.float64, device="cpu")
+            out = out.requires_grad_(requires_grad=tensor.requires_grad)
+            return out
 
+        # Check values
+        tols = {"rtol": 0.25, "atol": 0.5}  # Loose tols for sanity checking
+        torch.testing.assert_close(to_cpu(y_test), y_ref, **tols)
+        torch.testing.assert_close(to_cpu(x_test.grad), x_ref.grad, **tols)
+        torch.testing.assert_close(to_cpu(norm.weight.grad), norm_w_ref.grad, **tols)
+        torch.testing.assert_close(to_cpu(norm.bias.grad), norm_b_ref.grad, **tols)
+        torch.testing.assert_close(to_cpu(ffn2.weight.grad), w2_ref.grad, **tols)
+        torch.testing.assert_close(to_cpu(ffn1.weight.grad), w1_ref.grad, **tols)
+        if bias:
+            torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols)
+            torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols)
@@ -14,6 +14,7 @@
 #include <transformer_engine/transformer_engine.h>
 
 #include "../../common.h"
+#include "../../transpose/transpose.h"
 #include "../../utils.cuh"
 #include "../fp8/gated_fp8.cuh"
 #include "../mxfp8/gated_mxfp8.cuh"
@@ -53,6 +54,20 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp
       } else {
         fp8::cast_gated_fwd<ParamOP, ActOP>(input, output, p, stream);
       }
+      if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) {
+        // FP8 kernel only populates row-wise data, so perform
+        // transpose separately if needed
+        Tensor transpose_in, transpose_out, dummy;
+        transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
+        transpose_in.data.dptr = output->data.dptr;
+        transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()};
+        transpose_in.data.dtype = output->data.dtype;
+        transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
+        transpose_out.data.dptr = output->columnwise_data.dptr;
+        transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()};
+        transpose_out.data.dtype = output->data.dtype;
+        detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream);
+      }
       break;
     }
     case NVTE_MXFP8_1D_SCALING: {
@@ -129,6 +144,20 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte
       } else {
        fp8::cast_gated_bwd<ParamOP, ActOP, DActOP>(gated_input, grad, output, p, stream);
       }
+      if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) {
+        // FP8 kernel only populates row-wise data, so perform
+        // transpose separately if needed
+        Tensor transpose_in, transpose_out, dummy;
+        transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
+        transpose_in.data.dptr = output->data.dptr;
+        transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()};
+        transpose_in.data.dtype = output->data.dtype;
+        transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
+        transpose_out.data.dptr = output->columnwise_data.dptr;
+        transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()};
+        transpose_out.data.dtype = output->data.dtype;
+        detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream);
+      }
       break;
     }
     case NVTE_MXFP8_1D_SCALING: {
......
@@ -150,17 +150,10 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
   FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
   FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
   if (output_tensor->has_data()) {
-    bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
-                            Float8BlockScaleTensorFormat::COMPACT);
-    rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
-                                     : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
+    rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
   }
   if (output_tensor->has_columnwise_data()) {
-    bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
-                               Float8BlockScaleTensorFormat::COMPACT);
-    columnwise_option = columnwise_compact
-                            ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
-                            : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
+    columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
   }
   quantize_transpose_vector_blockwise(
       input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
@@ -298,17 +291,10 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
   FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
   FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
   if (output_tensor->has_data()) {
-    bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
-                            Float8BlockScaleTensorFormat::COMPACT);
-    rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
-                                     : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
+    rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
   }
   if (output_tensor->has_columnwise_data()) {
-    bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
-                               Float8BlockScaleTensorFormat::COMPACT);
-    columnwise_option = columnwise_compact
-                            ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
-                            : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
+    columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
   }
   quantize_transpose_vector_blockwise(
       grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
......
@@ -239,6 +239,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
     NVTE_CHECK(is_fp8_dtype(input.columnwise_data.dtype), "Input must have FP8 type.");
   }
+  NVTE_CHECK(!input.with_gemm_swizzled_scales, "Input must have scales in compact format.");
   NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
   NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match.");
......
@@ -20,6 +20,7 @@
 #include "../../util/math.h"
 #include "../../util/ptx.cuh"
 #include "../../utils.cuh"
+#include "swizzle.cuh"
 
 namespace transformer_engine {
 namespace dispatch {
@@ -51,7 +52,8 @@ constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X;  // 4 = 128
 template <bool IS_BWD, typename ParamOP, float (*ActOP)(float, const ParamOP &),
           float (*DActOP)(float, const ParamOP &), typename IType, typename OType,
-          bool ROWWISE_SCALING, bool COLWISE_SCALING, size_t THREADS_PER_CHUNK>
+          bool ROWWISE_SCALING, bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES,
+          size_t THREADS_PER_CHUNK>
 __global__ void __launch_bounds__(THREADS_PER_CHUNK)
     quantize_gated_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_grad,
                                 const __grid_constant__ CUtensorMap tensor_map_input_act,
@@ -68,6 +70,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
   using IType2 = typename ptx::FPx2<IType>;
   using OType2 = typename ptx::FPx2<OType>;
 
+  using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx;
+
   constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
   static_assert(STAGES >= 1);
@@ -355,14 +359,17 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
       // 2. Compute E8M0 scaling factor
       const e8m0_t biased_exponent_act =
           ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
 
       const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
       const size_t global_scales_offset_X = scales_offset_X_colwise;
-      const size_t scale_idx =
-          global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
+      size_t scale_idx;
+      if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
+        scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
+                                            DIVUP(rows, static_cast<size_t>(128)));
+      } else {
+        scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
+      }
       const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows;
       const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise;
       if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
         scales_colwise[scale_idx] = biased_exponent_act;
       }
@@ -374,8 +381,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
         const e8m0_t biased_exponent_gate =
             ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
-        // const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2;
-        const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
+        size_t scale_idx_gate;
+        if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
+          scale_idx_gate = gemm_swizzled_scale_idx(
+              global_scales_offset_X + gate_scale_idx_offset_colwise, global_scales_offset_Y,
+              DIVUP(rows, static_cast<size_t>(128)));
+        } else {
+          scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
+        }
         if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
           scales_colwise[scale_idx_gate] = biased_exponent_gate;
         }
@@ -557,7 +570,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
           ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
       const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
       const size_t stage_scales_offset_X = scales_offset_X_rowwise;
-      const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
+      size_t scale_idx;
+      if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
+        const size_t output_cols = (IS_BWD ? 2 : 1) * cols;
+        scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X,
+                                            DIVUP(output_cols, static_cast<size_t>(128)));
+      } else {
+        scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
+      }
       const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows;
       const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise;
       if (!out_of_bounds_rowwise) {
@@ -573,7 +593,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
       if constexpr (IS_BWD) {
         const e8m0_t biased_exponent_gate =
             ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
-        const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
+
+        size_t scale_idx_gate;
+        if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
+          const size_t output_cols = (IS_BWD ? 2 : 1) * cols;
+          scale_idx_gate = gemm_swizzled_scale_idx(
+              stage_scales_offset_Y, stage_scales_offset_X + gate_scale_idx_offset_rowwise,
+              DIVUP(output_cols, static_cast<size_t>(128)));
+        } else {
+          scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
+        }
         if (!out_of_bounds_rowwise) {
           scales_rowwise[scale_idx_gate] = biased_exponent_gate;
         }
@@ -667,7 +696,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
   parity ^= 1;
 
   destroy_barriers<STAGES>(mbar, is_master_thread);
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-}
+}  // NOLINT(readability/fn_size)
 
 }  // namespace gated_kernel
 
 template <bool IS_BWD, typename ParamOP, float (*ActOP)(float, const ParamOP &),
@@ -679,6 +709,7 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu
   const bool USE_ROWWISE_SCALING = output->has_data();
   const bool USE_COLWISE_SCALING = output->has_columnwise_data();
+  const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales;
 
   if (USE_ROWWISE_SCALING) {
     NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
@@ -722,113 +753,140 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu
       gated_input.dtype(), IType,
       TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
          output->dtype(), OType,
+          TRANSFORMER_ENGINE_SWITCH_CONDITION(
+              with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,
          alignas(64) CUtensorMap tensor_map_grad{};
          alignas(64) CUtensorMap tensor_map_input_act{};
          alignas(64) CUtensorMap tensor_map_input_gate{};
          alignas(64) CUtensorMap tensor_map_output_act_rowwise{};
          alignas(64) CUtensorMap tensor_map_output_gate_rowwise{};
          alignas(64) CUtensorMap tensor_map_output_act_colwise{};
          alignas(64) CUtensorMap tensor_map_output_gate_colwise{};
 
          constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
          constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
 
          if constexpr (IS_BWD) {
            create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
                                 cols, 0, input_type_bit_size);
          }
 
          const uint32_t tensor_stride_elems = output_cols;
          create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, BUFF_DIM_Y,
                               BUFF_DIM_X, cols * 2, 0, input_type_bit_size);
          create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, BUFF_DIM_Y,
                               BUFF_DIM_X, cols * 2, cols, input_type_bit_size);
 
          if (USE_ROWWISE_SCALING) {
            create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols,
                                 BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
                                 output_type_bit_size);
            create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols,
                                 BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols,
                                 output_type_bit_size);
          }
 
          if (USE_COLWISE_SCALING) {
-            create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows, cols,
-                                 BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
+            create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows,
+                                 cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
                                 output_type_bit_size);
            create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows,
                                 cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols,
                                 output_type_bit_size);
          }
 
          const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X;
          const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
          const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
          const size_t buff_size_aligned_in =
              DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
          const size_t buff_size_aligned_out =
              DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);
 
          const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0);
          const size_t in_act_mem = buff_size_aligned_in;
          const size_t in_gate_mem = buff_size_aligned_in;
          const size_t in_mem = grad_mem + in_act_mem + in_gate_mem;
 
          const size_t out_act_mem = buff_size_aligned_out;
          const size_t out_gate_mem = (IS_BWD ? buff_size_aligned_out : 0);
          size_t out_mem = out_act_mem + out_gate_mem;
          if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; }
 
          const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
 
-          switch (scaling_type) {
-            case ScalingType::ROWWISE: {
-              auto kernel =
-                  quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType, true,
-                                              false, THREADS_PER_CHUNK_NON_COLWISE>;
-              NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-                  kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
-
-              kernel<<<grid, block_size, shmem_size, stream>>>(
-                  tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
-                  tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
-                  tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr,
-                  scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p);
-              break;
-            }
-            case ScalingType::COLWISE: {
-              auto kernel =
-                  quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType, false,
-                                              true, THREADS_PER_CHUNK_COLWISE>;
-              NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-                  kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
-
-              kernel<<<grid, block_size, shmem_size, stream>>>(
-                  tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
-                  tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
-                  tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr,
-                  scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p);
-              break;
-            }
-            case ScalingType::BIDIMENSIONAL: {
-              auto kernel =
-                  quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType, true,
-                                              true, THREADS_PER_CHUNK_NON_COLWISE>;
-              NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-                  kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
-
-              kernel<<<grid, block_size, shmem_size, stream>>>(
-                  tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
-                  tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
-                  tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr,
-                  scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p);
-              break;
-            }
-          } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
-      ); // NOLINT(*)
+          // Zero out swizzled scales if padding is needed
+          /// TODO (tmoon) Handle this within the cast kernel
+          if (with_gemm_swizzled_scales) {
+            constexpr size_t TILE_DIM_X = 128;  // Tile dim in data buffer
+            constexpr size_t TILE_DIM_Y = 128;
+            if (cols % TILE_DIM_X != 0 || rows % TILE_DIM_Y != 0) {
+              if (USE_ROWWISE_SCALING) {
+                NVTE_CHECK_CUDA(cudaMemsetAsync(output->scale_inv.dptr, 0,
+                                                output->scale_inv.buffer_size_bytes(), stream));
+              }
+              if (USE_COLWISE_SCALING) {
+                NVTE_CHECK_CUDA(
+                    cudaMemsetAsync(output->columnwise_scale_inv.dptr, 0,
+                                    output->columnwise_scale_inv.buffer_size_bytes(), stream));
+              }
+            }
+          }
+
+          switch (scaling_type) {
+            case ScalingType::ROWWISE: {
+              auto kernel =
+                  quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType,
+                                              true, false, WITH_GEMM_SWIZZLED_SCALES,
+                                              THREADS_PER_CHUNK_NON_COLWISE>;
+              NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+                  kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
+
+              kernel<<<grid, block_size, shmem_size, stream>>>(
+                  tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
+                  tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
+                  tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
+                  scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
+                  scale_stride_colwise, p);
+              break;
+            }
+            case ScalingType::COLWISE: {
+              auto kernel =
+                  quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType,
+                                              false, true, WITH_GEMM_SWIZZLED_SCALES,
+                                              THREADS_PER_CHUNK_COLWISE>;
+              NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+                  kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
+
+              kernel<<<grid, block_size, shmem_size, stream>>>(
+                  tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
+                  tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
+                  tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
+                  scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
+                  scale_stride_colwise, p);
+              break;
+            }
+            case ScalingType::BIDIMENSIONAL: {
+              auto kernel =
+                  quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType,
+                                              true, true, WITH_GEMM_SWIZZLED_SCALES,
+                                              THREADS_PER_CHUNK_NON_COLWISE>;
+              NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+                  kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
+
+              kernel<<<grid, block_size, shmem_size, stream>>>(
+                  tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
+                  tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
+                  tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
+                  scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
+                  scale_stride_colwise, p);
+              break;
+            }
+          } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
+      ); // NOLINT(*)
+  ); // NOLINT(*)
 }
 
 }  // namespace mxfp8
......
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file swizzle.cuh
 *  \brief Helper function for GEMM-swizzled scales
 */

#ifndef TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_SWIZZLE_CUH_
#define TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_SWIZZLE_CUH_

namespace transformer_engine {
namespace dispatch {
namespace mxfp8 {
namespace swizzle {

/*! \brief Convert compact scale indices into GEMM swizzled scale index
 *
 * MXFP8 GEMM expects scaling factors to be in a "swizzled" order
 * (https://docs.nvidia.com/cuda/cublas/#d-block-scaling-factors-layout).
 * This function converts indices from "compact" order (i.e. matching
 * the FP8 data) to swizzled order.
 */
__device__ __forceinline__ size_t gemm_swizzled_scale_idx(size_t i, size_t j, size_t num_tiles_X) {
  constexpr size_t TILE_DIM_X = 4;  // Tile dim in scale buffer
  constexpr size_t TILE_DIM_Y = 128;
  constexpr size_t TILE_SIZE = TILE_DIM_X * TILE_DIM_Y;
  const size_t tile_idx_X = j / TILE_DIM_X;
  const size_t tile_idx_Y = i / TILE_DIM_Y;
  const size_t idx_in_tile_X = j % TILE_DIM_X;
  const size_t idx_in_tile_Y = i % TILE_DIM_Y;
  size_t idx = (tile_idx_Y * num_tiles_X + tile_idx_X) * TILE_SIZE;
  idx += (idx_in_tile_Y % 32) * 16 + (idx_in_tile_Y / 32) * 4 + idx_in_tile_X;
  return idx;
}

}  // namespace swizzle
}  // namespace mxfp8
}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_SWIZZLE_CUH_
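To make the swizzled layout concrete, here is a host-side sketch of the same index math with a few spot checks. It mirrors gemm_swizzled_scale_idx above and is illustrative only, not part of the commit.

#include <cassert>
#include <cstddef>

// Host-side mirror of gemm_swizzled_scale_idx: maps a compact scale
// coordinate (row i, col j) into the 128x4 swizzled tile layout that
// cuBLAS block-scaled GEMMs expect.
size_t swizzled_scale_idx(size_t i, size_t j, size_t num_tiles_X) {
  constexpr size_t TILE_DIM_X = 4, TILE_DIM_Y = 128;
  constexpr size_t TILE_SIZE = TILE_DIM_X * TILE_DIM_Y;
  const size_t tile = (i / TILE_DIM_Y) * num_tiles_X + (j / TILE_DIM_X);
  const size_t y = i % TILE_DIM_Y, x = j % TILE_DIM_X;
  return tile * TILE_SIZE + (y % 32) * 16 + (y / 32) * 4 + x;
}

int main() {
  // Within one tile: consecutive columns are contiguous...
  assert(swizzled_scale_idx(0, 0, 1) == 0);
  assert(swizzled_scale_idx(0, 1, 1) == 1);
  // ...rows advance by 16 until row 32 wraps back to offset 4.
  assert(swizzled_scale_idx(1, 0, 1) == 16);
  assert(swizzled_scale_idx(32, 0, 1) == 4);
  return 0;
}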
@@ -80,6 +80,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
   CheckInputTensor(input, "input");
   CheckOutputTensor(*output, "output");
   NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type.");
+  NVTE_CHECK(!input.with_gemm_swizzled_scales, "Input must have scales in compact format.");
   NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision.");
   NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
......
@@ -142,17 +142,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
   constexpr size_t buff_size_aligned_out_mxfp8 =
       DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
 
-  constexpr size_t buff_size_nvfp4_scales =
-      CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3);
-  constexpr size_t buff_size_mxfp8_scales =
-      (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0);
-
   constexpr size_t in_mem = buff_size_aligned_in;
 
   constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0);
   constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0);
-  constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0);
-  constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0);
 
   extern __shared__ char dynamic_shmem[];
   uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
@@ -167,8 +160,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
   OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise_data);
   fp8e4m3 *out_rowwise_scales_sh =
       reinterpret_cast<fp8e4m3 *>(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
-  e8m0_t *out_colwise_scales_sh = reinterpret_cast<e8m0_t *>(
-      dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
+  (void)out_rowwise_scales_sh;  // Suppress unused variable warning
 
   IType *cached_act_sh = in_sh;  // in_sh is used as a cache buffer
   constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
@@ -557,6 +549,7 @@ inline void quantize(const Tensor &input, const Tensor *noop, Tensor *output, cu
   NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
   NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
+  NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format.");
 
   bool use_colwise_scaling = output->has_columnwise_data();
   if (use_colwise_scaling) {
......
@@ -1179,6 +1179,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
   NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated.");
   NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
   NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
+  NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format.");
   if (return_transpose) {
     NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated.");
     NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype),
......
@@ -172,7 +172,17 @@ CommOverlapCore::~CommOverlapCore() {
 TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset,
                                                 const std::vector<size_t> &chunk_shape) {
+  // Check tensor format
   const auto scaling_mode = source.scaling_mode();
+  NVTE_CHECK(scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_MXFP8_1D_SCALING,
+             "Unsupported tensor format (", to_string(scaling_mode), ").");
+  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
+    uint8_t has_swizzled_scales = false;
+    nvte_get_tensor_param_v2(source.data(), NVTETensorParam::kNVTEWithGEMMSwizzledScales,
+                             &has_swizzled_scales, sizeof(has_swizzled_scales), nullptr);
+    NVTE_CHECK(has_swizzled_scales,
+               "Expected MXFP8 tensor to have scales in GEMM swizzled format.");
+  }
 
   // Tensor dimensions
   std::vector<size_t> shape = shape_to_vector(source.shape());
......
@@ -133,6 +133,23 @@ struct Tensor {
   NVTEScalingMode scaling_mode;
   NVTETensor nvte_tensor;
+
+  /*! \brief Whether scaling factors are in format expected by GEMM
+   *
+   * Only meaningful for MXFP8 and NVFP4.
+   */
+  bool with_gemm_swizzled_scales = false;
+
+  /*! Map from NVTETensorParam to parameter sizes */
+  static constexpr size_t attr_sizes[] = {
+      sizeof(NVTEBasicTensor),  // kNVTERowwiseData
+      sizeof(NVTEBasicTensor),  // kNVTEColumnwiseData
+      sizeof(NVTEBasicTensor),  // kNVTEScale
+      sizeof(NVTEBasicTensor),  // kNVTEAmax
+      sizeof(NVTEBasicTensor),  // kNVTERowwiseScaleInv
+      sizeof(NVTEBasicTensor),  // kNVTEColumnwiseScaleInv
+      sizeof(NVTEBasicTensor),  // kNVTEColumnwiseAmax
+      sizeof(uint8_t)           // kNVTEWithGEMMSwizzledScales
+  };
 
   Tensor() : scaling_mode{NVTE_DELAYED_TENSOR_SCALING}, nvte_tensor{0} {}
@@ -146,6 +163,7 @@ struct Tensor {
     scale_inv.clear();
     columnwise_scale_inv.clear();
     scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
+    with_gemm_swizzled_scales = false;
   }
 
   explicit operator NVTETensor() const noexcept { return nvte_tensor; }
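A producer that emits scales directly in the GEMM-ready layout (e.g. the pre-swizzled MXFP8 quantize kernel) would flip this flag through the matching setter. A sketch, with the caveat that the setter name merely mirrors the getter used earlier and is an assumption on my part; the uint8_t width matches the attr_sizes entry above:

// Sketch only: mark a tensor whose scales were produced pre-swizzled.
void mark_scales_swizzled(NVTETensor tensor) {
  const uint8_t flag = 1;
  nvte_set_tensor_param_v2(tensor, NVTETensorParam::kNVTEWithGEMMSwizzledScales,
                           &flag, sizeof(flag));
}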
@@ -389,22 +407,20 @@ struct QuantizationConfig {
   bool force_pow_2_scales = false;
   float amax_epsilon = 0.0f;
   NVTETensor noop_tensor = nullptr;
-  Float8BlockScaleTensorFormat float8_block_scale_tensor_format =
-      Float8BlockScaleTensorFormat::GEMM_READY;
   NVTETensor rng_state = nullptr;
   bool nvfp4_2d_quantization = false;
   bool stochastic_rounding = false;
   bool use_fast_math = false;
 
   static constexpr size_t attr_sizes[] = {
-      sizeof(bool),                          // force_pow_2_scales
+      sizeof(uint8_t),                       // force_pow_2_scales
       sizeof(float),                         // amax_epsilon
       sizeof(NVTETensor),                    // noop_tensor
-      sizeof(Float8BlockScaleTensorFormat),  // float8_block_scale_tensor_format
+      sizeof(Float8BlockScaleTensorFormat),  // (deprecated)
       sizeof(NVTETensor),                    // rng_seed and offset
-      sizeof(bool),                          // nvfp4_2d_quantization
+      sizeof(uint8_t),                       // nvfp4_2d_quantization
-      sizeof(bool),                          // stochastic_rounding
+      sizeof(uint8_t),                       // stochastic_rounding
-      sizeof(bool)                           // use_fast_math
+      sizeof(uint8_t)                        // use_fast_math
   };
 };
...
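With the bool attributes now declared as uint8_t in attr_sizes, callers must pass one-byte buffers through the attribute API. A hedged sketch; the create/set function and attribute names below follow Transformer Engine's existing naming pattern but are assumptions, not taken from this diff:

// Sketch: enable stochastic rounding via the quantization-config attribute API.
NVTEQuantizationConfig cfg = nvte_create_quantization_config();
const uint8_t stochastic_rounding = 1;  // bool travels as uint8_t
nvte_set_quantization_config_attribute(cfg, kNVTEQuantizationConfigStochasticRounding,
                                       &stochastic_rounding, sizeof(stochastic_rounding));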
@@ -36,6 +36,12 @@ void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigA
              static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
              " bytes)");
+
+  // bool size is implementation-dependent, so we explicitly specify
+  // uint8_t in the user-facing API.
+  auto bool_to_uint8 = [](bool in, void *out) {
+    *reinterpret_cast<uint8_t *>(out) = static_cast<uint8_t>(in);
+  };
 
   // Write to buffer
   NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)");
   const auto &config_ = *reinterpret_cast<const transformer_engine::MatmulConfig *>(config);
@@ -47,19 +53,19 @@ void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigA
       std::memcpy(buf, &config_.dbias_tensor, attr_size);
       break;
     case kNVTEMatmulConfigWithGELUEpilogue:
-      std::memcpy(buf, &config_.with_gelu_epilogue, attr_size);
+      bool_to_uint8(config_.with_gelu_epilogue, buf);
       break;
     case kNVTEMatmulConfigWithDGELUEpilogue:
-      std::memcpy(buf, &config_.with_dgelu_epilogue, attr_size);
+      bool_to_uint8(config_.with_dgelu_epilogue, buf);
       break;
     case kNVTEMatmulConfigEpilogueAuxTensor:
       std::memcpy(buf, &config_.epilogue_aux_tensor, attr_size);
       break;
     case kNVTEMatmulConfigUseSplitAccumulator:
-      std::memcpy(buf, &config_.use_split_accumulator, attr_size);
+      bool_to_uint8(config_.use_split_accumulator, buf);
       break;
     case kNVTEMatmulConfigSMCount:
-      std::memcpy(buf, &config_.sm_count, attr_size);
+      *reinterpret_cast<int32_t *>(buf) = static_cast<int32_t>(config_.sm_count);
       break;
     default:
       NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
@@ -79,6 +85,12 @@ void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigA
              " bytes)");
   NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");
+
+  // bool size is implementation-dependent, so we explicitly specify
+  // uint8_t in the user-facing API.
+  auto uint8_to_bool = [](const void *in, bool &out) {
+    out = static_cast<bool>(*reinterpret_cast<const uint8_t *>(in));
+  };
 
   // Read from buffer
   NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)");
   auto &config_ = *reinterpret_cast<transformer_engine::MatmulConfig *>(config);
@@ -90,19 +102,19 @@ void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigA
       std::memcpy(&config_.dbias_tensor, buf, attr_size);
       break;
     case kNVTEMatmulConfigWithGELUEpilogue:
-      std::memcpy(&config_.with_gelu_epilogue, buf, attr_size);
+      uint8_to_bool(buf, config_.with_gelu_epilogue);
       break;
     case kNVTEMatmulConfigWithDGELUEpilogue:
-      std::memcpy(&config_.with_dgelu_epilogue, buf, attr_size);
+      uint8_to_bool(buf, config_.with_dgelu_epilogue);
       break;
     case kNVTEMatmulConfigEpilogueAuxTensor:
       std::memcpy(&config_.epilogue_aux_tensor, buf, attr_size);
       break;
     case kNVTEMatmulConfigUseSplitAccumulator:
-      std::memcpy(&config_.use_split_accumulator, buf, attr_size);
+      uint8_to_bool(buf, config_.use_split_accumulator);
       break;
     case kNVTEMatmulConfigSMCount:
-      std::memcpy(&config_.sm_count, buf, attr_size);
+      config_.sm_count = static_cast<int>(*reinterpret_cast<const int32_t *>(buf));
       break;
     default:
       NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
...
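From the caller's side, the getter/setter pair above now exchanges bools as uint8_t and sm_count as int32_t. A minimal round-trip sketch; the attribute names and buffer widths come from the diff, but the create helper name is an assumption:

// Sketch: set a bool attribute and read an int attribute with fixed-width buffers.
NVTEMatmulConfig cfg = nvte_create_matmul_config();  // assumed helper name

const uint8_t with_gelu = 1;  // bool crosses the C API as uint8_t
nvte_set_matmul_config_attribute(cfg, kNVTEMatmulConfigWithGELUEpilogue,
                                 &with_gelu, sizeof(with_gelu));

int32_t sm_count = 0;  // int crosses the C API as int32_t
nvte_get_matmul_config_attribute(cfg, kNVTEMatmulConfigSMCount,
                                 &sm_count, sizeof(sm_count));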
@@ -23,11 +23,11 @@ struct MatmulConfig {
   static constexpr size_t attr_sizes[] = {
       sizeof(NVTETensor),  // bias_tensor
       sizeof(NVTETensor),  // dbias_tensor
-      sizeof(bool),        // with_gelu_epilogue
+      sizeof(uint8_t),     // with_gelu_epilogue
-      sizeof(bool),        // with_dgelu_epilogue
+      sizeof(uint8_t),     // with_dgelu_epilogue
       sizeof(NVTETensor),  // epilogue_aux_tensor
-      sizeof(bool),        // use_split_accumulator
+      sizeof(uint8_t),     // use_split_accumulator
-      sizeof(int)          // sm_count
+      sizeof(int32_t)      // sm_count
   };
 };
...
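Tables like this one silently drift out of sync when a new attribute is added to the enum but not to attr_sizes. A hedged compile-time guard; kNVTEMatmulConfigNumAttributes is a hypothetical end-of-enum sentinel, and the real header may use a different mechanism:

// Sketch: one attr_sizes entry per NVTEMatmulConfigAttribute value.
static_assert(sizeof(MatmulConfig::attr_sizes) / sizeof(size_t) ==
                  kNVTEMatmulConfigNumAttributes,
              "attr_sizes must have one entry per NVTEMatmulConfigAttribute");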
@@ -503,6 +503,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
 #if CUBLAS_VERSION >= 120800
       NVTE_CHECK(cublas_version() >= 120800,
                  "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
+
+      // Check that scales are in expected format
+      NVTE_CHECK(inputA->with_gemm_swizzled_scales,
+                 "MXFP8 scales are not in format expected by GEMM");
+      NVTE_CHECK(inputB->with_gemm_swizzled_scales,
+                 "MXFP8 scales are not in format expected by GEMM");
+
+      // Configure cuBLAS scales
       fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
       fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
       NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
@@ -513,6 +521,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                                                        &B_scale_inverse, sizeof(B_scale_inverse)));
       scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
       scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
+
       // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
       // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
       if (cublas_version() <= 120803) {
@@ -529,17 +538,22 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
 #if CUBLAS_VERSION >= 120800
       NVTE_CHECK(cublas_version() >= 120800,
                  "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
-      // make sure alpha beta computation dtype remains fp32 by CUBLASLT_MATMUL_DESC_SCALE_TYPE
-      cublasDataType_t scale_type = CUDA_R_32F;
+
+      // Check that scales are in expected format
+      NVTE_CHECK(inputA->with_gemm_swizzled_scales,
+                 "NVFP4 block scales are not in format expected by GEMM");
+      NVTE_CHECK(inputB->with_gemm_swizzled_scales,
+                 "NVFP4 block scales are not in format expected by GEMM");
+
+      // alpha and beta are device pointers to FP32
+      const cublasDataType_t scale_type = CUDA_R_32F;
       NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
           operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));
-      // Set pointer mode: alpha and beta are both device pointers
-      // https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t
-      cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
+      const cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
       NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
           operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));
+
+      // Configure cuBLAS scales
       fp8e4m3 *A_scale_inverse = reinterpret_cast<fp8e4m3 *>(param.A_scale_inv);
       fp8e4m3 *B_scale_inverse = reinterpret_cast<fp8e4m3 *>(param.B_scale_inv);
       NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
@@ -561,6 +575,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
       NVTE_CHECK(cublas_version() >= 120900,
                  "FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ",
                  cublas_version());
+
+      // Check that matrix formats are valid
+      NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
+                    inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)),
+                 "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling GEMM is supported, "
+                 "but got 2D by 2D");
+
+      // Configure cuBLAS scales
       float *A_scale_inverse = reinterpret_cast<float *>(param.A_scale_inv);
       float *B_scale_inverse = reinterpret_cast<float *>(param.B_scale_inv);
       NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
@@ -569,9 +591,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
       NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                        CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                        &B_scale_inverse, sizeof(B_scale_inverse)));
-      NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
-                    inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)),
-                 "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D");
       scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D
                            ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                            : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
...
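With these GEMM-side checks in place, a caller holding compact MXFP8 scales must swizzle them before the matmul so the with_gemm_swizzled_scales flag is set on the inputs. A hedged sketch of the intended call order; nvte_swizzle_scaling_factors is the core-library swizzle entry point, while the surrounding tensor/workspace setup and the full nvte_cublas_gemm argument list are elided:

// Sketch: swizzle compact scales into the GEMM-ready layout, then run the GEMM.
void gemm_with_compact_scales(NVTETensor A_compact, NVTETensor A_swizzled,
                              NVTETensor B_compact, NVTETensor B_swizzled,
                              cudaStream_t stream) {
  nvte_swizzle_scaling_factors(A_compact, A_swizzled, stream);
  nvte_swizzle_scaling_factors(B_compact, B_swizzled, stream);
  // A_swizzled/B_swizzled now carry scales in GEMM swizzled format, so the
  // format checks in cublas_gemm pass.
  // ... nvte_cublas_gemm(A_swizzled, B_swizzled, /* D, workspace, ... */ stream);
}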