OpenDAS / TransformerEngine · Commits

Commit 063ef88d, authored Dec 03, 2025 by wenjh

Merge nv main up to v2.10.0.dev0

Signed-off-by: wenjh <wenjh@sugon.com>

Parents: 91670b05, 5624dbb4
Changes: 298
Showing 20 changed files with 1680 additions and 228 deletions (+1680, -228)
tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py      +491  -0
tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py  +248  -0
tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py         +238  -0
tests/pytorch/test_checkpoint.py                      +8    -6
tests/pytorch/test_cpu_offloading.py                  +6    -7
tests/pytorch/test_cuda_graphs.py                     +89   -27
tests/pytorch/test_custom_recipe.py                   +290  -0
tests/pytorch/test_deferred_init.py                   +0    -1
tests/pytorch/test_float8_blockwise_gemm_exact.py     +7    -6
tests/pytorch/test_float8_blockwise_scaling_exact.py  +18   -6
tests/pytorch/test_float8_current_scaling_exact.py    +14   -9
tests/pytorch/test_float8blockwisetensor.py           +2    -3
tests/pytorch/test_float8tensor.py                    +2    -4
tests/pytorch/test_fused_optimizer.py                 +15   -18
tests/pytorch/test_fused_rope.py                      +16   -0
tests/pytorch/test_fusible_ops.py                     +188  -91
tests/pytorch/test_hf_integration.py                  +1    -1
tests/pytorch/test_multi_tensor.py                    +1    -1
tests/pytorch/test_numerics.py                        +32   -33
tests/pytorch/test_onnx_export.py                     +14   -15
tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py  (new file, mode 100644)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.experimental import utils

recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)


def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
    repeated = x.repeat_interleave(2, dim=1)
    repeated[:, 0::2] &= 0x0F
    repeated[:, 1::2] >>= 4
    return repeated
def check_quantization_nvfp4_versus_reference(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    return_transpose: bool,
    swizzled_scale: bool,
    use_cpp_allocator: bool,
    with_2d_quantization: bool,
) -> None:
    te_dtype = tex.DType.kFloat4E2M1

    # Setup device and random seed
    device = "cuda"
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Input
    x = torch.randn((M, N), dtype=x_dtype, device=device)

    # Quantize
    nvfp4_quantizer = NVFP4Quantizer(
        fp4_dtype=te_dtype,
        rowwise=True,
        columnwise=return_transpose,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=False,
        with_post_rht_amax=False,
        with_2d_quantization=with_2d_quantization,
    )
    if use_cpp_allocator:
        x_nvfp4_sut = nvfp4_quantizer(x)
    else:
        x_nvfp4_sut = nvfp4_quantizer.make_empty(
            (M, N), dtype=x_dtype, device=device, requires_grad=False
        )
        x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)

    # Extract data from NVFP4Tensor
    assert x_nvfp4_sut._rowwise_data is not None
    qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._rowwise_scale_inv is not None
    sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
    qx_t = (
        x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
        if x_nvfp4_sut._columnwise_data is not None
        else None
    )
    sx_t = x_nvfp4_sut._columnwise_scale_inv
    qx_amax = x_nvfp4_sut._amax_rowwise

    # Reference quantization
    quant_tile_shape = (1, 16) if not with_2d_quantization else (16, 16)
    ref_quantizer = NVFP4QuantizerRef(
        dtype=utils.Fp4Formats.E2M1,
        rowwise=True,
        columnwise=return_transpose,
        pow_2_scales=False,
        eps=0.0,
        quant_tile_shape=quant_tile_shape,
    )
    x_nvfp4_ref = ref_quantizer.quantize(x)

    # Extract data from RefNVFP4Tensor
    qx_ref = (
        unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8))
        if x_nvfp4_ref.data is not None
        else None
    )
    sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
    qx_t_ref = (
        unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8))
        if x_nvfp4_ref.data_t is not None
        else None
    )
    sx_t_ref = (
        x_nvfp4_ref.scale_t.view(dtype=torch.uint8)
        if x_nvfp4_ref.scale_t is not None
        else None
    )
    ref_amax = x_nvfp4_ref.global_amax_row

    qx = unpack_fp4(qx)
    qx_t = unpack_fp4(qx_t) if qx_t is not None else None

    torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
    # Compare only the valid portion of scale tensors (reference may not have padding)
    ref_sx_shape = sx_ref.shape
    sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
    torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
    if return_transpose:
        torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
        # Compare only the valid portion of transpose scale tensors
        ref_sx_t_shape = sx_t_ref.shape
        sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
        torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
    torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        # full tile cases
        (128, 128), (256, 256), (256, 1024), (1024, 256),
        # Padding required cases
        (256, 272), (304, 304), (320, 256),
        # Some larger tiles
        (2048, 2048), (1024, 2048), (2048, 1024),
        # largest tile
        (8192, 8192),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"])
@pytest.mark.parametrize("swizzled_scale", [False], ids=["linear_scale"])
@pytest.mark.parametrize("use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"])
@pytest.mark.parametrize("with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"])
def test_quantization_block_tiling_versus_reference(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    return_transpose: bool,
    swizzled_scale: bool,
    use_cpp_allocator: bool,
    with_2d_quantization: bool,
) -> None:
    check_quantization_nvfp4_versus_reference(
        x_dtype=x_dtype,
        M=M,
        N=N,
        return_transpose=return_transpose,
        swizzled_scale=swizzled_scale,
        use_cpp_allocator=use_cpp_allocator,
        with_2d_quantization=with_2d_quantization,
    )
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize("M, N", [(128, 128)])
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"])
@pytest.mark.parametrize("return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"])
@pytest.mark.parametrize("use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"])
def test_nvfp4_quantization_extrema_versus_reference(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    extrema_high: bool,
    return_transpose: bool,
    use_cpp_allocator: bool,
):
    te_dtype = tex.DType.kFloat4E2M1
    device = "cuda"
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    if extrema_high:
        x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device)
    else:
        x = torch.zeros((M, N), dtype=x_dtype, device=device)

    nvfp4_quantizer = NVFP4Quantizer(
        fp4_dtype=te_dtype,
        rowwise=True,
        columnwise=return_transpose,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=False,
        with_post_rht_amax=False,
    )
    if use_cpp_allocator:
        x_nvfp4_sut = nvfp4_quantizer(x)
    else:
        x_nvfp4_sut = nvfp4_quantizer.make_empty(
            (M, N), dtype=x_dtype, device=device, requires_grad=False
        )
        x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)

    assert x_nvfp4_sut._rowwise_data is not None
    qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._rowwise_scale_inv is not None
    sx = x_nvfp4_sut._rowwise_scale_inv
    qx_t = (
        x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
        if x_nvfp4_sut._columnwise_data is not None
        else None
    )
    sx_t = x_nvfp4_sut._columnwise_scale_inv
    qx_amax = x_nvfp4_sut._amax_rowwise

    ref_quantizer = NVFP4QuantizerRef(
        dtype=utils.Fp4Formats.E2M1,
        rowwise=True,
        columnwise=return_transpose,
        pow_2_scales=False,
        eps=0.0,
        quant_tile_shape=(1, 16),
    )
    x_nvfp4_ref = ref_quantizer.quantize(x)

    qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
    sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
    qx_t_ref = (
        x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
    )
    sx_t_ref = (
        x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
    )
    ref_amax = x_nvfp4_ref.global_amax_row

    torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
    ref_sx_shape = sx_ref.shape
    sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
    torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
    if return_transpose:
        torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
        ref_sx_t_shape = sx_t_ref.shape
        sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
        torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
    torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize("M, N", [(16, 128), (32, 128)])
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"])
@pytest.mark.parametrize("use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"])
def test_nvfp4_quantization_boundary_values(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    return_transpose: bool,
    use_cpp_allocator: bool,
):
    """
    Stress rounding/threshold behavior by placing values just below/above
    many potential bin edges within each 16-element microblock.
    Validates native vs reference byte-for-byte and scale parity.
    """
    te_dtype = tex.DType.kFloat4E2M1
    device = "cuda"
    seed = 123
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Construct a single row with paired boundary values: v-eps, v+eps
    # spanning a wide dynamic range to exercise clipping and multiple bins.
    # Ensure even N and N is multiple of 16 for microblocks, which holds for 128.
    base = torch.linspace(-12.0, 12.0, steps=N // 2, dtype=torch.float32, device=device)
    eps = torch.full_like(base, 1e-3)
    # Avoid zero eps for very small magnitudes
    eps = torch.maximum(eps, 1e-4 * torch.ones_like(base))
    lower = base - eps
    upper = base + eps
    row = torch.empty(N, dtype=torch.float32, device=device)
    row[0::2] = lower
    row[1::2] = upper
    x = row.unsqueeze(0).repeat(M, 1).to(dtype=x_dtype)

    nvfp4_quantizer = NVFP4Quantizer(
        fp4_dtype=te_dtype,
        rowwise=True,
        columnwise=return_transpose,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=False,
        with_post_rht_amax=False,
    )
    if use_cpp_allocator:
        x_nvfp4_sut = nvfp4_quantizer(x)
    else:
        x_nvfp4_sut = nvfp4_quantizer.make_empty(
            (M, N), dtype=x_dtype, device=device, requires_grad=False
        )
        x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)

    assert x_nvfp4_sut._rowwise_data is not None
    qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._rowwise_scale_inv is not None
    sx = x_nvfp4_sut._rowwise_scale_inv
    qx_t = (
        x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
        if x_nvfp4_sut._columnwise_data is not None
        else None
    )
    sx_t = x_nvfp4_sut._columnwise_scale_inv
    qx_amax = x_nvfp4_sut._amax_rowwise

    ref_quantizer = NVFP4QuantizerRef(
        dtype=utils.Fp4Formats.E2M1,
        rowwise=True,
        columnwise=return_transpose,
        pow_2_scales=False,
        eps=0.0,
        quant_tile_shape=(1, 16),
    )
    x_nvfp4_ref = ref_quantizer.quantize(x)

    qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
    sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
    qx_t_ref = (
        x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
    )
    sx_t_ref = (
        x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
    )
    ref_amax = x_nvfp4_ref.global_amax_row

    torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
    # Compare only valid portion of scales (trim any padding)
    ref_sx_shape = sx_ref.shape
    sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
    torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
    if return_transpose:
        torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
        ref_sx_t_shape = sx_t_ref.shape
        sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
        torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
    torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize("M, N", [(32, 128)])
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"])
@pytest.mark.parametrize("use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"])
def test_nvfp4_quantization_noncontiguous_inputs(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    return_transpose: bool,
    use_cpp_allocator: bool,
):
    te_dtype = tex.DType.kFloat4E2M1
    device = "cuda"
    seed = 17
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Start from a contiguous tensor, then make a non-contiguous view by transpose
    x_base = torch.randn((M, N), dtype=x_dtype, device=device)
    x_nc = x_base.t()  # shape (N, M), non-contiguous
    assert not x_nc.is_contiguous()

    nvfp4_quantizer = NVFP4Quantizer(
        fp4_dtype=te_dtype,
        rowwise=True,
        columnwise=return_transpose,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=False,
        with_post_rht_amax=False,
    )
    if use_cpp_allocator:
        x_nvfp4_sut = nvfp4_quantizer(x_nc)
    else:
        x_nvfp4_sut = nvfp4_quantizer.make_empty(
            x_nc.shape, dtype=x_dtype, device=device, requires_grad=False
        )
        x_nvfp4_sut = nvfp4_quantizer.update_quantized(x_nc, x_nvfp4_sut)

    assert x_nvfp4_sut._rowwise_data is not None
    qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._rowwise_scale_inv is not None
    sx = x_nvfp4_sut._rowwise_scale_inv
    qx_t = (
        x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
        if x_nvfp4_sut._columnwise_data is not None
        else None
    )
    sx_t = x_nvfp4_sut._columnwise_scale_inv
    qx_amax = x_nvfp4_sut._amax_rowwise

    ref_quantizer = NVFP4QuantizerRef(
        dtype=utils.Fp4Formats.E2M1,
        rowwise=True,
        columnwise=return_transpose,
        pow_2_scales=False,
        eps=0.0,
        quant_tile_shape=(1, 16),
    )
    x_nvfp4_ref = ref_quantizer.quantize(x_nc)

    qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
    sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
    qx_t_ref = (
        x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
    )
    sx_t_ref = (
        x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
    )
    ref_amax = x_nvfp4_ref.global_amax_row

    # Quantized must match
    torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
    # Compare only valid portion of scales (trim padding)
    ref_sx_shape = sx_ref.shape
    sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
    torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
    if return_transpose:
        torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
        ref_sx_t_shape = sx_t_ref.shape
        sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
        torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
    torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
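A note on the packed layout the tests above rely on: unpack_fp4 assumes each uint8 byte stores two E2M1 codes, with the first element in the low nibble and the second in the high nibble, and the code-to-value mapping follows the _FP4_LUT defined in test_nvfp4_sr_quantize.py further below. A minimal standalone sketch of that convention; the example byte 0x72 is illustrative only and not taken from the diff:

    import torch

    # One packed byte: low nibble = first element, high nibble = second element.
    packed = torch.tensor([[0x72]], dtype=torch.uint8)

    repeated = packed.repeat_interleave(2, dim=1)  # duplicate each byte
    repeated[:, 0::2] &= 0x0F                      # keep the low nibble at even positions
    repeated[:, 1::2] >>= 4                        # shift the high nibble down at odd positions

    print(repeated.tolist())  # [[2, 7]] -> E2M1 codes 2 and 7, i.e. values 1.0 and 6.0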
tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py  (new file, mode 100644)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# NOTE: This file is dependent on the success of test_nvfp4_quantize_exact.py.
# Separate to make sure all the functionalities are working as expected.
# Otherwise reference implementation will get messy.
# Due to the structure of NVFP4Quantizer, we need to test the RHT functionality
# together with the quantization functionality.
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
import pytest
import torch

recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)


def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
    repeated = x.repeat_interleave(2, dim=1)
    repeated[:, 0::2] &= 0x0F
    repeated[:, 1::2] >>= 4
    return repeated
def check_quantization_nvfp4_versus_reference(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    contiguous: bool,
    return_transpose: bool,
    use_cpp_allocator: bool,
    swizzled_scale: bool = False,
    hadamard_dimension: int = 16,
    with_rht: bool = True,
    with_post_rht_amax: bool = True,
    with_random_sign_mask: bool = True,
) -> None:
    assert with_rht and with_post_rht_amax, "RHT and post-RHT amax reduction must be enabled."
    te_dtype = tex.DType.kFloat4E2M1

    # Setup device and random seed
    device = "cuda"
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Input
    x = torch.randn((M, N), dtype=x_dtype, device=device)
    x = x.transpose(0, 1) if not contiguous else x

    # Quantize
    nvfp4_quantizer = NVFP4Quantizer(
        fp4_dtype=te_dtype,
        rowwise=True,
        columnwise=return_transpose,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=with_rht,
        with_post_rht_amax=with_post_rht_amax,
        with_random_sign_mask=with_random_sign_mask,
    )
    if use_cpp_allocator:
        x_nvfp4_sut = nvfp4_quantizer(x)
    else:
        x_nvfp4_sut = nvfp4_quantizer.make_empty(
            x.shape, dtype=x_dtype, device=device, requires_grad=False
        )
        x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)

    # Extract data from NVFP4Tensor
    assert x_nvfp4_sut._rowwise_data is not None
    qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._rowwise_scale_inv is not None
    sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
    qx_t = (
        x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
        if x_nvfp4_sut._columnwise_data is not None
        else None
    )
    sx_t = x_nvfp4_sut._columnwise_scale_inv
    amax_rowwise = x_nvfp4_sut._amax_rowwise
    amax_colwise = x_nvfp4_sut._amax_columnwise

    qx = unpack_fp4(qx)
    qx_t = unpack_fp4(qx_t) if qx_t is not None else None

    # Reference quantization using NVFP4QuantizerRef with built-in RHT
    ref_quantizer = NVFP4QuantizerRef(
        dtype=utils.Fp4Formats.E2M1,
        rowwise=True,
        columnwise=return_transpose,
        pow_2_scales=False,
        eps=0.0,
        quant_tile_shape=(1, 16),
        with_rht=with_rht,
        with_random_sign_mask=with_random_sign_mask,
    )
    x_nvfp4_ref = ref_quantizer.quantize(x)

    # Extract data from RefNVFP4Tensor
    qx_ref = (
        unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8))
        if x_nvfp4_ref.data is not None
        else None
    )
    sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
    ref_amax_rowwise = x_nvfp4_ref.global_amax_row

    if return_transpose:
        assert x_nvfp4_ref.data_t is not None
        assert x_nvfp4_ref.scale_t is not None
        qx_t_ref = unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8))
        sx_t_ref = x_nvfp4_ref.scale_t.view(dtype=torch.uint8)
        # Compute transpose amax using the same reference quantizer
        x_t_for_amax = (
            ref_quantizer._apply_rht(x.t().contiguous()) if with_rht else x.t().contiguous()
        )
        ref_amax_colwise_t = torch.max(torch.abs(x_t_for_amax)).to(torch.float32).view(1)
    else:
        qx_t_ref = None
        sx_t_ref = None
        ref_amax_colwise_t = None

    torch.testing.assert_close(amax_rowwise, ref_amax_rowwise, atol=0.0, rtol=0.0)
    torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
    # Compare only the valid portion of scale tensors (reference may not have padding)
    ref_sx_shape = sx_ref.shape
    sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
    torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
    if return_transpose:
        torch.testing.assert_close(amax_colwise, ref_amax_colwise_t, atol=0.0, rtol=0.0)
        torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
        # Compare only the valid portion of transpose scale tensors
        ref_sx_t_shape = sx_t_ref.shape
        sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
        torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        # full tile cases
        (128, 128), (256, 256), (256, 1024), (1024, 256),
        # Padding required cases
        (256, 272), (304, 304), (320, 256),
        # Some larger tiles
        (2048, 2048), (1024, 2048), (2048, 1024),
        # Real shapes
        (8192, 5120), (8192, 10240), (8192, 2560), (8192, 11328), (8192, 512), (8192, 3584),
        (5120, 8192), (10240, 8192), (2560, 8192), (11328, 8192), (512, 8192), (3584, 8192),
        (4096, 16384), (14336, 16384),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"])
@pytest.mark.parametrize("use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"])
@pytest.mark.parametrize("with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"])
def test_rht_with_quantization_block_tiling_versus_reference(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    return_transpose: bool,
    use_cpp_allocator: bool,
    with_random_sign_mask: bool,
) -> None:
    check_quantization_nvfp4_versus_reference(
        x_dtype=x_dtype,
        M=M,
        N=N,
        contiguous=True,
        return_transpose=return_transpose,
        use_cpp_allocator=use_cpp_allocator,
        with_random_sign_mask=with_random_sign_mask,
    )


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize("M, N", [(32, 128)])
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"])
@pytest.mark.parametrize("use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"])
@pytest.mark.parametrize("with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"])
def test_nvfp4_quantization_noncontiguous_inputs(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    return_transpose: bool,
    use_cpp_allocator: bool,
    with_random_sign_mask: bool,
):
    check_quantization_nvfp4_versus_reference(
        x_dtype=x_dtype,
        M=M,
        N=N,
        contiguous=False,
        return_transpose=return_transpose,
        use_cpp_allocator=use_cpp_allocator,
        with_random_sign_mask=with_random_sign_mask,
    )
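The RHT reference used by these tests (see the RHT helper in test_nvfp4_sr_quantize.py below) multiplies each 16-element block by a ±1 Hadamard matrix from the Sylvester construction, scaled by 1/sqrt(16). A quick standalone sketch, included here only as an illustration, of why that scaled transform is safe to apply before quantization: it is orthogonal, so it is exactly invertible and preserves per-block energy.

    import torch

    def sylvester_hadamard(size: int) -> torch.Tensor:
        # Sylvester construction: H_{2n} = [[H_n, H_n], [H_n, -H_n]]
        h = torch.ones((1, 1))
        while h.shape[0] < size:
            h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
        return h

    H = sylvester_hadamard(16) / 16 ** 0.5
    # Orthogonality: H @ H.T is the identity, so applying H and then H.T recovers the input.
    assert torch.allclose(H @ H.t(), torch.eye(16), atol=1e-6)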
tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py  (new file, mode 100755)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch

import transformer_engine.pytorch as te
from transformer_engine.pytorch import NVFP4Quantizer

recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)

seed = 12345
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
    repeated = x.repeat_interleave(2, dim=1)
    repeated[:, 0::2] &= 0x0F
    repeated[:, 1::2] >>= 4
    return repeated
_FP4_LUT = torch.tensor(
    [
        0.0,   # 0: 0000 - zero
        0.5,   # 1: 0001 - smallest positive normal
        1.0,   # 2: 0010
        1.5,   # 3: 0011
        2.0,   # 4: 0100
        3.0,   # 5: 0101
        4.0,   # 6: 0110
        6.0,   # 7: 0111 - largest positive normal
        -0.0,  # 8: 1000 - negative zero
        -0.5,  # 9: 1001 - smallest negative normal
        -1.0,  # 10: 1010
        -1.5,  # 11: 1011
        -2.0,  # 12: 1100
        -3.0,  # 13: 1101
        -4.0,  # 14: 1110
        -6.0,  # 15: 1111 - largest negative normal
    ],
    dtype=torch.float32,
)


def fp4_to_fp32(fp4: torch.Tensor) -> torch.Tensor:
    # Convert FP4 indices to their corresponding floating point values.
    # Each index (0-15) represents a 4-bit FP4 value in E2M1 format.
    # Values based on the FP4 E2M1 specification.
    fp4_lut = _FP4_LUT.to(fp4.device)
    return fp4_lut[fp4.to(torch.long)]


def dequantize_fp4(qx: torch.Tensor, sx: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32)
    dqx = fp4_to_fp32(unpack_fp4(qx))
    sf = sf[: dqx.shape[0], : dqx.shape[1]]
    dequant = dqx * sf * (amax / (6.0 * 448))
    return dequant
def RHT(x: torch.Tensor) -> torch.Tensor:
    def get_wgrad_sign_vector() -> torch.Tensor:
        """Hard-coded signs for Hadamard transform"""
        return torch.tensor(
            [1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0,
             -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0],
            dtype=torch.float32,
        )

    def _build_hadamard_matrix(
        size: int,
        device: torch.device,
        dtype: torch.dtype,
        with_random_sign_mask: bool = True,
    ) -> torch.Tensor:
        """Construct a Hadamard matrix of given power-of-two size with entries +-1.
        Uses Sylvester construction to avoid SciPy dependency.
        """
        assert (size & (size - 1)) == 0, "Hadamard size must be a power of two"
        h = torch.ones((1, 1), device=device, dtype=torch.float32)
        while h.shape[0] < size:
            h = torch.cat(
                [
                    torch.cat([h, h], dim=1),
                    torch.cat([h, -h], dim=1),
                ],
                dim=0,
            )
        if with_random_sign_mask:
            sign_mat = get_wgrad_sign_vector().to(device) * torch.eye(
                size, device=device, dtype=torch.float32
            )
            h = sign_mat @ h
        return h.to(dtype)

    rht_dim = 16
    # Build H and scale
    H = _build_hadamard_matrix(rht_dim, x.device, x.dtype)
    scale = 1.0 / float(rht_dim) ** 0.5
    # Perform blockwise transform along the last dimension
    original_shape = x.shape
    x_mat = x.contiguous().view(-1, rht_dim)
    # Random sign matrix is identity in this reference (no sign flipping)
    transform = H * scale
    out = x_mat @ transform
    return out.view(original_shape)
def quantize_fp4(
    x: torch.Tensor, use_stochastic_rounding: bool, use_2D: bool, use_RHT: bool
) -> torch.Tensor:
    nvfp4_quantizer = NVFP4Quantizer(
        rowwise=True,
        columnwise=True,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=use_RHT,
        with_post_rht_amax=True,
        stochastic_rounding=use_stochastic_rounding,
        with_2d_quantization=use_2D,
    )
    x_nvfp4_sut = nvfp4_quantizer(x)

    # Extract data from NVFP4Tensor
    assert x_nvfp4_sut._rowwise_data is not None
    qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._rowwise_scale_inv is not None
    sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
    assert x_nvfp4_sut._columnwise_data is not None
    qx_t: torch.Tensor = x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._columnwise_scale_inv is not None
    sx_t: torch.Tensor = x_nvfp4_sut._columnwise_scale_inv
    return qx, sx, qx_t, sx_t
def check_quantization_nvfp4_versus_reference(
    x_dtype: torch.dtype, M: int, N: int, use_2D: bool, use_RHT: bool
) -> None:
    device = "cuda"
    torch.manual_seed(seed)
    n_iters = 50

    x = torch.randn((M, N), dtype=x_dtype, device=device) * 2 - 1
    y = x.t().contiguous()
    if use_RHT:
        y = RHT(y)
    amax = torch.max(torch.abs(x)).float()

    q_rn, s_rn, q_t_rn, s_t_rn = quantize_fp4(
        x, use_stochastic_rounding=False, use_2D=use_2D, use_RHT=use_RHT
    )
    dq_rn = dequantize_fp4(q_rn, s_rn, amax)
    dq_t_rn = dequantize_fp4(q_t_rn, s_t_rn, amax)
    error_rn = (dq_rn - x).float()
    me_rn = torch.sqrt((error_rn * error_rn).mean())
    error_t_rn = (dq_t_rn - y).float()
    me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean())

    sr_result = torch.zeros_like(x).float()
    sr_t_result = torch.zeros_like(x).float().t().contiguous()
    for i in range(n_iters):
        q_sr, s_sr, q_t_sr, s_t_sr = quantize_fp4(
            x, use_stochastic_rounding=True, use_2D=use_2D, use_RHT=use_RHT
        )
        dq_sr = dequantize_fp4(q_sr, s_sr, amax)
        dq_t_sr = dequantize_fp4(q_t_sr, s_t_sr, amax)
        sr_result += dq_sr.float()
        sr_t_result += dq_t_sr.float()
        # sr_result_tmp = sr_result / (i + 1)
        # error_sr = (sr_result_tmp - x).float()
        # me_sr = torch.sqrt((error_sr * error_sr).mean())
        # sr_t_result_tmp = sr_t_result / (i + 1)
        # error_t_sr = (sr_t_result_tmp - y).float()
        # me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
        # print(f"Iteration {i}: RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
        # print(f"Iteration {i}: RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")

    # Get the mean result of the stochastic rounding
    # It should be more accurate than the RN result
    sr_result /= n_iters
    error_sr = (sr_result - x).float()
    me_sr = torch.sqrt((error_sr * error_sr).mean())
    sr_t_result /= n_iters
    error_t_sr = (sr_t_result - y).float()
    me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
    print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
    print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")

    assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest."
    assert me_t_sr < me_t_rn, "Stochastic rounding failed - error larger than the round to nearest."
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        (8192, 8192),
        (8192, 8256),  # to test the nonfused RHT path
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("use_2D", [False, True], ids=str)
@pytest.mark.parametrize("use_RHT", [False, True], ids=str)
def test_quantization_block_tiling_versus_reference(
    x_dtype: torch.dtype,
    use_2D: bool,
    use_RHT: bool,
    M: int,
    N: int,
) -> None:
    if x_dtype == torch.float32 and use_RHT:
        pytest.skip("RHT is only supported with bfloat16")
    check_quantization_nvfp4_versus_reference(
        x_dtype=x_dtype,
        use_2D=use_2D,
        use_RHT=use_RHT,
        M=M,
        N=N,
    )
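The assertion in the stochastic-rounding test above relies on a standard property: round-to-nearest leaves a fixed error on each element, while stochastic rounding is unbiased per element, so averaging many independently rounded copies drives the RMSE below the round-to-nearest error. A small self-contained sketch of the effect on a coarse uniform grid; the grid step and iteration count here are arbitrary illustration values, not taken from the test:

    import torch

    torch.manual_seed(0)
    step = 0.5                   # coarse quantization grid, stand-in for FP4 bins
    x = torch.rand(10000) * 4.0  # values to quantize

    rn = torch.round(x / step) * step                       # round to nearest
    rmse_rn = torch.sqrt(((rn - x) ** 2).mean())

    acc = torch.zeros_like(x)
    n_iters = 50
    for _ in range(n_iters):
        lo = torch.floor(x / step) * step
        p_up = (x - lo) / step                              # probability of rounding up
        sr = lo + step * (torch.rand_like(x) < p_up).float()  # stochastic rounding
        acc += sr
    rmse_sr = torch.sqrt(((acc / n_iters - x) ** 2).mean())

    assert rmse_sr < rmse_rn  # mirrors the assertion in the test above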
tests/pytorch/test_checkpoint.py

...
@@ -12,13 +12,15 @@ import pathlib
 import pytest
 import torch
 from typing import Optional

 import transformer_engine.pytorch as te
 from utils import make_recipe

 # Check supported quantization schemes
-fp8_available, reason_for_no_fp8 = te.fp8.FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = te.fp8.FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)

 # Test cases for loading checkpoint files
...
@@ -65,16 +67,16 @@ class TestLoadCheckpoint:
         if name == "ops_linear":
             return te.ops.Linear(1, 1)
         if name == "linear.fp8":
-            with te.fp8_model_init(recipe=make_recipe("fp8")):
+            with te.quantized_model_init(recipe=make_recipe("fp8")):
                 return te.Linear(16, 16)
         if name == "ops_linear.fp8":
-            with te.fp8_model_init(recipe=make_recipe("fp8")):
+            with te.quantized_model_init(recipe=make_recipe("fp8")):
                 return te.ops.Linear(16, 16)
         if name == "linear.mxfp8":
-            with te.fp8_model_init(recipe=make_recipe("mxfp8")):
+            with te.quantized_model_init(recipe=make_recipe("mxfp8")):
                 return te.Linear(32, 32)
         if name == "ops_linear.mxfp8":
-            with te.fp8_model_init(recipe=make_recipe("mxfp8")):
+            with te.quantized_model_init(recipe=make_recipe("mxfp8")):
                 return te.ops.Linear(32, 32)
         raise ValueError(f"Unrecognized module name ({name})")
...
tests/pytorch/test_cpu_offloading.py

...
@@ -12,14 +12,13 @@ import torch
 import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
 from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
 from utils import ModelConfig, get_available_attention_backends

 # Check supported quantization schemes
-fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available = te.is_fp8_available()
+mxfp8_available = te.is_mxfp8_available()

 quantization_recipes: Optional[recipe.Recipe] = [None]
 if fp8_available:
...
@@ -79,9 +78,9 @@ def _warmup_model(
     """Perform forward and backward pass"""
     tensor = _make_input()
     for module in modules:
-        with te.fp8_autocast(
+        with te.autocast(
             enabled=quantization_recipe is not None,
-            fp8_recipe=quantization_recipe,
+            recipe=quantization_recipe,
         ):
             tensor = module(tensor)
     tensor.sum().backward()
...
@@ -159,8 +158,8 @@ def _measure_cached_memory(
     tensor = inp
     memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
     for module in modules:
-        with te.fp8_autocast(
-            enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe
+        with te.autocast(
+            enabled=quantization_recipe is not None, recipe=quantization_recipe
         ), offload_context:
             tensor = module(tensor)
         tensor = sync_function(tensor)
...
tests/pytorch/test_cuda_graphs.py

...
@@ -13,12 +13,15 @@ from transformer_engine.pytorch import (
     Linear,
     MultiheadAttention,
     TransformerLayer,
-    fp8_autocast,
-    fp8_model_init,
+    autocast,
+    quantized_model_init,
     make_graphed_callables,
+    is_fp8_available,
+    is_fp8_block_scaling_available,
+    is_mxfp8_available,
+    is_bf16_available,
 )
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.utils import is_bf16_compatible
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.common import recipe
 from utils import ModelConfig, reset_rng_states
...
@@ -28,20 +31,67 @@ if IS_HIP_EXTENSION:
     from functools import cache

 # Check if FP8 is supported.
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available = is_fp8_available()
+fp8_block_scaling_available = is_fp8_block_scaling_available()
+mxfp8_available = is_mxfp8_available()

 # Reset RNG states.
 reset_rng_states()

 model_configs = {
-    "small": ModelConfig(32, 2, 2, 32),
+    "small": ModelConfig(2, 32, 2, 32),
 }


+def nvfp4_vanilla():
+    nvfp4_recipe = recipe.NVFP4BlockScaling()
+    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
+    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
+    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
+    return nvfp4_recipe
+
+
+def nvfp4_rht_and_2d_quantization():
+    nvfp4_recipe = recipe.NVFP4BlockScaling()
+    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
+        random_hadamard_transform=True, fp4_2d_quantization=False
+    )
+    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
+        random_hadamard_transform=False, fp4_2d_quantization=True
+    )
+    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
+        random_hadamard_transform=True, fp4_2d_quantization=False
+    )
+    return nvfp4_recipe
+
+
+def check_rht_usage(recipe: recipe.Recipe) -> bool:
+    # if using RHT, we can only support bf16
+    # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad
+    if recipe.nvfp4():
+        if (
+            recipe.fp4_quant_fwd_inp.random_hadamard_transform
+            or recipe.fp4_quant_fwd_weight.random_hadamard_transform
+            or recipe.fp4_quant_bwd_grad.random_hadamard_transform
+        ):
+            return True
+    return False
+
+
+def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> bool:
+    supported_input_dtypes = []
+    if recipe.nvfp4():
+        supported_input_dtypes.append(torch.bfloat16)
+        # if not using RHT, we can add fp32 as well
+        if not check_rht_usage(recipe):
+            supported_input_dtypes.append(torch.float32)
+    return supported_input_dtypes
+
+
 fp8_recipes = []
 if mxfp8_available:
     fp8_recipes.append(recipe.MXFP8BlockScaling())
+    fp8_recipes.append(nvfp4_rht_and_2d_quantization())
 if fp8_block_scaling_available:
     fp8_recipes.append(recipe.Float8BlockScaling())
 if fp8_available:
...
@@ -50,7 +100,7 @@ if fp8_available:

 # Supported data types
 dtypes: List[torch.dtype] = [torch.float32, torch.float16]
-if is_bf16_compatible():  # bf16 requires sm_80 or higher
+if is_bf16_available():  # bf16 requires sm_80 or higher
     dtypes.append(torch.bfloat16)
...
@@ -167,7 +217,7 @@ def _test_cuda_graphs(
         fp8_weight_caching = False

     # Create modules.
-    with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
+    with quantized_model_init(enabled=fp8_params, recipe=fp8_recipe):
         if module == "transformer":
             modules = [
                 TransformerLayer(
...
@@ -247,9 +297,9 @@ def _test_cuda_graphs(
                 model,
                 (generate_data(model_config, dtype, warmup=True),),
                 num_warmup_iters=10,
-                fp8_enabled=fp8,
-                fp8_weight_caching=fp8_weight_caching,
-                fp8_recipe=fp8_recipe,
+                enabled=fp8,
+                cache_quantized_params=fp8_weight_caching,
+                recipe=fp8_recipe,
             )
         elif graph_mode == "individual":
             # Graph individual modules.
...
@@ -258,9 +308,9 @@ def _test_cuda_graphs(
                     module,
                     (generate_data(model_config, dtype, warmup=True),),
                     num_warmup_iters=10,
-                    fp8_enabled=fp8,
-                    fp8_weight_caching=fp8_weight_caching,
-                    fp8_recipe=fp8_recipe,
+                    enabled=fp8,
+                    cache_quantized_params=fp8_weight_caching,
+                    recipe=fp8_recipe,
                 )
                 for module in modules
             ]
...
@@ -277,7 +327,7 @@ def _test_cuda_graphs(
     for grad_accumulation_step in range(2):
         input_ = generate_data(model_config, dtype)
         grad_output = generate_data(model_config, dtype, requires_grad=False)
-        with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
+        with autocast(enabled=fp8, recipe=fp8_recipe):
             kwargs = {}
             if fp8_weight_caching:
                 kwargs["is_first_microbatch"] = grad_accumulation_step == 0
...
@@ -291,7 +341,7 @@ def _test_cuda_graphs(
 @pytest.mark.parametrize("module", _test_cuda_graphs_modules)
 @pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None])
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__)
 def test_make_graphed_callables(
     *,
     module: str,
...
@@ -308,15 +358,25 @@ def test_make_graphed_callables(
         pytest.skip("FP8 needed for FP8 parameters.")
     if fp8_weight_caching and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
-    if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op":
-        pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
+    if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op":
+        pytest.skip(f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs")
+    if fp8 and fp8_recipe.nvfp4():
+        if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype):
+            pytest.skip(
+                f"Input dtype {dtype} not supported for NVFP4 Recipe"
+                f" {fp8_recipe.__class__.__name__}"
+            )
+        if fp8_params:
+            pytest.skip("NVFP4 params not supported")
     if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
+        pytest.skip("FP8 not supported on rocm GPU.")
     if fp8 and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
+        pytest.skip("FP8 block scaling not supported on rocm GPU.")
     if fp8 and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
+        pytest.skip("MXFP8 not supported on rocm GPU.")

     # Run model with different CUDA graph settings.
     model_config = model_configs[model_config]
     kwargs = dict(
...
@@ -353,17 +413,19 @@ _test_make_graphed_callables_with_fp8_weight_caching_modules = [
     "module",
     _test_make_graphed_callables_with_fp8_weight_caching_modules,
 )
 @pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__)
 def test_make_graphed_callables_with_fp8_weight_caching(
     *,
     module: str,
     dtype: torch.dtype,
     fp8_params: bool,
     fp8_recipe: recipe.Recipe,
 ) -> None:
     test_make_graphed_callables(
         module=module,
-        dtype=torch.float32,
+        dtype=dtype,
         fp8_params=fp8_params,
         fp8_recipe=fp8_recipe,
         fp8_weight_caching=True,
...
@@ -415,7 +477,7 @@ def _test_cuda_graphs_with_dot_product_attention(
         model,
         generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
         num_warmup_iters=10,
-        fp8_enabled=False,
+        enabled=False,
     )

     # Forward and backward passes.
...
tests/pytorch/test_custom_recipe.py  (new file, mode 100644)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.common import recipe
from transformer_engine.pytorch import (
    autocast,
    Linear,
    LayerNormLinear,
    LayerNormMLP,
    GroupedLinear,
    Float8CurrentScalingQuantizer,
)
import transformer_engine.pytorch.ops as te_ops


@pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear", "LayerNormMLP"])
def test_custom_recipe_sanity(module_type):
    available, reason = te.is_fp8_available(return_reason=True)
    if not torch.cuda.is_available() or not available:
        pytest.skip(f"FP8 unsupported on this device: {reason}")

    torch.manual_seed(0)

    # Simple linear layer with dims divisible by 16
    in_features = 64
    out_features = 64
    batch = 32

    if module_type == "Linear":
        model = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
    elif module_type == "LayerNormLinear":
        model = LayerNormLinear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
    elif module_type == "LayerNormMLP":
        # hidden_size == in_features == out_features for simplicity
        model = LayerNormMLP(
            hidden_size=in_features, ffn_hidden_size=out_features, params_dtype=torch.bfloat16
        ).cuda()
    else:
        # OpsLinear path
        model = te_ops.Linear(in_features, out_features, device="cuda", dtype=torch.bfloat16)

    inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    # Single factory: map roles to quantizers
    def quantizer_factory(role):
        if role in ("linear_input", "linear_weight", "linear_output"):
            return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
        if role in ("linear_grad_output", "linear_grad_input"):
            return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
        return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")

    custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)

    # Execute with custom recipe
    with autocast(enabled=True, recipe=custom_recipe):
        out = model(inp)
    loss = out.float().sum()
    loss.backward()

    # Basic sanity: gradients exist
    assert inp.grad is not None


def test_custom_recipe_grouped_linear_sanity():
    available, reason = te.is_fp8_available(return_reason=True)
    if not torch.cuda.is_available() or not available:
        pytest.skip(f"FP8 unsupported on this device: {reason}")

    torch.manual_seed(0)

    num_gemms = 3
    in_features = 64
    out_features = 64
    batch = 32
    base = batch // num_gemms
    rem = batch % num_gemms
    m_splits = [base + (1 if i < rem else 0) for i in range(num_gemms)]

    model = GroupedLinear(num_gemms, in_features, out_features, params_dtype=torch.bfloat16).cuda()
    inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    def quantizer_factory(role):
        if role in ("linear_input", "linear_weight", "linear_output"):
            return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
        if role in ("linear_grad_output", "linear_grad_input"):
            return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
        return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")

    custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)

    with autocast(enabled=True, recipe=custom_recipe):
        out = model(inp, m_splits)
    loss = out.float().sum()
    loss.backward()

    assert inp.grad is not None
def test_custom_recipe_matches_current_scaling():
    available, reason = te.is_fp8_available(return_reason=True)
    if not torch.cuda.is_available() or not available:
        pytest.skip(f"FP8 unsupported on this device: {reason}")

    torch.manual_seed(123)

    in_features = 64
    out_features = 64
    batch = 32

    # Create two identical models
    model_ref = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
    model_custom = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
    model_custom.load_state_dict(model_ref.state_dict())

    # Identical inputs for both paths
    base_inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16)
    inp_ref = base_inp.clone().detach().requires_grad_(True)
    inp_custom = base_inp.clone().detach().requires_grad_(True)

    # Reference: use Float8CurrentScaling recipe
    ref_recipe = recipe.Float8CurrentScaling()
    with autocast(enabled=True, recipe=ref_recipe):
        out_ref = model_ref(inp_ref)

    # Assert dtypes for reference quantizers: HYBRID = E4M3 (fwd), E5M2 (bwd)
    ref_fwd_in = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
    ref_fwd_w = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
    ref_fwd_out = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
    ref_bwd_go = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
    ref_bwd_gi = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
    assert ref_fwd_in.dtype == tex.DType.kFloat8E4M3
    assert ref_fwd_w.dtype == tex.DType.kFloat8E4M3
    assert ref_fwd_out.dtype == tex.DType.kFloat8E4M3
    assert ref_bwd_go.dtype == tex.DType.kFloat8E5M2
    assert ref_bwd_gi.dtype == tex.DType.kFloat8E5M2

    # Stress dynamic range in grad_output
    scale = torch.ones(out_features, device="cuda", dtype=torch.float32)
    scale[0] = 1e8
    scale[1] = 1e-8
    loss_ref = (out_ref.float() * scale.view(1, -1)).sum()
    loss_ref.backward()

    # Custom: single factory returning quantizers per role to match Float8CurrentScaling
    def quantizer_factory(role):
        if role in ("linear_input", "linear_weight", "linear_output"):
            return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
        if role in ("linear_grad_output", "linear_grad_input"):
            return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
        return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")

    custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
    with autocast(enabled=True, recipe=custom_recipe):
        out_custom = model_custom(inp_custom)

    # Assert dtypes for custom quantizers match reference mapping
    cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
    cus_fwd_w = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
    cus_fwd_out = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
    cus_bwd_go = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
    cus_bwd_gi = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
    assert cus_fwd_in.dtype == tex.DType.kFloat8E4M3
    assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3
    assert cus_fwd_out.dtype == tex.DType.kFloat8E4M3
    assert cus_bwd_go.dtype == tex.DType.kFloat8E5M2
    assert cus_bwd_gi.dtype == tex.DType.kFloat8E5M2

    loss_custom = (out_custom.float() * scale.view(1, -1)).sum()
    loss_custom.backward()

    # Compare forward outputs (exact match expected)
    assert torch.allclose(out_ref, out_custom, rtol=0.0, atol=0.0)

    # Compare input gradients
    assert inp_ref.grad is not None and inp_custom.grad is not None
    assert torch.allclose(inp_ref.grad, inp_custom.grad, rtol=0.0, atol=0.0)

    # Compare parameter gradients (weights and bias if present)
    ref_params = dict(model_ref.named_parameters())
    custom_params = dict(model_custom.named_parameters())
    for name, p_ref in ref_params.items():
        p_cus = custom_params[name]
        assert p_ref.grad is not None and p_cus.grad is not None
        assert torch.allclose(p_ref.grad, p_cus.grad, rtol=0.0, atol=0.0)
def test_custom_recipe_ops_linear_2_1_layout():
    available, reason = te.is_fp8_available(return_reason=True)
    if not torch.cuda.is_available() or not available:
        pytest.skip(f"FP8 unsupported on this device: {reason}")

    torch.manual_seed(7)

    in_features = 64
    out_features = 64
    batch = 16

    # Use ops.Linear which consumes 2 forward quantizers and 1 backward quantizer
    op = te_ops.Linear(in_features, out_features, device="cuda", dtype=torch.bfloat16)
    inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    def quantizer_factory(role):
        if role in ("linear_input", "linear_weight", "linear_output"):
            return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
        if role in ("linear_grad_output", "linear_grad_input"):
            return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
        return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")

    custom = recipe.CustomRecipe(qfactory=quantizer_factory)
    with autocast(enabled=True, recipe=custom):
        out = op(inp)
    loss = out.float().sum()
    loss.backward()

    assert inp.grad is not None


def test_custom_recipe_factory_invocation_counts_and_cycling():
    available, reason = te.is_fp8_available(return_reason=True)
    if not torch.cuda.is_available() or not available:
        pytest.skip(f"FP8 unsupported on this device: {reason}")

    torch.manual_seed(13)

    in_features = 64
    out_features = 64
    batch = 8

    op = Linear(in_features, out_features, params_dtype=torch.bfloat16)
    inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    # Counters per role
    counts = {
        "linear_input": 0,
        "linear_weight": 0,
        "linear_output": 0,
        "linear_grad_output": 0,
        "linear_grad_input": 0,
    }

    def quantizer_factory(role):
        if role in counts:
            counts[role] += 1
        if role in ("linear_input", "linear_weight", "linear_output"):
            return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
        if role in ("linear_grad_output", "linear_grad_input"):
            return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda"))
        return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))

    custom = recipe.CustomRecipe(qfactory=quantizer_factory)

    # Run fwd+bwd once; for a single GEMM, expect forward to build 3 quantizers (cycled from 1 factory),
    # and backward to build 2 quantizers (cycled from 1 factory).
    with autocast(enabled=True, recipe=custom):
        out = op(inp)
    loss = out.float().sum()
    loss.backward()

    # Single GEMM: forward should request input, weight, output; backward grad_output, grad_input
    assert counts["linear_input"] == 1
    assert counts["linear_weight"] == 1
    assert counts["linear_output"] == 1
    assert counts["linear_grad_output"] == 1
    assert counts["linear_grad_input"] == 1


def test_factories_return_distinct_instances_and_buffers():
    available, reason = te.is_fp8_available(return_reason=True)
    if not torch.cuda.is_available() or not available:
        pytest.skip(f"FP8 unsupported on this device: {reason}")

    # Two calls should produce distinct quantizer objects and distinct tensor buffers
    def factory():
        return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))

    q1 = factory()
    q2 = factory()
    assert q1 is not q2
    assert q1.scale.data_ptr() != q2.scale.data_ptr()
    assert q1.amax.data_ptr() != q2.amax.data_ptr()

    # Mutating one should not affect the other
    q1.scale.fill_(123.0)
    assert not torch.equal(q1.scale, q2.scale)
tests/pytorch/test_deferred_init.py

...
@@ -4,7 +4,6 @@
 import pytest
 import torch
-import torch.distributed as dist

 import transformer_engine.pytorch as te
...
tests/pytorch/test_float8_blockwise_gemm_exact.py

...
@@ -4,13 +4,13 @@
 import pytest
 import torch
-import transformer_engine as te
+import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len, int8_simulation_fp8)
-from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
+from transformer_engine.pytorch.fp8 import (blockwise_fp8_block_len, int8_simulation_fp8)
+from transformer_engine.pytorch import (
     Float8BlockQuantizer,
     Float8BlockwiseQTensor,
+    get_device_compute_capability,
 )
 from references.blockwise_quantizer_reference import CuBLASScaleMunger
...
@@ -19,8 +19,9 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
 from transformer_engine.pytorch.cpp_extensions.gemm import w8a8_int8_general_gemm


 def fp8_blockwise_gemm_supported() -> bool:
-    supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
-    return supported
+    supported = te.is_fp8_block_scaling_available()
+    emulated = get_device_compute_capability() >= (10, 0)
+    return supported and not emulated


 def cublas_gemm_fp8_blockwise_case(
...
tests/pytorch/test_float8_blockwise_scaling_exact.py  View file @ 063ef88d
@@ -8,14 +8,13 @@ import os
import pathlib
import pytest
import torch
-import transformer_engine as te
import transformer_engine_torch as tex
-from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
+import transformer_engine.pytorch as te
+from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
+from transformer_engine.pytorch import (
    Float8BlockQuantizer,
    Float8BlockwiseQTensor,
    get_device_compute_capability,
)
from references.blockwise_quantizer_reference import (
    BlockwiseQuantizerReference,
@@ -32,7 +31,8 @@ TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tenso
tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR")
if tensor_dump_dir_env is not None:
    TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)
-recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available()
+recipe_available, reason_for_no_recipe = te.is_fp8_block_scaling_available(return_reason=True)
+recipe_emulated = get_device_compute_capability() >= (10, 0)

class GetRecipes:
@@ -219,6 +219,12 @@ def check_quantization_block_tiling_versus_reference(
    pow_2_scales: bool,
    tile_size: Tuple[int, int],
) -> None:
+    if recipe_emulated and not pow_2_scales:
+        pytest.skip(
+            "On Blackwell and newer, the FP8 block scaling recipe is emulated "
+            "with MXFP8, which requires using power of two scaling factors."
+        )
    te_dtype = TE_DType[quant_dtype]
    if tile_size in ((1, 128), (1, 64)):
        block_scaling_dim = 1
@@ -414,6 +420,12 @@ def test_quantization_block_tiling_extrema_versus_reference(
    tile_size: Tuple[int, int],
    extrema_high: bool,
) -> None:
+    if recipe_emulated and not pow_2_scales:
+        pytest.skip(
+            "On Blackwell and newer, the FP8 block scaling recipe is emulated "
+            "with MXFP8, which requires using power of two scaling factors."
+        )
    # This test runs a single tile through a quantizer as a way to test
    # branch coverage of scale computation.
    if blockwise_fp8_block_len != tile_size[1]:
tests/pytorch/test_float8_current_scaling_exact.py  View file @ 063ef88d
@@ -8,12 +8,9 @@ import torch
import pytest
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8CurrentScaling
-from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype
+from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype
from transformer_engine.pytorch.fp8 import int8_simulation_fp8
@@ -25,7 +22,7 @@ if tensor_dump_dir_env is not None:
# Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)

class GetRecipes:
@@ -274,6 +271,14 @@ class TestFP8RecipeLinearBase:
            if bgrad_list is not None and bgrad is not None:
                bgrad_list.append(bgrad.detach().clone())
+        # Stack the results
+        return (
+            torch.stack(y_q_list),
+            torch.stack(dgrad_list),
+            torch.stack(wgrad_list),
+            torch.stack(bgrad_list) if bgrad_list is not None else None,
+        )

    @classmethod
    def run_linear(
        cls,
@@ -388,7 +393,7 @@ class TestFP8RecipeLinearBase:
        # recipe1
        using_fp8_recipe = recipe1() != GetRecipes.none()
        if using_fp8_recipe:
-            with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
+            with autocast(enabled=True, recipe=recipe1()):
                y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
        else:
            y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
@@ -396,7 +401,7 @@ class TestFP8RecipeLinearBase:
        # recipe2
        using_fp8_recipe = recipe2() != GetRecipes.none()
        if using_fp8_recipe:
-            with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
+            with autocast(enabled=True, recipe=recipe2()):
                y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
        else:
            y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
@@ -611,7 +616,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
        # recipe1
        using_fp8_recipe = recipe1() != GetRecipes.none()
        if using_fp8_recipe:
-            with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
+            with autocast(enabled=True, recipe=recipe1()):
                y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
                    x,
                    w,
@@ -633,7 +638,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
        # recipe2
        using_fp8_recipe = recipe2() != GetRecipes.none()
        if using_fp8_recipe:
-            with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
+            with autocast(enabled=True, recipe=recipe2()):
                y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear(
                    x,
                    w,
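Aside (not part of the commit): the change repeated throughout the remaining test diffs is a pure API rename. A rough before/after sketch, assuming the new names behave like the old ones and with my_recipe, model, and x as placeholder names:

    # Old spelling (removed by this commit's test updates):
    #     with te.fp8_autocast(enabled=True, fp8_recipe=my_recipe): ...
    #     with te.fp8_model_init(enabled=True, recipe=my_recipe): ...
    # New spelling used by the updated tests:
    with te.quantized_model_init(enabled=True, recipe=my_recipe):
        model = te.Linear(1024, 1024)
    with te.autocast(enabled=True, recipe=my_recipe):
        out = model(x)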
tests/pytorch/test_float8blockwisetensor.py  View file @ 063ef88d
@@ -11,12 +11,11 @@ import pytest
import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
-from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
+from transformer_engine.pytorch import (
    Float8BlockQuantizer,
    Float8BlockwiseQTensor,
    get_device_compute_capability,
)
-from transformer_engine.pytorch.utils import get_device_compute_capability
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
tests/pytorch/test_float8tensor.py  View file @ 063ef88d
@@ -11,13 +11,11 @@ import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.tensor.float8_tensor import (
+from transformer_engine.pytorch import (
    Float8Quantizer,
    Float8Tensor,
    Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
import transformer_engine_torch as tex
@@ -47,7 +45,7 @@ def _to_list(x: Union[Iterable, Any]) -> List:
DimsType = Union[Iterable[int], int]

# Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)

# delayed scaling
tests/pytorch/test_fused_optimizer.py  View file @ 063ef88d
@@ -11,14 +11,11 @@ from torch import nn
from torch.testing._internal.common_device_type import largeTensorTest
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
-from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
-from transformer_engine.pytorch import fp8_model_init
-from transformer_engine.pytorch.utils import is_bf16_compatible
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch import MultiheadAttention, quantized_model_init, is_bf16_available
from transformer_engine.pytorch.utils import gpu_autocast_ctx

# Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)

class TestFusedOptimizer:
@@ -188,7 +185,7 @@ class TestFusedAdam(TestFusedOptimizer):
        build_model_context = nullcontext
        build_model_context_args = {}
        if use_fp8_params:
-            build_model_context = fp8_model_init
+            build_model_context = quantized_model_init
            build_model_context_args["enabled"] = True
        with build_model_context(**build_model_context_args):
@@ -286,7 +283,7 @@ class TestFusedAdam(TestFusedOptimizer):
            exp_avg_sq_dtype=torch.float32,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_fp32_master(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
@@ -298,7 +295,7 @@ class TestFusedAdam(TestFusedOptimizer):
            exp_avg_sq_dtype=torch.float32,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_fp32_master_store_param_remainders(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
@@ -311,7 +308,7 @@ class TestFusedAdam(TestFusedOptimizer):
            store_param_remainders=True,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_fp16_master(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
@@ -325,7 +322,7 @@ class TestFusedAdam(TestFusedOptimizer):
            master_atol=2e-3,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_bf16_grad(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
@@ -339,7 +336,7 @@ class TestFusedAdam(TestFusedOptimizer):
            master_atol=2e-3,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_fp16_exp_avg(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
@@ -353,7 +350,7 @@ class TestFusedAdam(TestFusedOptimizer):
            master_atol=2e-3,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_bf16_exp_avg(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
@@ -367,7 +364,7 @@ class TestFusedAdam(TestFusedOptimizer):
            master_atol=2e-3,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_exp_avg(self):
        self.gen_precision_aware_test(
@@ -382,7 +379,7 @@ class TestFusedAdam(TestFusedOptimizer):
            master_atol=1e-2,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_fp16_exp_avg_sq(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
@@ -396,7 +393,7 @@ class TestFusedAdam(TestFusedOptimizer):
            master_atol=2e-3,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_bf16_exp_avg_sq(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
@@ -410,7 +407,7 @@ class TestFusedAdam(TestFusedOptimizer):
            master_atol=2e-3,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_exp_avg_sq(self):
        self.gen_precision_aware_test(
@@ -424,7 +421,7 @@ class TestFusedAdam(TestFusedOptimizer):
            skip_assert=True,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
    def test_bf16_model_weight_cast(self):
        dtype = torch.bfloat16
        model = MultiheadAttention(
@@ -468,7 +465,7 @@ class TestFusedAdam(TestFusedOptimizer):
    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_model_weight_cast(self):
        dtype = torch.bfloat16
-        with fp8_model_init(enabled=True, recipe=DelayedScaling()):
+        with quantized_model_init(enabled=True, recipe=DelayedScaling()):
            model = MultiheadAttention(
                hidden_size=1024,
                num_attention_heads=16,
tests/pytorch/test_fused_rope.py  View file @ 063ef88d
@@ -373,3 +373,19 @@ def test_fused_qkv_rope(
    if not isinstance(start_positions, torch.Tensor):
        torch.testing.assert_close(grad_fused, grad_unfused)

+def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_without_autocast():
+    rope_layer = RotaryPositionEmbedding(128)
+    rope_embeddings_no_autocast = rope_layer(max_seq_len=1024)
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        rope_embeddings_autocast = rope_layer(max_seq_len=1024)
+    torch.testing.assert_close(
+        rope_embeddings_no_autocast.to(dtype=torch.bfloat16),
+        rope_embeddings_autocast.to(dtype=torch.bfloat16),
+        atol=1e-8,
+        rtol=1e-8,
+    )
tests/pytorch/test_fusible_ops.py  View file @ 063ef88d
@@ -7,8 +7,6 @@ from __future__ import annotations
from collections.abc import Iterable
import io
import math
import pathlib
import sys
from typing import Optional
import pytest
@@ -18,7 +16,6 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import (
    BackwardActivationBias,
@@ -29,20 +26,18 @@ from transformer_engine.pytorch.ops.fused import (
    ForwardLinearBiasAdd,
    ForwardLinearScaleAdd,
)
-from transformer_engine.pytorch.tensor import QuantizedTensor
-from transformer_engine.pytorch.tensor.float8_tensor import (
-    Float8Tensor,
+from transformer_engine.pytorch import (
+    QuantizedTensor,
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
+    MXFP8Quantizer,
+    NVFP4Quantizer,
+    is_bf16_available,
)
-from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
-from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex

# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent))
-from utils import dtype_tols, make_recipe, reset_rng_states
+from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states

if IS_HIP_EXTENSION:
    import os
@@ -52,13 +47,14 @@ if IS_HIP_EXTENSION:
    return (
        os.getenv("NVTE_USE_HIPBLASLT") is not None or os.getenv("NVTE_USE_ROCBLAS") is None
    )

-# Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+# Check for supported quantization schemes
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
+nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)

# Supported data types
_dtypes: list[torch.dtype] = [torch.float32, torch.float16]
-if is_bf16_compatible():  # bf16 requires sm_80 or higher
+if is_bf16_available():  # bf16 requires sm_80 or higher
    _dtypes.append(torch.bfloat16)

# Supported devices
@@ -70,6 +66,8 @@ if fp8_available:
    _quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
    _quantization_list.append("mxfp8")
+if nvfp4_available:
+    _quantization_list.append("nvfp4")

def maybe_skip_quantization(
@@ -77,6 +75,7 @@ def maybe_skip_quantization(
    *,
    dims: Optional[Iterable[int] | int] = None,
    device: Optional[torch.device | str] = None,
+    dtype: Optional[torch.dtype] = None,
) -> None:
    """Skip test case if a quantization scheme is not supported"""
@@ -84,12 +83,17 @@ def maybe_skip_quantization(
    if quantization is None:
        return

-    # Check if quantization scheme is supported
+    # Check if quantization scheme is supported on device
+    if device is not None and torch.device(device).type != "cuda":
+        pytest.skip("Quantization is only supported on CUDA devices")
    if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
+    if quantization == "nvfp4" and not nvfp4_available:
+        pytest.skip(reason_for_no_nvfp4)

    # Check dims
    if dims is not None:
        if not isinstance(dims, Iterable):
            dims = (dims,)
@@ -99,10 +103,14 @@ def maybe_skip_quantization(
        elif quantization == "mxfp8":
            if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0:
                pytest.skip("MXFP8 GEMMs require dims that are divisible by 32")
+        elif quantization == "nvfp4":
+            if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0:
+                pytest.skip("NVFP4 GEMMs require dims that are divisible by 16")

-    # Check if device is supported
-    if device is not None and torch.device(device).type != "cuda":
-        pytest.skip("Quantization is only supported on CUDA devices")
+    # Check dtype
+    if dtype is not None:
+        if quantization == "nvfp4" and dtype != torch.bfloat16:
+            pytest.skip("NVFP4 quantization is only supported with BF16 data")

@torch.no_grad()
@@ -152,6 +160,14 @@ def make_reference_and_test_tensors(
        test = quantizer(test)
    elif quantization == "mxfp8":
        test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
+    elif quantization == "nvfp4":
+        test = NVFP4Quantizer(
+            with_rht=False,
+            with_post_rht_amax=False,
+            with_2d_quantization=False,
+            stochastic_rounding=False,
+            with_random_sign_mask=False,
+        )(test)
    else:
        raise ValueError(f"Unsupported quantization scheme ({quantization})")
    if isinstance(test, QuantizedTensor) and not test_is_quantized:
@@ -361,7 +377,7 @@ class TestFuser:
        )
        # Construct model
-        with te.fp8_model_init(recipe=recipe):
+        with te.quantized_model_init(recipe=recipe):
            model = te_ops.basic.BasicLinear(
                size,
                size,
@@ -393,7 +409,7 @@ class TestFuser:
            )
            # Training step
-            with te.fp8_autocast(fp8_recipe=recipe):
+            with te.autocast(recipe=recipe):
                y = model(x)
            y.backward(dy)
            with torch.no_grad():
@@ -406,12 +422,12 @@ class TestFuser:
            torch.testing.assert_close(
                y,
                torch.full_like(y, y_val_ref),
-                **dtype_tols(tex.DType.kFloat8E4M3),
+                **quantization_tols("fp8_delayed_scaling"),
            )
            torch.testing.assert_close(
                x.grad,
                torch.full_like(x.grad, dx_val_ref),
-                **dtype_tols(tex.DType.kFloat8E5M2),
+                **quantization_tols("fp8_delayed_scaling"),
            )
            # Check that scaling factors match expected
@@ -445,7 +461,8 @@ class TestFuser:
        # Skip invalid configurations
        in_shape = (size, size)
        with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=init_dtype)
+        maybe_skip_quantization(quantization, dtype=final_dtype)
        # Random data
        dtype = torch.float32
@@ -461,7 +478,7 @@ class TestFuser:
        )
        # Construct operation
-        with te.fp8_model_init(enabled=with_quantization, recipe=make_recipe(quantization)):
+        with te.quantized_model_init(enabled=with_quantization, recipe=make_recipe(quantization)):
            op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype)
        with torch.no_grad():
            op.weight.copy_(w_test)
@@ -513,11 +530,12 @@ class TestFuser:
        # Skip invalid configurations
        in_shape = (size, size)
        quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=model_dtype)
+        maybe_skip_quantization(quantization, dtype=autocast_dtype)
        # Construct operation
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weights, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weights, recipe=recipe):
            op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype)
        # Check forward and backward pass
@@ -527,7 +545,7 @@ class TestFuser:
            device=device,
            requires_grad=True,
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            with torch.autocast(device_type=device.type, dtype=autocast_dtype):
                y = op(x)
        y.backward(torch.zeros_like(y))
@@ -540,7 +558,7 @@ class TestFuser:
        x.grad = None
        op.weight.grad = None
        with torch.autocast(device_type=device.type, dtype=autocast_dtype):
-            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            with te.autocast(enabled=quantized_compute, recipe=recipe):
                y = op(x)
        y.backward(torch.zeros_like(y))
        assert y.dtype == autocast_dtype
@@ -569,7 +587,7 @@ class TestBasicOps:
        # Skip invalid configurations
        with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
@@ -635,7 +653,7 @@ class TestBasicOps:
        # Skip invalid configurations
        if memory_format == torch.channels_last and len(in_shape) != 4:
            pytest.skip("torch.channels_last only supports 4D tensors")
-        maybe_skip_quantization(quantization, device=device)
+        maybe_skip_quantization(quantization, device=device, dtype=dtype)
        with_quantization = quantization is not None
        # Random data
@@ -701,7 +719,7 @@ class TestBasicOps:
        # Skip invalid configurations
        with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
@@ -763,7 +781,7 @@ class TestBasicOps:
        # Skip invalid configurations
        with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, device=device)
+        maybe_skip_quantization(quantization, device=device, dtype=dtype)
        if quantization == "mxfp8":
            maybe_skip_quantization(quantization, dims=in_shape)
@@ -790,7 +808,7 @@ class TestBasicOps:
        # Implementation with fusible operation
        op = te_ops.Quantize(forward=cast_forward, backward=cast_backward)
        recipe = make_recipe(quantization)
-        with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
+        with te.autocast(enabled=with_quantization, recipe=recipe):
            y_test = op(x_test)
        y_test.backward(dy_test)
@@ -830,7 +848,7 @@ class TestBasicOps:
        out_shape = in_shape[:-1] + [out_features]
        # Skip invalid configurations
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        quantization_needed = any((
@@ -887,7 +905,7 @@ class TestBasicOps:
        # Implementation with fusible operation
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            op = te_ops.BasicLinear(
                in_features,
                out_features,
@@ -904,7 +922,7 @@ class TestBasicOps:
            op,
            te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
@@ -913,7 +931,7 @@ class TestBasicOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute or quantized_output or quantized_grad_input:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1024,7 +1042,7 @@ class TestBasicOps:
        out_shape = in_shape[:-1] + [out_features]
        # Skip invalid configurations
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantization is None and (quantized_compute or quantized_weight):
            pytest.skip("Quantization scheme is not specified")
@@ -1065,7 +1083,7 @@ class TestBasicOps:
        # Implementation with fusible operation
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            op = te_ops.Linear(
                in_features,
                out_features,
@@ -1081,7 +1099,7 @@ class TestBasicOps:
        del b_test
        for param in op.parameters():
            param.requires_grad_(requires_grad=weight_requires_grad)
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = op(x_test)
        if input_requires_grad or weight_requires_grad:
            y_test.backward(dy_test)
@@ -1091,7 +1109,7 @@ class TestBasicOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1128,7 +1146,7 @@ class TestBasicOps:
        in_shape = list(in_shape)[:-1] + list(weight_shape)
        # Skip invalid configurations
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
@@ -1182,14 +1200,14 @@ class TestBasicOps:
            op,
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1298,7 +1316,7 @@ class TestBasicOps:
        in_shape = list(in_shape)[:-1] + list(weight_shape)
        # Skip invalid configurations
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
@@ -1344,14 +1362,14 @@ class TestBasicOps:
            op,
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1431,7 +1449,7 @@ class TestBasicOps:
        # Skip invalid configurations
        with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        # Random data
        x1_ref, x1_test = make_reference_and_test_tensors(
@@ -1470,8 +1488,11 @@ class TestBasicOps:
        # Check results
        tols = dtype_tols(dtype)
-        if with_quantization:
-            tols = dtype_tols(x1_test._fp8_dtype)
+        if in_place:
+            if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"):
+                tols = dtype_tols(x1_test._fp8_dtype)
+            elif quantization == "nvfp4":
+                tols = dtype_tols(x1_test._fp4_dtype)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
        dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
@@ -1500,7 +1521,7 @@ class TestBasicOps:
        # Skip invalid configurations
        with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
@@ -1573,7 +1594,7 @@ class TestBasicOps:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        if cache_quantized_input:
            maybe_skip_quantization("fp8_current_scaling", device=device)
@@ -1641,14 +1662,16 @@ class TestBasicOps:
            make_op(cache_quantized_input=cache_quantized_input),
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
        # Expected numerical error
        tols = dtype_tols(dtype)
-        if quantized_compute or cache_quantized_input:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+        if quantized_compute:
+            tols = quantization_tols(quantization)
+        elif cache_quantized_input:
+            tols = quantization_tols("fp8_current_scaling")
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1679,7 +1702,7 @@ class TestBasicOps:
        quantized_compute = quantization is not None
        if not quantized_compute and (quantize_forward or quantize_backward):
            pytest.skip("Quantization scheme has not been provided")
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
@@ -1706,13 +1729,87 @@ class TestBasicOps:
            te_ops.SwiGLU(),
            te_ops.Quantize(forward=quantize_forward, backward=False),
        )
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    @pytest.mark.parametrize("quantize_forward", (False, True))
+    @pytest.mark.parametrize("quantize_backward", (False, True))
+    def test_clamped_swiglu(
+        self,
+        *,
+        out_shape: Iterable[int] = (32, 32),
+        dtype: torch.dtype,
+        device: torch.device = "cuda",
+        quantization: Optional[str],
+        quantize_forward: bool,
+        quantize_backward: bool,
+        limit: float = 0.75,
+        alpha: float = 1.702,
+    ):
+        # Test SwiGLU variant used in GPT OSS.
+        # Tensor dimensions
+        in_shape = list(out_shape)
+        in_shape[-1] *= 2
+        # Skip invalid configurations
+        quantized_compute = quantization is not None
+        if not quantized_compute and (quantize_forward or quantize_backward):
+            pytest.skip("Quantization scheme has not been provided")
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        # Random data
+        x_ref, x_test = make_reference_and_test_tensors(
+            in_shape,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            out_shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+        # Plain PyTorch implementation
+        x_glu, x_linear = x_ref.chunk(2, dim=-1)
+        x_glu = x_glu.clamp(min=None, max=limit)
+        x_linear = x_linear.clamp(min=-limit, max=limit)
+        out_glu = x_glu * torch.sigmoid(alpha * x_glu)
+        y_ref = out_glu * (x_linear + 1)
+        y_ref.backward(dy_ref)
+        # Implementation with fusible operation
+        recipe = make_recipe(quantization)
+        forward = te_ops.Sequential(
+            te_ops.Quantize(forward=False, backward=quantize_backward),
+            te_ops.ClampedSwiGLU(limit=limit, alpha=alpha),
+            te_ops.Quantize(forward=quantize_forward, backward=False),
+        )
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
+            y_test = forward(x_test)
+        y_test.backward(dy_test)
+        # Expected numerical error
+        tols = dtype_tols(dtype)
+        if quantized_compute and quantization == "nvfp4":
+            tols = dtype_tols(tex.DType.kFloat4E2M1)
+        elif quantized_compute:
+            tols = dtype_tols(tex.DType.kFloat8E4M3)
        # Check results
@@ -1781,7 +1878,7 @@ class TestBasicOps:
        # Skip invalid configurations
        quantized_input = quantization is not None
-        maybe_skip_quantization(quantization, dims=shape, device=device)
+        maybe_skip_quantization(quantization, dims=shape, device=device, dtype=dtype)
        # Random data
        # Note: Shift values to make sure inputs are non-zero
@@ -1872,7 +1969,7 @@ class TestFusedOps:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if dtype not in (torch.float16, torch.bfloat16):
            pytest.skip(
@@ -1913,7 +2010,7 @@ class TestFusedOps:
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_compute, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
@@ -1929,7 +2026,7 @@ class TestFusedOps:
            model[0].bias.copy_(b_test)
        del w_test
        del b_test
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)
@@ -1943,7 +2040,7 @@ class TestFusedOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1979,7 +2076,7 @@ class TestFusedOps:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
@@ -2023,7 +2120,7 @@ class TestFusedOps:
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
@@ -2040,7 +2137,7 @@ class TestFusedOps:
            model[0].bias.copy_(b_test)
        del w_test
        del b_test
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x1_test, x2_test)
        y_test.backward(dy_test)
@@ -2054,7 +2151,7 @@ class TestFusedOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -2093,7 +2190,7 @@ class TestFusedOps:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
@@ -2130,7 +2227,7 @@ class TestFusedOps:
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
@@ -2146,7 +2243,7 @@ class TestFusedOps:
        with torch.no_grad():
            model[0].weight.copy_(w_test)
        del w_test
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x1_test, x2_test)
        y_test.backward(dy_test)
@@ -2161,7 +2258,7 @@ class TestFusedOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -2194,7 +2291,7 @@ class TestFusedOps:
        # Skip invalid configurations
        with_quantization = quantization is not None
-        maybe_skip_quantization(quantization, device=device)
+        maybe_skip_quantization(quantization, device=device, dtype=dtype)
        if quantization == "mxfp8" and (len(in_shape) < 2 or in_shape[-1] % 32 != 0):
            pytest.skip("Unsupported tensor size for MXFP8")
@@ -2237,7 +2334,7 @@ class TestFusedOps:
        with torch.no_grad():
            model[1].bias.copy_(b_test)
        del b_test
-        with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
+        with te.autocast(enabled=with_quantization, recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)
@@ -2256,7 +2353,7 @@ class TestFusedOps:
        # Expected numerical error
        tols = dtype_tols(dtype)
        if with_quantization:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -2375,7 +2472,7 @@ class TestFusedOps:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
@@ -2415,7 +2512,7 @@ class TestFusedOps:
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight):
+        with te.quantized_model_init(enabled=quantized_weight):
            model = te_ops.Sequential(
                te_ops.MakeExtraOutput(in_place=True),
                te_ops.Linear(
@@ -2429,7 +2526,7 @@ class TestFusedOps:
        with torch.no_grad():
            model[1].weight.copy_(w_test)
        del w_test
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y1_test, y2_test = model(x_test)
        (y1_test * dy1_test + y2_test * dy2_test).sum().backward()
@@ -2443,7 +2540,7 @@ class TestFusedOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y1_test = y1_test.to(dtype=torch.float64, device="cpu")
@@ -2479,7 +2576,7 @@ class TestFusedOps:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
@@ -2511,7 +2608,7 @@ class TestFusedOps:
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight):
+        with te.quantized_model_init(enabled=quantized_weight):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
@@ -2525,7 +2622,7 @@ class TestFusedOps:
        with torch.no_grad():
            model[0].weight.copy_(w_test)
        del w_test
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x_test)
        (y_test * dy_test).sum().backward()
@@ -2539,7 +2636,7 @@ class TestFusedOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-            tols = dtype_tols(tex.DType.kFloat8E4M3)
+            tols = quantization_tols(quantization)
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -2580,12 +2677,12 @@ class TestCheckpointing:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        # Construct model
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            model_save = te_ops.Sequential(
                te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
            )
@@ -2596,7 +2693,7 @@ class TestCheckpointing:
        x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
        dy = torch.randn(out_shape, dtype=dtype, device=device)
        optim_save.zero_grad()
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y = model_save(x)
        y.backward(dy)
        optim_save.step()
@@ -2625,14 +2722,14 @@ class TestCheckpointing:
        ys_save = []
        for i in range(post_checkpoint_steps):
            optim_save.zero_grad()
-            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            with te.autocast(enabled=quantized_compute, recipe=recipe):
                y = model_save(xs_save[i])
            y.backward(dys[i])
            optim_save.step()
            ys_save.append(y)
        # Load checkpoint
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            model_load = te_ops.Sequential(
                te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
            )
@@ -2645,7 +2742,7 @@ class TestCheckpointing:
        ys_load = []
        for i in range(post_checkpoint_steps):
            optim_load.zero_grad()
-            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            with te.autocast(enabled=quantized_compute, recipe=recipe):
                y = model_load(xs_load[i])
            y.backward(dys[i])
            optim_load.step()
@@ -2706,7 +2803,7 @@ class TestSequentialModules:
        ffn_shape = in_shape[:-1] + (ffn_hidden_size,)
        # Skip invalid configurations
-        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=ffn_shape, device=device)
        quantization_needed = quantized_compute or quantized_weight
        if quantization is None and quantization_needed:
@@ -2732,7 +2829,7 @@ class TestSequentialModules:
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            if normalization == "LayerNorm":
                norm = te_ops.LayerNorm(
                    hidden_size,
@@ -2763,6 +2860,6 @@ class TestSequentialModules:
            dtype=dtype,
        )
        forward = te_ops.Sequential(norm, ffn1, act, ffn2)
-        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
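Aside (not part of the commit): test_fusible_ops.py now takes most of its tolerances from a quantization_tols helper imported from tests/pytorch/utils.py, which this diff does not show. A hypothetical sketch of such a helper, assuming it simply maps a scheme name onto dtype_tols bounds so NVFP4 can get looser tolerances than the FP8 schemes:

    def quantization_tols(quantization):
        # Hypothetical mapping; the real helper lives in tests/pytorch/utils.py and is not shown here.
        if quantization == "nvfp4":
            return dtype_tols(tex.DType.kFloat4E2M1)  # assumed FP4-level tolerances
        return dtype_tols(tex.DType.kFloat8E4M3)      # assumed FP8-level tolerances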
tests/pytorch/test_hf_integration.py  View file @ 063ef88d
@@ -6,7 +6,7 @@ import pytest
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
-from transformer_engine.pytorch.transformer import TransformerLayer
+from transformer_engine.pytorch import TransformerLayer

class SimpleTEModel(PreTrainedModel):
tests/pytorch/test_multi_tensor.py  View file @ 063ef88d
@@ -5,7 +5,7 @@
import pytest
import torch
-import transformer_engine.pytorch as te
+import transformer_engine.pytorch
import transformer_engine_torch as tex
from transformer_engine.pytorch.optimizers import MultiTensorApply
tests/pytorch/test_numerics.py
View file @
063ef88d
...
...
@@ -13,18 +13,15 @@ import torch.nn as nn
from
torch.nn
import
Parameter
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
from
transformer_engine.pytorch.fp8
import
(
FP8GlobalStateManager
,
fp8_autocast
,
fp8_model_init
,
)
from
transformer_engine.pytorch.quantization
import
FP8GlobalStateManager
from
transformer_engine.pytorch.utils
import
(
init_method_normal
,
scaled_init_method_normal
,
attention_mask_func
,
is_bf16_compatible
,
)
from
transformer_engine.pytorch
import
(
autocast
,
quantized_model_init
,
DotProductAttention
,
LayerNormLinear
,
LayerNormMLP
,
...
...
@@ -36,27 +33,29 @@ from transformer_engine.pytorch import (
LayerNorm
,
Fp8Padding
,
Fp8Unpadding
,
Float8Quantizer
,
Float8CurrentScalingQuantizer
,
MXFP8Quantizer
,
get_device_compute_capability
,
is_fp8_available
,
is_mxfp8_available
,
is_fp8_block_scaling_available
,
is_bf16_available
,
)
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch
.distributed
import
checkpoint
as
te_checkpoint
from
transformer_engine.pytorch
import
checkpoint
as
te_checkpoint
from
transformer_engine.pytorch.cpp_extensions
import
general_gemm
,
general_grouped_gemm
from
transformer_engine.pytorch.cpp_extensions.fused_attn
import
FusedAttnBackend
from
transformer_engine.pytorch.tensor.float8_tensor
import
(
Float8Quantizer
,
Float8CurrentScalingQuantizer
,
)
from
transformer_engine.pytorch.tensor.mxfp8_tensor
import
MXFP8Quantizer
from
transformer_engine.pytorch.module.base
import
get_multi_stream_cublas_workspace
,
get_workspace
from
transformer_engine.pytorch.utils
import
get_device_compute_capability
from
transformer_engine.common
import
recipe
import
transformer_engine_torch
as
tex
from
utils
import
ModelConfig
,
reset_rng_states
,
get_available_attention_backends
# Only run FP8 tests on supported devices.
fp8_available
,
reason_for_no_fp8
=
FP8GlobalStateManager
.
is_fp8_available
(
)
mxfp8_available
,
reason_for_no_mxfp8
=
FP8GlobalStateManager
.
is_mxfp8_available
()
fp8_block_scaling_available
,
reason_for_no_fp8_block_scaling
=
FP8GlobalStateManager
.
is_fp8_block_scaling_available
()
fp8_available
,
reason_for_no_fp8
=
is_fp8_available
(
return_reason
=
True
)
mxfp8_available
,
reason_for_no_mxfp8
=
is_mxfp8_available
(
return_reason
=
True
)
fp8_block_scaling_available
=
is_fp8_block_scaling_available
(
return_reason
=
True
)
sm_80plus
=
get_device_compute_capability
()
>=
(
8
,
0
)
...
...
@@ -82,7 +81,7 @@ module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference
=
[
"sbhd"
,
"bshd"
]
param_types
=
[
torch
.
float32
,
torch
.
float16
]
if
is_bf16_
compati
ble
():
# bf16 requires sm_80 or higher
if
is_bf16_
availa
ble
():
# bf16 requires sm_80 or higher
param_types
.
append
(
torch
.
bfloat16
)
batch_sizes
=
[
1
,
2
]
...
...
@@ -553,7 +552,7 @@ def _test_e2e_selective_recompute(
init_method
=
init_method_normal
(
sigma
)
output_layer_init_method
=
scaled_init_method_normal
(
sigma
,
config
.
num_layers
)
with
fp8
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
with
quantized
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
block
=
TransformerLayer
(
config
.
hidden_size
,
4
*
config
.
hidden_size
,
...
...
@@ -580,7 +579,7 @@ def _test_e2e_selective_recompute(
te_inp_hidden_states
.
retain_grad
()
te_inp_attn_mask
=
get_causal_attn_mask
(
config
.
max_seqlen_q
)
with
fp8_
autocast
(
enabled
=
fp8
,
fp8_
recipe
=
recipe
):
with
autocast
(
enabled
=
fp8
,
recipe
=
recipe
):
te_out
=
block
(
te_inp_hidden_states
,
attention_mask
=
te_inp_attn_mask
,
...
...
@@ -649,7 +648,7 @@ def _test_e2e_full_recompute(
init_method
=
init_method_normal
(
sigma
)
output_layer_init_method
=
scaled_init_method_normal
(
sigma
,
config
.
num_layers
)
with
fp8
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
with
quantized
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
block
=
TransformerLayer
(
config
.
hidden_size
,
4
*
config
.
hidden_size
,
...
...
@@ -677,7 +676,7 @@ def _test_e2e_full_recompute(
te_inp_hidden_states
.
retain_grad
()
te_inp_attn_mask
=
get_causal_attn_mask
(
config
.
max_seqlen_q
)
with
fp8_
autocast
(
enabled
=
fp8
,
fp8_
recipe
=
recipe
):
with
autocast
(
enabled
=
fp8
,
recipe
=
recipe
):
if
recompute
:
te_out
=
te_checkpoint
(
block
,
...
...
@@ -1107,7 +1106,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False,
)
inp_hidden_states
.
retain_grad
()
with
fp8_
autocast
(
enabled
=
fp8
,
fp8_
recipe
=
recipe
):
with
autocast
(
enabled
=
fp8
,
recipe
=
recipe
):
out
=
block
(
inp_hidden_states
)
if
isinstance
(
out
,
(
List
,
Tuple
)):
out
=
out
[
0
]
...
...
@@ -1328,7 +1327,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
if
config
.
max_seqlen_q
%
16
!=
0
and
fp8
:
pytest
.
skip
(
"FP8 requires sequence length to be divisible by 16."
)
with
fp8
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
with
quantized
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
te_linear_ref
=
Linear
(
config
.
hidden_size
,
4
*
config
.
hidden_size
,
...
...
@@ -1782,7 +1781,7 @@ def _test_grouped_linear_accuracy(
else
:
m_splits
=
torch
.
tensor
([
config
.
max_seqlen_q
])
with
fp8_
autocast
(
enabled
=
fp8
,
fp8_
recipe
=
recipe
):
with
autocast
(
enabled
=
fp8
,
recipe
=
recipe
):
if
isinstance
(
block
,
GroupedLinear
):
m_splits
=
m_splits
*
bs
out
=
block
(
inp_hidden_states
,
m_splits
.
tolist
())
...
...
@@ -1850,7 +1849,7 @@ def test_grouped_linear_accuracy(
if
config
.
max_seqlen_q
%
16
!=
0
and
fp8
:
pytest
.
skip
(
"FP8 requires sequence length to be divisible by 16."
)
with
fp8
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
with
quantized
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
grouped_linear
=
GroupedLinear
(
num_gemms
,
config
.
hidden_size
,
...
...
@@ -1994,7 +1993,7 @@ def test_grouped_linear_accuracy_save_original_input(
if
config
.
max_seqlen_q
%
16
!=
0
and
fp8
:
pytest
.
skip
(
"FP8 requires sequence length to be divisible by 16."
)
with
fp8
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
with
quantized
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
grouped_linear
=
GroupedLinear
(
num_gemms
,
config
.
hidden_size
,
...
...
@@ -2154,7 +2153,7 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
m_splits
=
_generate_random_numbers
(
num_gemms
,
config
.
max_seqlen_q
*
bs
)
with
fp8_
autocast
(
enabled
=
fp8
,
fp8_
recipe
=
recipe
):
with
autocast
(
enabled
=
fp8
,
recipe
=
recipe
):
if
isinstance
(
block
,
TorchGroupedLinearWithPadding
):
out
=
block
(
inp_hidden_states
,
m_splits
)
else
:
...
...
@@ -2208,7 +2207,7 @@ def test_padding_grouped_linear_accuracy(
if
config
.
max_seqlen_q
%
16
!=
0
and
fp8
:
pytest
.
skip
(
"FP8 requires sequence length to be divisible by 16."
)
with
fp8
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
with
quantized
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
grouped_linear
=
TorchGroupedLinearWithPadding
(
num_gemms
,
config
.
hidden_size
,
...
...
@@ -2219,7 +2218,7 @@ def test_padding_grouped_linear_accuracy(
fp8
=
fp8
,
).
eval
()
with
fp8
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
with
quantized
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
ref_grouped_linear
=
GroupedLinear
(
num_gemms
,
config
.
hidden_size
,
...
...
@@ -2285,7 +2284,7 @@ def test_padding_grouped_linear_accuracy_save_original_input(
if
config
.
max_seqlen_q
%
16
!=
0
and
fp8
:
pytest
.
skip
(
"FP8 requires sequence length to be divisible by 16."
)
with
fp8
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
with
quantized
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
grouped_linear
=
TorchGroupedLinearWithPadding
(
num_gemms
,
config
.
hidden_size
,
...
...
@@ -2296,7 +2295,7 @@ def test_padding_grouped_linear_accuracy_save_original_input(
fp8
=
fp8
,
).
eval
()
with
fp8
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
with
quantized
_model_init
(
enabled
=
fp8
and
fp8_model_params
,
recipe
=
recipe
):
ref_grouped_linear
=
GroupedLinear
(
num_gemms
,
config
.
hidden_size
,
...
...
@@ -2446,7 +2445,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
-    with fp8_model_init(enabled=fp8_model_params, recipe=recipe):
+    with quantized_model_init(enabled=fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
...
...
@@ -2473,7 +2472,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
-    with fp8_autocast(enabled=True, fp8_recipe=recipe):
+    with autocast(enabled=True, recipe=recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
...
...
tests/pytorch/test_onnx_export.py
View file @
063ef88d
...
...
@@ -33,10 +33,9 @@ from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_torch as tex
-from transformer_engine.pytorch.onnx_extensions import te_translation_table
from torch.utils.cpp_extension import IS_HIP_EXTENSION
-from transformer_engine.pytorch.export import is_in_onnx_export_mode
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.utils import get_default_init_method
import tensorrt as trt
...
...
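For the record, the consolidated import paths from the added lines above (no new names, only new homes: te_translation_table moves from onnx_extensions into export, and FP8GlobalStateManager moves from fp8 into quantization):

from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
from transformer_engine.pytorch.quantization import FP8GlobalStateManager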
@@ -59,8 +58,8 @@ NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(
# The directory where this file is stored.
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
fp8_recipes = []
if mxfp8_available:
...
...
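The availability checks switch from FP8GlobalStateManager static methods to module-level helpers that, with return_reason=True, hand back an (available, reason) pair; the same unpacking appears in the nvfp4 tests' te.is_nvfp4_available call. A short sketch of gating a test on them (the test body is a hypothetical placeholder):

import pytest
import transformer_engine.pytorch as te

# With return_reason=True each helper returns (available, reason).
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_requires_fp8():
    ...  # hypothetical placeholder test body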
@@ -179,8 +178,8 @@ def do_export(
    input_names = input_names or ["input"]
    output_names = output_names or ["output"]
-    with torch.inference_mode(), te.fp8_autocast(
-        enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
+    with torch.inference_mode(), te.autocast(
+        enabled=fp8_recipe is not None, recipe=fp8_recipe
    ), warnings.catch_warnings():
        warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*")
...
...
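The do_export path stacks three context managers: torch.inference_mode for tracing, te.autocast keyed off whether a recipe was passed, and a warnings filter for TracerWarning noise. A hedged sketch of that wrapper as a standalone helper (the helper name, model, and input are assumptions for illustration):

import warnings
import torch
import transformer_engine.pytorch as te

def trace_quietly(model, inp, fp8_recipe=None):
    # Hypothetical helper mirroring do_export's context-manager stack.
    with torch.inference_mode(), te.autocast(
        enabled=fp8_recipe is not None, recipe=fp8_recipe
    ), warnings.catch_warnings():
        # Silence tracer warnings emitted while exporting/tracing the module.
        warnings.filterwarnings(
            action="ignore", category=torch.jit.TracerWarning, module=r".*"
        )
        return model(inp)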
@@ -234,8 +233,8 @@ def te_infer(
    fp8_recipe: recipe.Recipe,
):
    """Transformer Engine forward propagation."""
-    with torch.inference_mode(), te.fp8_autocast(
-        enabled=is_fp8, fp8_recipe=fp8_recipe
+    with torch.inference_mode(), te.autocast(
+        enabled=is_fp8, recipe=fp8_recipe
    ), warnings.catch_warnings():
        te_outputs = model(*inps if isinstance(inps, tuple) else (inps,))
        if not isinstance(te_outputs, tuple):
...
...
@@ -441,7 +440,7 @@ def _test_export_linear(
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx"
-    with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+    with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to(device="cuda")
...
...
@@ -507,7 +506,7 @@ def _test_export_layernorm(
    fname = f"te.layernorm_linear{fp8_str}{high_prec_str}.onnx"
    with torch.no_grad():
-        with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+        with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
            layernorm_cls = te.LayerNorm if normalization == "LayerNorm" else te.RMSNorm
            model = layernorm_cls(
                hidden_size,
...
...
@@ -577,7 +576,7 @@ def _test_export_layernorm_linear(
    fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx"
    with torch.no_grad():
-        with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+        with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
            model = te.LayerNormLinear(
                hidden_size,
                3 * hidden_size,
...
...
@@ -673,7 +672,7 @@ def _test_export_layernorm_mlp(
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx"
-    with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+    with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        model = te.LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
...
...
@@ -1215,13 +1214,13 @@ def test_trt_integration(fp8_recipe: recipe.Recipe):
    ).eval()
    inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)
-    with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+    with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        out_ref = model(*inps)
    onnx_fd, onnx_path = tempfile.mkstemp(suffix=".onnx")
    os.close(onnx_fd)
    try:
-        with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+        with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
            with te.onnx_export(enabled=True):
                torch.onnx.export(
                    model,
...
Prev
1
2
3
4
5
6
7
8
9
10
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment