Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
0d874a4e
Commit
0d874a4e
authored
Mar 03, 2026
by
wenjh
Browse files
Merge branch 'nv_main' of v2.12
parents
a68e5f87
dfdd3820
Changes
640
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
897 additions
and
43 deletions
+897
-43
tests/pytorch/distributed/test_numerics_exact.py
tests/pytorch/distributed/test_numerics_exact.py
+1
-1
tests/pytorch/distributed/test_sanity.py
tests/pytorch/distributed/test_sanity.py
+106
-14
tests/pytorch/distributed/test_torch_fsdp2.py
tests/pytorch/distributed/test_torch_fsdp2.py
+1
-1
tests/pytorch/layernorm_mlp/test_selective_activation_checkpoint.py
...rch/layernorm_mlp/test_selective_activation_checkpoint.py
+175
-0
tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
+18
-7
tests/pytorch/nvfp4/test_nvfp4_group_quantize.py
tests/pytorch/nvfp4/test_nvfp4_group_quantize.py
+309
-0
tests/pytorch/nvfp4/test_nvfp4_module_exact.py
tests/pytorch/nvfp4/test_nvfp4_module_exact.py
+1
-1
tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
+1
-1
tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py
tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py
+1
-1
tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py
tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py
+196
-1
tests/pytorch/references/blockwise_fp8_gemm_reference.py
tests/pytorch/references/blockwise_fp8_gemm_reference.py
+1
-1
tests/pytorch/references/blockwise_quantizer_reference.py
tests/pytorch/references/blockwise_quantizer_reference.py
+1
-1
tests/pytorch/references/quantize_scale_calc.py
tests/pytorch/references/quantize_scale_calc.py
+1
-1
tests/pytorch/references/ref_per_tensor_cs.py
tests/pytorch/references/ref_per_tensor_cs.py
+1
-1
tests/pytorch/test_checkpoint.py
tests/pytorch/test_checkpoint.py
+1
-1
tests/pytorch/test_cpu_offloading.py
tests/pytorch/test_cpu_offloading.py
+10
-3
tests/pytorch/test_cpu_offloading_v1.py
tests/pytorch/test_cpu_offloading_v1.py
+1
-1
tests/pytorch/test_cuda_graphs.py
tests/pytorch/test_cuda_graphs.py
+28
-4
tests/pytorch/test_custom_recipe.py
tests/pytorch/test_custom_recipe.py
+1
-1
tests/pytorch/test_deferred_init.py
tests/pytorch/test_deferred_init.py
+43
-2
No files found.
Too many changes to show.
To preserve performance only
640 of 640+
files are displayed.
Plain diff
Email patch
tests/pytorch/distributed/test_numerics_exact.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/distributed/test_sanity.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -7,7 +7,16 @@ import sys
import
pytest
import
torch
import
transformer_engine
from
transformer_engine.pytorch
import
DotProductAttention
,
TransformerLayer
,
Linear
from
transformer_engine.pytorch
import
(
DotProductAttention
,
TransformerLayer
,
Linear
,
GroupedLinear
,
NVFP4Quantizer
,
autocast
,
is_nvfp4_available
,
)
from
transformer_engine.common
import
recipe
_current_file
=
pathlib
.
Path
(
__file__
).
resolve
()
sys
.
path
.
append
(
str
(
_current_file
.
parent
.
parent
))
...
...
@@ -17,9 +26,13 @@ model_configs = {
"small"
:
ModelConfig
(
2
,
10
,
2
,
16
),
}
nvfp4_available
,
reason_for_no_nvfp4
=
is_nvfp4_available
(
return_reason
=
True
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"small"
])
@
pytest
.
mark
.
parametrize
(
"module"
,
[
"TransformerLayer"
,
"DotProductAttention"
,
"Linear"
])
@
pytest
.
mark
.
parametrize
(
"module"
,
[
"TransformerLayer"
,
"DotProductAttention"
,
"Linear"
,
"GroupedLinear"
]
)
def
test_current_device
(
model
,
module
):
"""Test cases where current device is different from tensor device"""
...
...
@@ -42,7 +55,29 @@ def test_current_device(model, module):
self_attn_mask_type
=
"padding"
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
num_tokens
=
torch
.
randint
(
0
,
config
.
max_seqlen_q
,
(
1
,)).
item
()
seqlens_q
=
torch
.
randint
(
1
,
config
.
max_seqlen_q
,
[
config
.
batch_size
],
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
cu_seqlens_q
=
torch
.
zeros
(
config
.
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
)
cu_seqlens_q
[
1
:]
=
torch
.
cumsum
(
seqlens_q
,
dim
=
0
)
seqlens_kv
=
torch
.
randint
(
1
,
config
.
max_seqlen_kv
,
[
config
.
batch_size
],
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
cu_seqlens_kv
=
torch
.
zeros
(
config
.
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
)
cu_seqlens_kv
[
1
:]
=
torch
.
cumsum
(
seqlens_kv
,
dim
=
0
)
num_tokens
=
cu_seqlens_q
[
-
1
]
args
=
[
torch
.
randn
(
(
num_tokens
,
config
.
hidden_size
),
...
...
@@ -51,37 +86,55 @@ def test_current_device(model, module):
requires_grad
=
True
,
)
]
cu_seqlens_q
,
cu_seqlens_kv
=
[
torch
.
Tensor
([
0
,
2
,
3
]).
to
(
dtype
=
torch
.
int32
,
device
=
tensor_device
)
for
_
in
range
(
2
)
]
kwargs
[
"cu_seqlens_q"
]
=
cu_seqlens_q
kwargs
[
"cu_seqlens_kv"
]
=
cu_seqlens_kv
kwargs
[
"max_seqlen_q"
]
=
config
.
max_seqlen_q
kwargs
[
"max_seqlen_kv"
]
=
config
.
max_seqlen_kv
if
module
==
"DotProductAttention"
:
el
if
module
==
"DotProductAttention"
:
model
=
DotProductAttention
(
config
.
num_heads
,
config
.
head_dim_qk
,
qkv_format
=
"thd"
,
attn_mask_type
=
"padding"
)
num_tokens
=
torch
.
randint
(
0
,
config
.
max_seqlen_q
,
(
1
,)).
item
()
seqlens_q
=
torch
.
randint
(
1
,
config
.
max_seqlen_q
,
[
config
.
batch_size
],
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
cu_seqlens_q
=
torch
.
zeros
(
config
.
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
)
cu_seqlens_q
[
1
:]
=
torch
.
cumsum
(
seqlens_q
,
dim
=
0
)
seqlens_kv
=
torch
.
randint
(
1
,
config
.
max_seqlen_kv
,
[
config
.
batch_size
],
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
cu_seqlens_kv
=
torch
.
zeros
(
config
.
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
)
cu_seqlens_kv
[
1
:]
=
torch
.
cumsum
(
seqlens_kv
,
dim
=
0
)
num_tokens
=
cu_seqlens_q
[
-
1
]
args
=
[
torch
.
randn
(
num_tokens
,
config
.
num_heads
,
config
.
head_dim_qk
,
dtype
=
dtype
,
device
=
tensor_device
,
device
=
f
"cuda:
{
tensor_device
}
"
,
requires_grad
=
True
,
)
for
_
in
range
(
3
)
]
cu_seqlens_q
,
cu_seqlens_kv
=
[
torch
.
Tensor
([
0
,
2
,
3
]).
to
(
dtype
=
torch
.
int32
,
device
=
tensor_device
)
for
_
in
range
(
2
)
]
kwargs
[
"cu_seqlens_q"
]
=
cu_seqlens_q
kwargs
[
"cu_seqlens_kv"
]
=
cu_seqlens_kv
kwargs
[
"max_seqlen_q"
]
=
config
.
max_seqlen_q
kwargs
[
"max_seqlen_kv"
]
=
config
.
max_seqlen_kv
bwd_args
=
[
torch
.
randn
(
num_tokens
,
config
.
hidden_size
,
dtype
=
dtype
,
device
=
tensor_device
)]
bwd_args
=
[
torch
.
randn
(
num_tokens
,
config
.
hidden_size
,
dtype
=
dtype
,
device
=
f
"cuda:
{
tensor_device
}
"
)
]
elif
module
==
"Linear"
:
model
=
Linear
(
config
.
hidden_size
,
...
...
@@ -97,6 +150,24 @@ def test_current_device(model, module):
requires_grad
=
True
,
)
]
elif
module
==
"GroupedLinear"
:
num_gemms
=
4
model
=
GroupedLinear
(
num_gemms
,
config
.
hidden_size
,
4
*
config
.
hidden_size
,
params_dtype
=
dtype
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
args
=
[
torch
.
randn
(
(
config
.
max_seqlen_q
*
config
.
batch_size
*
(
num_gemms
-
1
),
config
.
hidden_size
),
dtype
=
dtype
,
device
=
f
"cuda:
{
tensor_device
}
"
,
requires_grad
=
True
,
),
[
0
]
+
[
config
.
max_seqlen_q
*
config
.
batch_size
]
*
(
num_gemms
-
1
),
# Empty first split.
]
current_device_before
=
torch
.
cuda
.
current_device
()
out
=
model
(
*
args
,
**
kwargs
)
...
...
@@ -118,3 +189,24 @@ def test_current_device(model, module):
assert
(
tensor_device_grad
==
tensor_device
),
"The gradient tensor should be the same as the input tensors!"
@
pytest
.
mark
.
skipif
(
not
nvfp4_available
,
reason
=
reason_for_no_nvfp4
)
def
test_nvfp4_rht_cache
():
"""Ensure correct RHT cache for NVFP4."""
num_devices
=
torch
.
cuda
.
device_count
()
assert
num_devices
>
1
,
"This test requires more than one GPU!"
# Populate cache on last device.
with
torch
.
cuda
.
device
(
num_devices
-
1
):
_
=
NVFP4Quantizer
()
hidden_size
=
128
dtype
=
torch
.
bfloat16
model
=
Linear
(
hidden_size
,
hidden_size
,
params_dtype
=
dtype
)
inp
=
torch
.
randn
(
hidden_size
,
hidden_size
,
device
=
torch
.
cuda
.
current_device
(),
dtype
=
dtype
)
fp4_recipe
=
recipe
.
NVFP4BlockScaling
()
with
autocast
(
recipe
=
fp4_recipe
):
_
=
model
(
inp
)
tests/pytorch/distributed/test_torch_fsdp2.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/layernorm_mlp/test_selective_activation_checkpoint.py
0 → 100644
View file @
0d874a4e
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import
torch
from
transformer_engine.pytorch
import
LayerNormMLP
import
pytest
torch
.
manual_seed
(
1234
)
device
=
torch
.
device
(
"cuda"
)
class
_Sequential
(
torch
.
nn
.
Sequential
):
"""Sequential model that forwards keyword arguments to modules"""
def
forward
(
self
,
input_
:
torch
.
Tensor
,
**
kwargs
)
->
torch
.
Tensor
:
x
=
input_
for
module
in
self
:
x
=
module
(
x
,
**
kwargs
)
return
x
class
ModelConfig
:
def
__init__
(
self
,
hidden_size
:
int
=
128
,
ffn_hidden_size
:
int
=
512
,
layers
:
int
=
1
,
):
self
.
_hidden_size
=
hidden_size
self
.
_ffn_hidden_size
=
ffn_hidden_size
self
.
_layers
=
layers
def
build
(
self
):
ln_list
,
sln_list
=
[],
[]
for
_
in
range
(
self
.
_layers
):
ln
=
LayerNormMLP
(
self
.
_hidden_size
,
self
.
_ffn_hidden_size
,
checkpoint
=
False
).
to
(
device
)
sln
=
LayerNormMLP
(
self
.
_hidden_size
,
self
.
_ffn_hidden_size
,
checkpoint
=
True
).
to
(
device
)
with
torch
.
no_grad
():
sln
.
layer_norm_weight
=
torch
.
nn
.
Parameter
(
ln
.
layer_norm_weight
.
clone
())
sln
.
layer_norm_bias
=
torch
.
nn
.
Parameter
(
ln
.
layer_norm_bias
.
clone
())
sln
.
fc1_weight
=
torch
.
nn
.
Parameter
(
ln
.
fc1_weight
.
clone
())
sln
.
fc2_weight
=
torch
.
nn
.
Parameter
(
ln
.
fc2_weight
.
clone
())
sln
.
fc1_bias
=
torch
.
nn
.
Parameter
(
ln
.
fc1_bias
.
clone
())
sln
.
fc2_bias
=
torch
.
nn
.
Parameter
(
ln
.
fc2_bias
.
clone
())
ln_list
.
append
(
ln
)
sln_list
.
append
(
sln
)
ln_model
=
_Sequential
(
*
ln_list
)
sln_model
=
_Sequential
(
*
sln_list
)
return
ln_model
,
sln_model
config
=
{
"small"
:
ModelConfig
(
128
,
512
,
12
),
"medium"
:
ModelConfig
(
512
,
2048
,
12
),
"large"
:
ModelConfig
(
1024
,
4096
,
12
),
"huge"
:
ModelConfig
(
2048
,
8192
,
12
),
}
seq_sizes
=
[
2
**
7
,
2
**
10
,
2
**
14
,
2
**
16
]
def
_warmup
(
model
,
tensor
):
for
_
in
range
(
3
):
model
(
tensor
).
sum
().
backward
()
def
_run_fwd
(
model
,
tensor
):
torch
.
cuda
.
reset_peak_memory_stats
(
device
)
start_time
,
end_time
=
torch
.
cuda
.
Event
(
enable_timing
=
True
),
torch
.
cuda
.
Event
(
enable_timing
=
True
)
torch
.
cuda
.
synchronize
()
start_mem
=
torch
.
cuda
.
memory_allocated
(
device
)
start_time
.
record
()
out
=
model
(
tensor
)
end_time
.
record
()
end_time
.
synchronize
()
elapsed
=
start_time
.
elapsed_time
(
end_time
)
peak_mem
=
torch
.
cuda
.
max_memory_allocated
(
device
)
mem
=
float
(
peak_mem
-
start_mem
)
return
out
,
elapsed
,
mem
def
_run_bwd
(
model
,
out
):
model
.
zero_grad
(
set_to_none
=
False
)
loss
=
out
.
sum
()
torch
.
cuda
.
reset_peak_memory_stats
(
device
)
start_time
,
end_time
=
torch
.
cuda
.
Event
(
enable_timing
=
True
),
torch
.
cuda
.
Event
(
enable_timing
=
True
)
torch
.
cuda
.
synchronize
()
start_mem
=
torch
.
cuda
.
memory_allocated
(
device
)
start_time
.
record
()
loss
.
backward
()
end_time
.
record
()
end_time
.
synchronize
()
elapsed
=
start_time
.
elapsed_time
(
end_time
)
peak_mem
=
torch
.
cuda
.
max_memory_allocated
(
device
)
mem
=
float
(
peak_mem
-
start_mem
)
param_grads
=
_collect_param_grads
(
model
)
return
param_grads
,
elapsed
,
mem
def
_max_diff
(
ref
,
other
):
"""Return max absolute difference between two tensors or collections."""
if
ref
is
None
or
other
is
None
:
return
0.0
if
isinstance
(
ref
,
(
list
,
tuple
)):
diffs
=
[
_max_diff
(
r
,
o
)
for
r
,
o
in
zip
(
ref
,
other
)]
return
max
(
diffs
)
if
diffs
else
0.0
return
torch
.
max
(
torch
.
abs
(
ref
.
detach
()
-
other
.
detach
())).
item
()
def
_collect_param_grads
(
model
):
grads
=
{}
for
name
,
param
in
model
.
named_parameters
():
if
param
.
grad
is
None
:
continue
key
=
_param_key
(
name
)
if
key
is
not
None
:
grads
[
key
]
=
param
.
grad
.
detach
().
clone
()
return
grads
def
_param_key
(
name
):
return
name
.
split
(
"."
)[
-
1
]
@
pytest
.
mark
.
parametrize
(
"size"
,
config
.
keys
())
@
pytest
.
mark
.
parametrize
(
"seq_size"
,
seq_sizes
)
def
test_selective_activation_checkpoint
(
size
,
seq_size
):
ln_model
,
sln_model
=
config
[
size
].
build
()
data
=
torch
.
randn
((
seq_size
,
config
[
size
].
_hidden_size
),
device
=
device
)
_warmup
(
ln_model
,
data
)
ln_fwd_out
,
ln_fwd_time
,
ln_fwd_mem
=
_run_fwd
(
ln_model
,
data
)
ln_grads
,
ln_bwd_time
,
ln_bwd_mem
=
_run_bwd
(
ln_model
,
ln_fwd_out
)
_warmup
(
sln_model
,
data
)
sln_fwd_out
,
sln_fwd_time
,
sln_fwd_mem
=
_run_fwd
(
sln_model
,
data
)
sln_grads
,
sln_bwd_time
,
sln_bwd_mem
=
_run_bwd
(
sln_model
,
sln_fwd_out
)
assert
ln_fwd_mem
>
6
*
sln_fwd_mem
,
(
"selective activation checkpointing does not reduce forward memory by 6X, only by"
f
"
{
ln_fwd_mem
/
sln_fwd_mem
}
!"
)
assert
ln_bwd_time
<
sln_bwd_time
,
(
"selective activation activation checkpointing backward pass is NOT slower than native!"
f
" got Native LayerNormMLP Backward Time:
{
ln_bwd_time
}
ms and Selective Activation"
f
" Checkpointed LayerNormMLP Backward Time:
{
sln_bwd_time
}
ms"
)
diff
=
_max_diff
(
ln_fwd_out
,
sln_fwd_out
)
assert
diff
==
0.0
,
f
"outputs are not equal! maximum difference
{
diff
}
"
for
key
in
[
"layer_norm_weight"
,
"layer_norm_bias"
,
"fc1_weight"
,
"fc1_bias"
,
"fc2_weight"
,
"fc2_bias"
,
]:
diff
=
_max_diff
(
ln_grads
[
key
],
sln_grads
[
key
])
assert
diff
==
0.0
,
f
"gradients for
{
key
}
are not equal! maximum difference:
{
diff
}
"
tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -122,8 +122,15 @@ def check_nvfp4_gemm_versus_reference(
)
# Create reference quantized tensors needed by reference GEMM
x_nvfp4_ref
=
ref_quantizer
.
quantize
(
x
)
w_nvfp4_ref
=
ref_quantizer
.
quantize
(
w
)
# Reference GEMM is only rowwise.
if
x_columnwise
:
x_nvfp4_ref
=
ref_quantizer
.
quantize
(
x
.
t
().
contiguous
())
else
:
x_nvfp4_ref
=
ref_quantizer
.
quantize
(
x
)
if
w_columnwise
:
w_nvfp4_ref
=
ref_quantizer
.
quantize
(
w
.
t
().
contiguous
())
else
:
w_nvfp4_ref
=
ref_quantizer
.
quantize
(
w
)
# Reference GEMM using quantizer's qgemm method
y_ref
=
ref_quantizer
.
qgemm
(
...
...
@@ -155,6 +162,10 @@ def check_nvfp4_gemm_versus_reference(
use_grad
=
False
use_split_accumulator
=
False
if
x_columnwise
:
x_nvfp4_native
.
update_usage
(
rowwise_usage
=
False
)
if
w_columnwise
:
w_nvfp4_native
.
update_usage
(
rowwise_usage
=
False
)
# Native cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
...
...
@@ -212,11 +223,11 @@ def check_nvfp4_gemm_versus_reference(
@
pytest
.
mark
.
parametrize
(
"is_x_columnwise, is_w_columnwise"
,
[
(
False
,
False
),
#
Only rowwise x rowwise is supported by reference GEMM
# Note: Reference GEMM expects inputs as (M,K) x (N,K) with rowwise quantization
# Columnwise layouts are not supported by the reference implementation
(
False
,
False
),
#
TN
(
True
,
False
),
# NN
(
True
,
True
),
# NT
],
ids
=
[
"rowxrow"
],
ids
=
[
"rowxrow"
,
"colxrow"
,
"colxcol"
],
)
def
test_nvfp4_gemm_versus_reference
(
M
:
int
,
...
...
tests/pytorch/nvfp4/test_nvfp4_group_quantize.py
0 → 100644
View file @
0d874a4e
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# NOTE: This file is dependent on the success of test_nvfp4_quantize_exact.py
# and also the test_nvfp4_rht_quantize_exact.py.
# Separate to make sure all the functionalities are working as expected.
# Otherwise reference implementation will get messy.
# Due to the structure of NVFP4Quantizer, we need to test the RHT functionality
# together with the quantization functionality.
import
transformer_engine.pytorch
as
te
import
transformer_engine_torch
as
tex
from
transformer_engine.pytorch
import
NVFP4Quantizer
from
transformer_engine.pytorch.custom_recipes.quantization_nvfp4
import
NVFP4QuantizerRef
from
transformer_engine.pytorch.custom_recipes
import
utils
from
transformer_engine.pytorch.constants
import
TE_DType
from
transformer_engine.common.recipe
import
NVFP4BlockScaling
import
pytest
import
torch
import
random
import
math
recipe_available
,
reason_for_no_recipe
=
te
.
is_nvfp4_available
(
return_reason
=
True
)
def
generate_random_multiples_sum
(
total
=
8192
,
n
=
4
,
multiple
=
64
):
if
total
%
multiple
!=
0
:
raise
ValueError
(
f
"Total (
{
total
}
) must be a multiple of
{
multiple
}
"
)
if
(
total
//
multiple
)
<
n
:
raise
ValueError
(
"Total too small for given n and multiple."
)
# Work in units of multiples
total_units
=
total
//
multiple
# choose n−1 random cut points in [1, total_units−1)
cuts
=
sorted
(
random
.
sample
(
range
(
1
,
total_units
),
n
-
1
))
# convert to segment lengths
parts
=
(
[
cuts
[
0
]]
+
[
cuts
[
i
]
-
cuts
[
i
-
1
]
for
i
in
range
(
1
,
len
(
cuts
))]
+
[
total_units
-
cuts
[
-
1
]]
)
# convert back to multiples
return
[
p
*
multiple
for
p
in
parts
]
def
generate_split_sections
(
M
:
int
,
N
:
int
,
edge_cases
:
str
)
->
list
[
int
]:
least_multiple
=
64
num_chunks
=
4
split_sections
=
None
avg_split
=
M
//
num_chunks
if
M
==
0
or
N
==
0
:
# all zeros
return
[
0
]
*
num_chunks
if
edge_cases
==
"regular"
:
split_sections
=
[
avg_split
]
*
num_chunks
elif
edge_cases
==
"zero_tokens_front"
:
split_sections
=
[
0
]
+
[
avg_split
]
*
(
num_chunks
-
2
)
+
[
avg_split
*
2
]
elif
edge_cases
==
"zero_tokens_end"
:
split_sections
=
[
avg_split
*
2
]
+
[
avg_split
]
*
(
num_chunks
-
2
)
+
[
0
]
elif
edge_cases
==
"zero_tokens_middle"
:
split_sections
=
[
avg_split
]
*
(
num_chunks
-
2
)
+
[
0
]
+
[
avg_split
*
2
]
elif
edge_cases
==
"random_uneven_split"
:
split_sections
=
generate_random_multiples_sum
(
M
,
num_chunks
,
least_multiple
)
else
:
raise
ValueError
(
f
"Invalid edge case:
{
edge_cases
}
"
)
# adds up the split_sections to make it M
assert
sum
(
split_sections
)
==
M
,
"The split_sections do not add up to M"
# make sure every split_section is a multiple of least_multiple
for
split_section
in
split_sections
:
assert
(
split_section
%
least_multiple
==
0
),
"The split_sections are not multiples of least_multiple"
return
split_sections
# Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization without padding
def
get_nvfp4_scale_shape_no_padding
(
shape
,
columnwise
):
M
,
K
=
1
,
1
M
=
math
.
prod
(
shape
[:
-
1
])
K
=
shape
[
-
1
]
if
columnwise
:
outer
=
K
inner
=
math
.
ceil
(
M
/
16
)
return
(
outer
,
inner
)
# rowwise
outer
=
M
inner
=
math
.
ceil
(
K
/
16
)
return
(
outer
,
inner
)
def
reference_group_quantize
(
x
:
torch
.
Tensor
,
quantizers
:
list
[
NVFP4Quantizer
],
split_sections
:
list
[
int
],
return_identity
:
bool
,
return_transpose
:
bool
,
)
->
torch
.
Tensor
:
x_view
=
x
.
reshape
(
-
1
,
x
.
size
(
-
1
))
x_chunks
=
torch
.
split
(
x
,
split_sections
)
# rowwise quantization
x_qx
=
[]
x_sx
=
[]
x_amax_rowwise
=
[]
# columnwise quantization
x_qx_t
=
[]
x_sx_t
=
[]
x_amax_colwise
=
[]
for
i
in
range
(
len
(
x_chunks
)):
x_chunk
=
x_chunks
[
i
]
x_nvfp4_res
=
quantizers
[
i
](
x_chunk
)
if
return_identity
:
x_qx
.
append
(
x_nvfp4_res
.
_rowwise_data
.
view
(
dtype
=
torch
.
uint8
))
x_sx
.
append
(
x_nvfp4_res
.
_rowwise_scale_inv
)
x_amax_rowwise
.
append
(
x_nvfp4_res
.
_amax_rowwise
)
else
:
x_qx
.
append
(
None
)
x_sx
.
append
(
None
)
x_amax_rowwise
.
append
(
None
)
if
return_transpose
:
x_qx_t
.
append
(
x_nvfp4_res
.
_columnwise_data
.
view
(
dtype
=
torch
.
uint8
))
x_sx_t
.
append
(
x_nvfp4_res
.
_columnwise_scale_inv
)
x_amax_colwise
.
append
(
x_nvfp4_res
.
_amax_columnwise
)
else
:
x_qx_t
.
append
(
None
)
x_sx_t
.
append
(
None
)
x_amax_colwise
.
append
(
None
)
return
x_qx
,
x_sx
,
x_amax_rowwise
,
x_qx_t
,
x_sx_t
,
x_amax_colwise
def
assert_same_shape_and_dtype
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
)
->
None
:
assert
x
.
shape
==
y
.
shape
assert
x
.
dtype
==
y
.
dtype
def
check_group_quantization_nvfp4_versus_reference
(
x_dtype
:
torch
.
dtype
,
M
:
int
,
N
:
int
,
return_identity
:
bool
,
return_transpose
:
bool
,
split_sections
:
list
[
int
],
with_rht
:
bool
=
True
,
with_post_rht_amax
:
bool
=
True
,
with_random_sign_mask
:
bool
=
True
,
)
->
None
:
te_dtype
=
tex
.
DType
.
kFloat4E2M1
# Setup device and random seed
device
=
"cuda"
seed
=
0
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
# Input
x
=
torch
.
randn
((
M
,
N
),
dtype
=
x_dtype
,
device
=
device
)
num_chunks
=
len
(
split_sections
)
x_splits
=
torch
.
split
(
x
,
split_sections
)
# Quantize
quantizers
=
[
NVFP4Quantizer
(
fp4_dtype
=
te_dtype
,
rowwise
=
return_identity
,
columnwise
=
return_transpose
,
with_amax_reduction
=
False
,
amax_reduction_group
=
None
,
with_rht
=
with_rht
,
with_post_rht_amax
=
with_post_rht_amax
,
with_random_sign_mask
=
with_random_sign_mask
,
)
for
_
in
range
(
len
(
split_sections
))
]
x_qx_ref
,
x_sx_ref
,
x_amax_rowwise_ref
,
x_qx_t_ref
,
x_sx_t_ref
,
x_amax_colwise_ref
=
(
reference_group_quantize
(
x
,
quantizers
,
split_sections
,
return_identity
,
return_transpose
)
)
split_quantize_outputs
=
tex
.
split_quantize
(
x
,
split_sections
,
quantizers
)
if
return_identity
:
x_qx
=
[
output
.
_rowwise_data
.
view
(
dtype
=
torch
.
uint8
)
for
output
in
split_quantize_outputs
]
x_sx
=
[
output
.
_rowwise_scale_inv
for
output
in
split_quantize_outputs
]
x_amax_rowwise
=
[
output
.
_amax_rowwise
for
output
in
split_quantize_outputs
]
for
i
in
range
(
len
(
x_qx
)):
if
split_sections
[
i
]
==
0
:
# then just assert the same shape and dtype because the buffer won't be zero out
assert_same_shape_and_dtype
(
x_amax_rowwise
[
i
],
x_amax_rowwise_ref
[
i
])
assert_same_shape_and_dtype
(
x_qx
[
i
],
x_qx_ref
[
i
])
assert_same_shape_and_dtype
(
x_sx
[
i
],
x_sx_ref
[
i
])
else
:
torch
.
testing
.
assert_close
(
x_amax_rowwise
[
i
],
x_amax_rowwise_ref
[
i
],
atol
=
0.0
,
rtol
=
0.0
)
torch
.
testing
.
assert_close
(
x_qx
[
i
],
x_qx_ref
[
i
],
atol
=
0.0
,
rtol
=
0.0
)
valid_scale_shape
=
get_nvfp4_scale_shape_no_padding
(
x_splits
[
i
].
shape
,
False
)
x_sx_valid
=
x_sx
[
i
][:
valid_scale_shape
[
0
],
:
valid_scale_shape
[
1
]]
x_sx_ref_valid
=
x_sx_ref
[
i
][:
valid_scale_shape
[
0
],
:
valid_scale_shape
[
1
]]
torch
.
testing
.
assert_close
(
x_sx_valid
,
x_sx_ref_valid
,
atol
=
0.0
,
rtol
=
0.0
)
if
return_transpose
:
x_qx_t
=
[
output
.
_columnwise_data
.
view
(
dtype
=
torch
.
uint8
)
for
output
in
split_quantize_outputs
]
x_sx_t
=
[
output
.
_columnwise_scale_inv
for
output
in
split_quantize_outputs
]
x_amax_colwise
=
[
output
.
_amax_columnwise
for
output
in
split_quantize_outputs
]
# assert with zero tolerance
for
i
in
range
(
len
(
x_qx_t
)):
if
split_sections
[
i
]
==
0
:
# then just assert the same shape and dtype because the buffer won't be zero out
assert_same_shape_and_dtype
(
x_amax_colwise
[
i
],
x_amax_colwise_ref
[
i
])
assert_same_shape_and_dtype
(
x_qx_t
[
i
],
x_qx_t_ref
[
i
])
assert_same_shape_and_dtype
(
x_sx_t
[
i
],
x_sx_t_ref
[
i
])
else
:
torch
.
testing
.
assert_close
(
x_amax_colwise
[
i
],
x_amax_colwise_ref
[
i
],
atol
=
0.0
,
rtol
=
0.0
)
torch
.
testing
.
assert_close
(
x_qx_t
[
i
],
x_qx_t_ref
[
i
],
atol
=
0.0
,
rtol
=
0.0
)
valid_scale_shape
=
get_nvfp4_scale_shape_no_padding
(
x_splits
[
i
].
shape
,
True
)
x_sx_t_valid
=
x_sx_t
[
i
][:
valid_scale_shape
[
0
],
:
valid_scale_shape
[
1
]]
x_sx_t_ref_valid
=
x_sx_t_ref
[
i
][:
valid_scale_shape
[
0
],
:
valid_scale_shape
[
1
]]
torch
.
testing
.
assert_close
(
x_sx_t_valid
,
x_sx_t_ref_valid
,
atol
=
0.0
,
rtol
=
0.0
)
@
pytest
.
mark
.
skipif
(
not
recipe_available
,
reason
=
reason_for_no_recipe
)
@
pytest
.
mark
.
parametrize
(
"M, N"
,
[
# edge case, zero tokens for all
(
0
,
512
),
# full tile cases
(
256
,
1024
),
(
1024
,
256
),
# larger sizes
(
8192
,
1024
),
(
16384
,
8192
),
(
16384
,
16384
),
],
)
@
pytest
.
mark
.
parametrize
(
"x_dtype"
,
[
torch
.
bfloat16
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"edge_cases"
,
[
"regular"
,
"zero_tokens_front"
,
"zero_tokens_end"
,
"zero_tokens_middle"
,
"random_uneven_split"
,
],
)
@
pytest
.
mark
.
parametrize
(
"quantize_mode"
,
[
"quantize"
,
"quantize_transpose"
,
"quantize_colwise_only"
]
)
@
pytest
.
mark
.
parametrize
(
"with_random_sign_mask"
,
[
True
,
False
],
ids
=
[
"with_random_sign_mask"
,
"no_random_sign_mask"
]
)
@
pytest
.
mark
.
parametrize
(
"with_rht"
,
[
True
,
False
],
ids
=
[
"with_rht"
,
"no_rht"
])
def
test_rht_with_quantization_block_tiling_versus_reference
(
x_dtype
:
torch
.
dtype
,
M
:
int
,
N
:
int
,
edge_cases
:
str
,
quantize_mode
:
str
,
with_random_sign_mask
:
bool
,
with_rht
:
bool
,
)
->
None
:
split_sections
=
generate_split_sections
(
M
,
N
,
edge_cases
)
# currently disable pre-RHT amax
with_post_rht_amax
=
with_rht
if
quantize_mode
==
"quantize"
:
return_identity
=
True
return_transpose
=
False
elif
quantize_mode
==
"quantize_transpose"
:
return_identity
=
True
return_transpose
=
True
elif
quantize_mode
==
"quantize_colwise_only"
:
return_identity
=
False
return_transpose
=
True
else
:
raise
ValueError
(
f
"Invalid quantize mode:
{
quantize_mode
}
"
)
check_group_quantization_nvfp4_versus_reference
(
x_dtype
=
x_dtype
,
M
=
M
,
N
=
N
,
return_identity
=
return_identity
,
return_transpose
=
return_transpose
,
split_sections
=
split_sections
,
with_rht
=
with_rht
,
with_post_rht_amax
=
with_post_rht_amax
,
with_random_sign_mask
=
with_random_sign_mask
,
)
tests/pytorch/nvfp4/test_nvfp4_module_exact.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from
typing
import
List
,
Tuple
import
pytest
import
torch
import
transformer_engine.pytorch
as
te
import
transformer_engine_torch
as
tex
from
transformer_engine.pytorch
import
NVFP4Quantizer
recipe_available
,
reason_for_no_recipe
=
te
.
is_nvfp4_available
(
return_reason
=
True
)
...
...
@@ -151,6 +156,74 @@ def quantize_fp4(
return
qx
,
sx
,
qx_t
,
sx_t
def
group_quantize_fp4
(
x
:
torch
.
Tensor
,
use_stochastic_rounding
:
bool
,
use_2D
:
bool
,
use_RHT
:
bool
,
split_sections
:
list
[
int
],
use_tex_split_quantize
:
bool
=
True
,
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
]]:
"""
Group quantize function with toggle between tex.split_quantize and manual split/call methods.
Args:
x (torch.Tensor): Input tensor.
use_stochastic_rounding (bool): Use stochastic rounding.
use_2D (bool): Use 2D quantization.
use_RHT (bool): Use RHT.
split_sections (list[int]): Split sizes for inputs.
use_tex_split_quantize (bool): Toggle method. If True, use tex.split_quantize, else use manual split and per-quantizer invocation.
Returns:
tuple: Lists of quantized tensors and scale tensors for all sections.
"""
num_tensors
=
len
(
split_sections
)
nvfp4_quantizers
=
[
NVFP4Quantizer
(
rowwise
=
True
,
columnwise
=
True
,
with_amax_reduction
=
False
,
amax_reduction_group
=
None
,
with_rht
=
use_RHT
,
with_post_rht_amax
=
True
,
stochastic_rounding
=
use_stochastic_rounding
,
with_2d_quantization
=
use_2D
,
)
for
_
in
range
(
num_tensors
)
]
if
use_tex_split_quantize
:
outputs
=
tex
.
split_quantize
(
x
,
split_sections
,
nvfp4_quantizers
)
qx_list
=
[
output
.
_rowwise_data
.
view
(
dtype
=
torch
.
uint8
)
for
output
in
outputs
]
sx_list
=
[
output
.
_rowwise_scale_inv
for
output
in
outputs
]
qx_t_list
=
[
output
.
_columnwise_data
.
view
(
dtype
=
torch
.
uint8
)
for
output
in
outputs
]
sx_t_list
=
[
output
.
_columnwise_scale_inv
for
output
in
outputs
]
else
:
x_chunks
=
torch
.
split
(
x
,
split_sections
)
qx_list
=
[]
sx_list
=
[]
qx_t_list
=
[]
sx_t_list
=
[]
for
i
in
range
(
num_tensors
):
x_chunk
=
x_chunks
[
i
]
x_nvfp4_sut
=
nvfp4_quantizers
[
i
](
x_chunk
)
assert
x_nvfp4_sut
.
_rowwise_data
is
not
None
qx
=
x_nvfp4_sut
.
_rowwise_data
.
view
(
dtype
=
torch
.
uint8
)
assert
x_nvfp4_sut
.
_rowwise_scale_inv
is
not
None
sx
=
x_nvfp4_sut
.
_rowwise_scale_inv
assert
x_nvfp4_sut
.
_columnwise_data
is
not
None
qx_t
=
x_nvfp4_sut
.
_columnwise_data
.
view
(
dtype
=
torch
.
uint8
)
assert
x_nvfp4_sut
.
_columnwise_scale_inv
is
not
None
sx_t
=
x_nvfp4_sut
.
_columnwise_scale_inv
qx_list
.
append
(
qx
)
sx_list
.
append
(
sx
)
qx_t_list
.
append
(
qx_t
)
sx_t_list
.
append
(
sx_t
)
return
qx_list
,
sx_list
,
qx_t_list
,
sx_t_list
def
check_quantization_nvfp4_versus_reference
(
x_dtype
:
torch
.
dtype
,
M
:
int
,
N
:
int
,
use_2D
:
bool
,
use_RHT
:
bool
)
->
None
:
...
...
@@ -209,6 +282,92 @@ def check_quantization_nvfp4_versus_reference(
assert
me_t_sr
<
me_t_rn
,
"Stochastic rounding failed - error larger than the round to nearest."
def check_group_quantization_nvfp4_versus_reference(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    use_2D: bool,
    use_RHT: bool,
    num_splits: int,
    use_tex_split_quantize: bool = True,
) -> None:
    """Compare grouped NVFP4 stochastic rounding against round-to-nearest.

    Quantizes an (M, N) tensor split row-wise into ``num_splits`` equal groups,
    once with round-to-nearest (RN) and ``n_iters`` times with stochastic
    rounding (SR). For every split, the RMSE of the SR average over all
    iterations must beat the RN RMSE, both for the row-wise data and for the
    column-wise (transposed, optionally RHT-transformed) data.

    Args:
        x_dtype: dtype of the randomly generated input tensor.
        M: number of rows of the input (must be divisible by ``num_splits``).
        N: number of columns of the input.
        use_2D: enable 2D quantization in the quantizers.
        use_RHT: apply the random Hadamard transform to column-wise data.
        num_splits: number of equal row-wise groups.
        use_tex_split_quantize: route through ``tex.split_quantize`` instead of
            quantizing each chunk individually.
    """
    device = "cuda"
    torch.manual_seed(seed)
    n_iters = 50
    split_sections = [M // num_splits] * num_splits
    x_total = torch.randn((M, N), dtype=x_dtype, device=device) * 2 - 1
    x_splits = torch.split(x_total, split_sections)

    # Single entry point for both rounding modes so RN and SR runs are
    # guaranteed to use identical quantizer settings.
    def quantize(stochastic: bool):
        return group_quantize_fp4(
            x_total,
            use_stochastic_rounding=stochastic,
            use_2D=use_2D,
            use_RHT=use_RHT,
            split_sections=split_sections,
            use_tex_split_quantize=use_tex_split_quantize,
        )

    q_rn_list, s_rn_list, q_t_rn_list, s_t_rn_list = quantize(False)
    sr_n_iter_results = [quantize(True) for _ in range(n_iters)]

    for i, x in enumerate(x_splits):
        # Column-wise reference: transpose (and RHT-transform if enabled).
        y = x.t().contiguous()
        if use_RHT:
            y = RHT(y)
        amax = torch.max(torch.abs(x)).float()

        # Round-to-nearest RMSE for this split (row-wise and column-wise).
        dq_rn = dequantize_fp4(q_rn_list[i], s_rn_list[i], amax)
        dq_t_rn = dequantize_fp4(q_t_rn_list[i], s_t_rn_list[i], amax)
        err_rn = (dq_rn - x).float()
        me_rn = torch.sqrt((err_rn * err_rn).mean())
        err_t_rn = (dq_t_rn - y).float()
        me_t_rn = torch.sqrt((err_t_rn * err_t_rn).mean())

        # Average the dequantized SR outputs over all iterations; the mean of
        # unbiased stochastic rounding should be more accurate than RN.
        sr_result = torch.zeros_like(x).float()
        sr_t_result = torch.zeros_like(x).float().t().contiguous()
        for q_sr_l, s_sr_l, q_t_sr_l, s_t_sr_l in sr_n_iter_results:
            sr_result += dequantize_fp4(q_sr_l[i], s_sr_l[i], amax).float()
            sr_t_result += dequantize_fp4(q_t_sr_l[i], s_t_sr_l[i], amax).float()
        sr_result /= n_iters
        sr_t_result /= n_iters

        err_sr = (sr_result - x).float()
        me_sr = torch.sqrt((err_sr * err_sr).mean())
        err_t_sr = (sr_t_result - y).float()
        me_t_sr = torch.sqrt((err_t_sr * err_t_sr).mean())

        print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
        print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")
        assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest."
        assert (
            me_t_sr < me_t_rn
        ), "Stochastic rounding failed - error larger than the round to nearest."
@
pytest
.
mark
.
skipif
(
not
recipe_available
,
reason
=
reason_for_no_recipe
)
@
pytest
.
mark
.
parametrize
(
"M, N"
,
...
...
@@ -236,3 +395,39 @@ def test_quantization_block_tiling_versus_reference(
M
=
M
,
N
=
N
,
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        (8192, 8192),
        (4096, 7168),
        (16384, 2048),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("use_2D", [False], ids=str)
@pytest.mark.parametrize("use_RHT", [True], ids=str)
@pytest.mark.parametrize("num_splits", [4, 8], ids=str)
@pytest.mark.parametrize("use_tex_split_quantize", [True, False], ids=str)
def test_group_stochastic_rounding_quantization_versus_reference(
    x_dtype: torch.dtype,
    use_2D: bool,
    use_RHT: bool,
    num_splits: int,
    use_tex_split_quantize: bool,
    M: int,
    N: int,
) -> None:
    """Grouped NVFP4 stochastic-rounding quantization vs. the RN reference.

    Thin parametrized wrapper: skips unsupported dtype/RHT combinations and
    delegates the actual numeric comparison to
    ``check_group_quantization_nvfp4_versus_reference``.
    """
    # RHT kernels only accept bfloat16 inputs.
    if x_dtype == torch.float32 and use_RHT:
        pytest.skip("RHT is only supported with bfloat16")
    check_group_quantization_nvfp4_versus_reference(
        x_dtype=x_dtype,
        use_2D=use_2D,
        use_RHT=use_RHT,
        M=M,
        N=N,
        num_splits=num_splits,
        use_tex_split_quantize=use_tex_split_quantize,
    )
tests/pytorch/references/blockwise_fp8_gemm_reference.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/references/blockwise_quantizer_reference.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/references/quantize_scale_calc.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/references/ref_per_tensor_cs.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/test_checkpoint.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/test_cpu_offloading.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -54,9 +54,13 @@ gc.disable()
class
Utils
:
# Tensor used for simulating long-running GPU work in long_job()
tensor1
=
torch
.
randn
((
1024
,
1024
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
_B
=
64
_S
=
256
# Test tensor dimensions: _B x _S x _D = 128 x 512 x 256 = 16,777,216 elements
# This exceeds the 256K element threshold for offloading (cpu_offload.py line 443).
# For quantized tensors, scale_inv tensors (~524K elements for block scaling) also exceed threshold.
_B
=
128
_S
=
512
_H
=
4
_D
=
256
...
...
@@ -395,6 +399,9 @@ class TestsDefaultOffloadSynchronizer:
offload_synchronizer
.
push_tensor
(
x1
)
offload_synchronizer
.
push_tensor
(
x1
)
offload_synchronizer
.
push_tensor
(
x1
)
# Verify x1 is not corrupted after pushing (important for QuantizedTensor)
if
recipe
is
not
None
:
x1
.
dequantize
()
# Should not raise - tensor should still be valid
offload_synchronizer
.
fwd_step
()
# Only one copy of tensor on cpu is allocated.
assert
Utils
.
get_cpu_memory_mb
()
==
pytest
.
approx
(
init_cpu_memory
+
1
*
x_size
,
0.1
)
...
...
tests/pytorch/test_cpu_offloading_v1.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/test_cuda_graphs.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -203,7 +203,8 @@ _test_cuda_graphs_modules: List[str] = [
# creating TMA descriptor for MXFP8 quantization.
"linear"
,
"transformer"
,
"layernorm_mlp"
,
"layernorm_mlp_nocheckpoint"
,
"layernorm_mlp_checkpoint"
,
"layernorm_linear"
,
"mha"
,
"linear_op"
,
...
...
@@ -245,12 +246,23 @@ def _test_cuda_graphs(
)
for
_
in
range
(
num_layers
)
]
elif
module
==
"layernorm_mlp"
:
elif
module
==
"layernorm_mlp
_nocheckpoint
"
:
modules
=
[
LayerNormMLP
(
model_config
.
hidden_size
,
model_config
.
hidden_size
,
params_dtype
=
dtype
,
checkpoint
=
False
,
)
for
_
in
range
(
num_layers
)
]
elif
module
==
"layernorm_mlp_checkpoint"
:
modules
=
[
LayerNormMLP
(
model_config
.
hidden_size
,
model_config
.
hidden_size
,
params_dtype
=
dtype
,
checkpoint
=
True
,
)
for
_
in
range
(
num_layers
)
]
...
...
@@ -389,6 +401,17 @@ def test_make_graphed_callables(
)
if
fp8_params
:
pytest
.
skip
(
"NVFP4 params not supported"
)
if
(
fp8
and
fp8_recipe
.
delayed
()
and
torch
.
cuda
.
get_device_capability
()
>=
(
10
,
0
)
and
module
==
"layernorm_mlp_checkpoint"
):
pytest
.
skip
(
"CUDA graphs not supported for LayerNormMLP "
"with checkpoint=True, SM>=10, "
"and DelayedScaling recipe"
)
if
fp8
and
not
fp8_available
:
pytest
.
skip
(
"FP8 not supported on rocm GPU."
)
if
fp8
and
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
...
...
@@ -421,7 +444,8 @@ def test_make_graphed_callables(
_test_make_graphed_callables_with_fp8_weight_caching_modules
=
[
"transformer"
,
"layernorm_mlp"
,
"layernorm_mlp_nocheckpoint"
,
"layernorm_mlp_checkpoint"
,
"layernorm_linear"
,
"linear"
,
"mha"
,
...
...
tests/pytorch/test_custom_recipe.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/test_deferred_init.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -28,7 +28,6 @@ dtype = torch.bfloat16
class
TestDeferredInit
:
@
staticmethod
def
get_module_args
(
module
):
hidden_size
=
num_heads
*
head_dim
...
...
@@ -82,3 +81,45 @@ class TestDeferredInit:
"on CUDA device"
)
del
module
@
pytest
.
mark
.
parametrize
(
"module_type"
,
_core_modules
)
def
test_reset_parameters_doesnt_change_parameter_stats
(
self
,
module_type
:
torch
.
nn
.
Module
,
)
->
None
:
"""Test for github issue #2528 and #2529 to ensure that reset_parameters() doesn't change
the parameter mean and std"""
args
,
kwargs
=
TestDeferredInit
.
get_module_args
(
module_type
)
kwargs
[
"device"
]
=
"cuda"
module
=
module_type
(
*
args
,
**
kwargs
)
param_stats
=
{
name
:
{
"mean"
:
param
.
mean
(),
"std"
:
param
.
std
()}
for
name
,
param
in
module
.
named_parameters
()
}
with
torch
.
no_grad
():
module
.
reset_parameters
()
param_stats_after
=
{
name
:
{
"mean"
:
param
.
mean
(),
"std"
:
param
.
std
()}
for
name
,
param
in
module
.
named_parameters
()
}
for
name
,
stats
in
param_stats_after
.
items
():
torch
.
testing
.
assert_close
(
stats
[
"mean"
],
param_stats
[
name
][
"mean"
],
atol
=
1e-3
,
rtol
=
1e-3
,
msg
=
f
"
{
name
}
mean changed after reset_parameters"
,
)
torch
.
testing
.
assert_close
(
stats
[
"std"
],
param_stats
[
name
][
"std"
],
atol
=
1e-3
,
rtol
=
1e-3
,
msg
=
f
"
{
name
}
std changed after reset_parameters"
,
)
del
module
Prev
1
…
8
9
10
11
12
13
14
15
16
…
32
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment