OpenDAS / TransformerEngine · Commit 0d874a4e

Merge branch 'nv_main' of v2.12
Authored Mar 03, 2026 by wenjh
Parents: a68e5f87, dfdd3820

Changes: 640. Showing 20 changed files with 1389 additions and 221 deletions (+1389 −221).
  tests/pytorch/test_float8_blockwise_gemm_exact.py      +2    −2
  tests/pytorch/test_float8_blockwise_scaling_exact.py   +1    −121
  tests/pytorch/test_float8_current_scaling_exact.py     +133  −1
  tests/pytorch/test_float8blockwisetensor.py            +2    −15
  tests/pytorch/test_fused_optimizer.py                  +1    −1
  tests/pytorch/test_fused_rope.py                       +1    −1
  tests/pytorch/test_fused_router.py                     +1    −1
  tests/pytorch/test_fusible_ops.py                      +119  −38
  tests/pytorch/test_gqa.py                              +1    −1
  tests/pytorch/test_hf_integration.py                   +1    −1
  tests/pytorch/test_jit.py                              +1    −1
  tests/pytorch/test_multi_tensor.py                     +33   −1
  tests/pytorch/test_numerics.py                         +81   −18
  tests/pytorch/test_onnx_export.py                      +1    −1
  tests/pytorch/test_parallel_cross_entropy.py           +15   −2
  tests/pytorch/test_partial_cast.py                     +137  −0
  tests/pytorch/test_permutation.py                      +650  −10
  tests/pytorch/test_qk_norm.py                          +1    −1
  tests/pytorch/test_quantized_tensor.py                 +207  −3
  tests/pytorch/test_recipe.py                           +1    −2
Note: too many changes to show; to preserve performance only 640 of 640+ files are displayed.
tests/pytorch/test_float8_blockwise_gemm_exact.py  (+2 −2)  View file @ 0d874a4e

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -909,7 +909,7 @@ def test_illegal_2D_by_2D_enforced(
     is_w_1d_scaled,
 ) -> None:
     # 2D block quantization by 2D block quantization is not supported.
-    expected_err_msg = "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported"
+    expected_err_msg = "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling GEMM is supported"
     cublas_gemm_test_constraint_enforced(
         x_dtype,
         w_dtype,
...
tests/pytorch/test_float8_blockwise_scaling_exact.py  (+1 −121)  View file @ 0d874a4e

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -89,126 +89,6 @@ def initialize_for_many_scales(
     return result


-@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
-@pytest.mark.parametrize(
-    "M, N",
-    [
-        # full tile cases
-        (128, 128),
-        (256, 256),
-        (256, 1024),
-        (1024, 256),
-        # Padding required cases
-        (256, 272),
-        (303, 300),
-        (305, 256),
-        # Some larger tiles.
-        (2000, 2000),
-        (2048, 2000),
-        (2000, 1024),
-        (2048, 1024),
-    ],
-)
-@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
-@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
-@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
-@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
-def test_quantization_1D_block_tiling_with_compact_data_and_scales(
-    x_dtype: torch.dtype,
-    M: int,
-    N: int,
-    quant_dtype: torch.dtype,
-    eps: float,
-    pow_2_scales: bool,
-) -> None:
-    te_dtype = TE_DType[quant_dtype]
-    tile_size = (1, blockwise_fp8_block_len)
-    # This test runs a comparison of the ref class versus the class using
-    # CUDA kernels to quantize. They should quantize identically for pixels
-    # that are not DC values in the scale factor shape.
-    ref_quantizer = BlockwiseQuantizerReference()
-    sut_quantizer = Float8BlockQuantizer(
-        fp8_dtype=te_dtype,
-        rowwise=True,
-        columnwise=True,
-        amax_epsilon=eps,
-        force_pow_2_scales=pow_2_scales,
-        block_scaling_dim=1,
-        all_gather_usage=True,
-    )
-    # Setup device and random seed
-    device = "cuda"
-    seed = 0
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    # Input
-    x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device)
-    x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False)
-    x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut)
-    x_fp8_sut_cpp_alloc = sut_quantizer(x)
-    assert x_fp8_sut._rowwise_data is not None
-    qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype)
-    assert x_fp8_sut._rowwise_scale_inv is not None
-    sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv
-    qx_t = x_fp8_sut._columnwise_data
-    sx_t = x_fp8_sut._columnwise_scale_inv
-    qresult_ref = ref_quantizer.quantize(
-        x,
-        quant_dtype=quant_dtype,
-        return_transpose=True,
-        eps=eps,
-        pow_2_scales=pow_2_scales,
-        quant_tile_shape=tile_size,
-        munge_scale_shapes=False,
-    )
-    qx_ref, sx_ref, qx_t_ref, sx_t_ref = (
-        qresult_ref.data,
-        qresult_ref.scale,
-        qresult_ref.data_t,
-        qresult_ref.scale_t,
-    )
-    # match the reference quantize transpose output with the columnwise non-transpose method
-    qx_t_ref = qx_t_ref.transpose(-1, -2).contiguous()
-    sx_t_ref = sx_t_ref.transpose(-1, -2).contiguous()
-    # Check
-    torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
-    torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
-    assert qx_t is not None
-    qx_t = qx_t.view(dtype=quant_dtype)
-    assert qx_t_ref is not None
-    assert sx_t is not None
-    assert sx_t_ref is not None
-    torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
-    torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)
-    # check that the C++ and Python allocators are equivalent
-    torch.testing.assert_close(
-        x_fp8_sut._rowwise_data, x_fp8_sut_cpp_alloc._rowwise_data, atol=0.0, rtol=0.0
-    )
-    torch.testing.assert_close(
-        x_fp8_sut._rowwise_scale_inv, x_fp8_sut_cpp_alloc._rowwise_scale_inv, atol=0.0, rtol=0.0
-    )
-    torch.testing.assert_close(
-        x_fp8_sut._columnwise_data, x_fp8_sut_cpp_alloc._columnwise_data, atol=0.0, rtol=0.0
-    )
-    torch.testing.assert_close(
-        x_fp8_sut._columnwise_scale_inv,
-        x_fp8_sut_cpp_alloc._columnwise_scale_inv,
-        atol=0.0,
-        rtol=0.0,
-    )
-    # check if the fp8 output between C++ and Python are the same
-    assert x_fp8_sut._data_format == x_fp8_sut_cpp_alloc._data_format
-
-
 def check_quantization_block_tiling_versus_reference(
     x_dtype: torch.dtype,
     M: int,
...
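For readers skimming the removed test: 1D block scaling quantizes each 1 x block_len slice of a tensor with its own scale derived from that block's amax. The sketch below illustrates the idea only; it is not the BlockwiseQuantizerReference or Float8BlockQuantizer implementation, and the function name and the fixed block length of 128 are assumptions for the example.

import torch

def quantize_1d_blocks(x: torch.Tensor, block_len: int = 128,
                       fp8_dtype: torch.dtype = torch.float8_e4m3fn):
    """Illustrative 1D (1 x block_len) FP8 block quantization, not the TE implementation."""
    M, N = x.shape
    assert N % block_len == 0, "sketch assumes N divisible by block_len"
    fp8_max = torch.finfo(fp8_dtype).max
    blocks = x.float().view(M, N // block_len, block_len)
    amax = blocks.abs().amax(dim=-1, keepdim=True)                      # per-block amax
    scale = fp8_max / amax.clamp(min=torch.finfo(torch.float32).tiny)   # per-block scale
    q = (blocks * scale).to(fp8_dtype)                                   # quantized data
    scale_inv = 1.0 / scale                                              # stored for dequantization
    return q.view(M, N), scale_inv.squeeze(-1)

Dequantization multiplies each quantized block by its scale_inv, which is the quantity the tests compare bit-exactly between the CUDA kernel and the Python reference.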
tests/pytorch/test_float8_current_scaling_exact.py  (+133 −1)  View file @ 0d874a4e

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -8,9 +8,15 @@ import torch
 import pytest
 import transformer_engine.pytorch as te
+import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Float8CurrentScaling
 from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype
+from transformer_engine.pytorch.constants import TE_DType
+from transformer_engine.pytorch.custom_recipes.quantization import MMParams
+from transformer_engine.pytorch.custom_recipes.quantization_current_scaling import (
+    CurrentScalingQuantizerRef,
+)
 from transformer_engine.pytorch.fp8 import int8_simulation_fp8
...
@@ -750,6 +756,132 @@ class TestFP8CurrentScalingRecipeLinear(TestFP8RecipeLinearBase):
    )


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8CurrentScalingNativeVsRef:

    @staticmethod
    def _make_quantizers(rowwise=True, columnwise=True):
        # TE native FP8 current scaling quantizer
        te_quant = te.Float8CurrentScalingQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            device=torch.device("cuda"),
            rowwise=rowwise,
            columnwise=columnwise,
        )
        # Reference quantizer
        ref_quant = CurrentScalingQuantizerRef(
            dtype=torch.float8_e4m3fn,
            rowwise=rowwise,
            columnwise=columnwise,
            pow_2_scales=False,
            eps=0.0,
        )
        return te_quant, ref_quant

    @pytest.mark.parametrize(
        "M, N, dtype",
        [
            (128, 256, torch.bfloat16),
        ],
        ids=["rowwise"],
    )
    def test_current_scaling_quantization_versus_reference(self, M, N, dtype):
        device = "cuda"
        seed = 123
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        x = torch.randn((M, N), dtype=dtype, device=device)
        te_quant, ref_quant = self._make_quantizers(rowwise=True, columnwise=False)
        # Native TE quantization
        x_te = te_quant(x)
        assert x_te._data is not None
        qx_native = x_te._data.view(dtype=torch.float8_e4m3fn)
        sx_native = x_te._scale_inv
        # Reference quantization
        x_ref = ref_quant.quantize(x)
        qx_ref = x_ref.data
        sx_ref = x_ref.scale
        # Byte-for-byte equality on data and exact scale_inv match
        torch.testing.assert_close(qx_native, qx_ref, atol=0.0, rtol=0.0)
        torch.testing.assert_close(sx_native, sx_ref, atol=0.0, rtol=0.0)

    @pytest.mark.parametrize(
        "M, K, N, out_dtype, accumulate",
        [
            (128, 256, 96, torch.bfloat16, False),
            (64, 128, 64, torch.float32, True),
        ],
        ids=["bf16_no_acc", "fp32_acc"],
    )
    def test_current_scaling_gemm_versus_reference(self, M, K, N, out_dtype, accumulate):
        device = "cuda"
        seed = 42
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        x = torch.randn((M, K), dtype=torch.bfloat16, device=device)
        w = torch.randn((N, K), dtype=torch.bfloat16, device=device)
        out = torch.randn((M, N), dtype=out_dtype, device=device) if accumulate else None
        te_quant_x, ref_quant = self._make_quantizers(rowwise=True, columnwise=True)
        te_quant_w, _ = self._make_quantizers(rowwise=True, columnwise=True)
        # Native TE quantization (direct)
        qx_native = te_quant_x(x)
        qw_native = te_quant_w(w)
        # Prepare inputs for reference qgemm
        assert qx_native._data is not None and qw_native._data is not None
        qx_data = qx_native._data.view(dtype=torch.float8_e4m3fn)
        qw_data = qw_native._data.view(dtype=torch.float8_e4m3fn)
        sx = qx_native._scale_inv
        sw = qw_native._scale_inv
        # Reference GEMM
        m_params = MMParams(out_dtype=out_dtype, use_split_accumulator=False)
        y_ref = ref_quant.qgemm(
            qx=qx_data,
            qw=qw_data,
            m_params=m_params,
            out_dtype=out_dtype,
            sx=sx,
            sw=sw,
            bias=None,
            out=out.clone() if accumulate else None,
            accumulate=accumulate,
            gemm_type=None,
            qresult_x=None,
            qresult_w=None,
        )
        # Native TE GEMM
        # return type is out, bias_grad, gelu_input, extra_output
        y_native = tex.generic_gemm(
            qw_native,  # A
            True,  # transa (treat (N,K) as (K,N))
            qx_native,  # B
            False,  # transb
            out.clone() if accumulate else None,
            None,  # out quantizer
            TE_DType[out_dtype],
            None,  # bias
            TE_DType[torch.bfloat16],
            False,  # use_gelu
            None,  # gelu_input
            False,  # use_grad
            torch.empty(0, dtype=torch.uint8, device=device),
            0,
            accumulate,
            False,  # use_split_accumulator
        )[0]
        torch.testing.assert_close(y_native, y_ref, atol=0.0, rtol=0.0)


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8CurrentScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase):
...
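The tests added above compare TE's native Float8CurrentScalingQuantizer against a Python reference. "Current scaling" means one scale per tensor computed from the tensor's current absolute maximum. A minimal sketch of that idea follows; the function name is illustrative and it is not the CurrentScalingQuantizerRef API.

import torch

def current_scaling_quantize(x: torch.Tensor, fp8_dtype: torch.dtype = torch.float8_e4m3fn):
    """Per-tensor FP8 quantization from the tensor's current amax (illustrative only)."""
    fp8_max = torch.finfo(fp8_dtype).max
    amax = x.abs().amax().float()
    scale = fp8_max / amax.clamp(min=torch.finfo(torch.float32).tiny)
    q = (x.float() * scale).to(fp8_dtype)   # quantized data
    scale_inv = 1.0 / scale                 # kept alongside q for the GEMM epilogue
    return q, scale_inv

Because both the native and the reference path derive the scale from the same amax, the tests can require byte-for-byte equality (atol=0.0, rtol=0.0) rather than a tolerance-based comparison.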
tests/pytorch/test_float8blockwisetensor.py  (+2 −15)  View file @ 0d874a4e

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -179,16 +179,12 @@ class TestFloat8BlockwiseTensor:
     )
     @pytest.mark.parametrize("block_scaling_dim", [1, 2])
     @pytest.mark.parametrize("dq_columnwise", [True, False])
-    @pytest.mark.parametrize("all_gather_usage", [True, False])
     def test_quantize_dequantize_dims(
         self,
         dims: DimsType,
         block_scaling_dim: int,
         dq_columnwise: bool,
-        all_gather_usage: bool,
     ) -> None:
-        if all_gather_usage and block_scaling_dim != 1:
-            pytest.skip("all_gather_usage only implemented for 1D block quantization.")
         atol = _tols[tex.DType.kFloat8E4M3]["atol"]
         rtol = _tols[tex.DType.kFloat8E4M3]["rtol"]
         quantizer = Float8BlockQuantizer(
...
@@ -196,7 +192,6 @@ class TestFloat8BlockwiseTensor:
             rowwise=True,
             columnwise=dq_columnwise,
             block_scaling_dim=block_scaling_dim,
-            all_gather_usage=all_gather_usage,
         )
         self._test_quantize_dequantize(
             quantizer=quantizer,
...
@@ -222,7 +217,6 @@ class TestFloat8BlockwiseTensor:
             rowwise=True,
             columnwise=dq_columnwise,
             block_scaling_dim=block_scaling_dim,
-            all_gather_usage=(block_scaling_dim == 1),
         )
         self._test_quantize_dequantize(
             quantizer=quantizer,
...
@@ -287,13 +281,8 @@ class TestFloat8BlockwiseTensor:
     @pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
     @pytest.mark.parametrize("block_scaling_dim", [1, 2])
-    @pytest.mark.parametrize("all_gather_usage", [True, False])
-    def test_serialization(
-        self, dims: DimsType, block_scaling_dim: int, all_gather_usage: bool
-    ) -> None:
+    def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None:
         """Test serialization of Float8BlockwiseQTensor"""
-        if all_gather_usage and block_scaling_dim != 1:
-            pytest.skip("all_gather_usage only implemented for 1D block quantization.")
         device = "cuda"
         dtype = torch.bfloat16
         x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
...
@@ -302,7 +291,6 @@ class TestFloat8BlockwiseTensor:
             rowwise=True,
             columnwise=True,
             block_scaling_dim=block_scaling_dim,
-            all_gather_usage=all_gather_usage,
         )
         # Create FP8 tensor
...
@@ -326,7 +314,6 @@ class TestFloat8BlockwiseTensor:
         assert x_fp8_loaded._is_2D_scaled == x_fp8._is_2D_scaled
         assert x_fp8_loaded.dtype == x_fp8.dtype
         assert x_fp8_loaded._fp8_dtype == x_fp8._fp8_dtype
-        assert x_fp8_loaded._data_format == x_fp8._data_format
         # Test that dequantized values match
         x_fp8_dequant = x_fp8.dequantize()
...
tests/pytorch/test_fused_optimizer.py  (+1 −1)  View file @ 0d874a4e

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
tests/pytorch/test_fused_rope.py  (+1 −1)  View file @ 0d874a4e

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
 from typing import Callable, Tuple, Union, List
...
tests/pytorch/test_fused_router.py  (+1 −1)  View file @ 0d874a4e

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
 import torch
...
tests/pytorch/test_fusible_ops.py  (+119 −38)  View file @ 0d874a4e

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -913,15 +913,15 @@ class TestBasicOps:
             dtype=dtype,
             accumulate_into_main_grad=accumulate_into_main_grad,
         )
-        forward = te_ops.Sequential(
-            te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
-            op,
-            te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
-        )
         with torch.no_grad():
             op.weight.copy_(w_test)
             del w_test
             op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32)
+        forward = te_ops.Sequential(
+            te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
+            op,
+            te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
+        )
         with te.autocast(enabled=quantized_compute, recipe=recipe):
             y_test = forward(x_test)
         y_test.backward(dy_test)
...
@@ -2751,7 +2751,11 @@ class TestCheckpointing:
         # Check that original and loaded model match exactly
         tols = {"rtol": 0, "atol": 0}
         for param_load, param_save in zip(model_load.parameters(), model_save.parameters()):
-            torch.testing.assert_close(param_load, param_save, **tols)
+            torch.testing.assert_close(
+                # Force dequantization by casting to FP64
+                param_load.to(dtype=torch.float64, device="cpu"),
+                param_save.to(dtype=torch.float64, device="cpu"),
+                **tols,
+            )
             torch.testing.assert_close(param_load.grad, param_save.grad, **tols)
         for y_load, y_save in zip(ys_load, ys_save):
             torch.testing.assert_close(y_load, y_save, **tols)
...
@@ -2768,7 +2772,6 @@ class TestSequentialModules:
     @pytest.mark.parametrize("requires_grad", (False, True))
     @pytest.mark.parametrize("bias", (False, True))
     @pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm"))
     @pytest.mark.parametrize("quantized_compute", (False, True))
     @pytest.mark.parametrize("quantized_weight", (False, True))
     @pytest.mark.parametrize("dtype", _dtypes)
...
@@ -2778,25 +2781,18 @@ class TestSequentialModules:
         *,
         requires_grad: bool,
         bias: bool,
         normalization: str,
         quantized_compute: bool,
         quantized_weight: bool,
         dtype: torch.dtype,
         quantization: Optional[str],
         device: torch.device = "cuda",
-        hidden_size: int = 32,
-        sequence_length: int = 512,
+        hidden_size: int = 256,
+        sequence_length: int = 48,
         batch_size: int = 4,
-        ffn_hidden_size: int = 64,
+        ffn_hidden_size: int = 384,
         layernorm_epsilon: float = 1e-5,
     ) -> None:
-        """
-        LayerNorm/RMSNorm + Linear + GELU + Linear
-        Note that this test checks only if the module runs
-        as when chaining multiple modules it is hard to validate
-        numerical accuracy.
-        """
+        """LayerNorm/RMSNorm + Linear + SwiGLU + Linear"""

         # Make input shape
         in_shape = (sequence_length, batch_size, hidden_size)
...
@@ -2812,38 +2808,90 @@ class TestSequentialModules:
             pytest.skip("Quantization scheme is not used")

         # Random data
-        _, x_test = make_reference_and_test_tensors(
+        x_ref, x_test = make_reference_and_test_tensors(
             in_shape,
             quantization=quantization,
             test_dtype=dtype,
             test_device=device,
             requires_grad=requires_grad,
         )
-        _, dy_test = make_reference_and_test_tensors(
+        norm_w_ref, norm_w_test = make_reference_and_test_tensors(
+            hidden_size,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        norm_b_ref, norm_b_test = make_reference_and_test_tensors(
+            hidden_size,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        w1_ref, w1_test = make_reference_and_test_tensors(
+            (ffn_hidden_size, hidden_size),
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        w2_ref, w2_test = make_reference_and_test_tensors(
+            (hidden_size, ffn_hidden_size // 2),
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        b1_ref, b1_test, b2_ref, b2_test = None, None, None, None
+        if bias:
+            b1_ref, b1_test = make_reference_and_test_tensors(
+                ffn_hidden_size,
+                test_dtype=dtype,
+                test_device=device,
+            )
+            b2_ref, b2_test = make_reference_and_test_tensors(
+                hidden_size,
+                test_dtype=dtype,
+                test_device=device,
+            )
+        dy_ref, dy_test = make_reference_and_test_tensors(
             in_shape,
             quantization=quantization,
             test_dtype=dtype,
             test_device=device,
             requires_grad=False,
         )
+        with torch.no_grad():
+            for t in (norm_w_ref, norm_w_test, norm_b_ref, norm_b_test):
+                t -= 0.5
+            for t in (w1_ref, w1_test, w2_ref, w2_test):
+                t *= 1 / 64
+            if bias:
+                for t in (b1_ref, b1_test, b2_ref, b2_test):
+                    t -= 0.5
+            for t in (dy_ref, dy_test):
+                t -= 0.5
+
+        # Reference implementation
+        x = x_ref
+        x = torch.nn.functional.layer_norm(
+            x,
+            (hidden_size,),
+            weight=norm_w_ref,
+            bias=norm_b_ref,
+            eps=layernorm_epsilon,
+        )
+        x = torch.nn.functional.linear(x, w1_ref, bias=b1_ref)
+        x1, x2 = x.chunk(2, dim=-1)
+        x = torch.nn.functional.silu(x1) * x2
+        x = torch.nn.functional.linear(x, w2_ref, bias=b2_ref)
+        y_ref = x
+        y_ref.backward(dy_ref)

-        # Implementation with fusible operations
+        # Construct operations
         recipe = make_recipe(quantization)
         with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
-            if normalization == "LayerNorm":
-                norm = te_ops.LayerNorm(
-                    hidden_size,
-                    eps=layernorm_epsilon,
-                    device=device,
-                    dtype=dtype,
-                )
-            else:
-                norm = te_ops.RMSNorm(
-                    hidden_size,
-                    eps=layernorm_epsilon,
-                    device=device,
-                    dtype=dtype,
-                )
+            norm = te_ops.LayerNorm(
+                hidden_size,
+                eps=layernorm_epsilon,
+                device=device,
+                dtype=dtype,
+            )
             ffn1 = te_ops.Linear(
                 hidden_size,
                 ffn_hidden_size,
...
@@ -2851,15 +2899,48 @@ class TestSequentialModules:
                 device=device,
                 dtype=dtype,
             )
-            act = te_ops.GELU()
+            act = te_ops.SwiGLU()
             ffn2 = te_ops.Linear(
-                ffn_hidden_size,
+                ffn_hidden_size // 2,
                 hidden_size,
                 bias=bias,
                 device=device,
                 dtype=dtype,
             )
+
+        # Copy weights
+        with torch.no_grad():
+            norm.weight.copy_(norm_w_test)
+            norm.bias.copy_(norm_b_test)
+            ffn1.weight.copy_(w1_test)
+            ffn2.weight.copy_(w2_test)
+            if bias:
+                ffn1.bias.copy_(b1_test)
+                ffn2.bias.copy_(b2_test)
+        del norm_w_test, norm_b_test, w1_test, b1_test, w2_test, b2_test
+
         # Fuse ops and perform forward and backward pass
         forward = te_ops.Sequential(norm, ffn1, act, ffn2)
         with te.autocast(enabled=quantized_compute, recipe=recipe):
             y_test = forward(x_test)
         y_test.backward(dy_test)
+
+        def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+            """Convert to FP64 CPU tensor"""
+            if tensor is None:
+                return None
+            out = tensor.detach().to(dtype=torch.float64, device="cpu")
+            out = out.requires_grad_(requires_grad=tensor.requires_grad)
+            return out
+
+        # Check values
+        tols = {"rtol": 0.25, "atol": 0.5}  # Loose tols for sanity checking
+        torch.testing.assert_close(to_cpu(y_test), y_ref, **tols)
+        torch.testing.assert_close(to_cpu(x_test.grad), x_ref.grad, **tols)
+        torch.testing.assert_close(to_cpu(norm.weight.grad), norm_w_ref.grad, **tols)
+        torch.testing.assert_close(to_cpu(norm.bias.grad), norm_b_ref.grad, **tols)
+        torch.testing.assert_close(to_cpu(ffn2.weight.grad), w2_ref.grad, **tols)
+        torch.testing.assert_close(to_cpu(ffn1.weight.grad), w1_ref.grad, **tols)
+        if bias:
+            torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols)
+            torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols)
tests/pytorch/test_gqa.py  (+1 −1)  View file @ 0d874a4e

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
tests/pytorch/test_hf_integration.py  (+1 −1)  View file @ 0d874a4e

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
tests/pytorch/test_jit.py  (+1 −1)  View file @ 0d874a4e

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
tests/pytorch/test_multi_tensor.py  (+33 −1)  View file @ 0d874a4e

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -7,6 +7,7 @@ import torch
 import transformer_engine.pytorch
 import transformer_engine_torch as tex
+from transformer_engine.pytorch import is_mxfp8_available
 from transformer_engine.pytorch.optimizers import MultiTensorApply
 from references.quantize_scale_calc import scale_from_amax_tensor
...
@@ -23,6 +24,7 @@ input_size_pairs = [
     (555, 33333),
 ]
 appliers = [MultiTensorApply(2048 * 32), MultiTensorApply(333), MultiTensorApply(33333)]
+mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)


 @pytest.mark.parametrize("input_size_pair", input_size_pairs)
...
@@ -260,3 +262,33 @@ def test_multi_tensor_compute_scale_and_scale_inv(
     torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0)
     torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0)
+
+
+@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
+@pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
+@pytest.mark.parametrize("applier", appliers)
+@pytest.mark.parametrize("repeat", [1, 55])
+def test_multi_tensor_compute_scale_inv_e8m0(input_size_pair, applier, repeat):
+    sizea, sizeb = input_size_pair
+    device = torch.device("cuda")
+    a = torch.randn([sizea], dtype=torch.bfloat16, device=device).abs()
+    b = torch.randn([sizeb], dtype=torch.bfloat16, device=device).abs()
+    amax_list = []
+    for _ in range(repeat):
+        amax_list += [a.clone(), b.clone()]
+    scale_inv_list = [torch.empty_like(x).to(torch.uint8) for x in amax_list]
+    applier(
+        tex.multi_tensor_compute_scale_inv_e8m0,
+        None,  # overflow_buf
+        [amax_list, scale_inv_list],
+    )
+    max_fp8 = torch.finfo(torch.float8_e4m3fn).max
+    for amax, scale_inv in zip(amax_list, scale_inv_list):
+        scale_inv_u32 = (amax.float() / max_fp8).view(torch.int)
+        exponent = scale_inv_u32 // 2**23
+        mantissa = scale_inv_u32 & 0x7FFFFF
+        exponent += (
+            ((mantissa > 0) & (exponent != 0xFE)) & ~((exponent == 0) & (mantissa <= 0x400000))
+        ).to(torch.int)
+        torch.testing.assert_close(exponent.to(torch.uint8), scale_inv)
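The new e8m0 check above verifies that each amax is converted to a power-of-two scale inverse whose 8-bit biased exponent is rounded up whenever the mantissa is non-zero. Ignoring the denormal and saturation edge cases handled by the bit-level reference, the same rounding can be sketched with ceil/log2; the function name below is illustrative, not a TE API.

import torch

def e8m0_scale_inv_sketch(amax: torch.Tensor) -> torch.Tensor:
    """Round amax / fp8_max up to a power of two and return the biased exponent byte."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    ratio = (amax.float() / fp8_max).clamp(min=2.0**-127)
    exponent = torch.ceil(torch.log2(ratio)).clamp(min=-127, max=127)
    return (exponent + 127).to(torch.uint8)  # E8M0 stores only a biased exponent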
tests/pytorch/test_numerics.py  (+81 −18)  View file @ 0d874a4e

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -13,7 +13,10 @@ import torch.nn as nn
 from torch.nn import Parameter
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
-from transformer_engine.pytorch.quantization import FP8GlobalStateManager
+from transformer_engine.pytorch.quantization import (
+    FP8GlobalStateManager,
+    get_align_size_for_quantization,
+)
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
...
@@ -46,7 +49,6 @@ from transformer_engine.pytorch import (
 from transformer_engine.pytorch import torch_version
 from transformer_engine.pytorch import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
 from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
 from utils import ModelConfig, reset_rng_states
...
@@ -191,7 +193,7 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
         return dict(rtol=1e-3, atol=1e-5)
     if dtype == torch.bfloat16:
         return dict(rtol=1.6e-2, atol=1e-5)
-    raise ValueError(f"Unsuppored dtype ({dtype})")
+    raise ValueError(f"Unsupported dtype ({dtype})")


 def assert_allclose(
...
@@ -1279,6 +1281,9 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias):
 @pytest.mark.parametrize("bias", all_boolean)
 @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
 def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_accumulation):
+    if NVTE_TEST_NVINSPECT_ENABLED:
+        pytest.skip("Delayed wgrad compute is not supported in debug mode.")
+
     config = model_configs[model]
     te_linear_ref = Linear(
...
@@ -1376,7 +1381,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
     te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe)
     te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe)

-    # Shoule be bit-wise match
+    # Should be bit-wise match
     for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
...
@@ -1576,6 +1581,9 @@ def test_layernorm_linear_accuracy(
 def test_layernorm_linear_accuracy_delay_wgrad_compute(
     dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation
 ):
+    if NVTE_TEST_NVINSPECT_ENABLED:
+        pytest.skip("Delayed wgrad compute is not supported in debug mode.")
+
     config = model_configs[model]
     ln_linear_ref = LayerNormLinear(
...
@@ -1709,8 +1717,15 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
 @pytest.mark.parametrize("bias", all_boolean)
 @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
 def test_layernorm_mlp_accuracy_delay_wgrad_compute(
-    dtype, bs, model, bias, fuse_wgrad_accumulation
+    dtype,
+    bs,
+    model,
+    bias,
+    fuse_wgrad_accumulation,
 ):
+    if NVTE_TEST_NVINSPECT_ENABLED:
+        pytest.skip("Delayed wgrad compute is not supported in debug mode.")
+
     config = model_configs[model]
     ln_mlp = LayerNormMLP(
...
@@ -1760,6 +1775,58 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
     torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("bs", [2])
+@pytest.mark.parametrize("model", ["small"])
+@pytest.mark.parametrize("bias", all_boolean)
+def test_layernorm_mlp_accuracy_checkpoint(
+    dtype,
+    bs,
+    model,
+    bias,
+):
+    config = model_configs[model]
+    ln_mlp = LayerNormMLP(
+        hidden_size=config.hidden_size,
+        ffn_hidden_size=4 * config.hidden_size,
+        eps=config.eps,
+        bias=bias,
+        params_dtype=dtype,
+        device="cuda",
+        checkpoint=True,
+    ).eval()
+    ln_mlp_ref = LayerNormMLP(
+        hidden_size=config.hidden_size,
+        ffn_hidden_size=4 * config.hidden_size,
+        eps=config.eps,
+        bias=bias,
+        params_dtype=dtype,
+        device="cuda",
+        checkpoint=False,
+    ).eval()
+
+    # Share params
+    with torch.no_grad():
+        ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
+        ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
+        ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
+        ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
+        if bias:
+            ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
+            ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
+
+    te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=False)
+    te_outputs_ref = _test_granular_accuracy(ln_mlp_ref, bs, dtype, config, delay_wgrad_compute=False)
+
+    # Shoule be bit-wise match
+    for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
+        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+
+
 def _test_grouped_linear_accuracy(
     block,
     num_gemms,
...
@@ -1786,9 +1853,7 @@ def _test_grouped_linear_accuracy(
     if num_gemms > 1:
-        split_size = 1
-        if fp8:
-            split_size = 16
-            if recipe.mxfp8() or recipe.nvfp4():
-                split_size = 32
+        split_size = get_align_size_for_quantization(recipe)
         m = config.max_seqlen_q // split_size
         dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
         dist.append(dist[-1])  # Manually add a zero
...
@@ -1857,6 +1922,8 @@ def test_grouped_linear_accuracy(
         pytest.skip(reason_for_no_fp8)
     if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute:
+        pytest.skip("Delayed wgrad compute is not supported in debug mode.")

     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:
...
@@ -2001,6 +2068,8 @@ def test_grouped_linear_accuracy_save_original_input(
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute:
+        pytest.skip("Delayed wgrad compute is not supported in debug mode.")
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
...
@@ -2106,9 +2175,7 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
 def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):

     def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
-        align_size = 16
-        if recipe.mxfp8() or recipe.nvfp4():
-            align_size = 32
+        align_size = get_align_size_for_quantization(recipe)
         padded_tokens_per_expert = [
             (num_tokens + align_size - 1) // align_size * align_size
             for num_tokens in tokens_per_expert
...
@@ -2725,7 +2792,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
             general_gemm(
                 A[i],
                 B[i],
                 get_workspace(),
                 dtype,
                 grad=grad,
                 accumulate=accumulate,
...
@@ -2739,8 +2805,8 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
             A,
             B,
             out,
             [None] * z,
             dtype,
             get_multi_stream_cublas_workspace(),
             m_splits=m_splits,
             grad=grad,
             accumulate=accumulate,
...
@@ -2800,7 +2866,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
     quantized_out, *_ = general_gemm(
         weight_fp8,
         inp_fp8,
         get_workspace(),
         outp_type,
         quantization_params=out_quantizer,
         bias=None,
...
@@ -2810,7 +2875,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
     out, *_ = general_gemm(
         weight_fp8,
         inp_fp8,
         get_workspace(),
         outp_type,
         quantization_params=None,
         bias=None,
...
@@ -2886,7 +2950,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
             general_gemm(
                 A_fp8[i],
                 B_fp8[i],
                 get_workspace(),
                 dtype,
                 out=out_ref[i],
                 accumulate=accumulate,
...
@@ -2895,8 +2958,8 @@ def test_fp8_grouped_gemm(shape, accumulate):
             A_fp8,
             B_fp8,
             out,
             [None] * z,
             dtype,
             get_multi_stream_cublas_workspace(),
             m_splits=m_splits,
             accumulate=accumulate,
         )
...
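The refactor above replaces the hard-coded 16/32 alignment constants with get_align_size_for_quantization(recipe); the padded token counts are then the usual round-up-to-multiple formula already visible in the list comprehension. A small worked example of that arithmetic:

def round_up(num_tokens: int, align_size: int) -> int:
    """Round a token count up to a multiple of the quantization alignment."""
    return (num_tokens + align_size - 1) // align_size * align_size

assert round_up(300, 16) == 304   # FP8 per-tensor recipes
assert round_up(300, 32) == 320   # MXFP8 / NVFP4 recipes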
tests/pytorch/test_onnx_export.py  (+1 −1)  View file @ 0d874a4e

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
tests/pytorch/test_parallel_cross_entropy.py  (+15 −2)  View file @ 0d874a4e

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -89,7 +89,7 @@ class TestParallelCrossEntropy:
         # Check that loss and grad input match
         tols = dtype_tols(dtype)
         test_loss = test_loss.to(dtype=torch.float64, device="cpu")
-        ref_loss = test_loss.to(dtype=torch.float64, device="cpu")
+        ref_loss = ref_loss.to(dtype=torch.float64, device="cpu")
         ref_loss = ref_loss.reshape(test_loss.size())
         test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu")
         ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu")
...
@@ -154,3 +154,16 @@
                 reduce_loss=False,
                 ignore_idx=True,
             )
+
+    def test_ignore_idx_reduced_loss(self):
+        """Test ignore_idx with reduce_loss=True"""
+        self.generate_iters(5)
+        self.generate_infra(True, 0)  # reduce_loss=True
+        for i in range(self.iters):
+            self.one_iteration_test(
+                dtype=torch.float32,
+                swap_dim=random.choice([True, False]),
+                label_smoothing=0,
+                reduce_loss=True,
+                ignore_idx=True,
+            )
tests/pytorch/test_partial_cast.py  (new file, +137 −0)  View file @ 0d874a4e

# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
import torch

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_compute_scale_inv_e8m0
from transformer_engine.pytorch import is_mxfp8_available
from transformer_engine.pytorch.optimizers.multi_tensor_apply import multi_tensor_applier

mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)


def compute_partial_amax_reference(inp, amax_rowwise, amax_colwise, h, w, start_offset):
    n = inp.view(-1).size(0)
    if n == h * w:
        full = inp.view(-1)
    else:
        full = torch.zeros(h * w, dtype=inp.dtype, device=inp.device)
        full[start_offset : start_offset + n].copy_(inp)
    full = torch.abs(full)
    _amax_rowwise, _ = torch.max(full.view(h, w // 32, 32), dim=2)
    amax_rowwise[:h, : (w // 32)].copy_(_amax_rowwise)
    _amax_colwise, _ = torch.max(full.view(h // 32, 32, w), dim=1)
    amax_colwise[: (h // 32), :w].copy_(_amax_colwise)


def partial_cast_reference(
    inp, rowwise_out, colwise_out, rowwise_inv_scale, colwise_inv_scale, h, w, start_offset
):
    rowwise_scale = ((254 - rowwise_inv_scale.int()) * 2**23).view(torch.float32)
    colwise_scale = ((254 - colwise_inv_scale.int()) * 2**23).view(torch.float32)
    n = inp.view(-1).size(0)
    if n == h * w:
        full = inp
    else:
        full = torch.empty(h * w, dtype=inp.dtype, device=inp.device)
        full[start_offset : start_offset + n].copy_(inp)
    full = full.float()
    rowwise_scale = rowwise_scale[:h, : (w // 32)].contiguous().float()
    colwise_scale = colwise_scale[: (h // 32), :w].contiguous().float()
    scaled = (full.view(-1, 32) * rowwise_scale.view(-1, 1)).view(-1)
    rowwise_out.copy_(
        scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(rowwise_out.dtype)
    )
    scaled = (full.view(h // 32, 32, w) * colwise_scale.view(h // 32, 1, w)).view(-1)
    colwise_out.copy_(
        scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(colwise_out.dtype)
    )


def run_one_case(n, h, w, start_offset):
    inp = torch.randn(n, dtype=torch.bfloat16, device="cuda")
    rowwise_padding = [128, 4]
    colwise_padding = [4, 128]

    def _pad(x, padding):
        return (x + padding - 1) // padding * padding

    rowwise_shape = [_pad(h, rowwise_padding[0]), _pad(w // 32, rowwise_padding[1])]
    colwise_shape = [_pad(h // 32, colwise_padding[0]), _pad(w, colwise_padding[1])]

    # Partial amax cuda kernel
    amax_rowwise = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device)
    amax_colwise = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device)
    tex.mxfp8_scaling_compute_partial_amax(inp, amax_rowwise, amax_colwise, h, w, start_offset)

    # Partial amax pytorch reference
    amax_rowwise_ref = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device)
    amax_colwise_ref = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device)
    compute_partial_amax_reference(inp, amax_rowwise_ref, amax_colwise_ref, h, w, start_offset)

    # Check partial amax
    torch.testing.assert_close(amax_rowwise, amax_rowwise_ref, atol=0, rtol=0)
    torch.testing.assert_close(amax_colwise, amax_colwise_ref, atol=0, rtol=0)

    # Calculate scales and scale_invs
    scale_inv_rowwise = torch.empty_like(amax_rowwise).to(torch.uint8)
    scale_inv_colwise = torch.empty_like(amax_colwise).to(torch.uint8)
    multi_tensor_applier(
        multi_tensor_compute_scale_inv_e8m0,
        None,
        [
            [amax_rowwise, amax_colwise],
            [scale_inv_rowwise, scale_inv_colwise],
        ],
    )

    # Partial cast cuda kernel
    output_rowwise = torch.empty_like(inp).to(torch.uint8)
    output_colwise = torch.empty_like(inp).to(torch.uint8)
    tex.mxfp8_scaling_partial_cast(
        inp,
        output_rowwise,
        output_colwise,
        scale_inv_rowwise,
        scale_inv_colwise,
        h,
        w,
        start_offset,
    )

    # Partial cast pytorch reference
    output_rowwise_ref = torch.empty_like(inp).to(torch.uint8)
    output_colwise_ref = torch.empty_like(inp).to(torch.uint8)
    partial_cast_reference(
        inp,
        output_rowwise_ref,
        output_colwise_ref,
        scale_inv_rowwise,
        scale_inv_colwise,
        h,
        w,
        start_offset,
    )

    # Check partial cast results
    torch.testing.assert_close(output_rowwise, output_rowwise_ref, atol=0, rtol=0)
    torch.testing.assert_close(output_colwise, output_colwise_ref, atol=0, rtol=0)


@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
def test_mxfp8_scaling_partial_cast():
    torch.cuda.manual_seed(1234)
    run_one_case(3, 32, 64, 31)
    run_one_case(64 * 64 - 2, 64, 64, 1)
    run_one_case(16384 * 6144, 16384, 6144, 0)
    run_one_case(32768, 256, 128, 0)
    run_one_case(131072, 768, 256, 0)
    run_one_case(65536, 768, 256, 131072)
    run_one_case(98304, 128, 768, 0)
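partial_cast_reference above converts an e8m0 scale_inv byte back to a float scale with the bit trick ((254 - e) * 2**23).view(torch.float32). That works because the byte stores the biased exponent e of scale_inv = 2**(e - 127), so the matching scale 2**(127 - e) is exactly a float32 whose exponent field is 254 - e and whose mantissa is zero. A quick check of the identity (values chosen arbitrarily, assuming e stays in the normal range 1..253):

import torch

e = torch.tensor([100, 127, 150], dtype=torch.uint8)        # e8m0 scale_inv bytes
scale_bits = ((254 - e.int()) * 2**23).view(torch.float32)   # bit trick used in the test
scale_ref = torch.pow(2.0, 127.0 - e.float())                # direct formula
torch.testing.assert_close(scale_bits, scale_ref)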
tests/pytorch/test_permutation.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import
os
import
random
import
torch
...
...
@@ -13,6 +14,7 @@ from transformer_engine.common import recipe
from
transformer_engine.pytorch
import
(
moe_permute
as
te_permute
,
moe_permute_with_probs
as
te_permute_with_probs
,
moe_permute_and_pad_with_probs
as
te_permute_and_pad_with_probs
,
moe_unpermute
as
te_unpermute
,
moe_sort_chunks_by_index
as
te_sort_chunks_by_index
,
moe_sort_chunks_by_index_with_probs
as
te_sort_chunks_by_index_with_probs
,
...
...
@@ -24,6 +26,7 @@ from transformer_engine.pytorch import (
MXFP8Quantizer
,
)
import
transformer_engine_torch
as
tex
from
transformer_engine.pytorch
import
Fp8Padding
,
Fp8Unpadding
import
copy
seed
=
1234
...
...
@@ -653,6 +656,522 @@ def _test_permutation_mask_map(
print
(
f
"unpermute
\t
bwd: pytorch:
{
t1
:.
3
f
}
ms, TE:
{
t2
:.
3
f
}
ms"
)
def
_test_permutation_and_padding_mask_map
(
te_dtype
,
num_tokens
,
num_expert
,
hidden_size
,
topK
,
num_out_tokens
,
with_merging_probs
=
False
,
align_size
=
16
,
BENCHMARK
=
False
,
):
if
topK
>
num_expert
:
pytest
.
skip
(
"topK should be smaller than the number of experts."
)
if
num_out_tokens
is
None
:
num_out_tokens
=
num_tokens
*
topK
print
(
"permutation and padding:"
f
" token:
{
num_tokens
}
hidden_size:
{
hidden_size
}
expert:
{
num_expert
}
topK:
{
topK
}
"
f
" with_merging_probs:
{
with_merging_probs
}
align_size:
{
align_size
}
{
te_dtype
}
"
)
# Convert TE dtypes to PyTorch dtypes
if
te_dtype
==
tex
.
DType
.
kFloat32
:
dtype
=
torch
.
float32
elif
te_dtype
==
tex
.
DType
.
kFloat16
:
dtype
=
torch
.
float16
elif
te_dtype
==
tex
.
DType
.
kBFloat16
:
dtype
=
torch
.
bfloat16
else
:
pytest
.
skip
(
"Invalid dtype."
)
_tmp_tensor
=
torch
.
zeros
((
num_tokens
*
num_expert
,))
_tmp_tensor
[:
int
(
num_out_tokens
)]
=
1.0
_tmp_idx
=
torch
.
randperm
(
num_tokens
*
num_expert
)
routing_map
=
torch
.
reshape
(
_tmp_tensor
[
_tmp_idx
],
(
num_tokens
,
num_expert
)).
bool
().
cuda
()
probs
=
torch
.
rand
(
num_tokens
,
num_expert
).
cuda
()
*
routing_map
row_sums
=
probs
.
sum
(
dim
=
1
,
keepdim
=
True
)
probs
=
probs
/
row_sums
probs
=
probs
.
to
(
dtype
)
probs
.
requires_grad_
(
True
)
tokens_per_expert
=
routing_map
.
sum
(
dim
=
0
).
cpu
()
target_tokens_per_expert
=
(
torch
.
ceil
(
tokens_per_expert
/
align_size
)
*
align_size
).
long
()
num_permute_pad_out_tokens
=
target_tokens_per_expert
.
sum
().
item
()
permute_pad_fwd_input
=
torch
.
rand
((
num_tokens
,
hidden_size
),
dtype
=
dtype
).
cuda
()
permute_pad_bwd_input
=
torch
.
rand
(
(
num_permute_pad_out_tokens
,
hidden_size
),
dtype
=
dtype
).
cuda
()
unpermute_unpad_bwd_input
=
torch
.
rand
((
num_tokens
,
hidden_size
),
dtype
=
dtype
).
cuda
()
permute_pad_fwd_input
.
requires_grad_
(
True
)
restore_shape
=
permute_pad_fwd_input
.
shape
###################################################################################################################################
#
# moe_permute_with_probs and Fp8Padding, moe_unpermute and Fp8Unpadding
#
###################################################################################################################################
# permute + padding
permuted_output
,
permuted_probs
,
row_id_map
=
te_permute_with_probs
(
permute_pad_fwd_input
,
probs
,
routing_map
,
num_out_tokens
=
num_out_tokens
,
)
tokens_per_expert_list
=
tokens_per_expert
.
tolist
()
fp8_padding
=
Fp8Padding
(
num_expert
,
align_size
)
permuted_paded_output
,
_
=
fp8_padding
(
permuted_output
,
tokens_per_expert_list
)
permuted_paded_probs
,
_
=
fp8_padding
(
permuted_probs
.
unsqueeze
(
-
1
),
tokens_per_expert_list
)
permuted_paded_output
.
backward
(
permute_pad_bwd_input
,
retain_graph
=
True
)
# unpadding + unpermute
unpermute_unpad_fwd_input
=
permuted_paded_output
.
detach
()
unpermute_unpad_fwd_input
.
requires_grad_
(
True
)
fp8_unpadding
=
Fp8Unpadding
(
num_expert
,
align_size
)
unpaded_output
=
fp8_unpadding
(
unpermute_unpad_fwd_input
,
tokens_per_expert_list
)
probs_naive
=
probs
unpermuted_unpaded_output
=
te_unpermute
(
unpaded_output
,
row_id_map
,
merging_probs
=
probs_naive
if
with_merging_probs
else
None
,
restore_shape
=
restore_shape
,
)
unpermuted_unpaded_output
.
backward
(
unpermute_unpad_bwd_input
,
retain_graph
=
True
)
###################################################################################################################################
#
# fusion moe_permute_with_probs and Fp8Padding, fusion fusion moe_unpermute and Fp8Unpadding
#
###################################################################################################################################
# fusion permute_and_pad
fusion_permute_and_pad_fwd_input
=
permute_pad_fwd_input
.
detach
()
fusion_permute_and_pad_fwd_input
.
requires_grad_
(
True
)
probs_fusion
=
probs_naive
.
detach
().
clone
()
probs_fusion
.
requires_grad_
(
True
)
(
fusion_permuted_padded_output
,
fusion_permuted_padded_probs
,
row_id_map
,
pad_offsets
,
target_tokens_per_expert
,
)
=
te_permute_and_pad_with_probs
(
fusion_permute_and_pad_fwd_input
,
probs_fusion
,
routing_map
,
tokens_per_expert
,
align_size
,
)
fusion_permuted_padded_probs
=
fusion_permuted_padded_probs
.
unsqueeze
(
-
1
)
fusion_permute_pad_bwd_input
=
permute_pad_bwd_input
.
detach
()
fusion_permuted_padded_output
.
backward
(
fusion_permute_pad_bwd_input
,
retain_graph
=
True
)
# fusion unpad and unpermute
fusion_unpermute_unpad_fwd_input
=
fusion_permuted_padded_output
.
detach
()
fusion_unpermute_unpad_fwd_input
.
requires_grad_
(
True
)
fusion_unpermuted_unpaded_output
=
te_unpermute
(
fusion_unpermute_unpad_fwd_input
,
row_id_map
,
merging_probs
=
probs_fusion
if
with_merging_probs
else
None
,
restore_shape
=
restore_shape
,
pad_offsets
=
pad_offsets
,
)
fusion_unpermute_bwd_input
=
unpermute_unpad_bwd_input
.
detach
()
fusion_unpermuted_unpaded_output
.
backward
(
fusion_unpermute_bwd_input
,
retain_graph
=
True
)
###################################################################################################################################
#
# Results Check
#
###################################################################################################################################
tols
=
dtype_tols
(
te_dtype
)
permuted_paded_output_
=
permuted_paded_output
.
float
()
fusion_permuted_padded_output_
=
fusion_permuted_padded_output
.
float
()
permute_pad_fwd_input_grad
=
permute_pad_fwd_input
.
grad
.
float
()
fusion_permute_and_pad_fwd_input_grad
=
fusion_permute_and_pad_fwd_input
.
grad
.
float
()
unpermuted_unpaded_output_
=
unpermuted_unpaded_output
.
float
()
fusion_unpermuted_unpaded_output_
=
fusion_unpermuted_unpaded_output
.
float
()
unpermute_unpad_fwd_input_grad
=
unpermute_unpad_fwd_input
.
grad
.
float
()
fusion_unpermute_unpad_fwd_input_grad
=
fusion_unpermute_unpad_fwd_input
.
grad
.
float
()
if
not
BENCHMARK
:
torch
.
testing
.
assert_close
(
permuted_paded_output_
,
fusion_permuted_padded_output_
,
msg
=
f
"Mismatch in te_permute_and_pad fwd"
,
**
tols
,
)
torch
.
testing
.
assert_close
(
permute_pad_fwd_input_grad
,
fusion_permute_and_pad_fwd_input_grad
,
msg
=
f
"Mismatch in te_permute_and_pad bwd"
,
**
tols
,
)
torch
.
testing
.
assert_close
(
unpermuted_unpaded_output_
,
fusion_unpermuted_unpaded_output_
,
msg
=
f
"Mismatch in te_unpermute fwd"
,
**
tols
,
)
torch
.
testing
.
assert_close
(
unpermute_unpad_fwd_input_grad
,
fusion_unpermute_unpad_fwd_input_grad
,
msg
=
f
"Mismatch in te_unpermute bwd"
,
**
tols
,
)
torch
.
testing
.
assert_close
(
permuted_paded_probs
.
float
(),
fusion_permuted_padded_probs
.
float
(),
msg
=
f
"Mismatch in te_permute_and_pad bwd"
,
**
tols
,
)
if
with_merging_probs
:
torch
.
testing
.
assert_close
(
probs_naive
.
grad
.
float
(),
probs_fusion
.
grad
.
float
(),
msg
=
f
"Mismatch in te_unpermute bwd"
,
**
tols
,
)
###################################################################################################################################
#
# Benchmark
#
###################################################################################################################################
if
BENCHMARK
:
def
permute_and_pad
():
permuted_output
,
permuted_probs
,
row_id_map
=
te_permute_with_probs
(
permute_pad_fwd_input
,
probs
,
routing_map
,
num_out_tokens
=
num_out_tokens
,
)
fp8_padding
(
permuted_output
,
tokens_per_expert_list
)
fp8_padding
(
permuted_probs
.
unsqueeze
(
-
1
),
tokens_per_expert_list
)
def
fusion_permute_and_pad
():
(
fusion_permuted_padded_output
,
fusion_permuted_padded_probs
,
row_id_map
,
pad_offsets
,
target_tokens_per_expert
,
)
=
te_permute_and_pad_with_probs
(
fusion_permute_and_pad_fwd_input
,
probs
,
routing_map
,
tokens_per_expert
,
align_size
,
)
fusion_permuted_padded_probs
=
fusion_permuted_padded_probs
.
unsqueeze
(
-
1
)
t1
=
perf_test_cuda_kernel
(
lambda
:
permute_and_pad
())
t2
=
perf_test_cuda_kernel
(
lambda
:
fusion_permute_and_pad
())
print
(
f
"permute_and_pad
\t\t
fwd: naive:
{
t1
:.
3
f
}
ms, fusion:
{
t2
:.
3
f
}
ms"
)
t1
=
perf_test_cuda_kernel
(
lambda
:
backward_wrapper
(
permuted_paded_output
,
permute_pad_bwd_input
,
forward_input
=
[
permute_pad_fwd_input
],
retain_graph
=
True
,
accumulate_grad
=
False
,
)
)
t2
=
perf_test_cuda_kernel
(
lambda
:
backward_wrapper
(
fusion_permuted_padded_output
,
fusion_permute_pad_bwd_input
,
forward_input
=
[
fusion_permute_and_pad_fwd_input
],
retain_graph
=
True
,
accumulate_grad
=
False
,
)
)
print
(
f
"permute_and_pad
\t\t
bwd: naive:
{
t1
:.
3
f
}
ms, fusion:
{
t2
:.
3
f
}
ms"
)
def
unpad_unpermute
():
unpaded_output
=
fp8_unpadding
(
unpermute_unpad_fwd_input
,
tokens_per_expert_list
)
unpermuted_unpaded_output
=
te_unpermute
(
unpaded_output
,
row_id_map
,
restore_shape
=
restore_shape
)
unpermuted_unpaded_output
.
backward
(
unpermute_unpad_bwd_input
,
retain_graph
=
True
)
t1
=
perf_test_cuda_kernel
(
lambda
:
unpad_unpermute
())
t2
=
perf_test_cuda_kernel
(
lambda
:
te_unpermute
(
fusion_unpermute_unpad_fwd_input
,
row_id_map
,
restore_shape
=
restore_shape
,
pad_offsets
=
pad_offsets
,
)
)
print
(
f
"unpermute_and_unpad
\t
fwd: naive:
{
t1
:.
3
f
}
ms, fusion:
{
t2
:.
3
f
}
ms"
)
t1
=
perf_test_cuda_kernel
(
lambda
:
backward_wrapper
(
unpermuted_unpaded_output
,
unpermute_unpad_bwd_input
,
forward_input
=
([
unpermute_unpad_fwd_input
,
probs
]),
retain_graph
=
True
,
accumulate_grad
=
False
,
)
)
t2
=
perf_test_cuda_kernel
(
lambda
:
backward_wrapper
(
fusion_unpermuted_unpaded_output
,
fusion_unpermute_bwd_input
,
forward_input
=
([
fusion_unpermute_unpad_fwd_input
,
probs
]),
retain_graph
=
True
,
accumulate_grad
=
False
,
)
)
print
(
f
"unpermute_and_unpad
\t
bwd: naive:
{
t1
:.
3
f
}
ms, fusion:
{
t2
:.
3
f
}
ms"
)
def
_test_permutation_and_padding_with_merging_probs
(
te_dtype
,
num_tokens
,
num_expert
,
hidden_size
,
topK
,
num_out_tokens
,
align_size
=
16
,
BENCHMARK
=
False
,
):
"""
Test the combination of merging_probs AND pad_offsets together in moe_unpermute.
This specifically tests the backward pass fix where pad_offsets must be used
when computing gradients with merging_probs.
"""
if
topK
>
num_expert
:
pytest
.
skip
(
"topK should be smaller than the number of experts."
)
if
num_out_tokens
==
None
:
num_out_tokens
=
num_tokens
*
topK
print
(
"permutation and padding with merging probs:"
f
" token:
{
num_tokens
}
hidden_size:
{
hidden_size
}
expert:
{
num_expert
}
topK:
{
topK
}
align_size:
{
align_size
}
{
te_dtype
}
"
)
# Convert TE dtypes to PyTorch dtypes
if
te_dtype
==
tex
.
DType
.
kFloat32
:
dtype
=
torch
.
float32
elif
te_dtype
==
tex
.
DType
.
kFloat16
:
dtype
=
torch
.
float16
elif
te_dtype
==
tex
.
DType
.
kBFloat16
:
dtype
=
torch
.
bfloat16
else
:
pytest
.
skip
(
"Invalid dtype."
)
_tmp_tensor
=
torch
.
zeros
((
num_tokens
*
num_expert
,))
_tmp_tensor
[:
int
(
num_out_tokens
)]
=
1.0
_tmp_idx
=
torch
.
randperm
(
num_tokens
*
num_expert
)
routing_map
=
torch
.
reshape
(
_tmp_tensor
[
_tmp_idx
],
(
num_tokens
,
num_expert
)).
bool
().
cuda
()
probs
=
torch
.
rand
(
num_tokens
,
num_expert
).
cuda
()
*
routing_map
row_sums
=
probs
.
sum
(
dim
=
1
,
keepdim
=
True
)
probs
=
probs
/
row_sums
probs
=
probs
.
to
(
dtype
)
probs
.
requires_grad_
(
True
)
tokens_per_expert
=
routing_map
.
sum
(
dim
=
0
).
cpu
()
target_tokens_per_expert
=
(
torch
.
ceil
(
tokens_per_expert
/
align_size
)
*
align_size
).
long
()
num_permute_pad_out_tokens
=
target_tokens_per_expert
.
sum
().
item
()
    permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
    permute_pad_bwd_input = torch.rand(
        (num_permute_pad_out_tokens, hidden_size), dtype=dtype
    ).cuda()
    unpermute_unpad_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
    permute_pad_fwd_input.requires_grad_(True)
    restore_shape = permute_pad_fwd_input.shape

    ###################################################################################################################################
    #
    # Reference: moe_permute_with_probs + Fp8Padding, then Fp8Unpadding + moe_unpermute with merging_probs
    #
    ###################################################################################################################################
    # permute + padding
    permuted_output, permuted_probs, row_id_map = te_permute_with_probs(
        permute_pad_fwd_input,
        probs,
        routing_map,
        num_out_tokens=num_out_tokens,
    )
    tokens_per_expert_list = tokens_per_expert.tolist()
    fp8_padding = Fp8Padding(num_expert, align_size)
    permuted_paded_output, _ = fp8_padding(permuted_output, tokens_per_expert_list)
    permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True)

    # Reference: unpadding + unpermute WITH merging_probs
    ref_unpermute_fwd_input = permuted_paded_output.detach()
    ref_unpermute_fwd_input.requires_grad_(True)
    ref_probs = probs.detach()
    ref_probs.requires_grad_(True)
    fp8_unpadding = Fp8Unpadding(num_expert, align_size)
    unpaded_output = fp8_unpadding(ref_unpermute_fwd_input, tokens_per_expert_list)
    ref_unpermuted_output = te_unpermute(
        unpaded_output, row_id_map, ref_probs, restore_shape=restore_shape
    )
    ref_unpermuted_output.backward(unpermute_unpad_bwd_input, retain_graph=True)

    ###################################################################################################################################
    #
    # Fused: moe_permute_and_pad_with_probs, then moe_unpermute with BOTH merging_probs AND pad_offsets
    #
    ###################################################################################################################################
    # fusion permute_and_pad
    fusion_permute_fwd_input = permute_pad_fwd_input.detach()
    fusion_permute_fwd_input.requires_grad_(True)
    fusion_probs = probs.detach()
    fusion_probs.requires_grad_(True)
    (
        fusion_permuted_padded_output,
        fusion_permuted_padded_probs,
        fused_row_id_map,
        pad_offsets,
        _,
    ) = te_permute_and_pad_with_probs(
        fusion_permute_fwd_input,
        fusion_probs,
        routing_map,
        tokens_per_expert,
        align_size,
    )
    fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach()
    fusion_permuted_padded_output.backward(fusion_permute_pad_bwd_input, retain_graph=True)

    # Fused: unpermute with BOTH merging_probs AND pad_offsets
    fusion_unpermute_fwd_input = fusion_permuted_padded_output.detach()
    fusion_unpermute_fwd_input.requires_grad_(True)
    fusion_merging_probs = probs.detach()
    fusion_merging_probs.requires_grad_(True)
    fusion_unpermuted_output = te_unpermute(
        fusion_unpermute_fwd_input,
        fused_row_id_map,
        fusion_merging_probs,
        restore_shape=restore_shape,
        pad_offsets=pad_offsets,
    )
    fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach()
    fusion_unpermuted_output.backward(fusion_unpermute_bwd_input, retain_graph=True)

    ###################################################################################################################################
    #
    # Results Check
    #
    ###################################################################################################################################
    tols = dtype_tols(te_dtype)

    # Check forward pass
    ref_unpermuted_output_ = ref_unpermuted_output.float()
    fusion_unpermuted_output_ = fusion_unpermuted_output.float()
    if not BENCHMARK:
        torch.testing.assert_close(
            ref_unpermuted_output_,
            fusion_unpermuted_output_,
            msg="Mismatch in te_unpermute with merging_probs and pad_offsets fwd",
            **tols,
        )

    # Check backward pass - activation gradients
    ref_unpermute_fwd_input_grad = ref_unpermute_fwd_input.grad.float()
    fusion_unpermute_fwd_input_grad = fusion_unpermute_fwd_input.grad.float()
    torch.testing.assert_close(
        ref_unpermute_fwd_input_grad,
        fusion_unpermute_fwd_input_grad,
        msg="Mismatch in te_unpermute with merging_probs and pad_offsets bwd (act_grad)",
        **tols,
    )

    # Check backward pass - probs gradients
    ref_probs_grad = ref_probs.grad.float()
    fusion_probs_grad = fusion_merging_probs.grad.float()
    torch.testing.assert_close(
        ref_probs_grad,
        fusion_probs_grad,
        msg="Mismatch in te_unpermute with merging_probs and pad_offsets bwd (probs_grad)",
        **tols,
    )

    ###################################################################################################################################
    #
    # Benchmark
    #
    ###################################################################################################################################
    if BENCHMARK:

        def ref_unpad_unpermute():
            unpaded = fp8_unpadding(ref_unpermute_fwd_input, tokens_per_expert_list)
            return te_unpermute(unpaded, row_id_map, ref_probs, restore_shape=restore_shape)

        def fused_unpermute():
            return te_unpermute(
                fusion_unpermute_fwd_input,
                fused_row_id_map,
                fusion_merging_probs,
                restore_shape=restore_shape,
                pad_offsets=pad_offsets,
            )

        t1 = perf_test_cuda_kernel(lambda: ref_unpad_unpermute())
        t2 = perf_test_cuda_kernel(lambda: fused_unpermute())
        print(f"unpermute_unpad_with_probs\t fwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")

        t1 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                ref_unpermuted_output,
                unpermute_unpad_bwd_input,
                forward_input=[ref_unpermute_fwd_input, ref_probs],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                fusion_unpermuted_output,
                fusion_unpermute_bwd_input,
                forward_input=[fusion_unpermute_fwd_input, fusion_merging_probs],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"unpermute_unpad_with_probs\t bwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")


def _test_permutation_mask_map_fp8(
    te_dtype,
    num_tokens,
...
...
@@ -1126,7 +1645,7 @@ if te.is_bf16_available():
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_index_map(
    te_dtype,
...
...
@@ -1155,7 +1674,7 @@ def test_permutation_index_map(
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_mask_map(
    te_dtype,
...
...
@@ -1180,6 +1699,74 @@ def test_permutation_mask_map(
    )


@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_out_tokens", [None])
@pytest.mark.parametrize(
    "num_tokens, num_expert, hidden_size, topK",
    [
        (4096, 8, 1280, 2),
        (4096, 64, 4096, 6),
        (4096, 256, 7168, 6),
        (4096, 512, 9216, 8),
    ],
)
@pytest.mark.parametrize("with_merging_probs", [True, False])
def test_permutation_and_padding_mask_map(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
    with_merging_probs,
):
    BENCHMARK = False

    _test_permutation_and_padding_mask_map(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        with_merging_probs=with_merging_probs,
        BENCHMARK=BENCHMARK,
    )


@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_out_tokens", [None])
@pytest.mark.parametrize(
    "num_tokens, num_expert, hidden_size, topK",
    [
        (4096, 8, 1280, 2),
        (4096, 64, 4096, 6),
        (4096, 256, 7168, 6),
        (4096, 512, 9216, 8),
    ],
)
def test_permutation_and_padding_with_merging_probs(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
):
    """Test moe_unpermute backward pass with BOTH merging_probs AND pad_offsets."""
    BENCHMARK = False

    _test_permutation_and_padding_with_merging_probs(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        BENCHMARK=BENCHMARK,
    )


@pytest.mark.parametrize("te_dtype", _te_dtypes)
def test_permutation_mask_map_empty_input(te_dtype):
    with_probs = True
...
...
@@ -1201,9 +1788,9 @@ def test_permutation_mask_map_empty_input(te_dtype):
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
@pytest.mark.parametrize("tp_size", [1, 2, 8])
@pytest.mark.parametrize("tp_size", [1, 2])
def test_permutation_mask_map_alongside_probs(
    te_dtype,
    num_tokens,
...
...
@@ -1253,10 +1840,10 @@ fp8_recipes = [
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048])
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_permutation_mask_map_fp8(
...
...
@@ -1341,7 +1928,7 @@ def test_permutation_mask_map_topk1_no_probs(
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("tp_size", [1, 2, 8])
@pytest.mark.parametrize("tp_size", [2, 8])
@pytest.mark.parametrize("hidden_size", [4096])
def test_chunk_permutation(
    te_dtype,
...
...
@@ -1376,6 +1963,10 @@ def test_chunk_permutation_empty_input(te_dtype):
    )


@pytest.mark.skipif(
    os.getenv("RUN_BENCHMARK_TESTS", "0") != "1",
    reason="Benchmark test - run with: RUN_BENCHMARK_TESTS=1 pytest -k single_case",
)
def test_permutation_single_case():
    print("GPU:", torch.cuda.get_device_name(0))
...
...
@@ -1413,6 +2004,26 @@ def test_permutation_single_case():
        BENCHMARK=Benchmark,
    )

    _test_permutation_and_padding_mask_map(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        BENCHMARK=Benchmark,
    )

    _test_permutation_and_padding_with_merging_probs(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        BENCHMARK=Benchmark,
    )

    _test_moe_chunk_sort(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
...
...
@@ -1479,6 +2090,30 @@ def benchmark_single_case(
    )
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("permutation_and_padding_mask_map")
    _test_permutation_and_padding_mask_map(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        BENCHMARK=True,
    )
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("permutation_and_padding_with_merging_probs")
    _test_permutation_and_padding_with_merging_probs(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        BENCHMARK=True,
    )
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("permutation_mask_map_alongside_probs")
    _test_permutation_mask_map_alongside_probs(
        te_dtype=te_dtype,
...
...
@@ -1495,7 +2130,12 @@ def benchmark_single_case(
    torch.cuda.nvtx.range_pop()


def benchmark_multiple_cases():
@pytest.mark.skipif(
    os.getenv("RUN_BENCHMARK_TESTS", "0") != "1",
    reason="Benchmark test - run with: RUN_BENCHMARK_TESTS=1 pytest -k benchmark",
)
def test_benchmark_multiple_cases():
    """Benchmark test - skipped by default. Run with: RUN_BENCHMARK_TESTS=1 pytest -k benchmark"""
    print("GPU:", torch.cuda.get_device_name(0))
    # te_dtype = tex.DType.kFloat32
...
...
@@ -1537,4 +2177,4 @@ def benchmark_multiple_cases():
if __name__ == "__main__":
    benchmark_multiple_cases()
    test_benchmark_multiple_cases()
tests/pytorch/test_qk_norm.py
View file @ 0d874a4e
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/test_float8tensor.py → tests/pytorch/test_quantized_tensor.py
View file @ 0d874a4e
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -13,9 +13,16 @@ import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch import (
    Float8Quantizer,
    Float8Tensor,
    Float8CurrentScalingQuantizer,
    Float8BlockQuantizer,
    MXFP8Quantizer,
    NVFP4Quantizer,
    Float8Tensor,
    MXFP8Tensor,
    NVFP4Tensor,
    QuantizedTensor,
)
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
import transformer_engine_torch as tex
...
...
@@ -44,8 +51,22 @@ def _to_list(x: Union[Iterable, Any]) -> List:
# Types that can be interpreted as tensor dims
DimsType = Union[Iterable[int], int]

# Check if FP8 is supported
# Supported quantization recipes
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
    return_reason=True
)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
_quantization_list: List[str] = []
if fp8_available:
    _quantization_list.append("fp8")
if fp8_block_scaling_available:
    _quantization_list.append("fp8_blockwise")
if mxfp8_available:
    _quantization_list.append("mxfp8")
if nvfp4_available:
    _quantization_list.append("nvfp4")
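# Note: only recipes supported on the current device end up in _quantization_list, so the
# parametrized tests below exercise just the formats the hardware can actually run.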
# delayed scaling
...
...
@@ -86,6 +107,79 @@ def to_float8_CS(
    return quantizer(tensor)


@torch.no_grad()
def make_reference_and_test_tensors(
    shape: int | Iterable[int],
    quantization: Optional[str] = None,
    ref_dtype: torch.dtype = torch.float64,
    ref_device: torch.device = "cpu",
    test_dtype: torch.dtype = torch.float32,
    test_device: torch.device = "cuda",
    requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Construct tensors with the same values

    The reference tensor is intended for use in plain PyTorch
    operations in high precision. The test tensor is intended for use
    in Transformer Engine operations.

    If a quantization scheme is provided, the tensor values are
    quantized so that they are representable.

    """
    # Random reference tensor
    ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)

    # Construct test tensor from reference tensor
    test = ref.to(device=test_device, dtype=test_dtype)
    if quantization is None:
        if test.data_ptr() == ref.data_ptr():
            test = test.clone()
    elif quantization in ("fp8", "fp8_delayed_scaling"):
        quantizer = Float8Quantizer(
            scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
            amax=torch.zeros(1, dtype=torch.float32, device=test_device),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
        test = quantizer(test)
    elif quantization == "fp8_current_scaling":
        quantizer = Float8CurrentScalingQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            device=test_device,
        )
        test = quantizer(test)
    elif quantization == "fp8_blockwise":
        quantizer = Float8BlockQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            rowwise=True,
            columnwise=True,
            force_pow_2_scales=True,
            amax_epsilon=0.0,
            block_scaling_dim=1,
        )
        test = quantizer(test)
    elif quantization == "mxfp8":
        test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
    elif quantization == "nvfp4":
        test = NVFP4Quantizer(
            with_rht=False,
            with_post_rht_amax=False,
            with_2d_quantization=False,
            stochastic_rounding=False,
            with_random_sign_mask=False,
        )(test)
    else:
        raise ValueError(f"Unsupported quantization scheme ({quantization})")

    # Make sure reference and test tensors match each other
    ref.copy_(test)

    ref.requires_grad_(requires_grad)
    test.requires_grad_(requires_grad)
    return ref, test


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFloat8Tensor:
...
...
@@ -452,3 +546,113 @@ class TestCurrentScalingFloat8Tensor:
        # Make sure we are not trivially passing the test
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8_dequantized, -x_hp, **_tols[fp8_dtype])


class TestQuantizedTensor:

    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("op", ("clone", "view", "reshape", "contiguous"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_identity_op(
        self,
        *,
        op: str,
        quantization: str,
        shape: Iterable[int] = (128, 128),
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device = "cuda",
    ) -> None:
        """Test operations that do not affect tensor values.

        These operations must produce outputs that are bit-wise
        equivalent to the inputs. They must support autograd.

        """
        # Create reference and quantized tensor
        x_ref, x_test = make_reference_and_test_tensors(
            shape=shape,
            quantization=quantization,
            test_dtype=dtype,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            shape=shape,
            test_dtype=dtype,
            requires_grad=False,
        )

        # Apply identity operation
        if op == "clone":
            y_ref = x_ref.clone()
            y_test = x_test.clone()
        elif op == "view":
            y_ref = x_ref.view(shape)
            y_test = x_test.view(shape)
        elif op == "reshape":
            y_ref = x_ref.reshape(shape)
            y_test = x_test.reshape(shape)
        elif op == "contiguous":
            y_ref = x_ref.contiguous()
            y_test = x_test.contiguous()

        # Check autograd
        y_test.backward(dy_test)
        assert x_test.grad is not None

        # Check values
        tols = dict(rtol=0, atol=0)
        if isinstance(y_test, QuantizedTensor):
            y_test = y_test.dequantize()
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dx_ref = dy_ref
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, dx_ref, **tols)

    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("dim", [0, 1])
    def test_chunk(
        self,
        *,
        quantization: str,
        dim: int,
        shape: Iterable[int] = (128, 128),
        chunks: int = 2,
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device = "cuda",
    ) -> None:

        # Create reference and quantized tensor
        x_ref, x_test = make_reference_and_test_tensors(
            shape=shape,
            quantization=quantization,
            test_dtype=dtype,
        )

        # Chunk tensors
        ys_ref = torch.chunk(x_ref, chunks, dim=dim)
        ys_test = torch.chunk(x_test, chunks, dim=dim)

        # Check splits
        for y_ref, y_test in zip(ys_ref, ys_test):

            # Check split shapes
            assert y_ref.size() == y_test.size()

            # Check that splits are quantized when expected
            if quantization == "fp8":
                assert isinstance(y_test, Float8Tensor)
                y_test = y_test.dequantize()
            elif quantization == "mxfp8" and dim == 0:
                assert isinstance(y_test, MXFP8Tensor)
                y_test = y_test.dequantize()

            # Check values
            tols = dict(rtol=0, atol=0)  # Chunking is exact
            y_test = y_test.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(y_test, y_ref, **tols)
tests/pytorch/test_recipe.py
View file @ 0d874a4e
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -30,7 +30,6 @@ from transformer_engine.pytorch.quantization import (
)
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
import transformer_engine_torch as tex

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
...
...