OpenDAS / TransformerEngine · Commits

Commit 063ef88d — authored Dec 03, 2025 by wenjh
Merge nv main up to v2.10.0.dev0
Signed-off-by: wenjh <wenjh@sugon.com>
Parents: 91670b05, 5624dbb4

Changes: 298 files in total; showing 20 changed files with 2217 additions and 110 deletions (+2217 / -110).
Changed files shown:

tests/pytorch/debug/test_config.py (+1, -1)
tests/pytorch/debug/test_log.py (+12, -8)
tests/pytorch/debug/test_numerics.py (+6, -6)
tests/pytorch/debug/test_sanity.py (+2, -3)
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py (+18, -18)
tests/pytorch/distributed/run_fsdp2_model.py (+2, -2)
tests/pytorch/distributed/run_gemm_with_overlap.py (+7, -4)
tests/pytorch/distributed/run_layer_with_overlap.py (+6, -4)
tests/pytorch/distributed/run_numerics.py (+238, -15)
tests/pytorch/distributed/run_numerics_exact.py (+756, -0)
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py (+4, -4)
tests/pytorch/distributed/test_comm_gemm_overlap.py (+2, -3)
tests/pytorch/distributed/test_fusible_ops.py (+241, -17)
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py (+10, -10)
tests/pytorch/distributed/test_numerics.py (+13, -9)
tests/pytorch/distributed/test_numerics_exact.py (+70, -0)
tests/pytorch/distributed/test_sanity.py (+1, -2)
tests/pytorch/distributed/test_torch_fsdp2.py (+3, -4)
tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py (+242, -0)
tests/pytorch/nvfp4/test_nvfp4_module_exact.py (+583, -0)
tests/pytorch/debug/test_config.py (view file @ 063ef88d)

 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
-import pathlib, os
+import pathlib
 from nvdlfw_inspect.config_manager import ConfigManager
 ...
tests/pytorch/debug/test_log.py (view file @ 063ef88d)

@@ -8,18 +8,22 @@ import transformer_engine.pytorch as te
 import torch
 import tempfile
 from transformer_engine.common import recipe
-from transformer_engine.pytorch.fp8 import RecipeState
 import pytest
 import contextlib
 import os
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch import (
+    is_fp8_available,
+    is_mxfp8_available,
+    is_fp8_block_scaling_available,
+)
+from transformer_engine.pytorch.quantization import RecipeState
 from transformer_engine.debug.pytorch.debug_state import TEDebugState

-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
-)
+fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
+    return_reason=True
+)

 LOG_QUANTIZED_CONFIG_BASE = """
...
@@ -128,7 +132,7 @@ def test_sanity(feature_dirs):
     inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda()
     for _ in range(10):
-        with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
+        with te.autocast(recipe=recipe.DelayedScaling()):
             output = model(inp)
         loss = output.sum()
         loss.backward()
...
@@ -232,7 +236,7 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
     for i in range(20):
         x = torch.randn(4, 128, 128).cuda()
-        with te.fp8_autocast(enabled=True):
+        with te.autocast(enabled=True):
             y = model(x)
         y.sum().backward()
         debug_api.step()
...
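Editorial note: the recurring change across these test files is the rename of the quantized-compute context manager — te.fp8_autocast(fp8_recipe=...) becomes te.autocast(recipe=...). A minimal before/after sketch of the pattern used in the hunks above (the layer and sizes are placeholders, not taken from the commit):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

model = te.Linear(128, 128, params_dtype=torch.bfloat16).cuda()  # placeholder layer
inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda()

# Old API: with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
# New API used throughout this commit:
with te.autocast(recipe=recipe.DelayedScaling()):
    out = model(inp)
out.sum().backward()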
tests/pytorch/debug/test_numerics.py (view file @ 063ef88d)

@@ -17,19 +17,19 @@ import transformer_engine.debug
 import transformer_engine.pytorch as tepytorch
 import transformer_engine_torch as tex
 from transformer_engine.common.recipe import DelayedScaling, Format
-from transformer_engine.pytorch.fp8 import _default_sf_compute
-from transformer_engine.pytorch.tensor.float8_tensor import (
+from transformer_engine.pytorch.quantization import _default_sf_compute
+from transformer_engine.pytorch import (
     Float8Quantizer,
     Float8CurrentScalingQuantizer,
+    is_fp8_available,
 )
 from transformer_engine.pytorch.module.base import (
     _2X_ACC_DGRAD,
     _2X_ACC_FPROP,
     _2X_ACC_WGRAD,
 )
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)

 all_boolean = [True, False]
 FP8_FORMAT = Format.HYBRID
...
@@ -250,7 +250,7 @@ def _init_model(weight):
 def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None, fp8=True):
-    with tepytorch.fp8_autocast(enabled=fp8, fp8_recipe=FP8_RECIPE):
+    with tepytorch.autocast(enabled=fp8, recipe=FP8_RECIPE):
         y = model(x, is_first_microbatch=is_first_microbatch)
     (y.sum() * loss_scale).backward()
     debug_api.step()
...
@@ -547,7 +547,7 @@ def run_per_tensor_scaling(
     LOSS_MULTIPLIER = 100
-    with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
+    with tepytorch.autocast(enabled=True, recipe=FP8_RECIPE):
         y = model(x, is_first_microbatch=True)
     model.zero_grad()
     y.retain_grad()
...
tests/pytorch/debug/test_sanity.py (view file @ 063ef88d)

@@ -7,11 +7,10 @@ import torch
 import nvdlfw_inspect.api as debug_api
 import transformer_engine.pytorch as te
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from test_numerics import create_config_file

-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)

 B, S, H, D = 64, 64, 64, 64
...
@@ -68,7 +67,7 @@ def _get_model(model_key):
 def _run_forward_backward(model, fp8):
     for _ in range(3):
         inp = torch.randn((S, B, H)).cuda()
-        with te.fp8_autocast(enabled=fp8):
+        with te.autocast(enabled=fp8):
             out = model(inp)
         out.sum().backward()
     debug_api.step()
...
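Editorial note: the companion change replaces FP8GlobalStateManager queries with top-level availability helpers that take return_reason=True and return an (available, reason) pair. A sketch of the new module-level check, mirroring the hunks above:

import transformer_engine.pytorch as te
from transformer_engine.pytorch import is_fp8_available, is_mxfp8_available

# Old: fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)

if not fp8_available:
    print(f"FP8 not available on this device: {reason_for_no_fp8}")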
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py (view file @ 063ef88d)

@@ -21,13 +21,13 @@ from transformer_engine.common.recipe import (
     Recipe,
 )
 import transformer_engine.pytorch as te
-from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8
-from transformer_engine.pytorch.tensor.float8_tensor import (
+from transformer_engine.pytorch import (
+    QuantizedTensor,
     Float8Tensor,
     Float8CurrentScalingQuantizer,
+    Float8BlockwiseQTensor,
 )
+from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8
 from transformer_engine.pytorch.tensor.utils import replace_raw_data
-from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor


 def _get_raw_data(quantized_tensor):
...
@@ -439,7 +439,7 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
     }

     # Create model with FP8 weights
-    with te.fp8.fp8_model_init(
+    with te.quantized_model_init(
         enabled=quantization is not None,
         recipe=quantization_recipe(quantization),
         preserve_high_precision_init_val=True,
...
@@ -475,17 +475,17 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
     # Choose based on rank to make sure the inputs of different ranks are different.
     x = inputs[rank]

-    with te.fp8.fp8_autocast(
+    with te.autocast(
         enabled=quantization is not None,
-        fp8_recipe=quantization_recipe(quantization),
-        fp8_group=mock_group,
+        recipe=quantization_recipe(quantization),
+        amax_reduction_group=mock_group,
     ):
         y_fp8 = model_fp8(x)

-    with te.fp8_autocast(
+    with te.autocast(
         enabled=quantization is not None,
-        fp8_recipe=quantization_recipe(quantization),
-        fp8_group=mock_group,
+        recipe=quantization_recipe(quantization),
+        amax_reduction_group=mock_group,
     ):
         y = model(x)
...
@@ -573,7 +573,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
     linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": False}

     # Create model with FP8 weights
-    with te.fp8.fp8_model_init(
+    with te.quantized_model_init(
         enabled=quantization is not None,
         recipe=quantization_recipe(quantization),
         preserve_high_precision_init_val=True,
...
@@ -615,17 +615,17 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
     # Choose based on rank to make sure the inputs of different ranks are different.
     x = inputs[rank]

-    with te.fp8.fp8_autocast(
+    with te.autocast(
         enabled=quantization is not None,
-        fp8_recipe=quantization_recipe(quantization),
-        fp8_group=mock_group,
+        recipe=quantization_recipe(quantization),
+        amax_reduction_group=mock_group,
     ):
         y_fp8 = model_fp8(x)

-    with te.fp8_autocast(
+    with te.autocast(
         enabled=quantization is not None,
-        fp8_recipe=quantization_recipe(quantization),
-        fp8_group=mock_group,
+        recipe=quantization_recipe(quantization),
+        amax_reduction_group=mock_group,
     ):
         y = model(x)
...
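Editorial note: two further renames appear in this file — te.fp8.fp8_model_init becomes te.quantized_model_init, and the fp8_group argument of the autocast becomes amax_reduction_group. A condensed sketch of the combined pattern, assuming a placeholder recipe and no specific process group (the tests pass a data-parallel mock group):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

quant_recipe = DelayedScaling()  # placeholder; the test builds this via quantization_recipe()

# Old: te.fp8.fp8_model_init(...)
with te.quantized_model_init(enabled=True, recipe=quant_recipe, preserve_high_precision_init_val=True):
    model = te.Linear(256, 256, params_dtype=torch.bfloat16).cuda()

x = torch.randn(32, 256, dtype=torch.bfloat16, device="cuda")
# Old: te.fp8_autocast(fp8_recipe=..., fp8_group=...)
with te.autocast(enabled=True, recipe=quant_recipe, amax_reduction_group=None):
    y = model(x)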
tests/pytorch/distributed/run_fsdp2_model.py (view file @ 063ef88d)

@@ -110,9 +110,9 @@ def _train(args):
         build_model_context = nullcontext
         build_model_context_args = {}

-        from transformer_engine.pytorch import fp8_model_init
+        from transformer_engine.pytorch import quantized_model_init

-        build_model_context = fp8_model_init
+        build_model_context = quantized_model_init
         build_model_context_args["enabled"] = True

     # Build the model with the specified context
...
tests/pytorch/distributed/run_gemm_with_overlap.py (view file @ 063ef88d)

@@ -19,9 +19,12 @@ from torch.distributed.elastic.multiprocessing.errors import record
 from torch.utils.cpp_extension import IS_HIP_EXTENSION

 import transformer_engine.pytorch as te
+from transformer_engine.pytorch import (
+    Float8Tensor,
+    Float8Quantizer,
+    MXFP8Quantizer,
+)
 import transformer_engine.pytorch.cpp_extensions as tex
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
-from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 from transformer_engine.pytorch.module.base import (
     fill_userbuffers_buffer_for_all_gather,
     get_cublas_workspace_size_bytes,
...
@@ -172,12 +175,12 @@ def _parse_args(argv=None, namespace=None):
         opts.p2p = True

     if opts.atomic:
-        if not te.fp8.check_fp8_support():
+        if not te.is_fp8_available():
             assert opts.quantization == "none", "Atomic GEMM is only supported in FP8."
             opts.quantization = "fp8"

     if opts.fp8_output:
-        assert ops.quantization == "fp8", "FP8 output is only supported with FP8 compute."
+        assert opts.quantization == "fp8", "FP8 output is only supported with FP8 compute."

     return opts
...
tests/pytorch/distributed/run_layer_with_overlap.py (view file @ 063ef88d)

@@ -165,7 +165,7 @@ def _parse_args(argv=None, namespace=None):
     )
     parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
     parser.add_argument(
-        "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
+        "--fp8", action="store_true", default=False, help="Enables the te.autocast() context."
     )
     parser.add_argument(
         "--quantization",
...
@@ -438,7 +438,7 @@ def _train(opts):
         ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg,
     )

-    with te.fp8_model_init(enabled=opts.fp8_init):
+    with te.quantized_model_init(enabled=opts.fp8_init):
         test_model = multi_module_model(opts.layer_type, opts.num_layers, *args, **kwargs)
     dist_print("Initialized test model...", debug=True)

     if WORLD_RANK == 0:
...
@@ -450,7 +450,7 @@ def _train(opts):
         ref_args, ref_kwargs, _ = _get_layer_args(
             opts, nccl_world, opts.tp, num_layers=opts.num_layers, reference=True
         )
-    with te.fp8_model_init(enabled=opts.fp8_init):
+    with te.quantized_model_init(enabled=opts.fp8_init):
         ref_model = multi_module_model(opts.layer_type, opts.num_layers, *ref_args, **ref_kwargs)
     dist_print("Initialized reference model...", debug=True)

     for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()):
...
@@ -473,7 +473,9 @@ def _train(opts):
     layer_contexts = [
         (
-            partial(te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world)
+            partial(
+                te.autocast, enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world
+            )
             if opts.num_layers_at_start_in_bf16 <= i and i < (opts.num_layers - opts.num_layers_at_end_in_bf16)
             else nullcontext
...
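Editorial note: the per-layer context list in the last hunk defers entering the autocast region with functools.partial, so only the middle layers run quantized while the first and last few stay in bf16. A reduced sketch of that idea, assuming a recipe and process group already exist (the helper name is illustrative, not from the commit):

from contextlib import nullcontext
from functools import partial
import transformer_engine.pytorch as te

def make_layer_contexts(num_layers, fp8_enabled, fp8_recipe, group, start_bf16=0, end_bf16=0):
    """Return one context factory per layer; leading/trailing layers stay un-quantized."""
    contexts = []
    for i in range(num_layers):
        if start_bf16 <= i < num_layers - end_bf16:
            contexts.append(
                partial(te.autocast, enabled=fp8_enabled, recipe=fp8_recipe, amax_reduction_group=group)
            )
        else:
            contexts.append(nullcontext)
    return contexts

# Usage per layer: with contexts[i](): y = layers[i](y)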
tests/pytorch/distributed/run_numerics.py (view file @ 063ef88d)

@@ -9,6 +9,7 @@ import datetime
 import os
 import sys
 from functools import wraps
+import math

 import torch
 from torch import nn
...
@@ -20,10 +21,14 @@ from transformer_engine.common.recipe import (
     DelayedScaling,
     Float8CurrentScaling,
     Float8BlockScaling,
+    NVFP4BlockScaling,
     Format,
     Recipe,
+    QParams,
 )
-from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
+from transformer_engine.pytorch import Float8CurrentScalingQuantizer, NVFP4Quantizer
+from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
 from transformer_engine.pytorch.distributed import gather_along_first_dim
 from run_layer_with_overlap import _compare_tensors
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
...
@@ -48,6 +53,14 @@ if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
     )


+def nvfp4_vanilla():
+    nvfp4_recipe = NVFP4BlockScaling()
+    nvfp4_recipe.fp4_quant_fwd_inp = QParams()
+    nvfp4_recipe.fp4_quant_fwd_weight = QParams()
+    nvfp4_recipe.fp4_quant_bwd_grad = QParams()
+    return nvfp4_recipe
+
+
 # Quantization recipe setup
 def quantization_recipe() -> Recipe:
     if QUANTIZATION == "fp8":
...
@@ -60,7 +73,9 @@ def quantization_recipe() -> Recipe:
         return Float8CurrentScaling()
     if QUANTIZATION == "fp8_block_scaling":
         return Float8BlockScaling()
-    return te.fp8.get_default_fp8_recipe()
+    if QUANTIZATION == "nvfp4":
+        return nvfp4_vanilla()
+    return te.quantization.get_default_fp8_recipe()


 def main(argv=None, namespace=None):
...
@@ -97,10 +112,14 @@ def main(argv=None, namespace=None):
     # Quantization scheme
     QUANTIZATION = args.quantization
     global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
-    if QUANTIZATION in ("fp8", "mxfp8"):
+    if QUANTIZATION in ("fp8", "mxfp8", "nvfp4"):
         SEQ_LEN = 32
         BATCH_SIZE = 32
         HIDDEN_SIZE = 128
+    # For fp8 block scaling, block size is 128,
+    # and to make low precision TP work, input tensor
+    # must be 128x128 divisible to be eligible for
+    # low precision All-Gather when needed
     elif QUANTIZATION == "fp8_block_scaling":
         SEQ_LEN = 128
         BATCH_SIZE = 128
...
@@ -108,6 +127,7 @@ def main(argv=None, namespace=None):
     test_dict = [
         test_quantizer,
+        test_quantized_all_gather,
         test_linear,
         test_layernorm,
         test_layernorm_linear,
...
@@ -177,6 +197,9 @@ def _get_tolerances(dtype):
     # row parallel & sequence parallel, because we do the all_gather in backward pass
     if QUANTIZATION == "fp8_cs":
         return {"rtol": 0.4, "atol": 0.25}
+    elif QUANTIZATION == "nvfp4":
+        # TODO(zhongboz): investigate why the tolerance is so large
+        return {"rtol": 0.125, "atol": 0.12}
     elif QUANTIZATION is not None:
         return {"rtol": 0.125, "atol": 0.0625}
...
@@ -293,15 +316,15 @@ def _apply_models(
     _alloc_main_grad(model_single_node, model_distributed)  # for fuse_wgrad_accumulation=True
     input_single_node.requires_grad_()
     input_distributed.requires_grad_()
-    with te.fp8_autocast(
+    with te.autocast(
         enabled=QUANTIZATION is not None,
-        fp8_recipe=quantization_recipe(),
+        recipe=quantization_recipe(),
     ):
         output_single_node = model_single_node(input_single_node, **kwargs)
-    with te.fp8_autocast(
+    with te.autocast(
         enabled=QUANTIZATION is not None,
-        fp8_recipe=quantization_recipe(),
-        fp8_group=NCCL_WORLD,
+        recipe=quantization_recipe(),
+        amax_reduction_group=NCCL_WORLD,
     ):
         output_distributed = model_distributed(input_distributed, **kwargs)
     return output_single_node, output_distributed
...
@@ -327,24 +350,36 @@ def _alloc_main_grad(model_single_node, model_distributed):
 ###############################################
 #             Quantizer                       #
 ###############################################
-def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size):
+def _construct_quantizer(quantizer_class, low_precision_dtype, device, tp_group, tp_size):
     """
     quantizer is the reference quantizer on a single GPU.
     quantizer_dist is the distributed quantizer to be tested on multiple GPUs.
     """
     if quantizer_class == Float8CurrentScalingQuantizer:
         quantizer_dist = quantizer_class(
-            fp8_dtype=fp8_dtype,
+            fp8_dtype=low_precision_dtype,
             device=device,
             with_amax_reduction=True,
             amax_reduction_group=tp_group,
         )
         quantizer = quantizer_class(
-            fp8_dtype=fp8_dtype,
+            fp8_dtype=low_precision_dtype,
             device=device,
             with_amax_reduction=False,
         )
         return quantizer, quantizer_dist
+    elif quantizer_class == NVFP4Quantizer:
+        quantizer_dist = quantizer_class(
+            fp4_dtype=low_precision_dtype,
+            with_amax_reduction=True,
+            amax_reduction_group=tp_group,
+        )
+        quantizer = quantizer_class(
+            fp4_dtype=low_precision_dtype,
+            with_amax_reduction=False,
+            amax_reduction_group=None,
+        )
+        return quantizer, quantizer_dist
     else:
         raise ValueError(f"Unsupported quantizer class: {quantizer_class}")
...
@@ -415,6 +450,194 @@ def test_quantizer():
             _test_quantizer(input_dtype, fp8_dtype)


+############################################
+#       Quantized All-Gather               #
+############################################
+def _ref_zero_padding_scale_inv(scale_inv, unpadded_shape):
+    """
+    Zero padding the scale_inv.
+    scale_inv shape is the padded shape, but not zero padded
+    unpadded_shape is the original shape before padding
+    """
+    dim0, dim1 = scale_inv.shape
+    unpadded_dim0, unpadded_dim1 = unpadded_shape
+    pad_dim0 = (128 - unpadded_dim0 % 128) % 128
+    pad_dim1 = (4 - unpadded_dim1 % 4) % 4
+    new_dim0 = unpadded_dim0 + pad_dim0
+    new_dim1 = unpadded_dim1 + pad_dim1
+    assert dim0 == new_dim0
+    assert dim1 == new_dim1
+    # return input if no padding is needed
+    if pad_dim0 == 0 and pad_dim1 == 0:
+        return scale_inv
+    # unpad first to remove random bits from torch empty
+    scale_inv = scale_inv[:unpadded_dim0, :unpadded_dim1].contiguous()
+    # using torch padding
+    new_scale_inv = torch.nn.functional.pad(scale_inv, (0, pad_dim1, 0, pad_dim0), mode="constant", value=0)
+    assert new_scale_inv.shape == (new_dim0, new_dim1)
+    return new_scale_inv
+
+
+def _get_unpadded_scale_inv_shape(input_shape, quantizer_cls, columnwise):
+    """
+    Calculate the unpadded shape of the scale_inv tensor.
+    """
+    M, K = 1, 1
+    M = math.prod(input_shape[:-1])
+    K = input_shape[-1]
+    if quantizer_cls == NVFP4Quantizer:
+        if columnwise:
+            outer = K
+            inner = math.ceil(M / NVFP4_BLOCK_SCALING_SIZE)
+            return (outer, inner)
+        else:
+            outer = M
+            inner = math.ceil(K / NVFP4_BLOCK_SCALING_SIZE)
+            return (outer, inner)
+    else:
+        raise ValueError(f"Unsupported quantizer class: {quantizer_cls}")
+
+
+@run_distributed_test()
+def _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls):
+    """Test the quantizer under distributed settings.
+
+    Args:
+        input_dtype (torch.dtype): The data type of the input.
+        low_precision_dtype (tex.DType): The data type of the low precision, can be fp4 or fp8.
+    """
+    M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE // 2
+
+    # high precision input
+    x_hp_cpu = torch.randn((M, N), device="cpu").to(input_dtype)
+    # set one element of the input to a very large value, which doesn't live in rank 0 after the split
+    # to test the amax reduction on purpose
+    # x_hp_cpu[M - 1, N - 1] = 1e4
+
+    # get the unpadded shapes
+    unpadded_rowwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, False)
+    unpadded_columnwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, True)
+
+    # rank 0 takes the full copy and quantize with GPU 0 for verification
+    if WORLD_RANK == 0:
+        x_hp_rank0 = x_hp_cpu.clone().detach().requires_grad_(True).to("cuda")
+    x_hp_local_rank = _shard_tensor(x_hp_cpu, WORLD_SIZE, 0)[WORLD_RANK]
+
+    # Create quantizers
+    quantizer, quantizer_dist = _construct_quantizer(
+        quantizer_cls, low_precision_dtype, x_hp_local_rank.device, NCCL_WORLD, WORLD_SIZE
+    )
+
+    # quantize the entire input
+    if WORLD_RANK == 0:
+        x_low_precision_single = quantizer(x_hp_rank0)
+
+    # run all-gather with a quantizer as input for quantized all-gather
+    x_low_precision_total, _ = gather_along_first_dim(
+        x_hp_local_rank, NCCL_WORLD, async_op=False, quantizer=quantizer_dist
+    )
+
+    # check the outputs
+    if WORLD_RANK == 0:
+        # assert all data and scale_inv are the same
+        torch.testing.assert_close(
+            x_low_precision_single._rowwise_data,
+            x_low_precision_total._rowwise_data,
+            rtol=0.0,
+            atol=0.0,
+        )
+        # check the rowwise scale without any padding
+        unpad_dim0, unpad_dim1 = unpadded_rowwise_scale_inv_shape
+        unpadded_rowwise_scale_inv_ref = x_low_precision_single._rowwise_scale_inv[:unpad_dim0, :unpad_dim1]
+        unpadded_rowwise_scale_inv = x_low_precision_total._rowwise_scale_inv[:unpad_dim0, :unpad_dim1]
+        torch.testing.assert_close(
+            unpadded_rowwise_scale_inv_ref,
+            unpadded_rowwise_scale_inv,
+            rtol=0.0,
+            atol=0.0,
+        )
+        torch.testing.assert_close(
+            _ref_zero_padding_scale_inv(
+                x_low_precision_single._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape
+            ),
+            _ref_zero_padding_scale_inv(
+                x_low_precision_total._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape
+            ),
+            rtol=0.0,
+            atol=0.0,
+        )
+        torch.testing.assert_close(
+            x_low_precision_single._columnwise_data,
+            x_low_precision_total._columnwise_data,
+            rtol=0.0,
+            atol=0.0,
+        )
+        unpad_dim0, unpad_dim1 = unpadded_columnwise_scale_inv_shape
+        unpadded_columnwise_scale_inv_ref = x_low_precision_single._columnwise_scale_inv[:unpad_dim0, :unpad_dim1]
+        unpadded_columnwise_scale_inv = x_low_precision_total._columnwise_scale_inv[:unpad_dim0, :unpad_dim1]
+        torch.testing.assert_close(
+            unpadded_columnwise_scale_inv_ref,
+            unpadded_columnwise_scale_inv,
+            rtol=0.0,
+            atol=0.0,
+        )
+        torch.testing.assert_close(
+            _ref_zero_padding_scale_inv(
+                x_low_precision_single._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape
+            ),
+            _ref_zero_padding_scale_inv(
+                x_low_precision_total._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape
+            ),
+            rtol=0.0,
+            atol=0.0,
+        )
+
+
+def test_quantized_all_gather():
+    """
+    Run quantized all-gather tests with various configurations.
+    """
+    # skip this test for other quantization schemes
+    is_nvfp4 = QUANTIZATION == "nvfp4"
+    # add other recipes for testing if needed
+    if not is_nvfp4:
+        return
+
+    input_dtypes = [torch.bfloat16]
+    fp4_dtype = [tex.DType.kFloat4E2M1]
+    fp8_dtype = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
+    quantizer_cls_nvfp4 = [NVFP4Quantizer]
+    # add FP8 quantizers if needed
+    quantizer_cls_fp8 = []
+
+    low_precisio_dtypes = fp4_dtype if is_nvfp4 else fp8_dtype
+    quantizer_cls_list = quantizer_cls_nvfp4 if is_nvfp4 else quantizer_cls_fp8
+    for quantizer_cls in quantizer_cls_list:
+        for input_dtype in input_dtypes:
+            for low_precision_dtype in low_precisio_dtypes:
+                _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls)
+
+
 ############################################
 #                   Linear                 #
 ############################################
...
@@ -515,7 +738,7 @@ def test_linear():
         {"init_method": _constant},
         {"fuse_wgrad_accumulation": True},
         {"return_bias": True},
-        {"params_dtype": torch.float16},
+        {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
         {"delay_wgrad_compute": True},
         {"save_original_input": True},
     ]
...
@@ -703,7 +926,7 @@ def test_layernorm_linear():
         {"init_method": _constant},
         {"fuse_wgrad_accumulation": True},
         {"return_bias": True},
-        {"params_dtype": torch.float16},
+        {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
         {"zero_centered_gamma": False},
         {"return_layernorm_output": True},
         {"delay_wgrad_compute": True},
...
@@ -818,7 +1041,7 @@ def test_layernorm_mlp():
         {"normalization": "RMSNorm"},
         {"zero_centered_gamma": True},
         {"bias": False},
-        {"params_dtype": torch.float16},
+        {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
         {"activation": "relu"},
         {"fuse_wgrad_accumulation": True},
         {"return_bias": True},
...
@@ -924,7 +1147,7 @@ def test_transformer_layer():
         {"fuse_qkv_params": True, "fuse_wgrad_accumulation": True},
         {"qkv_weight_interleaved": False},
         {"bias": False},
-        {"params_dtype": torch.float16},
+        {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
         {"fuse_qkv_params": True},
         {"activation": "relu"},
     ]
...
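Editorial note: run_numerics.py gains an NVFP4 path — nvfp4_vanilla() builds an NVFP4BlockScaling recipe with default QParams for the forward input, forward weight, and backward gradient, and test_quantized_all_gather exercises gather_along_first_dim with an NVFP4Quantizer. A minimal sketch of running a layer under that recipe (sizes are placeholders; NVFP4 support depends on the hardware and TE build):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling, QParams

nvfp4_recipe = NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = QParams()
nvfp4_recipe.fp4_quant_fwd_weight = QParams()
nvfp4_recipe.fp4_quant_bwd_grad = QParams()

model = te.Linear(128, 128, params_dtype=torch.bfloat16).cuda()
x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
with te.autocast(enabled=True, recipe=nvfp4_recipe):
    y = model(x)
y.sum().backward()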
tests/pytorch/distributed/run_numerics_exact.py (new file, mode 0 → 100644; view file @ 063ef88d)

#!/usr/bin/python3
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import argparse
import datetime
import os
import sys
from functools import wraps

import transformer_engine.pytorch as te
import torch
from torch import nn
import torch.distributed as dist
from transformer_engine.common.recipe import (
    NVFP4BlockScaling,
    Recipe,
    QParams,
    CustomRecipe,
)
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.experimental import quantization_nvfp4
from transformer_engine.pytorch.experimental import utils
from run_layer_with_overlap import _compare_tensors

BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE = 128, 256, 128
WORLD_RANK, WORLD_SIZE = None, None
NCCL_WORLD = None
LOSS_FN = nn.MSELoss()
QUANTIZATION = None


def nvfp4_rht_and_2d_quantization():
    nvfp4_recipe = NVFP4BlockScaling()
    nvfp4_recipe.fp4_quant_fwd_inp = QParams(random_hadamard_transform=True, fp4_2d_quantization=False)
    nvfp4_recipe.fp4_quant_fwd_weight = QParams(random_hadamard_transform=False, fp4_2d_quantization=True)
    nvfp4_recipe.fp4_quant_bwd_grad = QParams(random_hadamard_transform=True, fp4_2d_quantization=False)
    return nvfp4_recipe


def get_nvfp4_quantizer_factory():
    """
    Create a quantizer factory for NVFP4 reference implementation.
    This factory returns NVFP4QuantizerRef instances with RHT and 2D quantization enabled.

    Returns:
        A factory function that takes a role string and returns a quantizer instance
    """

    def factory(role):
        if role == "linear_input":
            return quantization_nvfp4.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(1, 16),
                pow_2_scales=False,
                with_rht=True,  # RHT enabled for input
            )
        elif role == "linear_weight":
            return quantization_nvfp4.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(16, 16),  # 2D quantization for weight
                pow_2_scales=False,
                with_rht=False,
            )
        elif role == "linear_output":
            # Output quantization not used
            return None
        elif role == "linear_grad_output":
            return quantization_nvfp4.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(1, 16),
                pow_2_scales=False,
                with_rht=True,  # RHT enabled for grad_output
            )
        elif role == "linear_grad_input":
            # Grad input quantization not used
            return None
        else:
            # For any other roles, return None
            return None

    return factory


# Quantization recipe setup
def quantization_recipe() -> Recipe:
    if QUANTIZATION == "nvfp4":
        return nvfp4_rht_and_2d_quantization()
    raise ValueError(f"Unsupported quantization: {QUANTIZATION}")


def quantization_reference_recipe() -> Recipe:
    """Create reference recipe using CustomRecipe with NVFP4 quantizer factory."""
    if QUANTIZATION == "nvfp4":
        nvfp4_ref_factory = get_nvfp4_quantizer_factory()
        return CustomRecipe(qfactory=nvfp4_ref_factory)
    raise ValueError(f"Unsupported quantization for reference: {QUANTIZATION}")


def main(argv=None, namespace=None):
    global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, QUANTIZATION, BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE

    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
    LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
    assert WORLD_SIZE == LOCAL_SIZE  # this test supports only 1 node
    assert LOCAL_SIZE <= torch.cuda.device_count()
    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
        "timeout": datetime.timedelta(seconds=30),
    }
    dist_init_kwargs["init_method"] = "env://"
    dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
    assert dist.is_nccl_available()
    torch.cuda.set_device(LOCAL_RANK)
    dist.init_process_group(**dist_init_kwargs)

    NCCL_WORLD = dist.new_group(backend="nccl")

    WORLD_SIZE = dist.get_world_size()

    parser = argparse.ArgumentParser()
    parser.add_argument("--quantization", type=str, default=None)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--hidden-size", type=int, default=128)
    parser.add_argument("--out-size", type=int, default=128)
    args = parser.parse_args(argv, namespace)

    # Quantization scheme
    QUANTIZATION = args.quantization
    BATCH_SIZE = args.batch_size
    HIDDEN_SIZE = args.hidden_size
    OUT_SIZE = args.out_size

    test_dict = [
        test_linear,
        test_layernorm_linear,
    ]

    for test in test_dict:
        test()
    dist.destroy_process_group()
    return 0


def run_distributed_test(test_name=None):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            name = test_name if test_name is not None else func.__name__

            dist_print(f"Starting test {name} with args {args} and {kwargs}")
            torch.cuda.set_device(WORLD_RANK)
            torch.manual_seed(12345)
            torch.cuda.manual_seed(12345)

            func(*args, **kwargs)

            dist.barrier()
            dist_print(f"Passed test {name}")

        return wrapper

    return decorator


def dist_print(msg, src=None, end="\n", error=False):
    stream = sys.stderr if error else sys.stdout
    if WORLD_RANK == (0 if src is None else src):
        stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n")


############################################
#                   Linear                 #
############################################


class TestDistributedLinearBase:
    @staticmethod
    def _prepare_data(batch_size, hidden_size, out_size, use_bias=True, seed=0, dtype=torch.float32):
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda")
        w = torch.randn((out_size, hidden_size), dtype=dtype, device="cuda")
        bias = torch.randn((out_size), dtype=dtype, device="cuda") if use_bias else None
        gradient = torch.randn((batch_size, out_size), dtype=dtype, device="cuda")
        return x, w, bias, gradient

    @staticmethod
    def _shard_tensor(x, world_size, axis):
        split_size = x.size()[axis] // world_size
        split_tensor = torch.split(x, split_size, axis)
        out = []
        for tensor in split_tensor:
            out.append(tensor.detach().clone().requires_grad_(x.requires_grad))
        return out

    @staticmethod
    def _gather_tensor(local, world_size, tp_group, concat_dim):
        out_list = [torch.zeros_like(local) for _ in range(world_size)]
        torch.distributed.all_gather(out_list, local, tp_group)
        return torch.cat(out_list, dim=concat_dim)

    @staticmethod
    def _all_reduce_tensor(local, world_size, tp_group):
        if world_size == 1:
            return local
        handle = torch.distributed.all_reduce(local, group=tp_group, async_op=False)
        return local

    @staticmethod
    def _get_sum_abs_error(a, b):
        return torch.sum(torch.abs(a - b))

    @staticmethod
    def _get_mean_abs_relative_error(a, b):
        error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b))
        return torch.mean(error)

    @classmethod
    def run_linear_preprocess_parallel(
        cls, x, w, bias, gradient, parallel_mode=None, sequence_parallel=False, tp_size=1, rank=0,
    ):
        if tp_size > 1:
            if parallel_mode == "column":
                # split w in N dim, which should be axis 0
                w = cls._shard_tensor(w, tp_size, 0)[rank]
                bias = cls._shard_tensor(bias, tp_size, 0)[rank] if bias is not None else None
                # split gradient in N dim, which should be axis 1
                gradient = cls._shard_tensor(gradient, tp_size, 1)[rank]
                if sequence_parallel:
                    # split x in M dim, which should be axis 0
                    x = cls._shard_tensor(x, tp_size, 0)[rank]
            # row parallel, split x in k dim, which should be axis 1, split w in k dim, should be axis 1
            if parallel_mode == "row":
                # split x in K dim, which should be axis 1
                x = cls._shard_tensor(x, tp_size, 1)[rank]
                # split w in K dim, which should be axis 1
                w = cls._shard_tensor(w, tp_size, 1)[rank]
                if sequence_parallel:
                    # split gradient in M dim, which should be axis 0
                    gradient = cls._shard_tensor(gradient, tp_size, 0)[rank]
        return x, w, bias, gradient

    @classmethod
    def run_linear_postprocess_parallel(
        cls, y_q, dgrad, wgrad, bgrad, parallel_mode, sequence_parallel, tp_size, tp_group,
    ):
        if tp_size > 1:
            if parallel_mode == "column":
                # gather y_q in N dim, which should be axis 1
                y_q = cls._gather_tensor(y_q, tp_size, tp_group, 1)
                # gather wgrad in N dim, which should be axis 0
                wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 0)
                # gather bgrad in N dim, which should be axis 0
                bgrad = cls._gather_tensor(bgrad, tp_size, tp_group, 0) if bgrad is not None else None
                if sequence_parallel:
                    # gather dgrad in M dim, which should be axis 0
                    dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 0)
            if parallel_mode == "row":
                # gather dgrad in K dim, which should be axis 1
                dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 1)
                # gather wgrad in K dim, which should be axis 1
                wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 1)
                if sequence_parallel:
                    # gather y_q in M dim, which should be axis 0
                    y_q = cls._gather_tensor(y_q, tp_size, tp_group, 0)
                    # we need to sum bias gradient when using TP + SP
                    bgrad = (
                        cls._all_reduce_tensor(bgrad, tp_size, tp_group) if bgrad is not None else None
                    )
        return y_q, dgrad, wgrad, bgrad

    @classmethod
    def run_linear_one_step(cls, layer, x, gradient, is_first_microbatch=None, fuse_wgrad_accumulation=False):
        # reset gradients
        layer.zero_grad()
        x.grad = None

        # Forward pass
        if isinstance(layer, te.Linear):
            # Kitchen Linear
            y_q = layer.forward(x, is_first_microbatch=is_first_microbatch)
        else:
            # the default torch.nn.Linear
            y_q = layer(x)

        # Backward pass
        y_q.backward(gradient)

        # Collect gradients
        dgrad = x.grad
        bgrad = (
            layer._parameters["bias"].grad if layer._parameters.get("bias", None) is not None else None
        )
        assert "weight" in layer._parameters
        if fuse_wgrad_accumulation:
            wgrad = layer._parameters["weight"].main_grad
            assert layer._parameters["weight"].grad is None
        else:
            wgrad = layer._parameters["weight"].grad

        return y_q, dgrad, wgrad, bgrad

    @classmethod
    def run_linear_multiple_steps(
        cls, layer, x, gradient, run_num_steps, enable_weight_cache, fuse_wgrad_accumulation=False,
    ):
        """
        Run multiple steps of linear layer and collect results.
        """
        y_q_list, dgrad_list, wgrad_list = [], [], []
        bgrad_list = [] if layer._parameters.get("bias", None) is not None else None

        for i in range(run_num_steps):
            x_i = (x + i).clone().detach().requires_grad_(True)
            # run_linear_one_step
            y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(
                layer,
                x_i,
                gradient,
                is_first_microbatch=(i == 0) if enable_weight_cache else None,
                fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            )

            # Collect results
            y_q_list.append(y_q.detach().clone())
            dgrad_list.append(dgrad.detach().clone())
            wgrad_list.append(wgrad.detach().clone())
            if bgrad_list is not None and bgrad is not None:
                bgrad_list.append(bgrad.detach().clone())

        # Stack the results
        return (
            torch.stack(y_q_list),
            torch.stack(dgrad_list),
            torch.stack(wgrad_list),
            torch.stack(bgrad_list) if bgrad_list is not None else None,
        )

    @classmethod
    def run_linear(
        cls,
        x,
        w,
        bias,
        gradient,
        parallel_mode=None,
        sequence_parallel=False,
        tp_group=None,
        tp_size=1,
        rank=0,
        run_num_steps=1,
        enable_weight_cache=False,
        fuse_wgrad_accumulation=False,
    ):
        """
        If Model parallel, split inputs for a given rank and return the gathered output and gradients,
        so that they can be compared with the reference single GPU run.
        """
        # clone inputs and move to current device
        # w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
        x = x.clone().detach().requires_grad_(True).to("cuda")
        w = w.clone().detach().to("cuda")
        gradient = gradient.clone().detach().to("cuda")
        bias = bias.clone().detach().to("cuda") if bias is not None else None

        in_features = x.shape[1]
        out_features = w.shape[0]

        # If Model parallel: split inputs for a given rank
        x, w, bias, gradient = cls.run_linear_preprocess_parallel(
            x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
        )

        # set data types
        params_dtype = x.dtype

        # Create linear layer and copy weights
        layer = te.Linear(
            in_features,
            out_features,
            bias=bias is not None,
            params_dtype=params_dtype,
            parallel_mode=parallel_mode,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            tp_size=tp_size,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
        )
        layer = layer.to("cuda")

        with torch.no_grad():
            layer.weight.copy_(w)
            if bias is not None:
                layer.bias.copy_(bias)

        if fuse_wgrad_accumulation:
            assert run_num_steps > 1, "Fused weight gradient accumulation requires run_num_steps > 1"
            layer.weight.main_grad = torch.zeros_like(layer.weight)

        # Run one step or multiple steps
        if run_num_steps == 1:
            y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient)
        else:
            y_q, dgrad, wgrad, bgrad = cls.run_linear_multiple_steps(
                layer, x, gradient, run_num_steps, enable_weight_cache, fuse_wgrad_accumulation,
            )

        # If Model parallel: gather output and gradients from all ranks
        y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel(
            y_q, dgrad, wgrad, bgrad, parallel_mode, sequence_parallel, tp_size, tp_group,
        )

        return y_q, dgrad, wgrad, bgrad


@run_distributed_test()
def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
    """Test the linear layer with specified parallel mode and sequence parallelization.

    Args:
        parallel_mode (str): 'row' or 'column' parallelism.
        sequence_parallel (bool): Enable sequence parallelism if True.
        kwargs (dict): Additional arguments for the linear layer.
        QUANTIZATION options: nvfp4 <=> experimental nvfp4 as a reference
    """
    params_dtype = torch.bfloat16
    use_bias = kwargs.get("bias", True)
    fuse_wgrad_accumulation = kwargs.get("fuse_wgrad_accumulation", False)
    seed = torch.initial_seed()
    recipe = quantization_recipe()
    # turn on weight quantization cache when fusing wgrad accumulation
    enable_weight_cache = fuse_wgrad_accumulation
    run_num_steps = 1 if not fuse_wgrad_accumulation else 5

    x, w, bias, gradient = TestDistributedLinearBase._prepare_data(
        BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE, use_bias=use_bias, seed=seed, dtype=params_dtype
    )

    # run the recipe under test
    with te.autocast(enabled=True, recipe=recipe):
        y_q, dgrad, wgrad, bgrad = TestDistributedLinearBase.run_linear(
            x,
            w,
            bias,
            gradient,
            parallel_mode=parallel_mode,
            sequence_parallel=sequence_parallel,
            tp_group=NCCL_WORLD,
            tp_size=WORLD_SIZE,
            rank=WORLD_RANK,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            run_num_steps=1 if not fuse_wgrad_accumulation else 5,
            enable_weight_cache=fuse_wgrad_accumulation,
        )

    # run the reference
    reference_recipe = quantization_reference_recipe()
    with te.autocast(enabled=True, recipe=reference_recipe):
        y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = TestDistributedLinearBase.run_linear(
            x,
            w,
            bias,
            gradient,
            parallel_mode=parallel_mode,
            sequence_parallel=sequence_parallel,
            tp_group=NCCL_WORLD,
            tp_size=WORLD_SIZE,
            rank=WORLD_RANK,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            run_num_steps=run_num_steps,
            enable_weight_cache=enable_weight_cache,
        )

    # compare results, zero tolerance
    if WORLD_RANK == 0:
        torch.testing.assert_close(y_q, y_q_ref, atol=0, rtol=0, msg="Output mismatch")
        torch.testing.assert_close(dgrad, dgrad_ref, atol=0, rtol=0, msg="Dgrad mismatch")
        torch.testing.assert_close(wgrad, wgrad_ref, atol=0, rtol=0, msg="Wgrad mismatch")
        if bgrad is not None and bgrad_ref is not None:
            torch.testing.assert_close(bgrad, bgrad_ref, atol=0, rtol=0, msg="Bgrad mismatch")


def test_linear():
    """Run linear layer tests with various configurations."""
    kwargs_list = [
        {"bias": False},
    ]
    for kwargs in kwargs_list:
        if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
            continue
        for parallel_mode in ["column", "row"]:
            for sequence_parallel in [False, True]:
                _test_linear(parallel_mode, sequence_parallel, **kwargs)


############################################
#             LayerNormLinear              #
############################################


class TestDistributedLayerNormLinearBase(TestDistributedLinearBase):
    @classmethod
    def run_linear_one_step(cls, layer, x, gradient, is_first_microbatch=None):
        # reset gradients
        layer.zero_grad()
        x.grad = None

        # Forward pass
        y_q, ln_out = layer.forward(x, is_first_microbatch=is_first_microbatch)

        # Backward pass
        y_q.backward(gradient)

        # Collect gradients
        dgrad = x.grad
        parameters = layer._parameters
        # bias and weight gradients
        bgrad = parameters["bias"].grad if parameters.get("bias", None) is not None else None
        assert "weight" in parameters
        wgrad = parameters["weight"].grad

        return y_q, ln_out, dgrad, wgrad, bgrad

    @classmethod
    def run_linear_multiple_steps(
        cls, layer, x, gradient, run_num_steps, enable_weight_cache, fuse_wgrad_accumulation=False
    ):
        # raise error, no test case for multiple steps for now
        raise NotImplementedError("LayerNormLinear does not support test multiple steps for now")

    @classmethod
    def run_layernorm_linear(
        cls,
        x,
        w,
        bias,
        gradient,
        parallel_mode=None,
        sequence_parallel=False,
        tp_group=None,
        tp_size=1,
        rank=0,
        run_num_steps=1,
        enable_weight_cache=False,
        LayerNormLinearClass=te.LayerNormLinear,
        normalization="LayerNorm",
    ):
        """
        If Model parallel, split inputs for a given rank and return the gathered output and gradients,
        so that they can be compared with the reference single GPU run.
        """
        # clone inputs and move to current device
        # w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
        x = x.clone().detach().requires_grad_(True).to("cuda")
        w = w.clone().detach().to("cuda")
        gradient = gradient.clone().detach().to("cuda")
        bias = bias.clone().detach().to("cuda") if bias is not None else None

        in_features = x.shape[1]
        out_features = w.shape[0]

        # If Model parallel: split inputs for a given rank
        x, w, bias, gradient = cls.run_linear_preprocess_parallel(
            x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
        )

        # set data types
        params_dtype = x.dtype

        # Create linear layer and copy weights
        layer = LayerNormLinearClass(
            in_features,
            out_features,
            bias=bias is not None,
            params_dtype=params_dtype,
            parallel_mode=parallel_mode,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            tp_size=tp_size,
            normalization=normalization,
            return_layernorm_output=True,
        )
        layer = layer.to("cuda")

        # Copy weights
        # kitchen_linear has different parameter names
        with torch.no_grad():
            layer.weight.copy_(w)
            if bias is not None:
                layer.bias.copy_(bias)

        # Run one step
        y_q, ln_out, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient)

        # If Model parallel: gather output and gradients from all ranks
        y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel(
            y_q, dgrad, wgrad, bgrad, parallel_mode, sequence_parallel, tp_size, tp_group,
        )

        return y_q, ln_out, dgrad, wgrad, bgrad


@run_distributed_test()
def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
    """Test the linear layer with specified parallel mode and sequence parallelization.

    Args:
        parallel_mode (str): 'column' parallelism.
        sequence_parallel (bool): Enable sequence parallelism if True.
        kwargs (dict): Additional arguments for the linear layer.
    """
    params_dtype = torch.bfloat16
    use_bias = kwargs.get("bias", True)
    seed = torch.initial_seed()
    recipe = quantization_recipe()
    # run multiple steps currently not supported for LayerNormLinear
    run_num_steps = 1

    x, w, bias, gradient = TestDistributedLayerNormLinearBase._prepare_data(
        BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE, use_bias=use_bias, seed=seed, dtype=params_dtype
    )

    # run the recipe under test
    with te.autocast(enabled=True, recipe=recipe):
        y_q, ln_out, dgrad, wgrad, bgrad = TestDistributedLayerNormLinearBase.run_layernorm_linear(
            x,
            w,
            bias,
            gradient,
            parallel_mode=parallel_mode,
            sequence_parallel=sequence_parallel,
            tp_group=NCCL_WORLD,
            tp_size=WORLD_SIZE,
            rank=WORLD_RANK,
            run_num_steps=run_num_steps,
            enable_weight_cache=False,
        )

    # run the reference
    reference_recipe = quantization_reference_recipe()
    with te.autocast(enabled=True, recipe=reference_recipe):
        y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = (
            TestDistributedLayerNormLinearBase.run_layernorm_linear(
                x,
                w,
                bias,
                gradient,
                parallel_mode=parallel_mode,
                sequence_parallel=sequence_parallel,
                tp_group=NCCL_WORLD,
                tp_size=WORLD_SIZE,
                rank=WORLD_RANK,
                run_num_steps=run_num_steps,
                enable_weight_cache=False,
            )
        )

    # compare results, zero tolerance
    if WORLD_RANK == 0:
        torch.testing.assert_close(y_q, y_q_ref, atol=0, rtol=0, msg="Output mismatch")
        torch.testing.assert_close(ln_out, ln_out_ref, atol=0, rtol=0, msg="LN output mismatch")
        torch.testing.assert_close(dgrad, dgrad_ref, atol=0, rtol=0, msg="Dgrad mismatch")
        torch.testing.assert_close(wgrad, wgrad_ref, atol=0, rtol=0, msg="Wgrad mismatch")
        if bgrad is not None and bgrad_ref is not None:
            torch.testing.assert_close(bgrad, bgrad_ref, atol=0, rtol=0, msg="Bgrad mismatch")


def test_layernorm_linear():
    kwargs_list = [
        {"bias": False},
    ]
    for kwargs in kwargs_list:
        for parallel_mode in ["column"]:
            for sequence_parallel in [False, True]:
                _test_layernorm_linear(parallel_mode, sequence_parallel, **kwargs)


if __name__ == "__main__":
    sys.exit(main())
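Editorial note: the exact-numerics runner above compares the native NVFP4 recipe against a CustomRecipe whose qfactory hands out experimental NVFP4QuantizerRef objects per tensor role. A trimmed sketch of wiring such a factory into te.autocast, following the new file (the factory body is abbreviated to one role; the model and input are placeholders):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import CustomRecipe
from transformer_engine.pytorch.experimental import quantization_nvfp4, utils

def ref_factory(role):
    # Only the linear input is quantized in this abbreviated sketch; the test
    # also returns quantizers for "linear_weight" and "linear_grad_output".
    if role == "linear_input":
        return quantization_nvfp4.NVFP4QuantizerRef(
            dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=True
        )
    return None

reference_recipe = CustomRecipe(qfactory=ref_factory)
model = te.Linear(256, 128, params_dtype=torch.bfloat16).cuda()  # placeholder layer
x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True, recipe=reference_recipe):
    y_ref = model(x)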
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py (view file @ 063ef88d)

@@ -8,15 +8,15 @@ from pathlib import Path
 import pytest
 import torch

-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch import is_fp8_available, is_fp8_block_scaling_available

 # NVTE_DISABLE_NVRTC=1 NVTE_INT8_SIM_FP8=1 torchrun --nproc_per_node=4 run_cast_master_weights_to_fp8.py --quantization fp8_block

 if torch.cuda.device_count() < 2:
     pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.")

-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
-)
+fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
+    return_reason=True
+)

 TEST_ROOT = Path(__file__).parent.resolve()
...
tests/pytorch/distributed/test_comm_gemm_overlap.py (view file @ 063ef88d)

@@ -14,14 +14,13 @@ import pytest
 import torch

 import transformer_engine.pytorch as te
 import transformer_engine.pytorch.cpp_extensions as tex
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 import logging

 if torch.cuda.device_count() < 2:
     pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")

-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)

 RNG_SEED: int = 42
 SEQ_LENGTH: int = 1024
...
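Editorial note: the (available, reason) pairs computed at import time in these test modules are typically consumed by pytest skip markers, so unsupported hardware skips cleanly instead of failing. A hedged sketch of that convention (the decorated test is a placeholder, not from the commit):

import pytest
import transformer_engine.pytorch as te

fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_path():
    ...  # placeholder test body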
tests/pytorch/distributed/test_fusible_ops.py
View file @
063ef88d
...
...
@@ -20,31 +20,34 @@ import torch
import
transformer_engine
import
transformer_engine.common.recipe
import
transformer_engine.pytorch
as
te
from
transformer_engine.pytorch.fp8
import
FP8GlobalStateManager
from
transformer_engine.pytorch.tensor
import
QuantizedTensor
from
transformer_engine.pytorch.tensor.float8_tensor
import
(
from
transformer_engine.pytorch
import
(
QuantizedTensor
,
Float8Quantizer
,
Float8CurrentScalingQuantizer
,
MXFP8Quantizer
,
NVFP4Quantizer
,
is_bf16_available
,
)
from
transformer_engine.pytorch.tensor.mxfp8_tensor
import
MXFP8Quantizer
import
transformer_engine.pytorch.ops
as
te_ops
from
transformer_engine.pytorch.utils
import
is_bf16_compatible
import
transformer_engine_torch
as
tex
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
# Import utility functions
_current_file
=
pathlib
.
Path
(
__file__
).
resolve
()
sys
.
path
.
append
(
str
(
_current_file
.
parent
.
parent
))
from
utils
import
dtype_tols
,
make_recipe
from
utils
import
dtype_tols
,
make_recipe
,
quantization_tols
# Check what quantization schemes are supported
fp8_available
,
reason_for_no_fp8
=
FP8GlobalStateManager
.
is_fp8_available
()
mxfp8_available
,
reason_for_no_mxfp8
=
FP8GlobalStateManager
.
is_mxfp8_available
()
fp8_available
,
reason_for_no_fp8
=
te
.
is_fp8_available
(
return_reason
=
True
)
mxfp8_available
,
reason_for_no_mxfp8
=
te
.
is_mxfp8_available
(
return_reason
=
True
)
nvfp4_available
,
reason_for_no_nvfp4
=
te
.
is_mxfp8_available
(
return_reason
=
True
)
quantization_list
:
list
[
Optional
[
str
]]
=
[
None
]
if
fp8_available
:
quantization_list
.
extend
((
"fp8_delayed_scaling"
,
"fp8_current_scaling"
))
if
mxfp8_available
:
quantization_list
.
append
(
"mxfp8"
)
if
nvfp4_available
:
quantization_list
.
append
(
"nvfp4"
)
@
functools
.
cache
...
...
@@ -115,6 +118,14 @@ def make_reference_and_test_tensors(
test
=
quantizer
(
test
)
elif
quantization
==
"mxfp8"
:
test
=
MXFP8Quantizer
(
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
)(
test
)
elif
quantization
==
"nvfp4"
:
test
=
NVFP4Quantizer
(
with_rht
=
False
,
with_post_rht_amax
=
False
,
with_2d_quantization
=
False
,
stochastic_rounding
=
False
,
with_random_sign_mask
=
False
,
)(
test
)
else
:
raise
ValueError
(
f
"Unsupported quantization scheme (
{
quantization
}
)"
)
if
isinstance
(
test
,
QuantizedTensor
)
and
not
test_is_quantized
:
...
...
@@ -415,7 +426,7 @@ def _test_basic_linear(
# Implementation with fusible operation
recipe
=
make_recipe
(
quantization
)
with
te
.
fp8
_model_init
(
enabled
=
quantized_weight
,
recipe
=
recipe
):
with
te
.
quantized
_model_init
(
enabled
=
quantized_weight
,
recipe
=
recipe
):
op
=
te_ops
.
BasicLinear
(
in_features
,
out_features
,
...
...
@@ -428,7 +439,7 @@ def _test_basic_linear(
with
torch
.
no_grad
():
op
.
weight
.
copy_
(
w_test
)
del
w_test
with
te
.
fp8_
autocast
(
enabled
=
quantized_compute
,
fp8_
recipe
=
recipe
):
with
te
.
autocast
(
enabled
=
quantized_compute
,
recipe
=
recipe
):
y_test
=
op
(
x_test
)
y_test
.
backward
(
dy_test
)
...
...
@@ -437,7 +448,7 @@ def _test_basic_linear(
if
dtype
==
torch
.
float32
:
tols
=
dtype_tols
(
torch
.
float16
)
# TF32 GEMM
if
quantized_compute
:
tols
=
dtype_tols
(
tex
.
DType
.
kFloat8E4M3
)
tols
=
quantization_tols
(
quantization
)
# Check results
y_test
=
y_test
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
...
...
@@ -581,7 +592,7 @@ def _test_linear(
# Implementation with fusible operation
recipe
=
make_recipe
(
quantization
)
with
te
.
fp8
_model_init
(
enabled
=
quantized_weight
,
recipe
=
recipe
):
with
te
.
quantized
_model_init
(
enabled
=
quantized_weight
,
recipe
=
recipe
):
model
=
te_ops
.
Sequential
(
te_ops
.
Linear
(
in_features
,
...
...
@@ -600,7 +611,7 @@ def _test_linear(
model
[
0
].
bias
.
copy_
(
b_test
)
del
w_test
del
b_test
with
te
.
fp8_
autocast
(
enabled
=
quantized_compute
,
fp8_
recipe
=
recipe
):
with
te
.
autocast
(
enabled
=
quantized_compute
,
recipe
=
recipe
):
y_test
=
model
(
x_test
)
y_test
.
backward
(
dy_test
)
...
...
@@ -609,7 +620,7 @@ def _test_linear(
if
dtype
==
torch
.
float32
:
tols
=
dtype_tols
(
torch
.
float16
)
# TF32 GEMM
if
quantized_compute
:
tols
=
dtype_tols
(
tex
.
DType
.
kFloat8E4M3
)
tols
=
quantization_tols
(
quantization
)
# Check results
y_test
=
y_test
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
...
...
@@ -623,6 +634,204 @@ def _test_linear(
        torch.testing.assert_close(db_test, db_ref, **tols)


def _test_mlp(
    *,
    bias: bool = True,
    hidden_size: int = 32,
    local_batch_size: int = 32,
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    quantization: Optional[str] = None,
    quantized_weight: bool = False,
    sequence_parallel: bool = False,
) -> None:
    """2-layer MLP

    MLP includes GELU activation in order to test op fusions. Model
    performs warmup steps in order to test inter-step logic.

    """

    # Skip invalid configurations
    quantized_compute = quantization is not None
    if not quantized_compute and quantized_weight:
        return

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Tensor dimensions
    mlp_size = hidden_size * world_size
    batch_size = local_batch_size
    if sequence_parallel:
        batch_size *= world_size
    in_shape = (batch_size, hidden_size)

    # Random data
    reset_rng()
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    w1_ref, w1_test = make_reference_and_test_tensors(
        (mlp_size, hidden_size),
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    b1_ref, b1_test = None, None
    w2_ref, w2_test = make_reference_and_test_tensors(
        (hidden_size, mlp_size),
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    b2_ref, b2_test = None, None
    if bias:
        b1_ref, b1_test = make_reference_and_test_tensors(
            (mlp_size,),
            test_dtype=dtype,
            test_device=device,
        )
        b2_ref, b2_test = make_reference_and_test_tensors(
            (world_size, hidden_size),
            test_dtype=dtype,
            test_device=device,
        )
    dy_ref, dy_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        requires_grad=False,
    )

    # Plain PyTorch implementation
    y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
    y_ref = torch.nn.functional.linear(y_ref, w1_ref)
    if bias:
        y_ref += b1_ref
    y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
    y_ref = torch.nn.functional.linear(y_ref, w2_ref)
    if bias:
        y_ref += b2_ref.sum(dim=0)
    y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
    y_ref.backward(dy_ref)

    # Convert to distributed tensors
    with torch.no_grad():
        local_mlp_size = mlp_size // world_size
        local_mlp_slice = slice(rank * local_mlp_size, (rank + 1) * local_mlp_size)
        dx_ref = x_ref.grad
        dw1_ref = w1_ref.grad[local_mlp_slice, :]
        w1_ref = w1_ref[local_mlp_slice, :]
        w1_test = w1_test[local_mlp_slice, :]
        dw2_ref = w2_ref.grad[:, local_mlp_slice]
        w2_ref = w2_ref[:, local_mlp_slice]
        w2_test = w2_test[:, local_mlp_slice]
        if bias:
            db1_ref = b1_ref.grad[local_mlp_slice]
            b1_ref = b1_ref[local_mlp_slice]
            b1_test = b1_test[local_mlp_slice]
            db2_ref = b2_ref.grad[rank, :]
            b2_ref = b2_ref[rank, :]
            b2_test = b2_test[rank, :]
        else:
            db1_ref = None
            db2_ref = None
        if sequence_parallel:
            local_batch_slice = slice(
                rank * local_batch_size,
                (rank + 1) * local_batch_size,
            )
            x_ref = x_ref[local_batch_slice, ...]
            dx_ref = dx_ref[local_batch_slice, ...]
            x_test = x_test[local_batch_slice, ...].clone()
            y_ref = y_ref[local_batch_slice, ...]
            dy_ref = dy_ref[local_batch_slice, ...]
            dy_test = dy_test[local_batch_slice, ...].clone()
    x_test.requires_grad_()

    # Implementation with fusible operation
    recipe = make_recipe(quantization)
    with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
        model = te_ops.Sequential(
            te_ops.GELU(),
            te_ops.Linear(
                hidden_size,
                mlp_size,
                bias=bias,
                device=device,
                dtype=dtype,
                tensor_parallel_mode="column",
                tensor_parallel_group=process_group,
                sequence_parallel=sequence_parallel,
            ),
            te_ops.GELU(),
            te_ops.Linear(
                mlp_size,
                hidden_size,
                bias=bias,
                device=device,
                dtype=dtype,
                tensor_parallel_mode="row",
                tensor_parallel_group=process_group,
                sequence_parallel=sequence_parallel,
            ),
            te_ops.GELU(),
        )
    with torch.no_grad():
        model[1].weight.copy_(w1_test)
        model[3].weight.copy_(w2_test)
        if bias:
            model[1].bias.copy_(b1_test)
            model[3].bias.copy_(b2_test)
    del w1_test, w2_test, b1_test, b2_test

    # Warmup steps
    for _ in range(3):
        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)
        x_test.grad = None
        model[1].weight.grad = None
        model[3].weight.grad = None
        if bias:
            model[1].bias.grad = None
            model[3].bias.grad = None

    # Forward and backward step
    with te.autocast(enabled=quantized_compute, recipe=recipe):
        y_test = model(x_test)
    y_test.backward(dy_test)

    # Expected numerical error
    tols = dtype_tols(dtype)
    if dtype == torch.float32:
        tols = dtype_tols(torch.float16)  # TF32 GEMM
    if quantized_compute:
        tols = quantization_tols(quantization)

    # Check results
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    dw1_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
    dw2_test = model[3].weight.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **tols)
    torch.testing.assert_close(dx_test, dx_ref, **tols)
    torch.testing.assert_close(dw1_test, dw1_ref, **tols)
    torch.testing.assert_close(dw2_test, dw2_ref, **tols)
    if bias:
        db1_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
        db2_test = model[3].bias.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(db1_test, db1_ref, **tols)
        torch.testing.assert_close(db2_test, db2_ref, **tols)


def _test_fp8_scale_update(
    *,
    amax_history_len: int = 31,
...
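(Aside on the sharding used by _test_mlp above: the first GEMM is column-parallel, so each rank keeps a row block of w1, while the second is row-parallel, so each rank keeps a column block of w2. A tiny stand-alone sketch of that slicing, with made-up rank/world-size values and nothing assumed beyond plain PyTorch:)

import torch

world_size, rank = 4, 1  # illustrative values; the test reads these from the process group
hidden_size = 32
mlp_size = hidden_size * world_size
local_mlp_size = mlp_size // world_size
local_slice = slice(rank * local_mlp_size, (rank + 1) * local_mlp_size)

w1 = torch.randn(mlp_size, hidden_size)  # column-parallel layer: shard rows across ranks
w2 = torch.randn(hidden_size, mlp_size)  # row-parallel layer: shard columns across ranks
assert w1[local_slice, :].shape == (local_mlp_size, hidden_size)
assert w2[:, local_slice].shape == (hidden_size, local_mlp_size)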
@@ -734,7 +943,7 @@ def _test_fp8_scale_update(
        amax_history_len=amax_history_len,
        amax_compute_algo=amax_compute_algo,
    )
-    with te.fp8_autocast(fp8_recipe=recipe):
+    with te.autocast(recipe=recipe):
        y_test = op(x_test)
    y_test.backward(dy_test)
...
@@ -789,16 +998,31 @@ def run_parallel_tests() -> None:
    for config in itertools.product(
        quantization_list,
        ("column", "row"),
        (False, True),
    ):
        if rank == 0:
            print(f"Running _test_linear with {config=}")
-        quantization, tensor_parallel_mode = config
-        dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32
+        quantization, tensor_parallel_mode, sequence_parallel = config
+        dtype = torch.bfloat16 if is_bf16_available() else torch.float32
        _test_linear(
            bias=True,  # bias=False is tested in _test_basic_linear
            dtype=dtype,
            quantization=quantization,
            tensor_parallel_mode=tensor_parallel_mode,
            sequence_parallel=sequence_parallel,
        )

    # MLP
    for config in itertools.product(quantization_list, (False, True)):
        if rank == 0:
            print(f"Running _test_mlp with {config=}")
        quantization, sequence_parallel = config
        dtype = torch.bfloat16 if is_bf16_available() else torch.float32
        _test_mlp(
            bias=True,  # bias=False is tested in _test_basic_linear
            dtype=dtype,
            quantization=quantization,
            sequence_parallel=sequence_parallel,
        )

    # FP8 scale update
...
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
...
@@ -16,23 +16,23 @@ import sys
import pytest
import torch
from typing import Optional, Iterable

import transformer_engine
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import (
    UserbuffersBackwardLinear,
    UserbuffersForwardLinear,
)
-from transformer_engine.pytorch.tensor.float8_tensor import (
+from transformer_engine.pytorch import (
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
    MXFP8Quantizer,
    QuantizedTensor,
    Float8Tensor,
)
-from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
-from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.utils import is_bf16_compatible

# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
...
@@ -40,8 +40,8 @@ sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, make_recipe, str_to_dtype

# Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
quantization_list: list[Optional[str]] = [None]
if fp8_available:
    quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
...
@@ -301,7 +301,7 @@ def _test_linear(
    # Implementation with fusible operation
    recipe = make_recipe(quantization)
-    with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
+    with te.quantized_model_init(enabled=quantized_compute, recipe=recipe):
        ops = []
        linear_op = None
        bias_op = None
...
@@ -351,7 +351,7 @@ def _test_linear(
            bias_op.bias.copy_(b_test)
    del w_test
    del b_test
-    with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+    with te.autocast(enabled=quantized_compute, recipe=recipe):
        y_test = model(x_test)
    y_test.backward(dy_test)
...
tests/pytorch/distributed/test_numerics.py
...
@@ -8,9 +8,8 @@ from pathlib import Path
import pytest
import torch
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch as te
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine as te

"""
Distributed numerics tests
...
@@ -27,11 +26,12 @@ import transformer_engine as te
if torch.cuda.device_count() < 2:
    pytest.skip("Distributed training needs at least 2 GPUs.")

-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
-)
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
+    return_reason=True
+)
+nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
...
@@ -52,7 +52,9 @@ def _run_test(quantization):
all_boolean = [True, False]


-@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
+@pytest.mark.parametrize(
+    "quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"]
+)
def test_distributed(quantization):
    if quantization == "fp8" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
...
@@ -62,15 +64,17 @@ def test_distributed(quantization):
        pytest.skip(reason_for_no_mxfp8)
    if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
+    if quantization == "nvfp4" and not nvfp4_available:
+        pytest.skip(reason_for_no_nvfp4)
    if IS_HIP_EXTENSION and quantization == "fp8_block_scaling":
        import importlib

        ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", "None")
        os.environ["NVTE_INT8_SIM_FP8"] = "1"
        importlib.reload(te.pytorch.fp8)
        importlib.reload(te.fp8)
    _run_test(quantization)
    if IS_HIP_EXTENSION and quantization == "fp8_block_scaling":
        if ori_int8_sim_fp8 is None or ori_int8_sim_fp8 == "None":
            os.environ["NVTE_INT8_SIM_FP8"] = "0"
        else:
            del os.environ["NVTE_INT8_SIM_FP8"]
        importlib.reload(te.pytorch.fp8)
        importlib.reload(te.fp8)
tests/pytorch/distributed/test_numerics_exact.py
0 → 100644
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import subprocess
from pathlib import Path

import pytest
import torch

import transformer_engine.pytorch as te

"""
Distributed numerics tests

This numerical test aims for a zero-tolerance comparison, for absolute confidence in numerics.
In the case of NVFP4, with the experimental NVFP4 quantization, we matched bitwise
results with the native silicon. For distributed test cases, we can do the same thing
by comparing BF16 AG results with the low-precision AG results at layer level.
"""

if torch.cuda.device_count() < 2:
    pytest.skip("Distributed training needs at least 2 GPUs.")

fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
    return_reason=True
)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]


def _run_test(quantization, batch_size, hidden_size, out_size):
    test_path = TEST_ROOT / "run_numerics_exact.py"
    test_cmd = LAUNCH_CMD + [str(test_path)]
    test_cmd += ["--quantization", quantization]
    test_cmd += ["--batch-size", str(batch_size)]
    test_cmd += ["--hidden-size", str(hidden_size)]
    test_cmd += ["--out-size", str(out_size)]
    result = subprocess.run(test_cmd, env=os.environ, check=False)
    assert result.returncode == 0


all_boolean = [True, False]


@pytest.mark.parametrize("quantization", ["nvfp4"])
@pytest.mark.parametrize(
    "batch_size, hidden_size, out_size",
    [
        (64, 128, 128),
        (128, 128, 128),
        (128, 256, 256),
        (512, 1024, 768),
        (512, 256, 1024),
        (2048, 2048, 2048),
    ],
)
def test_distributed(quantization, batch_size, hidden_size, out_size):
    if quantization == "nvfp4" and not nvfp4_available:
        pytest.skip(reason_for_no_nvfp4)
    _run_test(quantization, batch_size, hidden_size, out_size)
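The heavy lifting for these cases happens in run_numerics_exact.py, which _run_test launches under torchrun with the flags shown above. As a rough illustration of the comparison strategy the module docstring describes (gather a layer's output from every rank and demand an exact match against a reference gathered the same way), the stand-alone sketch below may help; the function name and arguments are invented for the example, it is not a copy of the script, and it assumes an already-initialized torch.distributed process group:

import torch
import torch.distributed as dist


def assert_allgather_exact(local_out: torch.Tensor, local_ref: torch.Tensor) -> None:
    """All-gather both tensors across ranks and require a bitwise-exact match."""
    world_size = dist.get_world_size()
    out_parts = [torch.empty_like(local_out) for _ in range(world_size)]
    ref_parts = [torch.empty_like(local_ref) for _ in range(world_size)]
    dist.all_gather(out_parts, local_out.contiguous())
    dist.all_gather(ref_parts, local_ref.contiguous())
    # atol=0 / rtol=0 turns assert_close into an exact equality check
    torch.testing.assert_close(torch.cat(out_parts), torch.cat(ref_parts), atol=0.0, rtol=0.0)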
tests/pytorch/distributed/test_sanity.py
...
@@ -7,8 +7,7 @@ import sys
import pytest
import torch

import transformer_engine
-from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
-from transformer_engine.pytorch import TransformerLayer, Linear
+from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear

_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
...
tests/pytorch/distributed/test_torch_fsdp2.py
...
@@ -7,12 +7,11 @@ import pytest
import subprocess
from pathlib import Path
import torch
-from transformer_engine.pytorch import torch_version
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+import transformer_engine.pytorch as te

-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)

NUM_PROCS: int = torch.cuda.device_count()
...
@@ -34,7 +33,7 @@ def _run_test(fp_init, sharding_dims):
@pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
-@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
+@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
@pytest.mark.parametrize("fp8_init", (False, True))
def test_distributed(fp8_init, sharding_dims):
...
tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
0 → 100644
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils

recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)


def check_nvfp4_gemm_versus_reference(
    x_dtype: torch.dtype,
    w_dtype: torch.dtype,
    out_dtype: torch.dtype,
    M: int,
    K: int,
    N: int,
    accumulate: bool,
    *,
    x_columnwise: bool = False,
    w_columnwise: bool = False,
):
    te_dtype = tex.DType.kFloat4E2M1

    # Setup device and random seed
    device = "cuda"
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Input tensors
    x_shape = (K, M) if x_columnwise else (M, K)
    w_shape = (K, N) if w_columnwise else (N, K)
    x = torch.randn(x_shape, dtype=x_dtype, device=device)
    w = torch.randn(w_shape, dtype=w_dtype, device=device)

    # Setup out tensor if accumulate is True
    if accumulate:
        out = torch.randn((M, N), dtype=out_dtype, device=device)
    else:
        out = None

    # Native TE NVFP4 quantization
    x_quantizer = NVFP4Quantizer(
        fp4_dtype=te_dtype,
        rowwise=True,
        columnwise=True,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=False,
        with_post_rht_amax=False,
    )
    w_quantizer = NVFP4Quantizer(
        fp4_dtype=te_dtype,
        rowwise=True,
        columnwise=True,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=False,
        with_post_rht_amax=False,
    )

    # Quantize x and w
    x_nvfp4_native = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False)
    x_nvfp4_native = x_quantizer.update_quantized(x, x_nvfp4_native)
    w_nvfp4_native = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False)
    w_nvfp4_native = w_quantizer.update_quantized(w, w_nvfp4_native)

    # Extract quantized data from native NVFP4Tensors
    qx_data = (
        x_nvfp4_native._columnwise_data.view(dtype=torch.uint8)
        if x_columnwise
        else x_nvfp4_native._rowwise_data.view(dtype=torch.uint8)
    )
    qw_data = (
        w_nvfp4_native._columnwise_data.view(dtype=torch.uint8)
        if w_columnwise
        else w_nvfp4_native._rowwise_data.view(dtype=torch.uint8)
    )
    sx_native = (
        x_nvfp4_native._columnwise_scale_inv if x_columnwise else x_nvfp4_native._rowwise_scale_inv
    )
    sw_native = (
        w_nvfp4_native._columnwise_scale_inv if w_columnwise else w_nvfp4_native._rowwise_scale_inv
    )

    # Trim quantized data to match the actual tensor dimensions (remove padding)
    qx_data = qx_data[:M, :]
    qw_data = qw_data[:N, :]

    # NVFP4 uses 16-element blocks, trim scales to remove padding
    block_length = 16  # NVFP4 uses 16-element blocks
    expected_sx_cols = expected_sw_cols = K // block_length

    # Trim the scales to remove padding
    sx_trimmed = sx_native[:M, :expected_sx_cols]
    sw_trimmed = sw_native[:N, :expected_sw_cols]

    # Native scales are stored as uint8 but need to be interpreted as float8_e4m3fn
    # for the reference GEMM to work correctly
    sx_trimmed = sx_trimmed.view(torch.float8_e4m3fn)
    sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn)

    # Create reference quantizer for reference GEMM
    ref_quantizer = NVFP4QuantizerRef(
        dtype=utils.Fp4Formats.E2M1,
        rowwise=True,
        columnwise=True,
        pow_2_scales=False,
        eps=0.0,
        quant_tile_shape=(1, 16),
    )

    # Create reference quantized tensors needed by reference GEMM
    x_nvfp4_ref = ref_quantizer.quantize(x)
    w_nvfp4_ref = ref_quantizer.quantize(w)

    # Reference GEMM using quantizer's qgemm method
    y_ref = ref_quantizer.qgemm(
        qx=qx_data,
        qw=qw_data,
        m_params=None,  # MMParams not used in reference
        out_dtype=out_dtype,
        sx=sx_trimmed,
        sw=sw_trimmed,
        bias=None,  # No bias for this test
        out=out.clone() if accumulate else None,
        accumulate=accumulate,
        gemm_type=None,  # GEMMType not used in reference
        qresult_x=x_nvfp4_ref,
        qresult_w=w_nvfp4_ref,
    )

    # Native TE GEMM using tex.generic_gemm (cuBLAS GEMM)
    # Allocate cuBLAS workspace
    workspace = torch.empty(4, dtype=torch.uint8, device=device)
    transa = True if not w_columnwise else False
    transb = False if not x_columnwise else True
    out_quantizer = None
    bias = None
    bias_dtype = TE_DType[torch.bfloat16]
    use_gelu = False
    gelu_input = None
    use_grad = False
    use_split_accumulator = False

    # Native cuBLAS GEMM
    # return type is out, bias_grad, gelu_input, extra_output
    # We are just capturing out.
    y_native = tex.generic_gemm(
        w_nvfp4_native,
        transa,
        x_nvfp4_native,
        transb,
        out.clone() if accumulate else None,
        out_quantizer,
        TE_DType[out_dtype],
        bias,
        bias_dtype,
        use_gelu,
        gelu_input,
        use_grad,
        workspace,
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
    )[0]

    # just in case of accumulation, make sure y_ref and y_native are not the same tensor
    assert y_ref is not y_native, "y_ref and y_native should not be the same tensor"

    # Reset nans to zeros because torch.assert_close does not assume nans to be equal
    assert not torch.isnan(y_ref.float()).all(), "All elements are nan"
    y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref)
    y_native = torch.where(y_native.isnan(), torch.zeros_like(y_native), y_native)

    # Compare results with some tolerance
    torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3)


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, K, N",
    [
        (128, 128, 128),
        (256, 128, 256),
        (256, 256, 256),
        (256, 1024, 256),
        (1024, 1024, 1024),
        (4096, 512, 3072),
        (112, 128, 96),
        (304, 640, 304),
        (1008, 3072, 992),
        (256, 64, 256),
        (128, 128, 112),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize(
    "is_x_columnwise, is_w_columnwise",
    [
        (False, False),
        # Only rowwise x rowwise is supported by reference GEMM
        # Note: Reference GEMM expects inputs as (M,K) x (N,K) with rowwise quantization
        # Columnwise layouts are not supported by the reference implementation
    ],
    ids=["rowxrow"],
)
def test_nvfp4_gemm_versus_reference(
    M: int,
    K: int,
    N: int,
    x_dtype: torch.dtype,
    w_dtype: torch.dtype,
    out_dtype: torch.dtype,
    accumulate: bool,
    is_x_columnwise: bool,
    is_w_columnwise: bool,
):
    check_nvfp4_gemm_versus_reference(
        x_dtype=x_dtype,
        w_dtype=w_dtype,
        out_dtype=out_dtype,
        M=M,
        K=K,
        N=N,
        accumulate=accumulate,
        x_columnwise=is_x_columnwise,
        w_columnwise=is_w_columnwise,
    )
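The padding-trim logic in check_nvfp4_gemm_versus_reference relies on NVFP4's fixed 16-element scaling blocks: a row-major (rows, K) operand carries one E4M3 scale per 16 contraction elements, so only K // 16 scale columns are meaningful and the rest is kernel padding. A small worked example of that arithmetic (the helper below is illustrative only, not part of the test):

# Illustrative helper: expected (unpadded) shape of the per-block scale factors for a
# row-wise 1x16 NVFP4 quantization, mirroring the "K // block_length" trimming above.
def nvfp4_scale_shape(rows: int, cols: int, block_length: int = 16) -> tuple:
    assert cols % block_length == 0, "NVFP4 expects the contraction dim to be a multiple of 16"
    return (rows, cols // block_length)


# e.g. an (M, K) = (256, 1024) activation carries a (256, 64) scale tensor
assert nvfp4_scale_shape(256, 1024) == (256, 64)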
tests/pytorch/nvfp4/test_nvfp4_module_exact.py
0 → 100644
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch

import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.experimental import quantization_nvfp4
from transformer_engine.pytorch.experimental import utils

recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)


class GetRecipes:

    @staticmethod
    def nvfp4_vanilla():
        nvfp4_recipe = recipe.NVFP4BlockScaling()
        nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
        nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
        nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
        return nvfp4_recipe

    @staticmethod
    def nvfp4_rht_only():
        nvfp4_recipe = recipe.NVFP4BlockScaling()
        nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(random_hadamard_transform=True)
        nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(random_hadamard_transform=False)
        nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(random_hadamard_transform=True)
        return nvfp4_recipe

    @staticmethod
    def nvfp4_2d_quantization_only():
        nvfp4_recipe = recipe.NVFP4BlockScaling()
        nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(fp4_2d_quantization=False)
        nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True)
        nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(fp4_2d_quantization=False)
        return nvfp4_recipe

    @staticmethod
    def nvfp4_rht_and_2d_quantization():
        nvfp4_recipe = recipe.NVFP4BlockScaling()
        nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
            random_hadamard_transform=True, fp4_2d_quantization=False
        )
        nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
            random_hadamard_transform=False, fp4_2d_quantization=True
        )
        nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
            random_hadamard_transform=True, fp4_2d_quantization=False
        )
        return nvfp4_recipe

    @staticmethod
    def nvfp4_recipe_to_test(with_rht: bool = False, with_2d_quantization: bool = False):
        if with_rht and with_2d_quantization:
            return GetRecipes.nvfp4_rht_and_2d_quantization()
        elif with_rht:
            return GetRecipes.nvfp4_rht_only()
        elif with_2d_quantization:
            return GetRecipes.nvfp4_2d_quantization_only()
        else:
            return GetRecipes.nvfp4_vanilla()


def get_nvfp4_quantizer_factory(with_rht: bool = False, with_2d_quantization: bool = False):
    """
    Create a quantizer factory for NVFP4 reference implementation.

    This factory returns NVFP4QuantizerRef instances based on the role and configuration.
    Used with CustomRecipe to create reference quantizers.

    Args:
        with_rht: Whether to enable random Hadamard transform
        with_2d_quantization: Whether to use 2D quantization (16x16 tiles for weights)

    Returns:
        A factory function that takes a role string and returns a quantizer instance
    """

    def factory(role):
        if role == "linear_input":
            return quantization_nvfp4.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(1, 16),
                pow_2_scales=False,
                with_rht=with_rht,
            )
        elif role == "linear_weight":
            return quantization_nvfp4.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(16, 16) if with_2d_quantization else (1, 16),
                pow_2_scales=False,
                with_rht=False,
            )
        elif role == "linear_output":
            # Output quantization not used
            return None
        elif role == "linear_grad_output":
            return quantization_nvfp4.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(1, 16),
                pow_2_scales=False,
                with_rht=with_rht,
            )
        elif role == "linear_grad_input":
            # Grad input quantization not used
            return None
        else:
            # For any other roles, return None
            return None

    return factory


def reset_rng_states():
    seed = 1234
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def check_nvfp4_module_versus_reference(
    module_class,
    in_features: int,
    out_features: int,
    bias: bool,
    x_dtype: torch.dtype,
    num_steps: int = 1,
    with_rht: bool = False,
    with_2d_quantization: bool = False,
):
    """
    Compare native NVFP4 module against reference implementation.

    Args:
        module_class: te.Linear or te.LayerNormLinear
        in_features: Input feature dimension
        out_features: Output feature dimension
        bias: Whether to use bias
        x_dtype: Input tensor dtype
        num_steps: Number of forward/backward steps to test
    """
    device = "cuda"
    batch_size = 32
    seq_len = 128

    # Create both modules with identical initialization
    reset_rng_states()

    # Create native module
    print("\nCreate native module")
    if module_class == te.Linear:
        native_module = te.Linear(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            params_dtype=x_dtype,
        )
    elif module_class == te.LayerNormLinear:
        native_module = te.LayerNormLinear(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            params_dtype=x_dtype,
        )
    else:
        raise ValueError(f"Unsupported module class: {module_class}")

    # Create reference module with same weights
    reset_rng_states()

    # Create reference module
    print("Create reference module")
    if module_class == te.Linear:
        ref_module = te.Linear(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            params_dtype=x_dtype,
        )
    elif module_class == te.LayerNormLinear:
        ref_module = te.LayerNormLinear(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            params_dtype=x_dtype,
        )

    # Sync weights between native and reference modules
    with torch.no_grad():
        # Copy main weight and bias parameters
        if hasattr(native_module, "weight") and hasattr(ref_module, "weight"):
            ref_module.weight.copy_(native_module.weight)
        if bias and hasattr(native_module, "bias") and hasattr(ref_module, "bias"):
            ref_module.bias.copy_(native_module.bias)
        # Copy layer norm parameters if they exist
        if hasattr(native_module, "layer_norm_weight") and hasattr(ref_module, "layer_norm_weight"):
            ref_module.layer_norm_weight.copy_(native_module.layer_norm_weight)
        if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"):
            ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias)

    # Create recipes for native and reference implementations
    nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization)
    nvfp4_ref_factory = get_nvfp4_quantizer_factory(with_rht, with_2d_quantization)
    nvfp4_ref_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_factory)

    # Training loop comparison
    native_outputs = []
    ref_outputs = []

    for step in range(num_steps):
        torch.manual_seed(1234 + step)
        torch.cuda.manual_seed(1234 + step)

        x_shape = (batch_size, seq_len, in_features)
        x_val = torch.normal(mean=0.0, std=1.0, size=x_shape, dtype=x_dtype, device=device)
        x_native = x_val.clone().detach().requires_grad_(True)
        x_ref = x_native.clone().detach().requires_grad_(True)

        grad_output_shape = (batch_size, seq_len, out_features)
        grad_output_val = torch.normal(
            mean=0.0, std=1.0, size=grad_output_shape, dtype=x_dtype, device=device
        )
        grad_output = grad_output_val.clone().detach()

        # Native forward/backward
        with te.autocast(enabled=True, recipe=nvfp4_recipe):
            # enable weight cache by giving is_first_microbatch
            y_native = native_module(x_native, is_first_microbatch=(step == 0))
        y_native.backward(grad_output)

        # Reference forward/backward
        with te.autocast(enabled=True, recipe=nvfp4_ref_recipe):
            y_ref = ref_module(x_ref)
        y_ref.backward(grad_output)

        # Store results
        native_outputs.append(
            {
                "output": y_native.detach().clone(),
                "input_grad": (
                    x_native.grad.detach().clone() if x_native.grad is not None else None
                ),
                "weight_grad": (
                    native_module.weight.grad.detach().clone()
                    if native_module.weight.grad is not None
                    else None
                ),
                "bias_grad": (
                    native_module.bias.grad.detach().clone()
                    if bias and native_module.bias.grad is not None
                    else None
                ),
            }
        )
        ref_outputs.append(
            {
                "output": y_ref.detach().clone(),
                "input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None),
                "weight_grad": (
                    ref_module.weight.grad.detach().clone()
                    if ref_module.weight.grad is not None
                    else None
                ),
                "bias_grad": (
                    ref_module.bias.grad.detach().clone()
                    if bias and ref_module.bias.grad is not None
                    else None
                ),
            }
        )

    # Compare results across all steps
    for step in range(num_steps):
        native_out = native_outputs[step]
        ref_out = ref_outputs[step]

        # Compare outputs
        torch.testing.assert_close(
            native_out["output"],
            ref_out["output"],
            atol=1e-6,
            rtol=1e-6,
            msg=f"Output mismatch at step {step}",
        )

        # Compare input gradients
        torch.testing.assert_close(
            native_out["input_grad"],
            ref_out["input_grad"],
            atol=1e-6,
            rtol=1e-6,
            msg=(
                f"Input gradient mismatch at step {step}. Native: {native_out['input_grad']}, Ref:"
                f" {ref_out['input_grad']}"
            ),
        )

        # Compare weight gradients
        torch.testing.assert_close(
            native_out["weight_grad"],
            ref_out["weight_grad"],
            atol=1e-6,
            rtol=1e-6,
            msg=(
                f"Weight gradient mismatch at step {step}. Native: {native_out['weight_grad']},"
                f" Ref: {ref_out['weight_grad']}"
            ),
        )

        # Compare bias gradients
        if bias and native_out["bias_grad"] is not None and ref_out["bias_grad"] is not None:
            torch.testing.assert_close(
                native_out["bias_grad"],
                ref_out["bias_grad"],
                atol=1e-6,
                rtol=1e-6,
                msg=f"Bias gradient mismatch at step {step}",
            )


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "in_features, out_features",
    [
        (128, 256),
        (256, 128),
        (512, 512),
        (768, 3072),
        (1024, 4096),
    ],
)
# @pytest.mark.parametrize("bias", [True, False], ids=["with_bias", "no_bias"])
@pytest.mark.parametrize("bias", [False], ids=["no_bias"])
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("num_steps", [1, 3], ids=["single_step", "multi_step"])
@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"])
@pytest.mark.parametrize(
    "with_2d_quantization", [True, False], ids=["with_2d_quantization", "no_2d_quantization"]
)
def test_nvfp4_linear_versus_reference(
    in_features: int,
    out_features: int,
    bias: bool,
    x_dtype: torch.dtype,
    num_steps: int,
    with_rht: bool,
    with_2d_quantization: bool,
):
    """Test NVFP4 Linear module against reference implementation."""
    if with_rht and x_dtype != torch.bfloat16:
        pytest.skip("RHT is only supported for bfloat16 input")
    check_nvfp4_module_versus_reference(
        module_class=te.Linear,
        in_features=in_features,
        out_features=out_features,
        bias=bias,
        x_dtype=x_dtype,
        num_steps=num_steps,
        with_rht=with_rht,
        with_2d_quantization=with_2d_quantization,
    )


def check_nvfp4_layernorm_linear_versus_reference(
    in_features: int,
    out_features: int,
    bias: bool,
    normalization: str,
    x_dtype: torch.dtype,
    num_steps: int = 1,
    with_rht: bool = False,
    with_2d_quantization: bool = False,
):
    """
    Compare native NVFP4 LayerNormLinear module against reference implementation,
    including ln_out.
    """
    device = "cuda"
    batch_size = 32
    seq_len = 128

    # Create both modules with identical initialization
    reset_rng_states()

    # Native module
    native_module = te.LayerNormLinear(
        in_features=in_features,
        out_features=out_features,
        bias=bias,
        device=device,
        params_dtype=x_dtype,
        normalization=normalization,
        return_layernorm_output=True,
    )

    # Reference module
    reset_rng_states()
    ref_module = te.LayerNormLinear(
        in_features=in_features,
        out_features=out_features,
        bias=bias,
        device=device,
        params_dtype=x_dtype,
        normalization=normalization,
        return_layernorm_output=True,
    )

    # Sync weights and LN params
    with torch.no_grad():
        if hasattr(native_module, "weight") and hasattr(ref_module, "weight"):
            ref_module.weight.copy_(native_module.weight)
        if bias and hasattr(native_module, "bias") and hasattr(ref_module, "bias"):
            ref_module.bias.copy_(native_module.bias)
        if hasattr(native_module, "layer_norm_weight") and hasattr(ref_module, "layer_norm_weight"):
            if (
                native_module.layer_norm_weight is not None
                and ref_module.layer_norm_weight is not None
            ):
                ref_module.layer_norm_weight.copy_(native_module.layer_norm_weight)
        if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"):
            if native_module.layer_norm_bias is not None and ref_module.layer_norm_bias is not None:
                ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias)

    # Create recipes for native and reference implementations
    nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization)
    nvfp4_ref_factory = get_nvfp4_quantizer_factory(with_rht, with_2d_quantization)
    nvfp4_ref_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_factory)

    native_outputs = []
    ref_outputs = []

    for step in range(num_steps):
        torch.manual_seed(1234 + step)
        torch.cuda.manual_seed(1234 + step)

        x_shape = (batch_size, seq_len, in_features)
        x_val = torch.normal(mean=0.0, std=1.0, size=x_shape, dtype=x_dtype, device=device)
        x_native = x_val.clone().detach().requires_grad_(True)
        x_ref = x_native.clone().detach().requires_grad_(True)

        grad_output_shape = (batch_size, seq_len, out_features)
        grad_output_val = torch.normal(
            mean=0.0, std=1.0, size=grad_output_shape, dtype=x_dtype, device=device
        )
        grad_output = grad_output_val.clone().detach()

        # Native forward/backward
        with te.autocast(enabled=True, recipe=nvfp4_recipe):
            y_native, ln_out_native = native_module(x_native, is_first_microbatch=(step == 0))
        y_native.backward(grad_output)

        # Reference forward/backward
        with te.autocast(enabled=True, recipe=nvfp4_ref_recipe):
            y_ref, ln_out_ref = ref_module(x_ref)
        y_ref.backward(grad_output)

        native_outputs.append(
            {
                "output": y_native.detach().clone(),
                "ln_out": ln_out_native.detach().clone(),
                "input_grad": (
                    x_native.grad.detach().clone() if x_native.grad is not None else None
                ),
                "weight_grad": (
                    native_module.weight.grad.detach().clone()
                    if native_module.weight.grad is not None
                    else None
                ),
                "bias_grad": (
                    native_module.bias.grad.detach().clone()
                    if bias and native_module.bias.grad is not None
                    else None
                ),
            }
        )
        ref_outputs.append(
            {
                "output": y_ref.detach().clone(),
                "ln_out": ln_out_ref.detach().clone(),
                "input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None),
                "weight_grad": (
                    ref_module.weight.grad.detach().clone()
                    if ref_module.weight.grad is not None
                    else None
                ),
                "bias_grad": (
                    ref_module.bias.grad.detach().clone()
                    if bias and ref_module.bias.grad is not None
                    else None
                ),
            }
        )

    # Compare results
    for step in range(num_steps):
        n = native_outputs[step]
        r = ref_outputs[step]
        torch.testing.assert_close(
            n["output"],
            r["output"],
            atol=1e-6,
            rtol=1e-6,
            msg=f"Output mismatch at step {step}",
        )
        torch.testing.assert_close(
            n["ln_out"],
            r["ln_out"],
            atol=1e-6,
            rtol=1e-6,
            msg=f"LN output mismatch at step {step}",
        )
        torch.testing.assert_close(
            n["input_grad"],
            r["input_grad"],
            atol=1e-6,
            rtol=1e-6,
            msg=f"Input gradient mismatch at step {step}",
        )
        torch.testing.assert_close(
            n["weight_grad"],
            r["weight_grad"],
            atol=1e-6,
            rtol=1e-6,
            msg=f"Weight gradient mismatch at step {step}",
        )
        if bias and n["bias_grad"] is not None and r["bias_grad"] is not None:
            torch.testing.assert_close(
                n["bias_grad"],
                r["bias_grad"],
                atol=1e-6,
                rtol=1e-6,
                msg=f"Bias gradient mismatch at step {step}",
            )


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "in_features, out_features",
    [
        (128, 256),
        (256, 128),
    ],
)
@pytest.mark.parametrize("bias", [False], ids=["no_bias"])
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("num_steps", [1], ids=["single_step"])
@pytest.mark.parametrize("normalization", ["LayerNorm", "RMSNorm"], ids=["LayerNorm", "RMSNorm"])
@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"])
@pytest.mark.parametrize(
    "with_2d_quantization", [True, False], ids=["with_2d_quantization", "no_2d_quantization"]
)
def test_nvfp4_layernorm_linear_versus_reference(
    in_features: int,
    out_features: int,
    bias: bool,
    normalization: str,
    x_dtype: torch.dtype,
    num_steps: int,
    with_rht: bool,
    with_2d_quantization: bool,
):
    if with_rht and x_dtype != torch.bfloat16:
        pytest.skip("RHT is only supported for bfloat16 input")
    check_nvfp4_layernorm_linear_versus_reference(
        in_features=in_features,
        out_features=out_features,
        bias=bias,
        normalization=normalization,
        x_dtype=x_dtype,
        num_steps=num_steps,
        with_rht=with_rht,
        with_2d_quantization=with_2d_quantization,
    )
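The pattern these module tests exercise, stripped to its core: a recipe.CustomRecipe built from a role-keyed quantizer factory drops into te.autocast exactly like the native NVFP4BlockScaling recipe. A minimal sketch under the same imports as the file above (the layer sizes, bf16 dtype, and quantized roles are arbitrary choices for illustration, not a prescription):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.experimental import quantization_nvfp4, utils


def _ref_factory(role):
    # Quantize only the GEMM operands the reference path cares about, as in the tests above
    if role in ("linear_input", "linear_weight", "linear_grad_output"):
        return quantization_nvfp4.NVFP4QuantizerRef(
            dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=False
        )
    return None


layer = te.Linear(128, 256, bias=False, device="cuda", params_dtype=torch.bfloat16)
x = torch.randn(32, 128, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
with te.autocast(enabled=True, recipe=recipe.CustomRecipe(qfactory=_ref_factory)):
    y = layer(x)
y.backward(torch.ones_like(y))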