Unverified commit d126cdd6 authored by Kunlun Li, committed by GitHub

Add primary weights fp8 support for mxfp8 (#2055)



* Add primary weights fp8 support for mxfp8
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Fix unit test and add a clearer error message to it
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Move post all-gather processing out of for loop
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Add descriptions and ASCII diagrams for partial cast and partial amax functions
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Minor fixes based on Greptile bot review
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix compilation errors due to arch-specific PTX instructions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unused noop flag from C API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Expose test_partial_cast
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Skip mxfp8 partial cast test if mxfp8 is not available
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Fix pytest error
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Add pylint ignore for unused manual_post_all_gather_processing
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Fix error when using is_mxfp8_available
Signed-off-by: kunlunl <kunlunl@nvidia.com>

---------
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent cc42a577
......@@ -49,6 +49,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
......
......@@ -18,6 +18,7 @@ from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
Float8BlockScaling,
MXFP8BlockScaling,
Format,
Recipe,
)
......@@ -25,9 +26,11 @@ import transformer_engine.pytorch as te
from transformer_engine.pytorch import (
is_fp8_available,
is_fp8_block_scaling_available,
is_mxfp8_available,
QuantizedTensor,
Float8Tensor,
Float8BlockwiseQTensor,
MXFP8Tensor,
)
from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8
from transformer_engine.pytorch.tensor.utils import post_all_gather_processing, replace_raw_data
......@@ -42,17 +45,21 @@ def _get_quantization_recipe(quantization) -> Recipe:
return Float8CurrentScaling(fp8_format=fp8_format)
elif quantization == "fp8_block":
return Float8BlockScaling(fp8_format=fp8_format)
elif quantization == "mxfp8":
return MXFP8BlockScaling()
else:
raise ValueError(f"Unsupported quantization: {quantization}")
def _get_raw_data(quantized_tensor):
def _get_raw_data(quantized_tensor, colwise=False):
"""Get the underlying data of a quantized tensor, used in zero-1 optimizer"""
if isinstance(quantized_tensor, Float8Tensor):
assert not colwise, "Float8Tensor does not support get colwise data"
assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute"
assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8"
return quantized_tensor._data
elif isinstance(quantized_tensor, Float8BlockwiseQTensor):
assert not colwise, "Float8BlockwiseQTensor does not support get colwise data"
assert hasattr(
quantized_tensor, "_rowwise_data"
), "Float8BlockwiseQTensor does not have _rowwise_data attribute"
......@@ -60,6 +67,23 @@ def _get_raw_data(quantized_tensor):
quantized_tensor._rowwise_data.dtype == torch.uint8
), "Float8BlockwiseQTensor _rowwise_data must be uint8"
return quantized_tensor._rowwise_data
elif isinstance(quantized_tensor, MXFP8Tensor):
if colwise:
assert hasattr(
quantized_tensor, "_columnwise_data"
), "MXFP8Tensor does not have columnwise_data attribute"
assert (
quantized_tensor._columnwise_data.dtype == torch.uint8
), "MXFP8Tensor columnwise_data must be uint8"
return quantized_tensor._columnwise_data
else:
assert hasattr(
quantized_tensor, "_rowwise_data"
), "MXFP8Tensor does not have rowwise_data attribute"
assert (
quantized_tensor._rowwise_data.dtype == torch.uint8
), "MXFP8Tensor rowwise_data must be uint8"
return quantized_tensor._rowwise_data
else:
raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")
......@@ -229,38 +253,43 @@ class MiniZero_1:
end = start_offset + master_weight.numel()
weight.data.view(-1)[start:end].copy_(master_weight)
# -----------------------------------------------------------------------------------------
# Step 5: Copy the updated weights (not all weights) to the weight buffer
# -----------------------------------------------------------------------------------------
for i in range(len(self.weights)):
master_weight = self.master_weights[i]
if master_weight is None:
continue
start_offset = self.start_offsets[i]
if isinstance(self.weights[i], QuantizedTensor):
weight = _get_raw_data(self.weights[i])
else:
weight = self.weights[i]
weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()]
overlapping_start, overlapping_end = self.overlapping_areas[i]
self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice)
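# MXFP8 keeps row-wise and column-wise data in separate uint8 buffers, so the buffer copy,
# all-gather, and copy-back below run once per layout.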
colwise_list = [False]
if isinstance(self.weights[0], MXFP8Tensor):
colwise_list.append(True)
# -----------------------------------------------------------------------------------------
# Step 6: Weight all-gather (FP8 or BF16)
# -----------------------------------------------------------------------------------------
dist.all_gather_into_tensor(
self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
)
for colwise in colwise_list:
# -------------------------------------------------------------------------------------
# Step 5: Copy the updated weights (not all weights) to the weight buffer
# -------------------------------------------------------------------------------------
for i in range(len(self.weights)):
master_weight = self.master_weights[i]
if master_weight is None:
continue
start_offset = self.start_offsets[i]
if isinstance(self.weights[i], QuantizedTensor):
weight = _get_raw_data(self.weights[i], colwise)
else:
weight = self.weights[i]
weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()]
overlapping_start, overlapping_end = self.overlapping_areas[i]
self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice)
# -------------------------------------------------------------------------------------
# Step 6: Weight all-gather (FP8 or BF16)
# -------------------------------------------------------------------------------------
dist.all_gather_into_tensor(
self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
)
# -----------------------------------------------------------------------------------------
# Step 7: Copy the gathered weights from weight buffer to the actual weights
# -----------------------------------------------------------------------------------------
for weight, offset in zip(self.weights, self.offsets[:-1]):
start = offset
end = offset + weight.numel()
if isinstance(weight, QuantizedTensor):
weight = _get_raw_data(weight)
weight.view(-1).data.copy_(self.weight_buffer[start:end])
# -------------------------------------------------------------------------------------
# Step 7: Copy the gathered weights from weight buffer to the actual weights
# -------------------------------------------------------------------------------------
for weight, offset in zip(self.weights, self.offsets[:-1]):
start = offset
end = offset + weight.numel()
if isinstance(weight, QuantizedTensor):
weight = _get_raw_data(weight, colwise)
weight.view(-1).data.copy_(self.weight_buffer[start:end])
if self.manual_post_all_gather_processing:
quantized_weights = [
......@@ -285,9 +314,15 @@ class MiniFSDP:
else:
raw_data_list = [w.view(-1) for w in weights]
self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list)
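# MXFP8 weights also carry a column-wise uint8 buffer; keep a second flattened buffer so the
# column-wise data can be sharded and all-gathered alongside the row-wise data.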
if isinstance(weights[0], MXFP8Tensor):
self.flatten_columnwise = self.flatten_weight.clone()
else:
self.flatten_columnwise = None
# Split flattened weights into shards
self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
if self.flatten_columnwise is not None:
self.local_columnwise_shard = torch.chunk(self.flatten_columnwise, world_size)[rank]
self.local_main_grad_shard = torch.zeros_like(
self.local_weight_shard, dtype=torch.float32, device="cuda"
)
......@@ -319,14 +354,25 @@ class MiniFSDP:
self.shard_indices.append((None, None))
if isinstance(weights[idx], QuantizedTensor):
replace_raw_data(
weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
)
if self.flatten_columnwise is not None:
new_rowwise_data = self.flatten_weight[start:end].view(weights[idx].shape)
new_rowwise_data.copy_(weights[idx]._rowwise_data)
weights[idx]._rowwise_data = new_rowwise_data
new_columnwise_data = self.flatten_columnwise[start:end].view(
weights[idx].shape
)
new_columnwise_data.copy_(weights[idx]._columnwise_data)
weights[idx]._columnwise_data = new_columnwise_data
else:
replace_raw_data(
weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
)
else:
weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape)
# Initialize local model weights and high-precision master weights
self.local_weights = []
self.local_columnwise = []
self.master_weights = []
for i, weight in enumerate(self.weights):
weight_start, weight_end = self.weight_indices[i]
......@@ -334,6 +380,11 @@ class MiniFSDP:
if shard_start is not None and shard_end is not None:
local_weight_shard = self.local_weight_shard[shard_start:shard_end]
self.local_weights.append(local_weight_shard)
if self.flatten_columnwise is not None:
local_columnwise_shard = self.local_columnwise_shard[shard_start:shard_end]
else:
local_columnwise_shard = None
self.local_columnwise.append(local_columnwise_shard)
if isinstance(weight, QuantizedTensor):
high_precision_init_val = weight.get_high_precision_init_val().view(-1)
......@@ -345,6 +396,7 @@ class MiniFSDP:
self.master_weights.append(master_weight_shard)
else:
self.local_weights.append(None)
self.local_columnwise.append(None)
self.master_weights.append(None)
setattr(
weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda")
......@@ -415,12 +467,12 @@ class MiniFSDP:
# Step 3: Cast master weights to FP8 or BF16 precision
if isinstance(self.weights[0], QuantizedTensor):
local_weights = []
for local_weight in self.local_weights:
if local_weight is None:
local_weights.append(None)
continue
local_weights.append(local_weight)
for i, local_weight in enumerate(self.local_weights):
if self.flatten_columnwise is not None:
local_columnwise = self.local_columnwise[i]
local_weights.append((local_weight, local_columnwise))
else:
local_weights.append(local_weight)
cast_master_weights_to_fp8(
self.weights,
......@@ -442,6 +494,10 @@ class MiniFSDP:
dist.all_gather_into_tensor(
self.flatten_weight, self.local_weight_shard, group=self.dp_group
)
if self.flatten_columnwise is not None:
dist.all_gather_into_tensor(
self.flatten_columnwise, self.local_columnwise_shard, group=self.dp_group
)
if self.manual_post_all_gather_processing:
quantized_weights = [
......@@ -513,15 +569,15 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat
preserve_high_precision_init_val=True,
):
model_fp8 = nn.Sequential(
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(128, 256 + 32, **linear_kwargs),
te.Linear(256 + 32, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Create model with BF16 weights
model = nn.Sequential(
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(128, 256 + 32, **linear_kwargs),
te.Linear(256 + 32, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
......@@ -546,7 +602,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat
w.main_grad.zero_()
inputs = [
torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
torch.randn(32, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
]
# Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank]
......@@ -577,7 +633,9 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat
optimizer_fp8.step()
optimizer.step()
torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
assert torch.allclose(
loss_fp8, loss, atol=0, rtol=0
), f"Loss mismatch at rank {rank}, step {i} for {quantization}"
def _test_fsdp_cast_master_weights_to_fp8(
......@@ -609,15 +667,15 @@ def _test_fsdp_cast_master_weights_to_fp8(
preserve_high_precision_init_val=True,
):
model_fp8 = nn.Sequential(
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(128, 256 + 32, **linear_kwargs),
te.Linear(256 + 32, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Create model with BF16 weights
model = nn.Sequential(
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(128, 256 + 32, **linear_kwargs),
te.Linear(256 + 32, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
......@@ -631,12 +689,12 @@ def _test_fsdp_cast_master_weights_to_fp8(
)
optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group)
for _ in range(100):
for i in range(100):
optimizer_fp8.zero_grad()
optimizer.zero_grad()
inputs = [
torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
torch.randn(32, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
]
# Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank]
......@@ -667,7 +725,9 @@ def _test_fsdp_cast_master_weights_to_fp8(
optimizer_fp8.step()
optimizer.step()
torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
assert torch.allclose(
loss_fp8, loss, atol=0, rtol=0
), f"Loss mismatch at rank {rank}, step {i} for {quantization} (FSDP)"
def run_parallel_tests() -> None:
......@@ -698,6 +758,8 @@ def run_parallel_tests() -> None:
quantizations.extend(["fp8", "fp8_cs"])
if is_fp8_block_scaling_available():
quantizations.append("fp8_block")
if is_mxfp8_available():
quantizations.append("mxfp8")
manual_post_all_gather_processings = [False, True]
......
......@@ -7,6 +7,7 @@ import torch
import transformer_engine.pytorch
import transformer_engine_torch as tex
from transformer_engine.pytorch import is_mxfp8_available
from transformer_engine.pytorch.optimizers import MultiTensorApply
from references.quantize_scale_calc import scale_from_amax_tensor
......@@ -23,6 +24,7 @@ input_size_pairs = [
(555, 33333),
]
appliers = [MultiTensorApply(2048 * 32), MultiTensorApply(333), MultiTensorApply(33333)]
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
@pytest.mark.parametrize("input_size_pair", input_size_pairs)
......@@ -259,3 +261,35 @@ def test_multi_tensor_compute_scale_and_scale_inv(
)
torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0)
torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0)
@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
@pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
@pytest.mark.parametrize("applier", appliers)
@pytest.mark.parametrize("repeat", [1, 55])
def test_multi_tensor_compute_scale_inv_e8m0(input_size_pair, applier, repeat):
sizea, sizeb = input_size_pair
device = torch.device("cuda")
a = torch.randn([sizea], dtype=torch.bfloat16, device=device).abs()
b = torch.randn([sizeb], dtype=torch.bfloat16, device=device).abs()
amax_list = []
for _ in range(repeat):
amax_list += [a.clone(), b.clone()]
scale_inv_list = [torch.empty_like(x).to(torch.uint8) for x in amax_list]
applier(
tex.multi_tensor_compute_scale_inv_e8m0,
None, # overflow_buf
[amax_list, scale_inv_list],
)
max_fp8 = torch.finfo(torch.float8_e4m3fn).max
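# Reference: split amax / max_fp8 into exponent and mantissa bits and round the exponent up to
# the next power of two, saturating at 0xFE and keeping sufficiently small denormal inputs at
# exponent 0; this mirrors the kernel's float_to_e8m0 conversion.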
for amax, scale_inv in zip(amax_list, scale_inv_list):
scale_inv_u32 = (amax.float() / max_fp8).view(torch.int)
exponent = scale_inv_u32 // 2**23
mantissa = scale_inv_u32 & 0x7FFFFF
exponent += (
((mantissa > 0) & (exponent != 0xFE)) & ~((exponent == 0) & (mantissa <= 0x400000))
).to(torch.int)
torch.testing.assert_close(exponent.to(torch.uint8), scale_inv)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_compute_scale_inv_e8m0
from transformer_engine.pytorch import is_mxfp8_available
from transformer_engine.pytorch.optimizers.multi_tensor_apply import multi_tensor_applier
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
def compute_partial_amax_reference(inp, amax_rowwise, amax_colwise, h, w, start_offset):
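# Reference: scatter the partial segment into a zero-filled (h, w) tensor, then take the
# absolute max over 1x32 row-wise blocks and 32x1 column-wise blocks.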
n = inp.view(-1).size(0)
if n == h * w:
full = inp.view(-1)
else:
full = torch.zeros(h * w, dtype=inp.dtype, device=inp.device)
full[start_offset : start_offset + n].copy_(inp)
full = torch.abs(full)
_amax_rowwise, _ = torch.max(full.view(h, w // 32, 32), dim=2)
amax_rowwise[:h, : (w // 32)].copy_(_amax_rowwise)
_amax_colwise, _ = torch.max(full.view(h // 32, 32, w), dim=1)
amax_colwise[: (h // 32), :w].copy_(_amax_colwise)
def partial_cast_reference(
inp, rowwise_out, colwise_out, rowwise_inv_scale, colwise_inv_scale, h, w, start_offset
):
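# E8M0 value e encodes scale_inv = 2^(e - 127); the matching scale 2^(127 - e) is a power of
# two whose float32 bit pattern is (254 - e) << 23.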
rowwise_scale = ((254 - rowwise_inv_scale.int()) * 2**23).view(torch.float32)
colwise_scale = ((254 - colwise_inv_scale.int()) * 2**23).view(torch.float32)
n = inp.view(-1).size(0)
if n == h * w:
full = inp
else:
full = torch.empty(h * w, dtype=inp.dtype, device=inp.device)
full[start_offset : start_offset + n].copy_(inp)
full = full.float()
rowwise_scale = rowwise_scale[:h, : (w // 32)].contiguous().float()
colwise_scale = colwise_scale[: (h // 32), :w].contiguous().float()
scaled = (full.view(-1, 32) * rowwise_scale.view(-1, 1)).view(-1)
rowwise_out.copy_(
scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(rowwise_out.dtype)
)
scaled = (full.view(h // 32, 32, w) * colwise_scale.view(h // 32, 1, w)).view(-1)
colwise_out.copy_(
scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(colwise_out.dtype)
)
def run_one_case(n, h, w, start_offset):
inp = torch.randn(n, dtype=torch.bfloat16, device="cuda")
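# Amax / scale_inv buffers are padded: row-wise shapes to multiples of (128, 4) and column-wise
# shapes to multiples of (4, 128), matching the padding the partial-amax/partial-cast kernels expect.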
rowwise_padding = [128, 4]
colwise_padding = [4, 128]
def _pad(x, padding):
return (x + padding - 1) // padding * padding
rowwise_shape = [_pad(h, rowwise_padding[0]), _pad(w // 32, rowwise_padding[1])]
colwise_shape = [_pad(h // 32, colwise_padding[0]), _pad(w, colwise_padding[1])]
# Partial amax cuda kernel
amax_rowwise = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device)
amax_colwise = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device)
tex.mxfp8_scaling_compute_partial_amax(inp, amax_rowwise, amax_colwise, h, w, start_offset)
# Partial amax pytorch reference
amax_rowwise_ref = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device)
amax_colwise_ref = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device)
compute_partial_amax_reference(inp, amax_rowwise_ref, amax_colwise_ref, h, w, start_offset)
# Check partial amax
torch.testing.assert_close(amax_rowwise, amax_rowwise_ref, atol=0, rtol=0)
torch.testing.assert_close(amax_colwise, amax_colwise_ref, atol=0, rtol=0)
# Calculate scales and scale_invs
scale_inv_rowwise = torch.empty_like(amax_rowwise).to(torch.uint8)
scale_inv_colwise = torch.empty_like(amax_colwise).to(torch.uint8)
multi_tensor_applier(
multi_tensor_compute_scale_inv_e8m0,
None,
[
[amax_rowwise, amax_colwise],
[scale_inv_rowwise, scale_inv_colwise],
],
)
# Partial cast cuda kernel
output_rowwise = torch.empty_like(inp).to(torch.uint8)
output_colwise = torch.empty_like(inp).to(torch.uint8)
tex.mxfp8_scaling_partial_cast(
inp,
output_rowwise,
output_colwise,
scale_inv_rowwise,
scale_inv_colwise,
h,
w,
start_offset,
)
# Partial cast pytorch reference
output_rowwise_ref = torch.empty_like(inp).to(torch.uint8)
output_colwise_ref = torch.empty_like(inp).to(torch.uint8)
partial_cast_reference(
inp,
output_rowwise_ref,
output_colwise_ref,
scale_inv_rowwise,
scale_inv_colwise,
h,
w,
start_offset,
)
# Check partial cast results
torch.testing.assert_close(output_rowwise, output_rowwise_ref, atol=0, rtol=0)
torch.testing.assert_close(output_colwise, output_colwise_ref, atol=0, rtol=0)
@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
def test_mxfp8_scaling_partial_cast():
torch.cuda.manual_seed(1234)
run_one_case(3, 32, 64, 31)
run_one_case(64 * 64 - 2, 64, 64, 1)
run_one_case(16384 * 6144, 16384, 6144, 0)
run_one_case(32768, 256, 128, 0)
run_one_case(131072, 768, 256, 0)
run_one_case(65536, 768, 256, 131072)
run_one_case(98304, 128, 768, 0)
......@@ -125,7 +125,6 @@ list(APPEND transformer_engine_cpp_sources
list(APPEND transformer_engine_cuda_sources
common.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
......@@ -167,16 +166,18 @@ list(APPEND transformer_engine_cuda_sources
comm_gemm_overlap/userbuffers/userbuffers.cu)
list(APPEND transformer_engine_cuda_arch_specific_sources
gemm/cutlass_grouped_gemm.cu
cast/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
hadamard_transform/hadamard_transform.cu
cast/cast.cu
gemm/cutlass_grouped_gemm.cu
hadamard_transform/group_hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu)
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu)
# Compiling the files with the worst compilation time first to hopefully overlap
# better with the faster-compiling cpp files
......
......@@ -265,6 +265,21 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETens
float max_fp8, int force_pow_2_scales,
float epsilon, cudaStream_t stream);
/*! \brief Compute E8M0 scale_inv for a list of tensors.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor **tensor_lists,
const size_t num_tensor_lists,
const size_t num_tensors_per_list,
cudaStream_t stream);
/*! \brief Split a tensor along dimension 0 and compute the amax for each split.
*
* This function is experimental and the API is not stable.
......
......@@ -111,17 +111,200 @@ void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output,
void nvte_compute_scale_from_amax(NVTETensor output, const NVTEQuantizationConfig config,
cudaStream_t stream);
/*! \brief Compute partial amax for FP8 blockwise scaling.
*
* This function computes the maximum absolute values for each block of the original tensor.
* `inp` contains a continuous segment from the flattened original tensor. For each block,
* if it overlaps with the range [start_offset, start_offset+inp.length), the amax is
* computed from inp; otherwise, the amax is set to 0.
*
* Example: Original tensor (logically 512x512) divided into 16 blocks of size 128x128.
* `inp` contains continuous elements starting from position start_offset
* in the flattened original tensor.
*
* Logical view - Original Tensor (e.g., 512x512) divided into 16 blocks of size 128x128:
* ┌─────────┬─────────┬─────────┬─────────┐
* │ Block0 │ Block1 │ Block2 │ Block3 │ Each block: 128x128
* │ 128x128 │ 128x128 │ 128x128 │ 128x128 │
* ├─────────┼─────────┼─────────┼─────────┤
* │ Block4 │ Block5 │ Block6 │ Block7 │
* ├─────────┼─────────┼─────────┼─────────┤
* │ Block8 │ Block9 │ Block10 │ Block11 │
* ├─────────┼─────────┼─────────┼─────────┤
* │ Block12 │ Block13 │ Block14 │ Block15 │
* └─────────┴─────────┴─────────┴─────────┘
*
* Physical view - Flattened in row-major order:
* ┌────────────────────────────────────────────────────────────────┐
* │[0...128][128...256][256...384][384...512]...[261632...262143] │
* └────────────────────────────────────────────────────────────────┘
* ^ ^
* start_offset start_offset + inp.length
*
* For each 128x128 block, compute amax:
* - If the block overlaps with [start_offset, start_offset+inp.length), compute amax
* - If the block is completely outside this range, set amax = 0
*
* amax output (one value per 128x128 block), block 1 and block 2 are non-zero because they
* overlap with the [start_offset, start_offset+inp.length) range:
* ┌───────┬───────┬───────┬───────┐
* │ 0 │ amax │ amax │ 0 │ Block0-3
* ├───────┼───────┼───────┼───────┤
* │ 0 │ 0 │ 0 │ 0 │ Block4-7
* ├───────┼───────┼───────┼───────┤
* │ 0 │ 0 │ 0 │ 0 │ Block8-11
* ├───────┼───────┼───────┼───────┤
* │ 0 │ 0 │ 0 │ 0 │ Block12-15
* └───────┴───────┴───────┴───────┘
*
* \param[in] inp Input tensor (continuous slice of flattened original tensor).
* \param[in,out] amax Output tensor for maximum absolute values per block.
* \param[in] h Height dimension of the logical tensor.
* \param[in] w Width dimension of the logical tensor.
* \param[in] amax_stride_h Stride in height dimension for amax tensor.
* \param[in] amax_stride_w Stride in width dimension for amax tensor.
* \param[in] start_offset Starting offset in the flattened tensor.
* \param[in] block_len Length of a quantization block to process.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h,
size_t w, size_t amax_stride_h,
size_t amax_stride_w, size_t start_offset,
size_t block_len, cudaStream_t stream);
/*! \brief Perform partial FP8 casting with blockwise scaling.
*
* This function casts the input tensor to FP8 format using blockwise scaling factors.
* `inp` contains a continuous segment from the flattened original tensor.
*
* \param[in] inp Input tensor.
* \param[out] out Output tensor in FP8 format.
* \param[in] scale Scaling factors per block.
* \param[in] h Height dimension of the tensor.
* \param[in] w Width dimension of the tensor.
* \param[in] scale_stride_h Stride in height dimension for scale tensor.
* \param[in] scale_stride_w Stride in width dimension for scale tensor.
* \param[in] start_offset Starting offset for partial computation.
* \param[in] block_len Length of the block to process.
* \param[in] out_dtype Output FP8 datatype.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
const NVTETensor scale, size_t h, size_t w,
size_t scale_stride_h, size_t scale_stride_w,
size_t start_offset, size_t block_len,
const NVTEDType out_dtype, cudaStream_t stream);
/*! \brief Compute partial amax for MXFP8 scaling.
*
* This function computes the maximum absolute values along both row and column dimensions.
* input contains a continuous segment from the flattened original tensor. For each row/column
* block, if it overlaps with the range starting from start_offset, the amax is computed from
* `input`; otherwise, the amax is set to 0.
*
* Example: Original tensor (64 rows x 64 cols).
* Rowwise amax granularity: 1x32 (each row divided into 2 blocks)
* Columnwise amax granularity: 32x1 (each column divided into 2 blocks)
* input contains a continuous segment starting from start_offset.
*
* Logical view - Original Tensor (64x64) with 1x32 and 32x1 blocks:
*
* Rowwise blocks (1x32): Each row has 2 blocks
* ┌──────────────┬──────────────┐
* row0 │ Block_r0_0 │ Block_r0_1 │ (cols 0-31, 32-63)
* ├──────────────┼──────────────┤
* row1 │ Block_r1_0 │ Block_r1_1 │
* ├──────────────┼──────────────┤
* ... │ ... │ ... │
* ├──────────────┼──────────────┤
* row63│ Block_r63_0 │ Block_r63_1 │
* └──────────────┴──────────────┘
*
* Columnwise blocks (32x1): Each column has 2 blocks
* ┌───┬───┬─────┬───┬───┐
* │c0 │c1 │ ... │c62│c63│
* ┌────┼───┼───┼─────┼───┼───┤
* │Blk0│ │ │ │ │ │ rows 0-31
* ├────┼───┼───┼─────┼───┼───┤
* │Blk1│ │ │ │ │ │ rows 32-63
* └────┴───┴───┴─────┴───┴───┘
*
* Physical view - Flattened in row-major order:
* Total elements: 64*64 = 4096
* ┌──────────────────────────────────────────────────────┐
* │[0...63][64...127][128...191]...[4032...4095] │
* └──────────────────────────────────────────────────────┘
* ^ ^
* start_offset=60 start_offset + input.length=130
*
* Row-wise amax output (one value per 1x32 block):
* ┌────────┬────────┐
* │ 0 │ amax │ row0 (only cols 60-63 covered)
* ├────────┼────────┤
* │ amax │ amax │ row1 (fully covered)
* ├────────┼────────┤
* │ amax │ 0 │ row2 (only cols 0-1 covered)
* ├────────┼────────┤
* │ 0 │ 0 │ rows 3-63 (not covered)
* └────────┴────────┘
*
* Column-wise amax output (one value per 32x1 block):
* ┌────────┬────────┬────────┬────────┬────────┬────────┬────────┐
* │ amax │ amax │ amax │ amax │ amax │ amax │ amax │ ... rows 0-31
* ├────────┼────────┼────────┼────────┼────────┼────────┼────────┤
* │ amax=0 │ amax=0 │ amax=0 │ amax=0 │ amax=0 │ amax=0 │ amax=0 │ ... rows 32-63
* └────────┴────────┴────────┴────────┴────────┴────────┴────────┘
* col0 col1 col2 col3 col4 col5 col6
*
* For each 1x32 or 32x1 block, if it overlaps with [start_offset, start_offset+input.length),
* compute amax; otherwise set to 0.
*
* \param[in] input Input tensor (continuous segment of flattened original tensor).
* \param[in,out] amax_rowwise Output tensor for row-wise maximum absolute values.
* \param[in,out] amax_colwise Output tensor for column-wise maximum absolute values.
* \param[in] rows Number of rows in the logical tensor.
* \param[in] cols Number of columns in the logical tensor.
* \param[in] start_offset Starting offset in the flattened tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_mxfp8_scaling_compute_partial_amax(const NVTETensor input, NVTETensor amax_rowwise,
NVTETensor amax_colwise, int rows, int cols,
size_t start_offset, cudaStream_t stream);
/*! \brief Perform partial MXFP8 casting.
*
* This function casts the input tensor to MXFP8 format, producing both row-wise and
* column-wise scaled outputs. input contains a continuous segment from the flattened
* original tensor.
*
* \param[in] input Input (continuous segment of flattened original tensor).
* \param[out] output_rowwise Output tensor with row-wise scaling (MXFP8 format).
* \param[out] output_colwise Output tensor with column-wise scaling (MXFP8 format).
* \param[in] scale_inv_rowwise Inverse scaling factors for row-wise scaling.
* \param[in] scale_inv_colwise Inverse scaling factors for column-wise scaling.
* \param[in] rows Number of rows in the logical tensor.
* \param[in] cols Number of columns in the logical tensor.
* \param[in] start_offset Starting offset in the flattened tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_mxfp8_scaling_partial_cast(const NVTETensor input, NVTETensor output_rowwise,
NVTETensor output_colwise, const NVTETensor scale_inv_rowwise,
const NVTETensor scale_inv_colwise, int rows, int cols,
size_t start_offset, cudaStream_t stream);
/*! \brief Compute per-tensor scaling factor for NVFP4 format.
*
* This function computes the scaling factor (alpha) for NVFP4 quantization based
* on the input tensors A and B, with options for using row-wise amax values.
*
* \param[in] inpA Input tensor A.
* \param[in] use_rowwise_amax_A Whether to use row-wise amax for tensor A.
* \param[in] inpB Input tensor B.
* \param[in] use_rowwise_amax_B Whether to use row-wise amax for tensor B.
* \param[in] alpha_in Input scaling factor.
* \param[out] alpha_out Output scaling factor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A,
const NVTETensor inpB, const bool use_rowwise_amax_B,
float alpha_in, NVTETensor alpha_out, cudaStream_t stream);
......
......@@ -14,6 +14,7 @@
#include <sstream>
#include "../recipe/recipe_common.cuh"
#include "../util/ptx.cuh"
#include "../utils.cuh"
#include "multi_tensor_apply.cuh"
......@@ -55,6 +56,28 @@ struct ComputeScaleAndScaleInvFunctor {
}
};
struct ComputeScaleInvE8M0Functor {
__device__ __forceinline__ void operator()(int chunk_size, volatile int *unused,
TensorListMetadata<2> &tl) {
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
bf16 *amax = reinterpret_cast<bf16 *>(tl.addresses[0][tensor_loc]);
amax += chunk_idx * chunk_size;
e8m0_t *scale_inv = reinterpret_cast<e8m0_t *>(tl.addresses[1][tensor_loc]);
scale_inv += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
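// For each amax in this chunk, store the E8M0 (power-of-two exponent) representation of
// amax divided by the FP8 E4M3 max norm as the scale_inv.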
for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) {
scale_inv[i_start] = ptx::float_to_e8m0(static_cast<float>(amax[i_start]) *
Quantized_Limits<fp8e4m3>::max_norm_rcp);
}
}
};
void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists,
float max_fp8, bool force_pow_2_scales,
......@@ -65,6 +88,19 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_f
NVTE_CHECK_CUDA(cudaGetLastError());
}
void multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size,
std::vector<std::vector<Tensor *>> tensor_lists,
cudaStream_t stream) {
NVTE_CHECK(tensor_lists[0][0]->data.dtype == DType::kBFloat16, "amax should be bf16");
auto scale_inv_dtype = tensor_lists[1][0]->data.dtype;
NVTE_CHECK(scale_inv_dtype == DType::kByte || scale_inv_dtype == DType::kFloat8E8M0,
"scale_inv should be e8m0/uint8");
Tensor dummy;
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, dummy, tensor_lists, ComputeScaleInvE8M0Functor(),
stream);
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_compute_scale
} // namespace transformer_engine
......@@ -82,3 +118,15 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETens
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8,
force_pow_2_scales, epsilon, stream);
}
void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor **tensor_lists,
const size_t num_tensor_lists,
const size_t num_tensors_per_list,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_compute_scale_inv_e8m0_cuda);
using namespace transformer_engine;
multi_tensor_compute_scale::multi_tensor_compute_scale_inv_e8m0_cuda(
chunk_size, convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/recipe.h>
#include "../common.h"
#include "../util/ptx.cuh"
#include "../utils.cuh"
namespace transformer_engine {
namespace mxfp8_scaling_recipe {
constexpr int rowwise_row_padding = 128; // Row padding of rowwise_scale and rowwise_amax
constexpr int rowwise_col_padding = 4; // Column padding of rowwise_scale and rowwise_amax
constexpr int colwise_row_padding = 4; // Row padding of colwise_scale and colwise_amax
constexpr int colwise_col_padding = 128; // Column padding of colwise_scale and colwise_amax
constexpr int kRowsPerTile = 32; // Rows each block processes
constexpr int kColsPerTile = 128; // Columns each block processes
constexpr int kThreadsPerBlock = 128;
template <typename IType>
__global__ void __launch_bounds__(kThreadsPerBlock)
mxfp8_scaling_compute_partial_amax_kernel(const IType *input, IType *amax_rowwise,
IType *amax_colwise, int amax_rowwise_stride,
int amax_colwise_stride, int rows, int cols,
size_t start_offset, size_t len) {
__shared__ float smem_amax_rowwise[kRowsPerTile][kColsPerTile / 32];
size_t end_offset = start_offset + len;
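// `input` holds only the segment [start_offset, start_offset + len) of the flattened tensor;
// shifting the base pointer lets the kernel index it with logical row-major offsets.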
const IType *input_minus_offset = input - start_offset;
int warp_idx = threadIdx.x / 32;
int lane_idx = threadIdx.x % 32;
int c = blockIdx.x * kColsPerTile + threadIdx.x;
int r = blockIdx.y * kRowsPerTile;
float col_amax = 0.0f;
#pragma unroll
for (int i = 0; i < kRowsPerTile; i++) {
size_t idx = r * cols + c;
float row_amax = 0.0f;
if (r < rows && c < cols && idx >= start_offset && idx < end_offset) {
float abs_input = fabs(static_cast<float>(input_minus_offset[idx]));
row_amax = fmaxf(row_amax, abs_input);
col_amax = fmaxf(col_amax, abs_input);
}
#pragma unroll
for (int delta = 16; delta > 0; delta /= 2) {
float other_row_amax = __shfl_down_sync(0xFFFFFFFF, row_amax, delta);
row_amax = fmaxf(row_amax, other_row_amax);
}
if (lane_idx == 0) {
smem_amax_rowwise[i][warp_idx] = row_amax;
}
r++;
}
amax_colwise[blockIdx.y * amax_colwise_stride + c] = static_cast<IType>(col_amax);
__syncthreads();
int r_ = threadIdx.x / (kColsPerTile / 32); // rows in shared memory
int c_ = threadIdx.x % (kColsPerTile / 32); // cols in shared memory
r = blockIdx.y * kRowsPerTile + r_;
c = blockIdx.x * kColsPerTile / 32 + c_;
amax_rowwise[r * amax_rowwise_stride + c] = static_cast<IType>(smem_amax_rowwise[r_][c_]);
}
template <typename IType, typename OType>
__global__ void __launch_bounds__(kThreadsPerBlock)
mxfp8_scaling_partial_cast_kernel(const IType *input, OType *output_rowwise,
OType *output_colwise, const e8m0_t *scale_inv_rowwise,
const e8m0_t *scale_inv_colwise, int scale_inv_rowwise_stride,
int scale_inv_colwise_stride, int rows, int cols,
size_t start_offset, size_t len) {
__shared__ float smem_scales_rowwise[kRowsPerTile][kColsPerTile / 32];
__shared__ float smem_scales_colwise[kColsPerTile];
// Load scales_rowwise
{
int r_ = threadIdx.x / (kColsPerTile / 32); // rows in shared memory
int c_ = threadIdx.x % (kColsPerTile / 32); // cols in shared memory
int r = blockIdx.y * kRowsPerTile + r_;
int c = blockIdx.x * kColsPerTile / 32 + c_;
size_t idx = r * scale_inv_rowwise_stride + c;
smem_scales_rowwise[r_][c_] = ptx::exp2f_rcp(scale_inv_rowwise[idx]);
}
// Load scales_colwise
{
int c_ = threadIdx.x;
int r = blockIdx.y * kRowsPerTile / 32;
int c = blockIdx.x * kColsPerTile + c_;
size_t idx = r * scale_inv_colwise_stride + c;
smem_scales_colwise[c_] = ptx::exp2f_rcp(scale_inv_colwise[idx]);
}
__syncthreads();
size_t end_offset = start_offset + len;
const IType *input_minus_offset = input - start_offset;
OType *output_rowwise_minus_offset = output_rowwise - start_offset;
OType *output_colwise_minus_offset = output_colwise - start_offset;
int warp_idx = threadIdx.x / 32;
int lane_idx = threadIdx.x % 32;
int c = blockIdx.x * kColsPerTile + threadIdx.x;
int r = blockIdx.y * kRowsPerTile;
#pragma unroll
for (int i = 0; i < kRowsPerTile; i++) {
size_t idx = r * cols + c;
if (r < rows && c < cols && idx >= start_offset && idx < end_offset) {
float inp = static_cast<float>(input_minus_offset[idx]);
OType out_rowwise = static_cast<OType>(inp * smem_scales_rowwise[i][warp_idx]);
OType out_colwise = static_cast<OType>(inp * smem_scales_colwise[threadIdx.x]);
output_rowwise_minus_offset[idx] = out_rowwise;
output_colwise_minus_offset[idx] = out_colwise;
}
r++;
}
}
void mxfp8_scaling_compute_partial_amax(const Tensor input, Tensor amax_rowwise,
Tensor amax_colwise, int rows, int cols,
size_t start_offset, cudaStream_t stream) {
NVTE_CHECK(rows % 32 == 0, "rows must be divisible by 32");
NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32");
NVTE_CHECK(input.data.shape.size() == 1, "input must be a 1D tensor");
NVTE_CHECK(start_offset + input.data.shape[0] <= static_cast<size_t>(rows) * cols,
"Invalid start_offset");
NVTE_CHECK(amax_rowwise.data.shape.size() == 2, "amax_rowwise must be a 2D tensor");
NVTE_CHECK(amax_rowwise.data.shape[0] % rowwise_row_padding == 0,
"Wrong padding of amax_rowwise's rows");
NVTE_CHECK(amax_rowwise.data.shape[0] >= rows, "Invalid rows");
NVTE_CHECK(amax_rowwise.data.shape[1] % rowwise_col_padding == 0,
"Wrong padding of amax_rowwise's cols");
NVTE_CHECK(amax_rowwise.data.shape[1] >= cols / 32, "Invalid cols");
NVTE_CHECK(amax_rowwise.dtype() == input.dtype(), "Wrong dtype of amax_rowwise");
NVTE_CHECK(amax_colwise.data.shape.size() == 2, "amax_colwise must be a 2D tensor");
NVTE_CHECK(amax_colwise.data.shape[0] % colwise_row_padding == 0,
"Wrong padding of amax_colwise's rows");
NVTE_CHECK(amax_colwise.data.shape[0] >= rows / 32, "Invalid rows");
NVTE_CHECK(amax_colwise.data.shape[1] % colwise_col_padding == 0,
"Wrong padding of amax_colwise's cols");
NVTE_CHECK(amax_colwise.data.shape[1] >= cols, "Invalid cols");
NVTE_CHECK(amax_colwise.dtype() == input.dtype(), "Wrong dtype of amax_colwise");
int blocks_x = (cols + kColsPerTile - 1) / kColsPerTile;
int blocks_y = (rows + kRowsPerTile - 1) / kRowsPerTile;
dim3 grid(blocks_x, blocks_y);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input.dtype(), IType,
mxfp8_scaling_compute_partial_amax_kernel<IType><<<grid, kColsPerTile, 0, stream>>>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<IType *>(amax_rowwise.data.dptr),
reinterpret_cast<IType *>(amax_colwise.data.dptr), amax_rowwise.data.shape[1],
amax_colwise.data.shape[1], rows, cols, start_offset, input.data.shape[0]);)
}
void mxfp8_scaling_partial_cast(const Tensor input, Tensor output_rowwise, Tensor output_colwise,
const Tensor scale_inv_rowwise, const Tensor scale_inv_colwise,
int rows, int cols, size_t start_offset, cudaStream_t stream) {
NVTE_CHECK(rows % 32 == 0, "rows must be divisible by 32");
NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32");
NVTE_CHECK(input.data.shape.size() == 1, "input must be a 1D tensor");
NVTE_CHECK(start_offset + input.data.shape[0] <= static_cast<size_t>(rows) * cols,
"Invalid start_offset");
NVTE_CHECK(output_rowwise.data.shape.size() == 1, "output_rowwise must be a 1D tensor");
NVTE_CHECK(output_colwise.data.shape.size() == 1, "output_colwise must be a 1D tensor");
NVTE_CHECK(output_rowwise.data.shape[0] == input.data.shape[0],
"Size of input and output_rowwise mismatch");
NVTE_CHECK(output_colwise.data.shape[0] == input.data.shape[0],
"Size of input and output_colwise mismatch");
NVTE_CHECK(output_rowwise.dtype() == DType::kFloat8E4M3 || output_rowwise.dtype() == DType::kByte,
"output_rowwise should be e4m3 or uint8");
NVTE_CHECK(output_colwise.dtype() == DType::kFloat8E4M3 || output_colwise.dtype() == DType::kByte,
"output_colwise should be e4m3 or uint8");
NVTE_CHECK(scale_inv_rowwise.data.shape.size() == 2, "scale_inv_rowwise must be a 2D tensor");
NVTE_CHECK(scale_inv_rowwise.data.shape[0] % rowwise_row_padding == 0,
"Wrong padding of scale_inv_rowwise's rows");
NVTE_CHECK(scale_inv_rowwise.data.shape[0] >= rows, "Invalid rows");
NVTE_CHECK(scale_inv_rowwise.data.shape[1] % rowwise_col_padding == 0,
"Wrong padding of scale_inv_rowwise's cols");
NVTE_CHECK(scale_inv_rowwise.data.shape[1] >= cols / 32, "Invalid cols");
NVTE_CHECK(scale_inv_rowwise.dtype() == DType::kByte, "Wrong dtype of scale_inv_rowwise");
NVTE_CHECK(scale_inv_colwise.data.shape.size() == 2, "scale_inv_colwise must be a 2D tensor");
NVTE_CHECK(scale_inv_colwise.data.shape[0] % colwise_row_padding == 0,
"Wrong padding of scale_inv_colwise's rows");
NVTE_CHECK(scale_inv_colwise.data.shape[0] >= rows / 32, "Invalid rows");
NVTE_CHECK(scale_inv_colwise.data.shape[1] % colwise_col_padding == 0,
"Wrong padding of scale_inv_colwise's cols");
NVTE_CHECK(scale_inv_colwise.data.shape[1] >= cols, "Invalid cols");
NVTE_CHECK(scale_inv_colwise.dtype() == DType::kByte, "Wrong dtype of scale_inv_colwise");
int blocks_x = (cols + kColsPerTile - 1) / kColsPerTile;
int blocks_y = (rows + kRowsPerTile - 1) / kRowsPerTile;
dim3 grid(blocks_x, blocks_y);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input.dtype(), IType,
mxfp8_scaling_partial_cast_kernel<IType, fp8e4m3><<<grid, kColsPerTile, 0, stream>>>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<fp8e4m3 *>(output_rowwise.data.dptr),
reinterpret_cast<fp8e4m3 *>(output_colwise.data.dptr),
reinterpret_cast<const e8m0_t *>(scale_inv_rowwise.data.dptr),
reinterpret_cast<const e8m0_t *>(scale_inv_colwise.data.dptr),
scale_inv_rowwise.data.shape[1], scale_inv_colwise.data.shape[1], rows, cols,
start_offset, input.data.shape[0]);)
}
} // namespace mxfp8_scaling_recipe
} // namespace transformer_engine
void nvte_mxfp8_scaling_compute_partial_amax(const NVTETensor input, NVTETensor amax_rowwise,
NVTETensor amax_colwise, int rows, int cols,
size_t start_offset, cudaStream_t stream) {
NVTE_API_CALL(nvte_mxfp8_scaling_compute_partial_amax);
using namespace transformer_engine;
mxfp8_scaling_recipe::mxfp8_scaling_compute_partial_amax(
*convertNVTETensorCheck(input), *convertNVTETensorCheck(amax_rowwise),
*convertNVTETensorCheck(amax_colwise), rows, cols, start_offset, stream);
}
void nvte_mxfp8_scaling_partial_cast(const NVTETensor input, NVTETensor output_rowwise,
NVTETensor output_colwise, const NVTETensor scale_inv_rowwise,
const NVTETensor scale_inv_colwise, int rows, int cols,
size_t start_offset, cudaStream_t stream) {
NVTE_API_CALL(nvte_mxfp8_scaling_partial_cast);
using namespace transformer_engine;
mxfp8_scaling_recipe::mxfp8_scaling_partial_cast(
*convertNVTETensorCheck(input), *convertNVTETensorCheck(output_rowwise),
*convertNVTETensorCheck(output_colwise), *convertNVTETensorCheck(scale_inv_rowwise),
*convertNVTETensorCheck(scale_inv_colwise), rows, cols, start_offset, stream);
}
......@@ -335,6 +335,15 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const
size_t h, size_t w, size_t start_offset, size_t block_len,
const DType out_dtype);
void mxfp8_scaling_compute_partial_amax(const at::Tensor &input, at::Tensor amax_rowwise,
at::Tensor amax_colwise, int rows, int cols,
size_t start_offset);
void mxfp8_scaling_partial_cast(const at::Tensor &input, at::Tensor output_rowwise,
at::Tensor output_colwise, const at::Tensor &scale_inv_rowwise,
const at::Tensor &scale_inv_colwise, int rows, int cols,
size_t start_offset);
/***************************************************************************************************
* Rotary positional embedding
**************************************************************************************************/
......@@ -451,6 +460,9 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
float max_fp8, bool force_pow_2_scales, float epsilon);
void multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, const py::object &dummy,
std::vector<std::vector<at::Tensor>> tensor_lists);
/***************************************************************************************************
* padding
**************************************************************************************************/
......
......@@ -48,4 +48,42 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const
start_offset, block_len, static_cast<NVTEDType>(out_dtype), at::cuda::getCurrentCUDAStream());
}
void mxfp8_scaling_compute_partial_amax(const at::Tensor &input, at::Tensor amax_rowwise,
at::Tensor amax_colwise, int rows, int cols,
size_t start_offset) {
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
TORCH_CHECK(amax_rowwise.is_contiguous(), "amax_rowwise must be contiguous");
TORCH_CHECK(amax_colwise.is_contiguous(), "amax_colwise must be contiguous");
const TensorWrapper input_cu = makeTransformerEngineTensor(input);
TensorWrapper amax_rowwise_cu = makeTransformerEngineTensor(amax_rowwise);
TensorWrapper amax_colwise_cu = makeTransformerEngineTensor(amax_colwise);
nvte_mxfp8_scaling_compute_partial_amax(input_cu.data(), amax_rowwise_cu.data(),
amax_colwise_cu.data(), rows, cols, start_offset,
at::cuda::getCurrentCUDAStream());
}
void mxfp8_scaling_partial_cast(const at::Tensor &input, at::Tensor output_rowwise,
at::Tensor output_colwise, const at::Tensor &scale_inv_rowwise,
const at::Tensor &scale_inv_colwise, int rows, int cols,
size_t start_offset) {
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
TORCH_CHECK(output_rowwise.is_contiguous(), "output_rowwise must be contiguous");
TORCH_CHECK(output_colwise.is_contiguous(), "output_colwise must be contiguous");
TORCH_CHECK(scale_inv_rowwise.is_contiguous(), "scale_inv_rowwise must be contiguous");
TORCH_CHECK(scale_inv_colwise.is_contiguous(), "scale_inv_colwise must be contiguous");
const TensorWrapper input_cu = makeTransformerEngineTensor(input);
TensorWrapper output_rowwise_cu = makeTransformerEngineTensor(output_rowwise);
TensorWrapper output_colwise_cu = makeTransformerEngineTensor(output_colwise);
const TensorWrapper scale_inv_rowwise_cu = makeTransformerEngineTensor(scale_inv_rowwise);
const TensorWrapper scale_inv_colwise_cu = makeTransformerEngineTensor(scale_inv_colwise);
nvte_mxfp8_scaling_partial_cast(input_cu.data(), output_rowwise_cu.data(),
output_colwise_cu.data(), scale_inv_rowwise_cu.data(),
scale_inv_colwise_cu.data(), rows, cols, start_offset,
at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
......@@ -20,4 +20,14 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(
force_pow_2_scales, epsilon, at::cuda::getCurrentCUDAStream());
}
void multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, const py::object &dummy,
std::vector<std::vector<at::Tensor>> tensor_lists) {
NVTE_CHECK(dummy.is_none(), "No-op flag is not supported.");
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
nvte_multi_tensor_compute_scale_inv_e8m0_cuda(chunk_size, tensor_lists_ptr.data(), num_lists,
num_tensors, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
......@@ -276,6 +276,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"),
py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"),
py::arg("out_dtype"), py::call_guard<py::gil_scoped_release>());
m.def("mxfp8_scaling_compute_partial_amax",
&transformer_engine::pytorch::mxfp8_scaling_compute_partial_amax,
"Compute partial amax from master weights for fp8 mxfp8 scaling", py::arg("input"),
py::arg("amax_rowwise"), py::arg("amax_colwise"), py::arg("rows"), py::arg("cols"),
py::arg("start_offset"), py::call_guard<py::gil_scoped_release>());
m.def("mxfp8_scaling_partial_cast", &transformer_engine::pytorch::mxfp8_scaling_partial_cast,
"Partial cast from master weights for fp8 mxfp8 scaling", py::arg("input"),
py::arg("output_rowwise"), py::arg("output_colwise"), py::arg("scale_inv_rowwise"),
py::arg("scale_inv_colwise"), py::arg("rows"), py::arg("cols"), py::arg("start_offset"),
py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding,
"Fused Multi-tensor padding", py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding,
......@@ -427,6 +437,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_compute_scale_and_scale_inv",
&transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda,
"Fused compute scale and scale_inv from amax", py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_compute_scale_inv_e8m0",
&transformer_engine::pytorch::multi_tensor_compute_scale_inv_e8m0_cuda,
"Fused compute E8M0 scale_inv from amax", py::call_guard<py::gil_scoped_release>());
// Comm+GEMM Overlap
m.def("bulk_overlap_ag_with_external_gemm",
......
......@@ -8,7 +8,11 @@ from typing import Optional, Union, List
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
from transformer_engine_torch import (
multi_tensor_scale,
multi_tensor_compute_scale_and_scale_inv,
multi_tensor_compute_scale_inv_e8m0,
)
from ..quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage
from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
......@@ -85,6 +89,7 @@ def cast_master_weights_to_fp8(
delayed_scaling_params = []
current_scaling_params = []
blockwise_scaling_params = []
mxfp8_scaling_params = []
if fsdp_shard_model_weights is None:
use_fsdp_shard_model_weights = False
......@@ -131,8 +136,8 @@ def cast_master_weights_to_fp8(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, MXFP8Quantizer):
raise NotImplementedError(
"cast_master_weights_to_fp8 for MXFP8BlockScaling is not supported yet"
mxfp8_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
else:
raise ValueError(
......@@ -146,6 +151,8 @@ def cast_master_weights_to_fp8(
_cast_master_weights_to_fp8_current_scaling(current_scaling_params, *extra_args)
if len(blockwise_scaling_params) > 0:
_cast_master_weights_to_fp8_blockwise_scaling(blockwise_scaling_params, *extra_args)
if len(mxfp8_scaling_params) > 0:
_cast_master_weights_to_fp8_mxfp8_scaling(mxfp8_scaling_params, *extra_args)
def _cast_master_weights_to_fp8_delayed_scaling(
......@@ -467,6 +474,131 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
)
def _cast_master_weights_to_fp8_mxfp8_scaling(
params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
): # pylint: disable=unused-argument
r"""Helper function to cast master weights to FP8 primary weights for mxfp8 scaling.
Parameters
----------
params : List of tuples; each tuple contains a model weight, a master weight, an offset
indicating the starting index of the master weight in the model weight, and an FSDP shard
of the model weight (for MXFP8, a pair of row-wise and column-wise fragments; None when
FSDP sharding is not used).
group : The distributed group to do amax reduction. Typically it's the data parallel
group.
use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded.
"""
# Parameter attributes
device = params[0][0].device
for _, master_weight, _, _ in params:
if master_weight is not None:
master_weight_dtype = master_weight.dtype
break
# Compute cumulative amax sizes (row-wise and column-wise) over all the model weights.
cu_rowwise_amax_sizes = [0]
cu_colwise_amax_sizes = [0]
for model_weight, _, _, _ in params:
rowwise_shape = model_weight._rowwise_scale_inv.shape
assert len(rowwise_shape) == 2
colwise_shape = model_weight._columnwise_scale_inv.shape
assert len(colwise_shape) == 2
cu_rowwise_amax_sizes.append(
cu_rowwise_amax_sizes[-1] + rowwise_shape[0] * rowwise_shape[1]
)
cu_colwise_amax_sizes.append(
cu_colwise_amax_sizes[-1] + colwise_shape[0] * colwise_shape[1]
)
# Create a contiguous buffer to store the amaxes temporarily, so the amax all-reduce can be
# issued as a single NCCL kernel.
packed_amaxes = torch.zeros(
cu_rowwise_amax_sizes[-1] + cu_colwise_amax_sizes[-1],
dtype=master_weight_dtype,
device=device,
)
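# Layout of packed_amaxes: all row-wise amaxes first, followed by all column-wise amaxes.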
# ---------------------------------------------------------------------------------------------
# Step 1: Iterate through all the non-empty master weights and compute their amax. Store the
# amaxes in a contiguous buffer. If a block of a master weight is empty, the
# corresponding amax will be set to 0.
# ---------------------------------------------------------------------------------------------
amaxes_rowwise, scale_invs_rowwise = [], []
amaxes_colwise, scale_invs_colwise = [], []
for i, (model_weight, master_weight, start_offset, _) in enumerate(params):
rowwise_shape = model_weight._rowwise_scale_inv.shape
colwise_shape = model_weight._columnwise_scale_inv.shape
rowwise_start = cu_rowwise_amax_sizes[i]
rowwise_end = cu_rowwise_amax_sizes[i + 1]
colwise_start = cu_rowwise_amax_sizes[-1] + cu_colwise_amax_sizes[i]
colwise_end = cu_rowwise_amax_sizes[-1] + cu_colwise_amax_sizes[i + 1]
amax_rowwise = packed_amaxes[rowwise_start:rowwise_end].reshape(rowwise_shape)
amax_colwise = packed_amaxes[colwise_start:colwise_end].reshape(colwise_shape)
amaxes_rowwise.append(amax_rowwise)
amaxes_colwise.append(amax_colwise)
scale_invs_rowwise.append(model_weight._rowwise_scale_inv)
scale_invs_colwise.append(model_weight._columnwise_scale_inv)
# Compute amax of the master weight and store it in packed_amaxes.
if master_weight is not None:
assert len(model_weight.shape) == 2
h, w = model_weight.shape
tex.mxfp8_scaling_compute_partial_amax(
master_weight, amax_rowwise, amax_colwise, h, w, start_offset
)
# ---------------------------------------------------------------------------------------------
# Step 2: Perform all-reduce on packed_amaxes to get the global amax.
# ---------------------------------------------------------------------------------------------
torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group)
# ---------------------------------------------------------------------------------------------
# Step 3: Update scales and scale_invs.
# ---------------------------------------------------------------------------------------------
multi_tensor_applier(
multi_tensor_compute_scale_inv_e8m0,
None, # dummy_overflow_buf
[
amaxes_rowwise + amaxes_colwise,
scale_invs_rowwise + scale_invs_colwise,
],
)
# ---------------------------------------------------------------------------------------------
# Step 4: Cast master weights to FP8.
# ---------------------------------------------------------------------------------------------
for (
(model_weight, master_weight, start_offset, model_weight_fragment),
scale_inv_rowwise,
scale_inv_colwise,
) in zip(params, scale_invs_rowwise, scale_invs_colwise):
# If master weight is None, it means that the master weight of the current model weight
# is owned by another DP rank.
if master_weight is None:
continue
# Cast master weight to FP8
end_offset = start_offset + master_weight.numel()
if use_fsdp_shard_model_weights:
rowwise_fragment = model_weight_fragment[0]
colwise_fragment = model_weight_fragment[1]
else:
rowwise_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset]
colwise_fragment = model_weight._columnwise_data.reshape(-1)[start_offset:end_offset]
assert len(model_weight.shape) == 2
h, w = model_weight.shape
tex.mxfp8_scaling_partial_cast(
master_weight,
rowwise_fragment,
colwise_fragment,
scale_inv_rowwise,
scale_inv_colwise,
h,
w,
start_offset,
)
def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Tensor]]):
"""
Post-processing after all-gather for weights in distributed optimizer.
......@@ -485,6 +617,9 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten
elif isinstance(model_weight, Float8BlockwiseQTensor):
# Blockwise scaling: create column-wise storage.
model_weight._create_columnwise()
elif isinstance(model_weight, MXFP8Tensor):
# MXFP8 scaling: no need to do anything.
pass
elif isinstance(model_weight, QuantizedTensor):
raise ValueError(f"post_processing for {type(model_weight)} is not supported")
......