Unverified Commit b4a1d4d6 authored by Zhongbo Zhu, committed by GitHub

[PyTorch][MOE] Support NVFP4 Grouped Linear (#2215)



* pipeclean; fix NVFP4 padding for 32-element alignment
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* numerical test passed
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* fix CI failure with test_cast_master_weights_to_fp8 (in a hacky way)
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* found a CUDA misaligned-address error in multi-tensor swizzle during training; hack vec_load_size to 1 to unblock
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* leave comments about alignment issue
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* fused bulk alloc nvfp4
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* fix RHT sign mask CPU overhead
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* fix
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* resolve comments
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* Remove incorrect logic that treats 0-D tensor as uninitialized

Tensor shape logic still requires treating 0-D tensors as uninitialized.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix invalid conversion from tensor to int
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent bd55e7ba
...@@ -8,53 +8,67 @@ import torch.utils.benchmark as benchmark
import pandas as pd
from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.common.recipe import (
    Float8BlockScaling,
    MXFP8BlockScaling,
    NVFP4BlockScaling,
)
from transformer_engine.pytorch.quantization import autocast, FP8GlobalStateManager
from contextlib import nullcontext
"""
# Profile BF16 recipe with Nsight Systems
nsys profile \
    --output=./benchmarks/linear/b200_numgemm_8_bf16 \
    --force-overwrite true \
    --trace=cuda,nvtx,cudnn,cublas \
    python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe bf16
# Profile FP8 sub-channel recipe with Nsight Systems
nsys profile \
    --output=./benchmarks/linear/h100hbm_numgemm_8_fp8_sub_channel \
    --force-overwrite true \
    --trace=cuda,nvtx,cudnn,cublas \
    python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe fp8_sub_channel
# Profile MXFP8 recipe with Nsight Systems
nsys profile \
    --output=./benchmarks/linear/b200_numgemm_8_mxfp8 \
    --force-overwrite true \
    --trace=cuda,nvtx,cudnn,cublas \
    python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe mxfp8
# Profile NVFP4 recipe with Nsight Systems
nsys profile \
    --output=./benchmarks/linear/b200_numgemm_8_nvfp4 \
    --force-overwrite true \
    --trace=cuda,nvtx,cudnn,cublas \
    python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4
"""
RECIPES = {
    "bf16": None,
    "fp8_sub_channel": Float8BlockScaling(),
    "mxfp8": MXFP8BlockScaling(),
    "nvfp4": NVFP4BlockScaling(),
}
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
    FP8GlobalStateManager.is_fp8_block_scaling_available()
)
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()

def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None):
    assert mode in ["fwd_only", "fwd_bwd"]
fp8_context = autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext() quantization_context = (
# print(f"fp8_context: {fp8_context} and is it nullcontext? {isinstance(fp8_context, nullcontext)}") autocast(enabled=True, recipe=recipe) if recipe is not None else nullcontext()
)
if mode == "fwd_only": if mode == "fwd_only":
with torch.no_grad(), fp8_context: with torch.no_grad(), quantization_context:
for i in range(run_num_steps): for i in range(run_num_steps):
y_q = layer.forward( y_q = layer.forward(
x, x,
...@@ -67,7 +81,7 @@ def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps= ...@@ -67,7 +81,7 @@ def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=
layer.zero_grad() layer.zero_grad()
x.grad = None x.grad = None
with fp8_context: with quantization_context:
for i in range(run_num_steps): for i in range(run_num_steps):
label = f"step_{i}" label = f"step_{i}"
torch.cuda.nvtx.range_push(label) torch.cuda.nvtx.range_push(label)
...@@ -142,7 +156,7 @@ def benchmark_linear( ...@@ -142,7 +156,7 @@ def benchmark_linear(
"recipe": recipe, "recipe": recipe,
}, },
num_threads=1, num_threads=1,
).blocked_autorange(min_run_time=5) ).blocked_autorange(min_run_time=10)
print(f"{recipe_name}: {timing} \n") print(f"{recipe_name}: {timing} \n")
timing_ms = timing.median * 1000 / num_microbatches timing_ms = timing.median * 1000 / num_microbatches
...@@ -225,30 +239,44 @@ if __name__ == "__main__": ...@@ -225,30 +239,44 @@ if __name__ == "__main__":
use_bias = False use_bias = False
# Set the MKN values to benchmark # Set the MKN values to benchmark
# Deepseek V3 EP64, SEQ_LEN=8192, topK8
# 256 expert => 4 local experts
# Avg M per expert: AvgM = SEQ_LEN * topK / localExperts = 16384
# M = AvgM * localExperts = 65536
# K = 7168
# N = 2048
# Deepseek V3 EP32, SEQ_LEN=8192, topK8
# 256 expert => 8 local experts
# Avg M per expert: AvgM = SEQ_LEN * topK / localExperts = 8192
# M = AvgM * localExperts = 65536
# K = 7168
# N = 2048
# 4 or 8 local experts per rank
num_gemms_list = [4, 8]
# MKN for group linear
mkns = [] mkns = []
for m in [8192]: for m in [65536]:
# for m in [4096, 8192, 16384]: for k in [7168]:
# for n in [1024, 2048, 4096, 8192, 16384]: for n in [2048]:
for n in [8192]:
for k in [4096]:
mkns.append((m, k, n)) mkns.append((m, k, n))
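The comment block above derives the grouped-GEMM problem sizes from DeepSeek-V3-style expert parallelism. A minimal sketch of that arithmetic (standalone illustration, not part of the benchmark script; the sequence length, top-k, and expert counts come from the comments above):

```python
# Average tokens routed to each local expert, per the comments above:
#   AvgM = SEQ_LEN * topK / localExperts
# The grouped-GEMM M dimension is the sum over local experts.
SEQ_LEN, TOP_K, NUM_EXPERTS = 8192, 8, 256

for ep_size in (64, 32):                      # expert-parallel world size
    local_experts = NUM_EXPERTS // ep_size    # 4 or 8 local experts per rank
    avg_m = SEQ_LEN * TOP_K // local_experts  # 16384 (EP64) or 8192 (EP32)
    total_m = avg_m * local_experts           # 65536 rows fed to the grouped GEMM
    print(f"EP{ep_size}: M={total_m}, K=7168, N=2048, avg M per expert={avg_m}")
```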
# default recipes to run if not specified # default recipes to run if not specified
recipe_list = ["bf16"] recipe_list = ["bf16"]
if args.recipe == "all": if args.recipe == "all":
recipe_list = ["bf16", "fp8_sub_channel", "mxfp8"] recipe_list = ["bf16", "fp8_sub_channel", "mxfp8", "nvfp4"]
else: else:
recipe_list = [args.recipe] recipe_list = [args.recipe]
num_gemms_list = [8]
if args.profile: if args.profile:
mkns = [(4096 * 8, 4096, 4096)] mkns = [(8192 * 8, 7168, 2048)]
# in profile mode, only run one recipe specified in args.recipe # in profile mode, only run one recipe specified in args.recipe
assert args.recipe != "all", ( assert args.recipe != "all", (
"In profile mode, only one recipe can be specified, please specify the recipe as" "In profile mode, only one recipe can be specified, please specify the recipe as"
" fp8_sub_channel, mxfp8, or bf16" " fp8_sub_channel, mxfp8, nvfp4, or bf16"
) )
recipe_list = [args.recipe] recipe_list = [args.recipe]
num_gemms_list = [8] num_gemms_list = [8]
...@@ -265,13 +293,17 @@ if __name__ == "__main__": ...@@ -265,13 +293,17 @@ if __name__ == "__main__":
"bf16", "bf16",
"fp8_sub_channel", "fp8_sub_channel",
"mxfp8", "mxfp8",
], "Recipe must be one of bf16, fp8_sub_channel, or mxfp8" "nvfp4",
], "Recipe must be one of bf16, fp8_sub_channel, mxfp8, or nvfp4"
if recipe_name == "mxfp8" and not mxfp8_available: if recipe_name == "mxfp8" and not mxfp8_available:
print(f"MXFP8 is not available, skipping {recipe_name}") print(f"MXFP8 is not available, skipping {recipe_name}")
continue continue
if recipe_name == "fp8_sub_channel" and not fp8_block_scaling_available: if recipe_name == "fp8_sub_channel" and not fp8_block_scaling_available:
print(f"FP8 block scaling is not available, skipping {recipe_name}") print(f"FP8 block scaling is not available, skipping {recipe_name}")
continue continue
if recipe_name == "nvfp4" and not nvfp4_available:
print(f"NVFP4 is not available, skipping {recipe_name}")
continue
df = run_benchmark_linear( df = run_benchmark_linear(
mkns, mkns,
......
...@@ -40,6 +40,7 @@ from transformer_engine.pytorch import (
    is_mxfp8_available,
    is_fp8_block_scaling_available,
    is_bf16_available,
    is_nvfp4_available,
)
from transformer_engine.pytorch import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
...@@ -53,6 +54,7 @@ from utils import ModelConfig, reset_rng_states
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
fp8_block_scaling_available = is_fp8_block_scaling_available()
nvfp4_available = is_nvfp4_available()
sm_80plus = get_device_compute_capability() >= (8, 0)
...@@ -114,6 +116,43 @@ if NVTE_TEST_NVINSPECT_ENABLED:
)
def nvfp4_rht_and_2d_quantization():
nvfp4_recipe = recipe.NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
random_hadamard_transform=True, fp4_2d_quantization=False
)
nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
random_hadamard_transform=False, fp4_2d_quantization=True
)
nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
random_hadamard_transform=True, fp4_2d_quantization=False
)
return nvfp4_recipe
def check_rht_usage(recipe: recipe.Recipe) -> bool:
# if using RHT, we can only support bf16
# check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad
if recipe.nvfp4():
if (
recipe.fp4_quant_fwd_inp.random_hadamard_transform
or recipe.fp4_quant_fwd_weight.random_hadamard_transform
or recipe.fp4_quant_bwd_grad.random_hadamard_transform
):
return True
return False
def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> list:
supported_input_dtypes = []
if recipe.nvfp4():
supported_input_dtypes.append(torch.bfloat16)
# if not using RHT, we can add fp32 as well
if not check_rht_usage(recipe):
supported_input_dtypes.append(torch.float32)
return supported_input_dtypes
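The two helpers above encode the current NVFP4 input-dtype restriction: with the random Hadamard transform enabled only BF16 inputs are exercised, while plain NVFP4 also allows FP32. A hedged sketch of how the tests below use them (assumes `torch` and the helpers above are in scope, as they are in this test file):

```python
# Illustrative guard mirroring the pytest.skip() calls added in the tests below.
nvfp4_recipe = nvfp4_rht_and_2d_quantization()   # RHT on activations/grads, 2D weight quantization
for dtype in (torch.bfloat16, torch.float32):
    if dtype not in get_nvfp4_inp_supported_dtypes(nvfp4_recipe, dtype):
        print(f"would skip: input dtype {dtype} not supported for this NVFP4 recipe")
```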
fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
...@@ -122,6 +161,8 @@ if fp8_block_scaling_available:
if fp8_available:
    fp8_recipes.append(recipe.Float8CurrentScaling())
    fp8_recipes.append(recipe.DelayedScaling())
if nvfp4_available:
    fp8_recipes.append(nvfp4_rht_and_2d_quantization())
use_cutlass_grouped_gemm = [False]
# Only enable cutlass grouped gemm on Hopper
...@@ -582,6 +623,11 @@ def _test_e2e_selective_recompute( ...@@ -582,6 +623,11 @@ def _test_e2e_selective_recompute(
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params): def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.") pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
config = model_configs[model] config = model_configs[model]
...@@ -692,6 +738,11 @@ def test_gpt_full_activation_recompute( ...@@ -692,6 +738,11 @@ def test_gpt_full_activation_recompute(
): ):
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.") pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
config = model_configs[model] config = model_configs[model]
...@@ -1275,6 +1326,12 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe): ...@@ -1275,6 +1326,12 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
if config.max_seqlen_q % 16 != 0 and fp8: if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.") pytest.skip("FP8 requires sequence length to be divisible by 16.")
if recipe is not None and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
te_linear_ref = Linear( te_linear_ref = Linear(
config.hidden_size, config.hidden_size,
...@@ -1718,8 +1775,8 @@ def _test_grouped_linear_accuracy( ...@@ -1718,8 +1775,8 @@ def _test_grouped_linear_accuracy(
split_size = 1 split_size = 1
if fp8: if fp8:
split_size = 16 split_size = 16
if recipe.mxfp8(): if recipe.mxfp8() or recipe.nvfp4():
split_size = 128 split_size = 32
m = config.max_seqlen_q // split_size m = config.max_seqlen_q // split_size
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1]) # Manually add a zero dist.append(dist[-1]) # Manually add a zero
...@@ -1791,6 +1848,12 @@ def test_grouped_linear_accuracy( ...@@ -1791,6 +1848,12 @@ def test_grouped_linear_accuracy(
if config.max_seqlen_q % 16 != 0 and fp8: if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.") pytest.skip("FP8 requires sequence length to be divisible by 16.")
if recipe is not None and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = GroupedLinear( grouped_linear = GroupedLinear(
num_gemms, num_gemms,
...@@ -1927,6 +1990,12 @@ def test_grouped_linear_accuracy_save_original_input( ...@@ -1927,6 +1990,12 @@ def test_grouped_linear_accuracy_save_original_input(
if config.max_seqlen_q % 16 != 0 and fp8: if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.") pytest.skip("FP8 requires sequence length to be divisible by 16.")
if recipe is not None and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = GroupedLinear( grouped_linear = GroupedLinear(
num_gemms, num_gemms,
...@@ -2014,7 +2083,7 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r ...@@ -2014,7 +2083,7 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
align_size = 16 align_size = 16
if recipe.mxfp8(): if recipe.mxfp8() or recipe.nvfp4():
align_size = 32 align_size = 32
padded_tokens_per_expert = [ padded_tokens_per_expert = [
(num_tokens + align_size - 1) // align_size * align_size (num_tokens + align_size - 1) // align_size * align_size
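For NVFP4 (as for MXFP8) the per-expert token counts are padded to a multiple of 32 instead of 16. A standalone illustration of the rounding used in `padded_tokens_per_expert` above:

```python
def pad_to_multiple(num_tokens: int, align_size: int) -> int:
    # Same ceiling rounding as the list comprehension above.
    return (num_tokens + align_size - 1) // align_size * align_size

assert pad_to_multiple(100, 32) == 128   # NVFP4/MXFP8 alignment
assert pad_to_multiple(100, 16) == 112   # default FP8 alignment
assert pad_to_multiple(64, 32) == 64     # already aligned
```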
...@@ -2129,6 +2198,12 @@ def test_padding_grouped_linear_accuracy( ...@@ -2129,6 +2198,12 @@ def test_padding_grouped_linear_accuracy(
if config.max_seqlen_q % 16 != 0 and fp8: if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.") pytest.skip("FP8 requires sequence length to be divisible by 16.")
if recipe is not None and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = TorchGroupedLinearWithPadding( grouped_linear = TorchGroupedLinearWithPadding(
num_gemms, num_gemms,
...@@ -2200,6 +2275,12 @@ def test_padding_grouped_linear_accuracy_save_original_input( ...@@ -2200,6 +2275,12 @@ def test_padding_grouped_linear_accuracy_save_original_input(
if config.max_seqlen_q % 16 != 0 and fp8: if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.") pytest.skip("FP8 requires sequence length to be divisible by 16.")
if recipe is not None and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = TorchGroupedLinearWithPadding( grouped_linear = TorchGroupedLinearWithPadding(
num_gemms, num_gemms,
...@@ -2409,6 +2490,12 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe): ...@@ -2409,6 +2490,12 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe):
if NVTE_TEST_NVINSPECT_ENABLED: if NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.") pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
config = model_configs[model] config = model_configs[model]
outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe) outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe)
......
...@@ -183,21 +183,38 @@ struct Tensor { ...@@ -183,21 +183,38 @@ struct Tensor {
* https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569). * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
*/ */
switch (scaling_mode) { switch (scaling_mode) {
case NVTE_NVFP4_1D_SCALING:
case NVTE_DELAYED_TENSOR_SCALING: case NVTE_DELAYED_TENSOR_SCALING:
if (!has_data() && has_columnwise_data()) { case NVTE_NVFP4_1D_SCALING: {
// Choose data buffer based on whether it is initialized
// Note: Uninitialized buffers currently have shape=[].
// However, this is logically incorrect. 0-D tensors have 1
// entry, and uninitialized tensors should have shape=[0].
bool use_columnwise_shape = false;
if (data.dptr != nullptr) {
use_columnwise_shape = false;
} else if (columnwise_data.dptr != nullptr) {
use_columnwise_shape = true;
} else if (data.shape.size() != 0) {
use_columnwise_shape = false;
} else if (columnwise_data.shape.size() != 0) {
use_columnwise_shape = true;
}
// Infer shape based on data
if (use_columnwise_shape) {
// Column-wise data is transposed
std::vector<size_t> ret; std::vector<size_t> ret;
if (!columnwise_data.shape.empty()) { if (!columnwise_data.shape.empty()) {
ret.reserve(columnwise_data.shape.size());
for (size_t i = 1; i < columnwise_data.shape.size(); i++) { for (size_t i = 1; i < columnwise_data.shape.size(); i++) {
ret.push_back(columnwise_data.shape[i]); ret.push_back(columnwise_data.shape[i]);
} }
ret.push_back(columnwise_data.shape.front()); ret.push_back(columnwise_data.shape.front());
} }
return ret; return ret;
} else {
return data.shape;
} }
break; return data.shape;
}
case NVTE_MXFP8_1D_SCALING: case NVTE_MXFP8_1D_SCALING:
if (!has_data() && has_columnwise_data()) { if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape; return columnwise_data.shape;
......
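The C++ change above infers the logical shape from whichever data buffer is actually initialized and, for the column-wise (transposed) buffer, rotates its first dimension to the end. A rough Python rendering of that selection logic, assuming row-wise data takes precedence as in the code above:

```python
def infer_shape(rowwise_shape, rowwise_ptr, columnwise_shape, columnwise_ptr):
    # Prefer the buffer with an allocated pointer; fall back to a non-empty shape.
    if rowwise_ptr is not None:
        use_columnwise = False
    elif columnwise_ptr is not None:
        use_columnwise = True
    elif len(rowwise_shape) != 0:
        use_columnwise = False
    else:
        use_columnwise = len(columnwise_shape) != 0
    if use_columnwise and columnwise_shape:
        # Column-wise data is stored transposed: [K, M0, M1, ...] -> [M0, M1, ..., K]
        return list(columnwise_shape[1:]) + [columnwise_shape[0]]
    return list(rowwise_shape)

# Only column-wise data allocated: the shape is the un-transposed view.
assert infer_shape([], None, [128, 64], object()) == [64, 128]
```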
...@@ -332,11 +332,9 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ ...@@ -332,11 +332,9 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
} // namespace } // namespace
void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
NVTE_CHECK(input->scaling_mode == NVTE_MXFP8_1D_SCALING || NVTE_CHECK(
input->scaling_mode == NVTE_BLOCK_SCALING_1D || input->scaling_mode == NVTE_MXFP8_1D_SCALING || input->scaling_mode == NVTE_NVFP4_1D_SCALING,
input->scaling_mode == NVTE_BLOCK_SCALING_2D || "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");
input->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");
NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()), NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()),
"Input tensor has invalid dtype (", to_string(input->dtype()), ")."); "Input tensor has invalid dtype (", to_string(input->dtype()), ").");
...@@ -583,16 +581,19 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, ...@@ -583,16 +581,19 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
// TODO(nvfp4): Add NVFP4 support.
void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input, void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
std::vector<Tensor*>& output, cudaStream_t stream) { std::vector<Tensor*>& output, cudaStream_t stream) {
auto num_tensors = input.size(); auto num_tensors = input.size();
bool all_has_data = true; bool all_has_data = true;
bool all_has_columnwise_data = true; bool all_has_columnwise_data = true;
bool all_nvfp4 = true;
for (size_t i = 0; i < num_tensors; i++) { for (size_t i = 0; i < num_tensors; i++) {
if (!is_fp8_dtype(input[i]->dtype()) || !is_mxfp_scaling(input[i]->scaling_mode)) { auto scaling_mode = input[i]->scaling_mode;
NVTE_ERROR("Not implemented caling mode " + to_string(input[i]->scaling_mode) + "."); auto is_fp8 = is_fp8_dtype(input[i]->dtype());
} auto is_fp4 = is_fp4_dtype(input[i]->dtype());
NVTE_CHECK(
(is_fp8 && is_mxfp8_scaling(scaling_mode)) || (is_fp4 && is_nvfp4_scaling(scaling_mode)),
"Not implemented scaling mode " + to_string(scaling_mode) + ".");
// We don't allow empty tensors. They should be filtered out before calling this function. // We don't allow empty tensors. They should be filtered out before calling this function.
if (input[i]->data.numel() == 0) { if (input[i]->data.numel() == 0) {
NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty."); NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty.");
...@@ -601,13 +602,17 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input, ...@@ -601,13 +602,17 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]"); CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]");
all_has_data &= input[i]->has_data(); all_has_data &= input[i]->has_data();
all_has_columnwise_data &= input[i]->has_columnwise_data(); all_has_columnwise_data &= input[i]->has_columnwise_data();
all_nvfp4 &= is_nvfp4_scaling(scaling_mode);
} }
NVTE_CHECK(all_has_data || all_has_columnwise_data, NVTE_CHECK(all_has_data || all_has_columnwise_data,
"All tensors should have data or columnwise data."); "All tensors should have data or columnwise data.");
const bool rowwise_swizzle = all_has_data || all_nvfp4;
const bool columnwise_swizzle = all_has_columnwise_data && !all_nvfp4;
constexpr int SF_TILE_DIM_M = 128; constexpr int SF_TILE_DIM_M = 128;
constexpr int SF_TILE_DIM_K = 4; constexpr int SF_TILE_DIM_K = 4;
if (all_has_data) { if (rowwise_swizzle) {
MultiSwizzleArgs kernel_args; MultiSwizzleArgs kernel_args;
kernel_args.num_tensors = 0; kernel_args.num_tensors = 0;
kernel_args.block_range[0] = 0; kernel_args.block_range[0] = 0;
...@@ -623,29 +628,60 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input, ...@@ -623,29 +628,60 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
kernel_args.num_tensors = 0; kernel_args.num_tensors = 0;
vec_load_size = 4; vec_load_size = 4;
} }
const int m = input[i]->scale_inv.shape[0];
const int k = input[i]->scale_inv.shape[1]; int m, k;
if (all_has_data) {
m = input[i]->scale_inv.shape[0];
k = input[i]->scale_inv.shape[1];
} else {
NVTE_CHECK(all_nvfp4, "Rowwise swizzle without rowwise data is only expected for NVFP4");
m = input[i]->columnwise_scale_inv.shape[0];
k = input[i]->columnwise_scale_inv.shape[1];
}
NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
NVTE_CHECK(
m * k == std::accumulate(output[i]->scale_inv.shape.begin(), if (output[i]->has_data()) {
output[i]->scale_inv.shape.end(), 1, std::multiplies<int>()), NVTE_CHECK(
"Input.scale_inv size is not equal to Output.scale_inv size!"); m * k == std::accumulate(output[i]->scale_inv.shape.begin(),
output[i]->scale_inv.shape.end(), 1, std::multiplies<int>()),
"Input.scale_inv size is not equal to Output.scale_inv size!");
}
if (output[i]->has_columnwise_data()) {
NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(),
output[i]->columnwise_scale_inv.shape.end(), 1,
std::multiplies<int>()),
"Input.columnwise_scale_inv size is not equal to "
"Output.columnwise_scale_inv size!");
}
int num_tiles_k = k / SF_TILE_DIM_K; int num_tiles_k = k / SF_TILE_DIM_K;
int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; int vec_load_size_i = (num_tiles_k - 1) % 4 + 1;
// We use the minimum vec_load_size across all tensors. // We use the minimum vec_load_size across all tensors.
vec_load_size = std::min(vec_load_size, vec_load_size_i); // TODO(zhongbo): fix vec_load_size for NVFP4
// Current unit test won't capture this issue, but in E2E
// using a vec_load_size other than 1 will lead to a misaligned
// address error in MOE training
vec_load_size = all_nvfp4 ? 1 : std::min(vec_load_size, vec_load_size_i);
const int pos = kernel_args.num_tensors; const int pos = kernel_args.num_tensors;
kernel_args.input_list[pos] = const_cast<void*>(input[i]->scale_inv.dptr);
kernel_args.output_list[pos] = output[i]->scale_inv.dptr;
kernel_args.m_list[pos] = m; kernel_args.m_list[pos] = m;
kernel_args.k_list[pos] = k; kernel_args.k_list[pos] = k;
kernel_args.original_m_list[pos] = input[i]->flat_first_dim(); if (!all_nvfp4 || all_has_data) {
kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / MXFP8_BLOCK_SIZE; int block_scale_size = all_nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE;
kernel_args.input_list[pos] = const_cast<void*>(input[i]->scale_inv.dptr);
kernel_args.output_list[pos] = output[i]->scale_inv.dptr;
kernel_args.original_m_list[pos] = input[i]->flat_first_dim();
kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / block_scale_size;
} else {
kernel_args.input_list[pos] = const_cast<void*>(input[i]->columnwise_scale_inv.dptr);
kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr;
kernel_args.original_m_list[pos] = input[i]->flat_last_dim();
kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / NVFP4_BLOCK_SIZE;
}
kernel_args.num_tensors++; kernel_args.num_tensors++;
} }
// Launch the remaining tensors // Launch the remaining tensors
...@@ -655,7 +691,10 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input, ...@@ -655,7 +691,10 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
kernel_args, vec_load_size, true, stream); kernel_args, vec_load_size, true, stream);
} }
if (all_has_columnwise_data) { if (columnwise_swizzle) {
// NVFP4 shouldn't end up here because it only needs rowwise swizzle
NVTE_CHECK(!all_nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle");
MultiSwizzleArgs kernel_args; MultiSwizzleArgs kernel_args;
kernel_args.num_tensors = 0; kernel_args.num_tensors = 0;
kernel_args.block_range[0] = 0; kernel_args.block_range[0] = 0;
......
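For NVFP4 the multi-tensor swizzle above always takes the row-wise path (NVFP4 on SM100 is TN-only, so even column-wise-only data reuses the row-wise swizzle), and `vec_load_size` is pinned to 1 as the TODO notes, to avoid the misaligned-address error seen in end-to-end MoE training. A rough Python sketch of that dispatch under those assumptions (not the C++ implementation itself):

```python
def choose_swizzle_paths(all_has_data: bool, all_has_columnwise_data: bool, all_nvfp4: bool):
    # Mirrors rowwise_swizzle / columnwise_swizzle in the C++ above.
    rowwise_swizzle = all_has_data or all_nvfp4
    columnwise_swizzle = all_has_columnwise_data and not all_nvfp4
    return rowwise_swizzle, columnwise_swizzle

def pick_vec_load_size(num_tiles_k_per_tensor, all_nvfp4: bool) -> int:
    # Per-tensor candidate is (num_tiles_k - 1) % 4 + 1; the minimum across
    # tensors is used, except NVFP4 is pinned to 1 (alignment workaround).
    if all_nvfp4:
        return 1
    return min((k - 1) % 4 + 1 for k in num_tiles_k_per_tensor)
```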
...@@ -190,8 +190,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( ...@@ -190,8 +190,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
const std::vector<size_t> meta_shape{1}; const std::vector<size_t> meta_shape{1};
ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
auto scale_inv_dtype = auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0
(scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat8E4M3
: DType::kFloat32;
ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype, ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype,
columnwise_scale_inv_shape); columnwise_scale_inv_shape);
......
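The dtype selection above reflects how each recipe stores its scale-inverse factors: E8M0 for MXFP8 block scales, E4M3 for NVFP4 block scales, FP32 otherwise. Restated as a small sketch (string names stand in for the DType enum values):

```python
def scale_inv_dtype(scaling_mode: str) -> str:
    if scaling_mode == "NVTE_MXFP8_1D_SCALING":
        return "Float8E8M0"   # power-of-two block scales
    if scaling_mode == "NVTE_NVFP4_1D_SCALING":
        return "Float8E4M3"   # NVFP4 keeps FP8 E4M3 block scale factors
    return "Float32"          # per-tensor scaling recipes keep FP32 scales
```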
...@@ -491,6 +491,207 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx ...@@ -491,6 +491,207 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
return retval; return retval;
} }
// allocate fp4 data, fp8 scalings, and amax values
// layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN]
// amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_nvfp4_tensors(
std::vector<std::vector<size_t>> &shape_list, std::vector<py::handle> &quantizer_py_list,
std::vector<NVFP4Quantizer *> &quantizer_cpp_list) {
init_extension();
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> retval;
auto &tensor_py_list = std::get<0>(retval);
auto &tensor_cpp_list = std::get<1>(retval);
// Number of tensors
const size_t num_tensors = shape_list.size();
if (num_tensors == 0) {
return retval;
}
// Quantization parameters
const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage;
const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage;
const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
const auto fp4_dtype = quantizer_cpp_list[0]->dtype;
constexpr size_t scale_elem_size = 1;
// Helper function to construct tensor view
// Note: Deleter holds a shared_ptr for the buffer, so the buffer
// will survive until all views are deleted.
auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
size_t offset, at::ScalarType dtype) -> at::Tensor {
std::vector<int64_t> shape_int64(shape.begin(), shape.end());
bool is_empty_shape = product(shape) == 0;
if (buffer->data_ptr<uint8_t>() == nullptr || is_empty_shape) {
return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
}
return at::from_blob(
buffer->data_ptr<uint8_t>() + offset, shape_int64,
[buffer](void *) {}, // deleter holds shared_ptr
at::device(at::kCUDA).dtype(dtype));
};
// Lambda function for converting std::vector<size_t> shape to NVFP4 shape (last dim divided by 2)
auto to_fp4_shape = [](const std::vector<size_t> &shape) {
std::vector<size_t> fp4_shape(shape.begin(), shape.end());
if (!fp4_shape.empty()) {
fp4_shape.back() /= 2;
}
return fp4_shape;
};
// Allocate row-wise data
std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list, amax_rowwise_list;
std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;
if (rowwise_usage) {
// Tensor sizes
for (size_t i = 0; i < num_tensors; ++i) {
rowwise_data_shapes.emplace_back(shape_list[i]);
rowwise_scale_shapes.emplace_back(
quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false));
}
// Offsets in full buffer
size_t buffer_size = 0;
std::vector<size_t> data_offsets, scale_offsets, amax_offsets;
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 256); // align to 256B
data_offsets.push_back(buffer_size);
// Store ceil(product/2) bytes for fp4 (since each element is 4 bits = 0.5 bytes).
// Integer arithmetic: ceil(product / 2) == (product + 1) / 2.
buffer_size += (product(rowwise_data_shapes[i]) + 1) / 2;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
scale_offsets.push_back(buffer_size);
buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
amax_offsets.push_back(buffer_size);
// amax is scalar in fp32, 4 bytes each
buffer_size += 4;
}
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
rowwise_data_list.emplace_back(make_torch_view(buffer, to_fp4_shape(rowwise_data_shapes[i]),
data_offsets[i], torch::kUInt8));
rowwise_scale_list.emplace_back(
make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
amax_rowwise_list.emplace_back(
make_torch_view(buffer, std::vector<size_t>{1}, amax_offsets[i], torch::kUInt8));
}
}
// Allocate column-wise data
std::vector<at::Tensor> columnwise_data_list, columnwise_scale_list, amax_columnwise_list;
std::vector<std::vector<size_t>> columnwise_data_shapes, columnwise_scale_shapes;
if (columnwise_usage) {
// Tensor sizes
for (size_t i = 0; i < num_tensors; ++i) {
// push the transposed shape into NVFP4 columnwise shape
// NVFP4 on SM100 is TN only
columnwise_data_shapes.emplace_back();
auto &shape = columnwise_data_shapes.back();
shape.push_back(shape_list[i].back());
for (size_t j = 0; j < shape_list[i].size() - 1; ++j) {
shape.push_back(shape_list[i][j]);
}
columnwise_scale_shapes.emplace_back(
quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true));
}
// Offsets in full buffer
size_t buffer_size = 0;
std::vector<size_t> data_offsets, scale_offsets, amax_offsets;
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 256); // align to 256B
data_offsets.push_back(buffer_size);
// Store ceil(product/2) bytes for fp4 (since each element is 4 bits = 0.5 bytes).
// Integer arithmetic: ceil(product / 2) == (product + 1) / 2.
buffer_size += (product(columnwise_data_shapes[i]) + 1) / 2;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
scale_offsets.push_back(buffer_size);
buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
amax_offsets.push_back(buffer_size);
// amax is scalar in fp32, 4 bytes each
buffer_size += 4;
}
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
columnwise_data_list.emplace_back(make_torch_view(
buffer, to_fp4_shape(columnwise_data_shapes[i]), data_offsets[i], torch::kUInt8));
columnwise_scale_list.emplace_back(
make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
amax_columnwise_list.emplace_back(
make_torch_view(buffer, std::vector<size_t>{1}, amax_offsets[i], torch::kUInt8));
}
}
// Construct nvfp4 tensors
py::handle NVFP4TensorClass(reinterpret_cast<PyObject *>(NVFP4TensorStoragePythonClass));
for (size_t i = 0; i < num_tensors; ++i) {
// Create tensor objects with proper reference counting
py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none();
py::object columnwise_data =
(columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none());
py::object columnwise_scale =
(columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none());
py::object amax_rowwise = rowwise_usage ? py::cast(amax_rowwise_list[i]) : py::none();
py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none();
// Construct Python tensor
tensor_py_list.emplace_back(NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data,
columnwise_scale, amax_rowwise, amax_columnwise,
fp4_dtype, quantizer_py_list[i]));
// Construct C++ tensor
// Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor,
// then set the amax and amax_columnwise values.
{
auto tensor_wrapper = makeTransformerEngineTensor(
rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{},
columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{}, fp4_dtype,
/*amax_ptr=*/nullptr,
/*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{},
columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{}, scaling_mode);
// Set the amax rowwise and amax columnwise if available
if (rowwise_usage) {
tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
if (columnwise_usage) {
tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
tensor_cpp_list.emplace_back(std::move(tensor_wrapper));
}
}
return retval;
}
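`make_torch_view` above carves per-tensor views out of one flat uint8 buffer and keeps the buffer alive through the deleter's captured `shared_ptr`. A rough PyTorch analogue of that carving (views into a 1-D buffer; PyTorch's storage refcounting plays the role of the shared_ptr here):

```python
import torch

def carve_views(buffer: torch.Tensor, shapes, offsets):
    # buffer: flat uint8 tensor; each view aliases a slice of it, so the
    # underlying storage stays alive as long as any view does.
    views = []
    for shape, offset in zip(shapes, offsets):
        numel = 1
        for dim in shape:
            numel *= dim
        views.append(buffer[offset:offset + numel].view(*shape))
    return views

buf = torch.zeros(1024, dtype=torch.uint8)
data_view, scale_view = carve_views(buf, [(16, 8), (16, 4)], [0, 256])
```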
} // namespace } // namespace
std::vector<py::object> split_quantize(const at::Tensor &tensor, std::vector<py::object> split_quantize(const at::Tensor &tensor,
...@@ -549,7 +750,8 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor, ...@@ -549,7 +750,8 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
bool use_fused_bulk_alloc = true; bool use_fused_bulk_alloc = true;
for (size_t i = 0; i < quantizer_list.size(); i++) { for (size_t i = 0; i < quantizer_list.size(); i++) {
if (!detail::IsFloat8BlockwiseQuantizers(quantizer_list[i].ptr()) && if (!detail::IsFloat8BlockwiseQuantizers(quantizer_list[i].ptr()) &&
!detail::IsMXFP8Quantizers(quantizer_list[i].ptr())) { !detail::IsMXFP8Quantizers(quantizer_list[i].ptr()) &&
!detail::IsNVFP4Quantizers(quantizer_list[i].ptr())) {
use_fused_bulk_alloc = false; use_fused_bulk_alloc = false;
break; break;
} }
...@@ -570,6 +772,7 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor, ...@@ -570,6 +772,7 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
// TODO(zhongbo): make a better api to make this part less hacky // TODO(zhongbo): make a better api to make this part less hacky
bool is_fp8_blockwise = detail::IsFloat8BlockwiseQuantizers(quantizer_list[0].ptr()); bool is_fp8_blockwise = detail::IsFloat8BlockwiseQuantizers(quantizer_list[0].ptr());
bool is_mxfp8 = detail::IsMXFP8Quantizers(quantizer_list[0].ptr()); bool is_mxfp8 = detail::IsMXFP8Quantizers(quantizer_list[0].ptr());
bool is_nvfp4 = detail::IsNVFP4Quantizers(quantizer_list[0].ptr());
if (is_fp8_blockwise) { if (is_fp8_blockwise) {
// FP8 block-scaling: construct output tensors with bulk allocations // FP8 block-scaling: construct output tensors with bulk allocations
std::vector<Float8BlockQuantizer *> blockwise_quantizers; std::vector<Float8BlockQuantizer *> blockwise_quantizers;
...@@ -586,6 +789,14 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor, ...@@ -586,6 +789,14 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
} }
std::tie(output_py_list, output_cpp_list) = std::tie(output_py_list, output_cpp_list) =
bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers); bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers);
} else if (is_nvfp4) {
// NVFP4: construct output tensors with bulk allocations
std::vector<NVFP4Quantizer *> nvfp4_quantizers;
for (auto &quantizer : quantizer_cpp_list) {
nvfp4_quantizers.push_back(static_cast<NVFP4Quantizer *>(quantizer.get()));
}
std::tie(output_py_list, output_cpp_list) =
bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers);
} else { } else {
NVTE_CHECK(false, "Expected either FP8 block-scaling or MXFP8 quantizer"); NVTE_CHECK(false, "Expected either FP8 block-scaling or MXFP8 quantizer");
} }
......
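The bulk allocation above packs every expert's FP4 payload, FP8 scale factors, and FP32 amax scalar into a single uint8 buffer, with 256-byte alignment for data, 16-byte alignment for scales and amax, and ceil(numel/2) bytes per FP4 tensor. A hedged sketch of just the offset bookkeeping (shapes are illustrative; this is not the extension's API):

```python
def roundup(value: int, align: int) -> int:
    return (value + align - 1) // align * align

def nvfp4_buffer_layout(data_shapes, scale_shapes):
    """Byte offsets for the [fp4_data..., fp8_scale..., amax...] layout above."""
    def numel(shape):
        n = 1
        for dim in shape:
            n *= dim
        return n

    size, data_off, scale_off, amax_off = 0, [], [], []
    for shape in data_shapes:            # FP4 payload: two elements per byte
        size = roundup(size, 256)
        data_off.append(size)
        size += (numel(shape) + 1) // 2
    for shape in scale_shapes:           # one byte per FP8 (E4M3) scale factor
        size = roundup(size, 16)
        scale_off.append(size)
        size += numel(shape)
    for _ in data_shapes:                # one FP32 amax scalar per tensor
        size = roundup(size, 16)
        amax_off.append(size)
        size += 4
    return size, data_off, scale_off, amax_off
```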
...@@ -20,10 +20,11 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) { ...@@ -20,10 +20,11 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor"); TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor");
TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element"); TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");
auto* amax_ptr = amax.data_ptr<float>();
TensorWrapper fake_te_output( TensorWrapper fake_te_output(
nullptr, te_input.shape(), nullptr, te_input.shape(),
DType::kFloat8E4M3, // It doesn't matter because we only compute amax. DType::kFloat8E4M3, // It doesn't matter because we only compute amax.
amax.data_ptr<float>()); amax_ptr);
nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream()); nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
} }
......
...@@ -1200,6 +1200,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve ...@@ -1200,6 +1200,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
rowwise_scale_inv_shape.end()); rowwise_scale_inv_shape.end());
rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts);
rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts);
// hadamard amax kernel will zero out pointer with ZeroAmaxKernel
// nvte_compute_amax_with_config will zero out the pointer if needed
amax_rowwise = at::empty({1}, bit32_tensor_opts); amax_rowwise = at::empty({1}, bit32_tensor_opts);
} }
if (columnwise_usage) { if (columnwise_usage) {
...@@ -1213,6 +1215,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve ...@@ -1213,6 +1215,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
columnwise_data_tensor = columnwise_data_tensor =
at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts); at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts);
columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts);
// hadamard amax kernel will zero out pointer with ZeroAmaxKernel
// nvte_compute_amax_with_config will zero out the pointer if needed
amax_columnwise = at::empty({1}, bit32_tensor_opts); amax_columnwise = at::empty({1}, bit32_tensor_opts);
} }
...@@ -1352,6 +1356,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor( ...@@ -1352,6 +1356,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
} }
if (!amax_rowwise) { if (!amax_rowwise) {
const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
// hadamard amax kernel will zero out pointer with ZeroAmaxKernel
// nvte_compute_amax_with_config will zero out the pointer if needed
amax_rowwise = at::empty({1}, opts); amax_rowwise = at::empty({1}, opts);
tensor.attr("_amax_rowwise") = *amax_rowwise; tensor.attr("_amax_rowwise") = *amax_rowwise;
} }
...@@ -1392,7 +1398,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor( ...@@ -1392,7 +1398,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
} }
if (!amax_columnwise) { if (!amax_columnwise) {
const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
amax_columnwise = at::zeros({1}, opts); // hadamard amax kernel will zero out pointer with ZeroAmaxKernel
// nvte_compute_amax_with_config will zero out the pointer if needed
amax_columnwise = at::empty({1}, opts);
tensor.attr("_amax_columnwise") = *amax_columnwise; tensor.attr("_amax_columnwise") = *amax_columnwise;
} }
} else { // columnwise_usage == false } else { // columnwise_usage == false
......
...@@ -50,8 +50,6 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap ...@@ -50,8 +50,6 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
void* scale_inv_dptr = scale_inv.data_ptr; void* scale_inv_dptr = scale_inv.data_ptr;
void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
// Reconstruct input only to avoid swizzling both directions if not needed.
// The specific dtype used is irrelevant, just needs to be correct bits.
transformer_engine::TensorWrapper input_cu(input.scaling_mode()); transformer_engine::TensorWrapper input_cu(input.scaling_mode());
transformer_engine::TensorWrapper output_cu(input.scaling_mode()); transformer_engine::TensorWrapper output_cu(input.scaling_mode());
...@@ -100,10 +98,14 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors( ...@@ -100,10 +98,14 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) { if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) {
NVTE_ERROR("Invalid scaling mode for swizzle."); NVTE_ERROR("Invalid scaling mode for swizzle.");
} else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING) { } else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING &&
tensors.front().scaling_mode() != NVTE_NVFP4_1D_SCALING) {
return std::nullopt; return std::nullopt;
} }
const auto scaling_mode = tensors.front().scaling_mode();
const auto nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING;
std::vector<transformer_engine::TensorWrapper> wrappers; std::vector<transformer_engine::TensorWrapper> wrappers;
std::vector<NVTETensor> input_tensors, output_tensors; std::vector<NVTETensor> input_tensors, output_tensors;
...@@ -131,39 +133,44 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors( ...@@ -131,39 +133,44 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
// Allocate full buffer // Allocate full buffer
auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8));
const auto input_dtype =
(nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3;
const auto scale_inv_dtype =
(nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0;
for (size_t i = 0; i < tensors.size(); ++i) { for (size_t i = 0; i < tensors.size(); ++i) {
auto& tensor = tensors[i]; auto& tensor = tensors[i];
void* scale_inv_dptr = scale_inv_dptrs[i]; void* scale_inv_dptr = scale_inv_dptrs[i];
void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]); void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]);
auto input_shape = nvte_shape_to_vector(tensor.shape()); // auto input_shape = nvte_shape_to_vector(tensor.shape());
NVTEShape nvte_input_shape;
if (rowwise) {
nvte_input_shape = tensor.shape();
} else {
nvte_input_shape = tensor.get_columnwise_data().shape;
}
auto input_shape = nvte_shape_to_vector(nvte_input_shape);
// Reconstruct input only to avoid swizzling both directions if not needed. // Reconstruct input only to avoid swizzling both directions if not needed.
// Use any 8 bit type, it's irrelevant. // Use any 8 bit type, it's irrelevant.
transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); transformer_engine::TensorWrapper input_cu(scaling_mode);
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); transformer_engine::TensorWrapper output_cu(scaling_mode);
if (rowwise) { if (rowwise) {
input_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); input_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape);
input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
scale_inv_shapes[i]); output_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape);
output_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
input_shape); scale_inv_shapes[i]);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr,
transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
// Set the swizzled scaling factor to the original tensor. // Set the swizzled scaling factor to the original tensor.
tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
scale_inv_shapes[i]);
} else { } else {
input_cu.set_columnwise_data(tensor.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, input_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape);
input_shape); input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, output_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape);
scale_inv_shapes[i]); output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
output_cu.set_columnwise_data(tensor.columnwise_dptr(), scale_inv_shapes[i]);
transformer_engine::DType::kFloat8E4M3, input_shape);
output_cu.set_columnwise_scale_inv(
swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
// Set the swizzled scaling factor to the original tensor. // Set the swizzled scaling factor to the original tensor.
tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); scale_inv_shapes[i]);
} }
input_tensors.emplace_back(input_cu.data()); input_tensors.emplace_back(input_cu.data());
......
...@@ -78,7 +78,7 @@ class Fp8Padding(torch.nn.Module): ...@@ -78,7 +78,7 @@ class Fp8Padding(torch.nn.Module):
number of GEMMs to be performed simultaneously. number of GEMMs to be performed simultaneously.
align_size : int, optional align_size : int, optional
the alignment size for the input tensor. If not provided, the alignment size will the alignment size for the input tensor. If not provided, the alignment size will
be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first be determined by the FP8/FP4 recipe (32 for MXFP8/NVFP4 and 16 for others) in the first
forward pass. forward pass.
""" """
...@@ -111,7 +111,14 @@ class Fp8Padding(torch.nn.Module): ...@@ -111,7 +111,14 @@ class Fp8Padding(torch.nn.Module):
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if self.align_size is None: if self.align_size is None:
self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 self.align_size = (
32
if (
FP8GlobalStateManager.get_fp8_recipe().mxfp8()
or FP8GlobalStateManager.get_fp8_recipe().nvfp4()
)
else 16
)
# FP8 padding calculate # FP8 padding calculate
padded_m_splits = [ padded_m_splits = [
......
...@@ -75,9 +75,9 @@ class Fp8Unpadding(torch.nn.Module): ...@@ -75,9 +75,9 @@ class Fp8Unpadding(torch.nn.Module):
num_gemms : int num_gemms : int
number of GEMMs to be performed simultaneously. number of GEMMs to be performed simultaneously.
align_size : int, optional align_size : int, optional
the alignment size for the input tensor. If not provided, the alignment size will The alignment size for the input tensor. If not provided, the alignment size will
be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first be automatically determined based on the FP8/FP4 recipe in the first forward pass:
forward pass. 32 for MXFP8 or NVFP4, otherwise 16.
""" """
def __init__( def __init__(
...@@ -109,7 +109,14 @@ class Fp8Unpadding(torch.nn.Module): ...@@ -109,7 +109,14 @@ class Fp8Unpadding(torch.nn.Module):
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if self.align_size is None: if self.align_size is None:
self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 self.align_size = (
32
if (
FP8GlobalStateManager.get_fp8_recipe().mxfp8()
or FP8GlobalStateManager.get_fp8_recipe().nvfp4()
)
else 16
)
# FP8 padding calculate # FP8 padding calculate
padded_m_splits = [ padded_m_splits = [
......
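Both padding modules now treat NVFP4 like MXFP8 when lazily resolving the alignment from the active recipe. A compact restatement of that selection (standalone sketch; `recipe` stands for the object returned by FP8GlobalStateManager.get_fp8_recipe()):

```python
def resolve_align_size(recipe) -> int:
    # 32-row alignment for block-scaled MXFP8/NVFP4 recipes, 16 otherwise.
    if recipe is not None and (recipe.mxfp8() or recipe.nvfp4()):
        return 32
    return 16
```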