OpenDAS / TransformerEngine
Commit 063ef88d

Merge nv main up to v2.10.0.dev0

Authored Dec 03, 2025 by wenjh <wenjh@sugon.com>
Signed-off-by: wenjh <wenjh@sugon.com>
Parents: 91670b05, 5624dbb4

Changes: 298 files in this commit; showing 20 changed files with 1680 additions and 228 deletions (+1680, -228).
Files shown on this page:

tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py       +491  -0
tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py   +248  -0
tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py          +238  -0
tests/pytorch/test_checkpoint.py                       +8    -6
tests/pytorch/test_cpu_offloading.py                   +6    -7
tests/pytorch/test_cuda_graphs.py                      +89   -27
tests/pytorch/test_custom_recipe.py                    +290  -0
tests/pytorch/test_deferred_init.py                    +0    -1
tests/pytorch/test_float8_blockwise_gemm_exact.py      +7    -6
tests/pytorch/test_float8_blockwise_scaling_exact.py   +18   -6
tests/pytorch/test_float8_current_scaling_exact.py     +14   -9
tests/pytorch/test_float8blockwisetensor.py            +2    -3
tests/pytorch/test_float8tensor.py                     +2    -4
tests/pytorch/test_fused_optimizer.py                  +15   -18
tests/pytorch/test_fused_rope.py                       +16   -0
tests/pytorch/test_fusible_ops.py                      +188  -91
tests/pytorch/test_hf_integration.py                   +1    -1
tests/pytorch/test_multi_tensor.py                     +1    -1
tests/pytorch/test_numerics.py                         +32   -33
tests/pytorch/test_onnx_export.py                      +14   -15
tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py (new file, mode 100644):
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
import torch

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.experimental import utils

recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)


def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
    repeated = x.repeat_interleave(2, dim=1)
    repeated[:, 0::2] &= 0x0F
    repeated[:, 1::2] >>= 4
    return repeated


def check_quantization_nvfp4_versus_reference(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    return_transpose: bool,
    swizzled_scale: bool,
    use_cpp_allocator: bool,
    with_2d_quantization: bool,
) -> None:
    te_dtype = tex.DType.kFloat4E2M1

    # Setup device and random seed
    device = "cuda"
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Input
    x = torch.randn((M, N), dtype=x_dtype, device=device)

    # Quantize
    nvfp4_quantizer = NVFP4Quantizer(
        fp4_dtype=te_dtype,
        rowwise=True,
        columnwise=return_transpose,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=False,
        with_post_rht_amax=False,
        with_2d_quantization=with_2d_quantization,
    )
    if use_cpp_allocator:
        x_nvfp4_sut = nvfp4_quantizer(x)
    else:
        x_nvfp4_sut = nvfp4_quantizer.make_empty(
            (M, N), dtype=x_dtype, device=device, requires_grad=False
        )
        x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)

    # Extract data from NVFP4Tensor
    assert x_nvfp4_sut._rowwise_data is not None
    qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._rowwise_scale_inv is not None
    sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
    qx_t = (
        x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
        if x_nvfp4_sut._columnwise_data is not None
        else None
    )
    sx_t = x_nvfp4_sut._columnwise_scale_inv
    qx_amax = x_nvfp4_sut._amax_rowwise

    # Reference quantization
    quant_tile_shape = (1, 16) if not with_2d_quantization else (16, 16)
    ref_quantizer = NVFP4QuantizerRef(
        dtype=utils.Fp4Formats.E2M1,
        rowwise=True,
        columnwise=return_transpose,
        pow_2_scales=False,
        eps=0.0,
        quant_tile_shape=quant_tile_shape,
    )
    x_nvfp4_ref = ref_quantizer.quantize(x)

    # Extract data from RefNVFP4Tensor
    qx_ref = (
        unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8))
        if x_nvfp4_ref.data is not None
        else None
    )
    sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
    qx_t_ref = (
        unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8))
        if x_nvfp4_ref.data_t is not None
        else None
    )
    sx_t_ref = (
        x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
    )
    ref_amax = x_nvfp4_ref.global_amax_row

    qx = unpack_fp4(qx)
    qx_t = unpack_fp4(qx_t) if qx_t is not None else None

    torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
    # Compare only the valid portion of scale tensors (reference may not have padding)
    ref_sx_shape = sx_ref.shape
    sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
    torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
    if return_transpose:
        torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
        # Compare only the valid portion of transpose scale tensors
        ref_sx_t_shape = sx_t_ref.shape
        sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
        torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
    torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        # full tile cases
        (128, 128),
        (256, 256),
        (256, 1024),
        (1024, 256),
        # Padding required cases
        (256, 272),
        (304, 304),
        (320, 256),
        # Some larger tiles
        (2048, 2048),
        (1024, 2048),
        (2048, 1024),
        # largest tile
        (8192, 8192),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
    "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize("swizzled_scale", [False], ids=["linear_scale"])
@pytest.mark.parametrize(
    "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
@pytest.mark.parametrize(
    "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"]
)
def test_quantization_block_tiling_versus_reference(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    return_transpose: bool,
    swizzled_scale: bool,
    use_cpp_allocator: bool,
    with_2d_quantization: bool,
) -> None:
    check_quantization_nvfp4_versus_reference(
        x_dtype=x_dtype,
        M=M,
        N=N,
        return_transpose=return_transpose,
        swizzled_scale=swizzled_scale,
        use_cpp_allocator=use_cpp_allocator,
        with_2d_quantization=with_2d_quantization,
    )


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        (128, 128),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"])
@pytest.mark.parametrize(
    "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
    "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
def test_nvfp4_quantization_extrema_versus_reference(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    extrema_high: bool,
    return_transpose: bool,
    use_cpp_allocator: bool,
):
    te_dtype = tex.DType.kFloat4E2M1
    device = "cuda"
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    if extrema_high:
        x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device)
    else:
        x = torch.zeros((M, N), dtype=x_dtype, device=device)

    nvfp4_quantizer = NVFP4Quantizer(
        fp4_dtype=te_dtype,
        rowwise=True,
        columnwise=return_transpose,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=False,
        with_post_rht_amax=False,
    )
    if use_cpp_allocator:
        x_nvfp4_sut = nvfp4_quantizer(x)
    else:
        x_nvfp4_sut = nvfp4_quantizer.make_empty(
            (M, N), dtype=x_dtype, device=device, requires_grad=False
        )
        x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)

    assert x_nvfp4_sut._rowwise_data is not None
    qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._rowwise_scale_inv is not None
    sx = x_nvfp4_sut._rowwise_scale_inv
    qx_t = (
        x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
        if x_nvfp4_sut._columnwise_data is not None
        else None
    )
    sx_t = x_nvfp4_sut._columnwise_scale_inv
    qx_amax = x_nvfp4_sut._amax_rowwise

    ref_quantizer = NVFP4QuantizerRef(
        dtype=utils.Fp4Formats.E2M1,
        rowwise=True,
        columnwise=return_transpose,
        pow_2_scales=False,
        eps=0.0,
        quant_tile_shape=(1, 16),
    )
    x_nvfp4_ref = ref_quantizer.quantize(x)

    qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
    sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
    qx_t_ref = (
        x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
    )
    sx_t_ref = (
        x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
    )
    ref_amax = x_nvfp4_ref.global_amax_row

    torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
    ref_sx_shape = sx_ref.shape
    sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
    torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
    if return_transpose:
        torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
        ref_sx_t_shape = sx_t_ref.shape
        sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
        torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
    torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        (16, 128),
        (32, 128),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
    "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
    "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
def test_nvfp4_quantization_boundary_values(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    return_transpose: bool,
    use_cpp_allocator: bool,
):
    """
    Stress rounding/threshold behavior by placing values just below/above
    many potential bin edges within each 16-element microblock.
    Validates native vs reference byte-for-byte and scale parity.
    """
    te_dtype = tex.DType.kFloat4E2M1
    device = "cuda"
    seed = 123
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Construct a single row with paired boundary values: v-eps, v+eps
    # spanning a wide dynamic range to exercise clipping and multiple bins.
    # Ensure even N and N is multiple of 16 for microblocks, which holds for 128.
    base = torch.linspace(-12.0, 12.0, steps=N // 2, dtype=torch.float32, device=device)
    eps = torch.full_like(base, 1e-3)
    # Avoid zero eps for very small magnitudes
    eps = torch.maximum(eps, 1e-4 * torch.ones_like(base))
    lower = base - eps
    upper = base + eps
    row = torch.empty(N, dtype=torch.float32, device=device)
    row[0::2] = lower
    row[1::2] = upper
    x = row.unsqueeze(0).repeat(M, 1).to(dtype=x_dtype)

    nvfp4_quantizer = NVFP4Quantizer(
        fp4_dtype=te_dtype,
        rowwise=True,
        columnwise=return_transpose,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=False,
        with_post_rht_amax=False,
    )
    if use_cpp_allocator:
        x_nvfp4_sut = nvfp4_quantizer(x)
    else:
        x_nvfp4_sut = nvfp4_quantizer.make_empty(
            (M, N), dtype=x_dtype, device=device, requires_grad=False
        )
        x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)

    assert x_nvfp4_sut._rowwise_data is not None
    qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._rowwise_scale_inv is not None
    sx = x_nvfp4_sut._rowwise_scale_inv
    qx_t = (
        x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
        if x_nvfp4_sut._columnwise_data is not None
        else None
    )
    sx_t = x_nvfp4_sut._columnwise_scale_inv
    qx_amax = x_nvfp4_sut._amax_rowwise

    ref_quantizer = NVFP4QuantizerRef(
        dtype=utils.Fp4Formats.E2M1,
        rowwise=True,
        columnwise=return_transpose,
        pow_2_scales=False,
        eps=0.0,
        quant_tile_shape=(1, 16),
    )
    x_nvfp4_ref = ref_quantizer.quantize(x)

    qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
    sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
    qx_t_ref = (
        x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
    )
    sx_t_ref = (
        x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
    )
    ref_amax = x_nvfp4_ref.global_amax_row

    torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
    # Compare only valid portion of scales (trim any padding)
    ref_sx_shape = sx_ref.shape
    sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
    torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
    if return_transpose:
        torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
        ref_sx_t_shape = sx_t_ref.shape
        sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
        torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
    torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        (32, 128),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
    "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
    "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
def test_nvfp4_quantization_noncontiguous_inputs(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    return_transpose: bool,
    use_cpp_allocator: bool,
):
    te_dtype = tex.DType.kFloat4E2M1
    device = "cuda"
    seed = 17
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Start from a contiguous tensor, then make a non-contiguous view by transpose
    x_base = torch.randn((M, N), dtype=x_dtype, device=device)
    x_nc = x_base.t()  # shape (N, M), non-contiguous
    assert not x_nc.is_contiguous()

    nvfp4_quantizer = NVFP4Quantizer(
        fp4_dtype=te_dtype,
        rowwise=True,
        columnwise=return_transpose,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=False,
        with_post_rht_amax=False,
    )
    if use_cpp_allocator:
        x_nvfp4_sut = nvfp4_quantizer(x_nc)
    else:
        x_nvfp4_sut = nvfp4_quantizer.make_empty(
            x_nc.shape, dtype=x_dtype, device=device, requires_grad=False
        )
        x_nvfp4_sut = nvfp4_quantizer.update_quantized(x_nc, x_nvfp4_sut)

    assert x_nvfp4_sut._rowwise_data is not None
    qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._rowwise_scale_inv is not None
    sx = x_nvfp4_sut._rowwise_scale_inv
    qx_t = (
        x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
        if x_nvfp4_sut._columnwise_data is not None
        else None
    )
    sx_t = x_nvfp4_sut._columnwise_scale_inv
    qx_amax = x_nvfp4_sut._amax_rowwise

    ref_quantizer = NVFP4QuantizerRef(
        dtype=utils.Fp4Formats.E2M1,
        rowwise=True,
        columnwise=return_transpose,
        pow_2_scales=False,
        eps=0.0,
        quant_tile_shape=(1, 16),
    )
    x_nvfp4_ref = ref_quantizer.quantize(x_nc)

    qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
    sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
    qx_t_ref = (
        x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
    )
    sx_t_ref = (
        x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
    )
    ref_amax = x_nvfp4_ref.global_amax_row

    # Quantized must match
    torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
    # Compare only valid portion of scales (trim padding)
    ref_sx_shape = sx_ref.shape
    sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
    torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
    if return_transpose:
        torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
        ref_sx_t_shape = sx_t_ref.shape
        sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
        torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
    torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
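For reference, a minimal standalone sketch of the byte layout that the unpack_fp4 helper above assumes (the byte value is made up for illustration): each uint8 in the packed rowwise data holds two FP4 codes, with the low nibble going to the even output column and the high nibble to the odd one.

import torch

packed = torch.tensor([[0xBA]], dtype=torch.uint8)  # one byte packing two FP4 codes
repeated = packed.repeat_interleave(2, dim=1)       # [[0xBA, 0xBA]]
repeated[:, 0::2] &= 0x0F                           # low nibble  -> 0x0A (even column)
repeated[:, 1::2] >>= 4                             # high nibble -> 0x0B (odd column)
assert repeated.tolist() == [[0x0A, 0x0B]]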
tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py (new file, mode 100644):
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# NOTE: This file is dependent on the success of test_nvfp4_quantize_exact.py.
# Separate to make sure all the functionalities are working as expected.
# Otherwise reference implementation will get messy.
# Due to the structure of NVFP4Quantizer, we need to test the RHT functionality
# together with the quantization functionality.

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils

import pytest
import torch

recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)


def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
    repeated = x.repeat_interleave(2, dim=1)
    repeated[:, 0::2] &= 0x0F
    repeated[:, 1::2] >>= 4
    return repeated


def check_quantization_nvfp4_versus_reference(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    contiguous: bool,
    return_transpose: bool,
    use_cpp_allocator: bool,
    swizzled_scale: bool = False,
    hadamard_dimension: int = 16,
    with_rht: bool = True,
    with_post_rht_amax: bool = True,
    with_random_sign_mask: bool = True,
) -> None:
    assert with_rht and with_post_rht_amax, "RHT and post-RHT amax reduction must be enabled."
    te_dtype = tex.DType.kFloat4E2M1

    # Setup device and random seed
    device = "cuda"
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Input
    x = torch.randn((M, N), dtype=x_dtype, device=device)
    x = x.transpose(0, 1) if not contiguous else x

    # Quantize
    nvfp4_quantizer = NVFP4Quantizer(
        fp4_dtype=te_dtype,
        rowwise=True,
        columnwise=return_transpose,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=with_rht,
        with_post_rht_amax=with_post_rht_amax,
        with_random_sign_mask=with_random_sign_mask,
    )
    if use_cpp_allocator:
        x_nvfp4_sut = nvfp4_quantizer(x)
    else:
        x_nvfp4_sut = nvfp4_quantizer.make_empty(
            x.shape, dtype=x_dtype, device=device, requires_grad=False
        )
        x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)

    # Extract data from NVFP4Tensor
    assert x_nvfp4_sut._rowwise_data is not None
    qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._rowwise_scale_inv is not None
    sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
    qx_t = (
        x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
        if x_nvfp4_sut._columnwise_data is not None
        else None
    )
    sx_t = x_nvfp4_sut._columnwise_scale_inv
    amax_rowwise = x_nvfp4_sut._amax_rowwise
    amax_colwise = x_nvfp4_sut._amax_columnwise

    qx = unpack_fp4(qx)
    qx_t = unpack_fp4(qx_t) if qx_t is not None else None

    # Reference quantization using NVFP4QuantizerRef with built-in RHT
    ref_quantizer = NVFP4QuantizerRef(
        dtype=utils.Fp4Formats.E2M1,
        rowwise=True,
        columnwise=return_transpose,
        pow_2_scales=False,
        eps=0.0,
        quant_tile_shape=(1, 16),
        with_rht=with_rht,
        with_random_sign_mask=with_random_sign_mask,
    )
    x_nvfp4_ref = ref_quantizer.quantize(x)

    # Extract data from RefNVFP4Tensor
    qx_ref = (
        unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8))
        if x_nvfp4_ref.data is not None
        else None
    )
    sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
    ref_amax_rowwise = x_nvfp4_ref.global_amax_row

    if return_transpose:
        assert x_nvfp4_ref.data_t is not None
        assert x_nvfp4_ref.scale_t is not None
        qx_t_ref = unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8))
        sx_t_ref = x_nvfp4_ref.scale_t.view(dtype=torch.uint8)
        # Compute transpose amax using the same reference quantizer
        x_t_for_amax = (
            ref_quantizer._apply_rht(x.t().contiguous()) if with_rht else x.t().contiguous()
        )
        ref_amax_colwise_t = torch.max(torch.abs(x_t_for_amax)).to(torch.float32).view(1)
    else:
        qx_t_ref = None
        sx_t_ref = None
        ref_amax_colwise_t = None

    torch.testing.assert_close(amax_rowwise, ref_amax_rowwise, atol=0.0, rtol=0.0)
    torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
    # Compare only the valid portion of scale tensors (reference may not have padding)
    ref_sx_shape = sx_ref.shape
    sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
    torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
    if return_transpose:
        torch.testing.assert_close(amax_colwise, ref_amax_colwise_t, atol=0.0, rtol=0.0)
        torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
        # Compare only the valid portion of transpose scale tensors
        ref_sx_t_shape = sx_t_ref.shape
        sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
        torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        # full tile cases
        (128, 128),
        (256, 256),
        (256, 1024),
        (1024, 256),
        # Padding required cases
        (256, 272),
        (304, 304),
        (320, 256),
        # Some larger tiles
        (2048, 2048),
        (1024, 2048),
        (2048, 1024),
        # Real shapes
        (8192, 5120),
        (8192, 10240),
        (8192, 2560),
        (8192, 11328),
        (8192, 512),
        (8192, 3584),
        (5120, 8192),
        (10240, 8192),
        (2560, 8192),
        (11328, 8192),
        (512, 8192),
        (3584, 8192),
        (4096, 16384),
        (14336, 16384),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize(
    "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
    "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
@pytest.mark.parametrize(
    "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"]
)
def test_rht_with_quantization_block_tiling_versus_reference(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    return_transpose: bool,
    use_cpp_allocator: bool,
    with_random_sign_mask: bool,
) -> None:
    check_quantization_nvfp4_versus_reference(
        x_dtype=x_dtype,
        M=M,
        N=N,
        contiguous=True,
        return_transpose=return_transpose,
        use_cpp_allocator=use_cpp_allocator,
        with_random_sign_mask=with_random_sign_mask,
    )


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        (32, 128),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize(
    "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
)
@pytest.mark.parametrize(
    "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
)
@pytest.mark.parametrize(
    "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"]
)
def test_nvfp4_quantization_noncontiguous_inputs(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    return_transpose: bool,
    use_cpp_allocator: bool,
    with_random_sign_mask: bool,
):
    check_quantization_nvfp4_versus_reference(
        x_dtype=x_dtype,
        M=M,
        N=N,
        contiguous=False,
        return_transpose=return_transpose,
        use_cpp_allocator=use_cpp_allocator,
        with_random_sign_mask=with_random_sign_mask,
    )
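The RHT path exercised above applies a 16-point Hadamard transform to each 16-element block before quantizing. As a minimal standalone sketch (not from the commit), this is the Sylvester construction the reference helpers rely on (see _build_hadamard_matrix in test_nvfp4_sr_quantize.py below); the orthogonality check is the property that makes H scaled by 1/sqrt(16) norm-preserving.

import torch

def sylvester_hadamard(size: int) -> torch.Tensor:
    # H_1 = [[1]]; H_{2n} = [[H_n, H_n], [H_n, -H_n]]
    assert size & (size - 1) == 0, "size must be a power of two"
    h = torch.ones((1, 1))
    while h.shape[0] < size:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h

H = sylvester_hadamard(16)
# Rows are mutually orthogonal: H @ H.T == 16 * I, so (1 / sqrt(16)) * H is orthonormal.
assert torch.equal(H @ H.t(), 16.0 * torch.eye(16))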
tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py (new file, mode 100755):
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
import torch

import transformer_engine.pytorch as te
from transformer_engine.pytorch import NVFP4Quantizer

recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)

seed = 12345
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
    repeated = x.repeat_interleave(2, dim=1)
    repeated[:, 0::2] &= 0x0F
    repeated[:, 1::2] >>= 4
    return repeated


_FP4_LUT = torch.tensor(
    [
        0.0,  # 0: 0000 - zero
        0.5,  # 1: 0001 - smallest positive normal
        1.0,  # 2: 0010
        1.5,  # 3: 0011
        2.0,  # 4: 0100
        3.0,  # 5: 0101
        4.0,  # 6: 0110
        6.0,  # 7: 0111 - largest positive normal
        -0.0,  # 8: 1000 - negative zero
        -0.5,  # 9: 1001 - smallest negative normal
        -1.0,  # 10: 1010
        -1.5,  # 11: 1011
        -2.0,  # 12: 1100
        -3.0,  # 13: 1101
        -4.0,  # 14: 1110
        -6.0,  # 15: 1111 - largest negative normal
    ],
    dtype=torch.float32,
)


def fp4_to_fp32(fp4: torch.Tensor) -> torch.Tensor:
    # Convert FP4 indices to their corresponding floating point values
    # Each index (0-15) represents a 4-bit FP4 value in E2M1 format
    # Values based on the FP4 E2M1 specification
    fp4_lut = _FP4_LUT.to(fp4.device)
    return fp4_lut[fp4.to(torch.long)]


def dequantize_fp4(qx: torch.Tensor, sx: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32)
    dqx = fp4_to_fp32(unpack_fp4(qx))
    sf = sf[: dqx.shape[0], : dqx.shape[1]]
    dequant = dqx * sf * (amax / (6.0 * 448))
    return dequant


def RHT(x: torch.Tensor) -> torch.Tensor:
    def get_wgrad_sign_vector() -> torch.Tensor:
        """Hard-coded signs for Hadamard transform"""
        return torch.tensor(
            [
                1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0,
                -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0,
            ],
            dtype=torch.float32,
        )

    def _build_hadamard_matrix(
        size: int,
        device: torch.device,
        dtype: torch.dtype,
        with_random_sign_mask: bool = True,
    ) -> torch.Tensor:
        """Construct a Hadamard matrix of given power-of-two size with entries +-1.
        Uses Sylvester construction to avoid SciPy dependency.
        """
        assert (size & (size - 1)) == 0, "Hadamard size must be a power of two"
        h = torch.ones((1, 1), device=device, dtype=torch.float32)
        while h.shape[0] < size:
            h = torch.cat(
                [
                    torch.cat([h, h], dim=1),
                    torch.cat([h, -h], dim=1),
                ],
                dim=0,
            )
        if with_random_sign_mask:
            sign_mat = get_wgrad_sign_vector().to(device) * torch.eye(
                size, device=device, dtype=torch.float32
            )
            h = sign_mat @ h
        return h.to(dtype)

    rht_dim = 16
    # Build H and scale
    H = _build_hadamard_matrix(rht_dim, x.device, x.dtype)
    scale = 1.0 / float(rht_dim) ** 0.5
    # Perform blockwise transform along the last dimension
    original_shape = x.shape
    x_mat = x.contiguous().view(-1, rht_dim)
    # Random sign matrix is identity in this reference (no sign flipping)
    transform = H * scale
    out = x_mat @ transform
    return out.view(original_shape)


def quantize_fp4(
    x: torch.Tensor, use_stochastic_rounding: bool, use_2D: bool, use_RHT: bool
) -> torch.Tensor:
    nvfp4_quantizer = NVFP4Quantizer(
        rowwise=True,
        columnwise=True,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=use_RHT,
        with_post_rht_amax=True,
        stochastic_rounding=use_stochastic_rounding,
        with_2d_quantization=use_2D,
    )
    x_nvfp4_sut = nvfp4_quantizer(x)

    # Extract data from NVFP4Tensor
    assert x_nvfp4_sut._rowwise_data is not None
    qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._rowwise_scale_inv is not None
    sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
    assert x_nvfp4_sut._columnwise_data is not None
    qx_t: torch.Tensor = x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
    assert x_nvfp4_sut._columnwise_scale_inv is not None
    sx_t: torch.Tensor = x_nvfp4_sut._columnwise_scale_inv
    return qx, sx, qx_t, sx_t


def check_quantization_nvfp4_versus_reference(
    x_dtype: torch.dtype, M: int, N: int, use_2D: bool, use_RHT: bool
) -> None:
    device = "cuda"
    torch.manual_seed(seed)
    n_iters = 50

    x = torch.randn((M, N), dtype=x_dtype, device=device) * 2 - 1
    y = x.t().contiguous()
    if use_RHT:
        y = RHT(y)
    amax = torch.max(torch.abs(x)).float()

    q_rn, s_rn, q_t_rn, s_t_rn = quantize_fp4(
        x, use_stochastic_rounding=False, use_2D=use_2D, use_RHT=use_RHT
    )
    dq_rn = dequantize_fp4(q_rn, s_rn, amax)
    dq_t_rn = dequantize_fp4(q_t_rn, s_t_rn, amax)
    error_rn = (dq_rn - x).float()
    me_rn = torch.sqrt((error_rn * error_rn).mean())
    error_t_rn = (dq_t_rn - y).float()
    me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean())

    sr_result = torch.zeros_like(x).float()
    sr_t_result = torch.zeros_like(x).float().t().contiguous()
    for i in range(n_iters):
        q_sr, s_sr, q_t_sr, s_t_sr = quantize_fp4(
            x, use_stochastic_rounding=True, use_2D=use_2D, use_RHT=use_RHT
        )
        dq_sr = dequantize_fp4(q_sr, s_sr, amax)
        dq_t_sr = dequantize_fp4(q_t_sr, s_t_sr, amax)
        sr_result += dq_sr.float()
        sr_t_result += dq_t_sr.float()
        # sr_result_tmp = sr_result / (i + 1)
        # error_sr = (sr_result_tmp - x).float()
        # me_sr = torch.sqrt((error_sr * error_sr).mean())
        # sr_t_result_tmp = sr_t_result / (i + 1)
        # error_t_sr = (sr_t_result_tmp - y).float()
        # me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
        # print(f"Iteration {i}: RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
        # print(f"Iteration {i}: RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")

    # Get the mean result of the stochastic rounding
    # It should be more accurate than the RN result
    sr_result /= n_iters
    error_sr = (sr_result - x).float()
    me_sr = torch.sqrt((error_sr * error_sr).mean())
    sr_t_result /= n_iters
    error_t_sr = (sr_t_result - y).float()
    me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
    print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
    print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")

    assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest."
    assert me_t_sr < me_t_rn, "Stochastic rounding failed - error larger than the round to nearest."


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        (8192, 8192),
        (8192, 8256),  # to test the nonfused RHT path
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("use_2D", [False, True], ids=str)
@pytest.mark.parametrize("use_RHT", [False, True], ids=str)
def test_quantization_block_tiling_versus_reference(
    x_dtype: torch.dtype,
    use_2D: bool,
    use_RHT: bool,
    M: int,
    N: int,
) -> None:
    if x_dtype == torch.float32 and use_RHT:
        pytest.skip("RHT is only supported with bfloat16")
    check_quantization_nvfp4_versus_reference(
        x_dtype=x_dtype,
        use_2D=use_2D,
        use_RHT=use_RHT,
        M=M,
        N=N,
    )
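A worked sketch of the two-level scaling that dequantize_fp4 above inverts. The numbers are made up, and the choice of 448 as the per-block decode scale is an assumption for this example: the file stores one FP8 (E4M3) decode scale per 16-element block plus a single FP32 global scale of amax / (6.0 * 448), so a stored FP4 code c reconstructs as c * block_scale * global_scale.

amax = 24.0                               # tensor-wide absolute maximum
global_scale = amax / (6.0 * 448.0)       # one FP32 value for the whole tensor
block_decode_scale = 448.0                # hypothetical E4M3 scale for a block whose local amax is also 24.0
decoded = 6.0 * block_decode_scale * global_scale   # FP4 code 6.0 decodes back to roughly 24.0
assert abs(decoded - 24.0) < 1e-6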
tests/pytorch/test_checkpoint.py (modified):
@@ -12,13 +12,15 @@ import pathlib
 import pytest
 import torch
+from typing import Optional
 
 import transformer_engine.pytorch as te
 from utils import make_recipe
 
 # Check supported quantization schemes
-fp8_available, reason_for_no_fp8 = te.fp8.FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = te.fp8.FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
 
 # Test cases for loading checkpoint files
@@ -65,16 +67,16 @@ class TestLoadCheckpoint:
         if name == "ops_linear":
             return te.ops.Linear(1, 1)
         if name == "linear.fp8":
-            with te.fp8_model_init(recipe=make_recipe("fp8")):
+            with te.quantized_model_init(recipe=make_recipe("fp8")):
                 return te.Linear(16, 16)
         if name == "ops_linear.fp8":
-            with te.fp8_model_init(recipe=make_recipe("fp8")):
+            with te.quantized_model_init(recipe=make_recipe("fp8")):
                 return te.ops.Linear(16, 16)
         if name == "linear.mxfp8":
-            with te.fp8_model_init(recipe=make_recipe("mxfp8")):
+            with te.quantized_model_init(recipe=make_recipe("mxfp8")):
                 return te.Linear(32, 32)
         if name == "ops_linear.mxfp8":
-            with te.fp8_model_init(recipe=make_recipe("mxfp8")):
+            with te.quantized_model_init(recipe=make_recipe("mxfp8")):
                 return te.ops.Linear(32, 32)
         raise ValueError(f"Unrecognized module name ({name})")
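This diff, and the test_cpu_offloading.py one that follows, moves callers from the fp8_-prefixed entry points to the renamed ones. A minimal before/after sketch of the pattern under those assumptions, with a placeholder recipe (DelayedScaling is only an example choice) and a placeholder module:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

quant_recipe = recipe.DelayedScaling()  # placeholder recipe for the sketch

# Old spelling, as removed by this commit:
#   with te.fp8_model_init(recipe=quant_recipe):
#       model = te.Linear(16, 16)
#   with te.fp8_autocast(enabled=True, fp8_recipe=quant_recipe):
#       out = model(x)

# New spelling, as added by this commit:
with te.quantized_model_init(recipe=quant_recipe):
    model = te.Linear(16, 16)
with te.autocast(enabled=True, recipe=quant_recipe):
    out = model(torch.randn(16, 16, device="cuda", dtype=torch.float32))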
tests/pytorch/test_cpu_offloading.py (modified):
@@ -12,14 +12,13 @@ import torch
 import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
 from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
 from utils import ModelConfig, get_available_attention_backends
 
 # Check supported quantization schemes
-fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available = te.is_fp8_available()
+mxfp8_available = te.is_mxfp8_available()
 
 quantization_recipes: Optional[recipe.Recipe] = [None]
 if fp8_available:
@@ -79,9 +78,9 @@ def _warmup_model(
     """Perform forward and backward pass"""
     tensor = _make_input()
     for module in modules:
-        with te.fp8_autocast(
+        with te.autocast(
             enabled=quantization_recipe is not None,
-            fp8_recipe=quantization_recipe,
+            recipe=quantization_recipe,
         ):
             tensor = module(tensor)
     tensor.sum().backward()
@@ -159,8 +158,8 @@ def _measure_cached_memory(
     tensor = inp
     memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
     for module in modules:
-        with te.fp8_autocast(
-            enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe
+        with te.autocast(
+            enabled=quantization_recipe is not None, recipe=quantization_recipe
         ), offload_context:
             tensor = module(tensor)
         tensor = sync_function(tensor)
tests/pytorch/test_cuda_graphs.py (modified):
@@ -13,12 +13,15 @@ from transformer_engine.pytorch import (
     Linear,
     MultiheadAttention,
     TransformerLayer,
-    fp8_autocast,
-    fp8_model_init,
+    autocast,
+    quantized_model_init,
     make_graphed_callables,
+    is_fp8_available,
+    is_fp8_block_scaling_available,
+    is_mxfp8_available,
+    is_bf16_available,
 )
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.utils import is_bf16_compatible
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.common import recipe
 from utils import ModelConfig, reset_rng_states
@@ -28,20 +31,67 @@ if IS_HIP_EXTENSION:
 from functools import cache
 
 # Check if FP8 is supported.
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available = is_fp8_available()
+fp8_block_scaling_available = is_fp8_block_scaling_available()
+mxfp8_available = is_mxfp8_available()
 
 # Reset RNG states.
 reset_rng_states()
 
 model_configs = {
-    "small": ModelConfig(32, 2, 2, 32),
+    "small": ModelConfig(2, 32, 2, 32),
 }
 
+
+def nvfp4_vanilla():
+    nvfp4_recipe = recipe.NVFP4BlockScaling()
+    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
+    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
+    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
+    return nvfp4_recipe
+
+
+def nvfp4_rht_and_2d_quantization():
+    nvfp4_recipe = recipe.NVFP4BlockScaling()
+    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
+        random_hadamard_transform=True, fp4_2d_quantization=False
+    )
+    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
+        random_hadamard_transform=False, fp4_2d_quantization=True
+    )
+    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
+        random_hadamard_transform=True, fp4_2d_quantization=False
+    )
+    return nvfp4_recipe
+
+
+def check_rht_usage(recipe: recipe.Recipe) -> bool:
+    # if using RHT, we can only support bf16
+    # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad
+    if recipe.nvfp4():
+        if (
+            recipe.fp4_quant_fwd_inp.random_hadamard_transform
+            or recipe.fp4_quant_fwd_weight.random_hadamard_transform
+            or recipe.fp4_quant_bwd_grad.random_hadamard_transform
+        ):
+            return True
+    return False
+
+
+def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> bool:
+    supported_input_dtypes = []
+    if recipe.nvfp4():
+        supported_input_dtypes.append(torch.bfloat16)
+        # if not using RHT, we can add fp32 as well
+        if not check_rht_usage(recipe):
+            supported_input_dtypes.append(torch.float32)
+    return supported_input_dtypes
+
+
 fp8_recipes = []
 if mxfp8_available:
     fp8_recipes.append(recipe.MXFP8BlockScaling())
+    fp8_recipes.append(nvfp4_rht_and_2d_quantization())
 if fp8_block_scaling_available:
     fp8_recipes.append(recipe.Float8BlockScaling())
 if fp8_available:
@@ -50,7 +100,7 @@ if fp8_available:
 # Supported data types
 dtypes: List[torch.dtype] = [torch.float32, torch.float16]
-if is_bf16_compatible():  # bf16 requires sm_80 or higher
+if is_bf16_available():  # bf16 requires sm_80 or higher
     dtypes.append(torch.bfloat16)
@@ -167,7 +217,7 @@ def _test_cuda_graphs(
         fp8_weight_caching = False
 
     # Create modules.
-    with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
+    with quantized_model_init(enabled=fp8_params, recipe=fp8_recipe):
         if module == "transformer":
             modules = [
                 TransformerLayer(
@@ -247,9 +297,9 @@ def _test_cuda_graphs(
             model,
            (generate_data(model_config, dtype, warmup=True),),
             num_warmup_iters=10,
-            fp8_enabled=fp8,
-            fp8_weight_caching=fp8_weight_caching,
-            fp8_recipe=fp8_recipe,
+            enabled=fp8,
+            cache_quantized_params=fp8_weight_caching,
+            recipe=fp8_recipe,
         )
     elif graph_mode == "individual":
         # Graph individual modules.
@@ -258,9 +308,9 @@ def _test_cuda_graphs(
                 module,
                 (generate_data(model_config, dtype, warmup=True),),
                 num_warmup_iters=10,
-                fp8_enabled=fp8,
-                fp8_weight_caching=fp8_weight_caching,
-                fp8_recipe=fp8_recipe,
+                enabled=fp8,
+                cache_quantized_params=fp8_weight_caching,
+                recipe=fp8_recipe,
             )
             for module in modules
         ]
@@ -277,7 +327,7 @@ def _test_cuda_graphs(
     for grad_accumulation_step in range(2):
         input_ = generate_data(model_config, dtype)
         grad_output = generate_data(model_config, dtype, requires_grad=False)
-        with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
+        with autocast(enabled=fp8, recipe=fp8_recipe):
             kwargs = {}
             if fp8_weight_caching:
                 kwargs["is_first_microbatch"] = grad_accumulation_step == 0
@@ -291,7 +341,7 @@ def _test_cuda_graphs(
 @pytest.mark.parametrize("module", _test_cuda_graphs_modules)
 @pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None])
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__)
 def test_make_graphed_callables(
     *,
     module: str,
@@ -308,15 +358,25 @@ def test_make_graphed_callables(
         pytest.skip("FP8 needed for FP8 parameters.")
     if fp8_weight_caching and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
-    if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op":
-        pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
+    if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op":
+        pytest.skip(
+            f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs"
+        )
+    if fp8 and fp8_recipe.nvfp4():
+        if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype):
+            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe"
                f" {fp8_recipe.__class__.__name__}"
            )
+        if fp8_params:
+            pytest.skip("NVFP4 params not supported")
     if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
+        pytest.skip("FP8 not supported on rocm GPU.")
     if fp8 and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
+        pytest.skip("FP8 block scaling not supported on rocm GPU.")
     if fp8 and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
+        pytest.skip("MXFP8 not supported on rocm GPU.")
 
     # Run model with different CUDA graph settings.
     model_config = model_configs[model_config]
     kwargs = dict(
@@ -353,17 +413,19 @@ _test_make_graphed_callables_with_fp8_weight_caching_modules = [
     "module",
     _test_make_graphed_callables_with_fp8_weight_caching_modules,
 )
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__)
 def test_make_graphed_callables_with_fp8_weight_caching(
     *,
     module: str,
+    dtype: torch.dtype,
     fp8_params: bool,
     fp8_recipe: recipe.Recipe,
 ) -> None:
     test_make_graphed_callables(
         module=module,
-        dtype=torch.float32,
+        dtype=dtype,
         fp8_params=fp8_params,
         fp8_recipe=fp8_recipe,
         fp8_weight_caching=True,
@@ -415,7 +477,7 @@ def _test_cuda_graphs_with_dot_product_attention(
         model,
         generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
         num_warmup_iters=10,
-        fp8_enabled=False,
+        enabled=False,
     )
 
     # Forward and backward passes.
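The test_cuda_graphs.py diff above also renames the quantization keywords of make_graphed_callables. A minimal call-site sketch of that change; the module, sample data, and recipe below are placeholders, not taken from the commit:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

model = te.Linear(32, 32)                                              # placeholder module
sample_args = (torch.randn(32, 32, device="cuda", requires_grad=True),)  # placeholder warmup data
fp8_recipe = recipe.MXFP8BlockScaling()                                # placeholder recipe

# Old keywords (removed): fp8_enabled=..., fp8_weight_caching=..., fp8_recipe=...
graphed = te.make_graphed_callables(
    model,
    sample_args,
    num_warmup_iters=10,
    enabled=True,                 # was fp8_enabled
    cache_quantized_params=True,  # was fp8_weight_caching
    recipe=fp8_recipe,            # was fp8_recipe
)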
tests/pytorch/test_custom_recipe.py (new file, mode 100644):
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import
pytest
import
torch
import
transformer_engine.pytorch
as
te
import
transformer_engine_torch
as
tex
from
transformer_engine.common
import
recipe
from
transformer_engine.pytorch
import
(
autocast
,
Linear
,
LayerNormLinear
,
LayerNormMLP
,
GroupedLinear
,
Float8CurrentScalingQuantizer
,
)
import
transformer_engine.pytorch.ops
as
te_ops
@
pytest
.
mark
.
parametrize
(
"module_type"
,
[
"Linear"
,
"LayerNormLinear"
,
"OpsLinear"
,
"LayerNormMLP"
])
def
test_custom_recipe_sanity
(
module_type
):
available
,
reason
=
te
.
is_fp8_available
(
return_reason
=
True
)
if
not
torch
.
cuda
.
is_available
()
or
not
available
:
pytest
.
skip
(
f
"FP8 unsupported on this device:
{
reason
}
"
)
torch
.
manual_seed
(
0
)
# Simple linear layer with dims divisible by 16
in_features
=
64
out_features
=
64
batch
=
32
if
module_type
==
"Linear"
:
model
=
Linear
(
in_features
,
out_features
,
params_dtype
=
torch
.
bfloat16
).
cuda
()
elif
module_type
==
"LayerNormLinear"
:
model
=
LayerNormLinear
(
in_features
,
out_features
,
params_dtype
=
torch
.
bfloat16
).
cuda
()
elif
module_type
==
"LayerNormMLP"
:
# hidden_size == in_features == out_features for simplicity
model
=
LayerNormMLP
(
hidden_size
=
in_features
,
ffn_hidden_size
=
out_features
,
params_dtype
=
torch
.
bfloat16
).
cuda
()
else
:
# OpsLinear path
model
=
te_ops
.
Linear
(
in_features
,
out_features
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
inp
=
torch
.
randn
(
batch
,
in_features
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
,
requires_grad
=
True
)
# Single factory: map roles to quantizers
def
quantizer_factory
(
role
):
if
role
in
(
"linear_input"
,
"linear_weight"
,
"linear_output"
):
return
Float8CurrentScalingQuantizer
(
tex
.
DType
.
kFloat8E4M3
,
device
=
"cuda"
)
if
role
in
(
"linear_grad_output"
,
"linear_grad_input"
):
return
Float8CurrentScalingQuantizer
(
tex
.
DType
.
kFloat8E5M2
,
device
=
"cuda"
)
return
Float8CurrentScalingQuantizer
(
tex
.
DType
.
kFloat8E4M3
,
device
=
"cuda"
)
custom_recipe
=
recipe
.
CustomRecipe
(
qfactory
=
quantizer_factory
)
# Execute with custom recipe
with
autocast
(
enabled
=
True
,
recipe
=
custom_recipe
):
out
=
model
(
inp
)
loss
=
out
.
float
().
sum
()
loss
.
backward
()
# Basic sanity: gradients exist
assert
inp
.
grad
is
not
None
def
test_custom_recipe_grouped_linear_sanity
():
available
,
reason
=
te
.
is_fp8_available
(
return_reason
=
True
)
if
not
torch
.
cuda
.
is_available
()
or
not
available
:
pytest
.
skip
(
f
"FP8 unsupported on this device:
{
reason
}
"
)
torch
.
manual_seed
(
0
)
num_gemms
=
3
in_features
=
64
out_features
=
64
batch
=
32
base
=
batch
//
num_gemms
rem
=
batch
%
num_gemms
m_splits
=
[
base
+
(
1
if
i
<
rem
else
0
)
for
i
in
range
(
num_gemms
)]
model
=
GroupedLinear
(
num_gemms
,
in_features
,
out_features
,
params_dtype
=
torch
.
bfloat16
).
cuda
()
inp
=
torch
.
randn
(
batch
,
in_features
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
,
requires_grad
=
True
)
def
quantizer_factory
(
role
):
if
role
in
(
"linear_input"
,
"linear_weight"
,
"linear_output"
):
return
Float8CurrentScalingQuantizer
(
tex
.
DType
.
kFloat8E4M3
,
device
=
"cuda"
)
if
role
in
(
"linear_grad_output"
,
"linear_grad_input"
):
return
Float8CurrentScalingQuantizer
(
tex
.
DType
.
kFloat8E5M2
,
device
=
"cuda"
)
return
Float8CurrentScalingQuantizer
(
tex
.
DType
.
kFloat8E4M3
,
device
=
"cuda"
)
custom_recipe
=
recipe
.
CustomRecipe
(
qfactory
=
quantizer_factory
)
with
autocast
(
enabled
=
True
,
recipe
=
custom_recipe
):
out
=
model
(
inp
,
m_splits
)
loss
=
out
.
float
().
sum
()
loss
.
backward
()
assert
inp
.
grad
is
not
None
def
test_custom_recipe_matches_current_scaling
():
available
,
reason
=
te
.
is_fp8_available
(
return_reason
=
True
)
if
not
torch
.
cuda
.
is_available
()
or
not
available
:
pytest
.
skip
(
f
"FP8 unsupported on this device:
{
reason
}
"
)
torch
.
manual_seed
(
123
)
in_features
=
64
out_features
=
64
batch
=
32
    # Create two identical models
    model_ref = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
    model_custom = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
    model_custom.load_state_dict(model_ref.state_dict())

    # Identical inputs for both paths
    base_inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16)
    inp_ref = base_inp.clone().detach().requires_grad_(True)
    inp_custom = base_inp.clone().detach().requires_grad_(True)

    # Reference: use Float8CurrentScaling recipe
    ref_recipe = recipe.Float8CurrentScaling()
    with autocast(enabled=True, recipe=ref_recipe):
        out_ref = model_ref(inp_ref)

    # Assert dtypes for reference quantizers: HYBRID = E4M3 (fwd), E5M2 (bwd)
    ref_fwd_in = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
    ref_fwd_w = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
    ref_fwd_out = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
    ref_bwd_go = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
    ref_bwd_gi = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
    assert ref_fwd_in.dtype == tex.DType.kFloat8E4M3
    assert ref_fwd_w.dtype == tex.DType.kFloat8E4M3
    assert ref_fwd_out.dtype == tex.DType.kFloat8E4M3
    assert ref_bwd_go.dtype == tex.DType.kFloat8E5M2
    assert ref_bwd_gi.dtype == tex.DType.kFloat8E5M2

    # Stress dynamic range in grad_output
    scale = torch.ones(out_features, device="cuda", dtype=torch.float32)
    scale[0] = 1e8
    scale[1] = 1e-8
    loss_ref = (out_ref.float() * scale.view(1, -1)).sum()
    loss_ref.backward()

    # Custom: single factory returning quantizers per role to match Float8CurrentScaling
    def quantizer_factory(role):
        if role in ("linear_input", "linear_weight", "linear_output"):
            return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
        if role in ("linear_grad_output", "linear_grad_input"):
            return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
        return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")

    custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
    with autocast(enabled=True, recipe=custom_recipe):
        out_custom = model_custom(inp_custom)

    # Assert dtypes for custom quantizers match reference mapping
    cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
    cus_fwd_w = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
    cus_fwd_out = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
    cus_bwd_go = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
    cus_bwd_gi = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
    assert cus_fwd_in.dtype == tex.DType.kFloat8E4M3
    assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3
    assert cus_fwd_out.dtype == tex.DType.kFloat8E4M3
    assert cus_bwd_go.dtype == tex.DType.kFloat8E5M2
    assert cus_bwd_gi.dtype == tex.DType.kFloat8E5M2

    loss_custom = (out_custom.float() * scale.view(1, -1)).sum()
    loss_custom.backward()

    # Compare forward outputs (exact match expected)
    assert torch.allclose(out_ref, out_custom, rtol=0.0, atol=0.0)

    # Compare input gradients
    assert inp_ref.grad is not None and inp_custom.grad is not None
    assert torch.allclose(inp_ref.grad, inp_custom.grad, rtol=0.0, atol=0.0)

    # Compare parameter gradients (weights and bias if present)
    ref_params = dict(model_ref.named_parameters())
    custom_params = dict(model_custom.named_parameters())
    for name, p_ref in ref_params.items():
        p_cus = custom_params[name]
        assert p_ref.grad is not None and p_cus.grad is not None
        assert torch.allclose(p_ref.grad, p_cus.grad, rtol=0.0, atol=0.0)


def test_custom_recipe_ops_linear_2_1_layout():
    available, reason = te.is_fp8_available(return_reason=True)
    if not torch.cuda.is_available() or not available:
        pytest.skip(f"FP8 unsupported on this device: {reason}")

    torch.manual_seed(7)
    in_features = 64
    out_features = 64
    batch = 16

    # Use ops.Linear which consumes 2 forward quantizers and 1 backward quantizer
    op = te_ops.Linear(in_features, out_features, device="cuda", dtype=torch.bfloat16)
    inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    def quantizer_factory(role):
        if role in ("linear_input", "linear_weight", "linear_output"):
            return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
        if role in ("linear_grad_output", "linear_grad_input"):
            return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
        return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")

    custom = recipe.CustomRecipe(qfactory=quantizer_factory)
    with autocast(enabled=True, recipe=custom):
        out = op(inp)
    loss = out.float().sum()
    loss.backward()
    assert inp.grad is not None


def test_custom_recipe_factory_invocation_counts_and_cycling():
    available, reason = te.is_fp8_available(return_reason=True)
    if not torch.cuda.is_available() or not available:
        pytest.skip(f"FP8 unsupported on this device: {reason}")

    torch.manual_seed(13)
    in_features = 64
    out_features = 64
    batch = 8

    op = Linear(in_features, out_features, params_dtype=torch.bfloat16)
    inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    # Counters per role
    counts = {
        "linear_input": 0,
        "linear_weight": 0,
        "linear_output": 0,
        "linear_grad_output": 0,
        "linear_grad_input": 0,
    }

    def quantizer_factory(role):
        if role in counts:
            counts[role] += 1
        if role in ("linear_input", "linear_weight", "linear_output"):
            return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
        if role in ("linear_grad_output", "linear_grad_input"):
            return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda"))
        return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))

    custom = recipe.CustomRecipe(qfactory=quantizer_factory)

    # Run fwd+bwd once; for a single GEMM, expect forward to build 3 quantizers (cycled from 1 factory),
    # and backward to build 2 quantizers (cycled from 1 factory).
    with autocast(enabled=True, recipe=custom):
        out = op(inp)
    loss = out.float().sum()
    loss.backward()

    # Single GEMM: forward should request input, weight, output; backward grad_output, grad_input
    assert counts["linear_input"] == 1
    assert counts["linear_weight"] == 1
    assert counts["linear_output"] == 1
    assert counts["linear_grad_output"] == 1
    assert counts["linear_grad_input"] == 1


def test_factories_return_distinct_instances_and_buffers():
    available, reason = te.is_fp8_available(return_reason=True)
    if not torch.cuda.is_available() or not available:
        pytest.skip(f"FP8 unsupported on this device: {reason}")

    # Two calls should produce distinct quantizer objects and distinct tensor buffers
    def factory():
        return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))

    q1 = factory()
    q2 = factory()
    assert q1 is not q2
    assert q1.scale.data_ptr() != q2.scale.data_ptr()
    assert q1.amax.data_ptr() != q2.amax.data_ptr()

    # Mutating one should not affect the other
    q1.scale.fill_(123.0)
    assert not torch.equal(q1.scale, q2.scale)
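The usage pattern these tests exercise can be condensed into a minimal sketch; the names below (recipe.CustomRecipe, autocast, Float8CurrentScalingQuantizer, and the "linear_*" role strings) are taken from the tests above, and the exact signatures should be treated as assumptions rather than a reference.

# Minimal sketch of the CustomRecipe pattern from the tests above (assumed API).
def qfactory(role):
    # Forward tensors in E4M3, gradient tensors in E5M2, mirroring Float8CurrentScaling.
    if role in ("linear_grad_output", "linear_grad_input"):
        return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
    return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")

with autocast(enabled=True, recipe=recipe.CustomRecipe(qfactory=qfactory)):
    out = model(inp)  # `model` and `inp` stand in for any TE module and its input
out.float().sum().backward()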
tests/pytorch/test_deferred_init.py
@@ -4,7 +4,6 @@
 import pytest
 import torch
-import torch.distributed as dist
 import transformer_engine.pytorch as te
tests/pytorch/test_float8_blockwise_gemm_exact.py
@@ -4,13 +4,13 @@
 import pytest
 import torch
-import transformer_engine as te
+import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len, int8_simulation_fp8)
+from transformer_engine.pytorch.fp8 import (blockwise_fp8_block_len, int8_simulation_fp8)
-from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
+from transformer_engine.pytorch import (
     Float8BlockQuantizer,
-    Float8BlockwiseQTensor,
+    get_device_compute_capability,
 )
 from references.blockwise_quantizer_reference import CuBLASScaleMunger
@@ -19,8 +19,9 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
 from transformer_engine.pytorch.cpp_extensions.gemm import w8a8_int8_general_gemm
 def fp8_blockwise_gemm_supported() -> bool:
-    supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
-    return supported
+    supported = te.is_fp8_block_scaling_available()
+    emulated = get_device_compute_capability() >= (10, 0)
+    return supported and not emulated
 def cublas_gemm_fp8_blockwise_case(
tests/pytorch/test_float8_blockwise_scaling_exact.py
@@ -8,14 +8,13 @@ import os
 import pathlib
 import pytest
 import torch
-import transformer_engine as te
+import transformer_engine.pytorch as te
-import transformer_engine_torch as tex
-from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
+from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
 from transformer_engine.common.recipe import Float8BlockScaling
 from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
+from transformer_engine.pytorch import (
     Float8BlockQuantizer,
-    Float8BlockwiseQTensor,
+    get_device_compute_capability,
 )
 from references.blockwise_quantizer_reference import (
     BlockwiseQuantizerReference,
@@ -32,7 +31,8 @@ TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tenso
 tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR")
 if tensor_dump_dir_env is not None:
     TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)
-recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available()
+recipe_available, reason_for_no_recipe = te.is_fp8_block_scaling_available(return_reason=True)
+recipe_emulated = get_device_compute_capability() >= (10, 0)
 class GetRecipes:
@@ -219,6 +219,12 @@ def check_quantization_block_tiling_versus_reference(
     pow_2_scales: bool,
     tile_size: Tuple[int, int],
 ) -> None:
+    if recipe_emulated and not pow_2_scales:
+        pytest.skip(
+            "On Blackwell and newer, the FP8 block scaling recipe is emulated "
+            "with MXFP8, which requires using power of two scaling factors."
+        )
     te_dtype = TE_DType[quant_dtype]
     if tile_size in ((1, 128), (1, 64)):
         block_scaling_dim = 1
@@ -414,6 +420,12 @@ def test_quantization_block_tiling_extrema_versus_reference(
     tile_size: Tuple[int, int],
     extrema_high: bool,
 ) -> None:
+    if recipe_emulated and not pow_2_scales:
+        pytest.skip(
+            "On Blackwell and newer, the FP8 block scaling recipe is emulated "
+            "with MXFP8, which requires using power of two scaling factors."
+        )
     # This test runs a single tile through a quantizer as a way to test
     # branch coverage of scale computation.
     if blockwise_fp8_block_len != tile_size[1]:
tests/pytorch/test_float8_current_scaling_exact.py
@@ -8,12 +8,9 @@ import torch
 import pytest
 import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.common.recipe import Float8CurrentScaling
-from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype
+from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype
 from transformer_engine.pytorch.fp8 import int8_simulation_fp8
@@ -25,7 +22,7 @@ if tensor_dump_dir_env is not None:
 # Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
 class GetRecipes:
@@ -274,6 +271,14 @@ class TestFP8RecipeLinearBase:
         if bgrad_list is not None and bgrad is not None:
             bgrad_list.append(bgrad.detach().clone())
+        # Stack the results
+        return (
+            torch.stack(y_q_list),
+            torch.stack(dgrad_list),
+            torch.stack(wgrad_list),
+            torch.stack(bgrad_list) if bgrad_list is not None else None,
+        )
     @classmethod
     def run_linear(
         cls,
@@ -388,7 +393,7 @@ class TestFP8RecipeLinearBase:
         # recipe1
         using_fp8_recipe = recipe1() != GetRecipes.none()
         if using_fp8_recipe:
-            with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
+            with autocast(enabled=True, recipe=recipe1()):
                 y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
         else:
             y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
@@ -396,7 +401,7 @@ class TestFP8RecipeLinearBase:
         # recipe2
         using_fp8_recipe = recipe2() != GetRecipes.none()
         if using_fp8_recipe:
-            with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
+            with autocast(enabled=True, recipe=recipe2()):
                 y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
         else:
             y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
@@ -611,7 +616,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
         # recipe1
         using_fp8_recipe = recipe1() != GetRecipes.none()
         if using_fp8_recipe:
-            with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
+            with autocast(enabled=True, recipe=recipe1()):
                 y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
                     x,
                     w,
@@ -633,7 +638,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
         # recipe2
         using_fp8_recipe = recipe2() != GetRecipes.none()
         if using_fp8_recipe:
-            with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
+            with autocast(enabled=True, recipe=recipe2()):
                 y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear(
                     x,
                     w,
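The recurring change in this file is the context-manager rename from fp8_autocast/fp8_recipe to autocast/recipe. A minimal before/after sketch of that migration, assuming a hypothetical layer named `layer` purely for illustration:

# Old spelling (pre-merge):
with fp8_autocast(enabled=True, fp8_recipe=Float8CurrentScaling()):
    y = layer(x)
# New spelling after this merge:
with autocast(enabled=True, recipe=Float8CurrentScaling()):
    y = layer(x)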
tests/pytorch/test_float8blockwisetensor.py
@@ -11,12 +11,11 @@ import pytest
 import torch
 import transformer_engine.common.recipe
-import transformer_engine.pytorch as te
-from transformer_engine.pytorch import (
+from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
     Float8BlockQuantizer,
     Float8BlockwiseQTensor,
-    get_device_compute_capability,
 )
+from transformer_engine.pytorch.utils import get_device_compute_capability
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine_torch as tex
tests/pytorch/test_float8tensor.py
@@ -11,13 +11,11 @@ import torch
 import transformer_engine.common.recipe
 import transformer_engine.pytorch as te
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch import (
+from transformer_engine.pytorch.tensor.float8_tensor import (
     Float8Quantizer,
     Float8Tensor,
     Float8CurrentScalingQuantizer,
 )
-from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
 from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
 import transformer_engine_torch as tex
@@ -47,7 +45,7 @@ def _to_list(x: Union[Iterable, Any]) -> List:
 DimsType = Union[Iterable[int], int]
 # Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
 # delayed scaling
tests/pytorch/test_fused_optimizer.py
@@ -11,14 +11,11 @@ from torch import nn
 from torch.testing._internal.common_device_type import largeTensorTest
 import transformer_engine.pytorch as te
 from transformer_engine.common.recipe import DelayedScaling
-from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
-from transformer_engine.pytorch import fp8_model_init
-from transformer_engine.pytorch.utils import is_bf16_compatible
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch import MultiheadAttention, quantized_model_init, is_bf16_available
 from transformer_engine.pytorch.utils import gpu_autocast_ctx
 # Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
 class TestFusedOptimizer:
@@ -188,7 +185,7 @@ class TestFusedAdam(TestFusedOptimizer):
         build_model_context = nullcontext
         build_model_context_args = {}
         if use_fp8_params:
-            build_model_context = fp8_model_init
+            build_model_context = quantized_model_init
             build_model_context_args["enabled"] = True
         with build_model_context(**build_model_context_args):
@@ -286,7 +283,7 @@ class TestFusedAdam(TestFusedOptimizer):
             exp_avg_sq_dtype=torch.float32,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_fp32_master(self):
         self.gen_precision_aware_test(
             use_fp8_params=False,
@@ -298,7 +295,7 @@ class TestFusedAdam(TestFusedOptimizer):
            exp_avg_sq_dtype=torch.float32,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_fp32_master_store_param_remainders(self):
         self.gen_precision_aware_test(
             use_fp8_params=False,
@@ -311,7 +308,7 @@ class TestFusedAdam(TestFusedOptimizer):
             store_param_remainders=True,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_fp16_master(self):
         self.gen_precision_aware_test(
             use_fp8_params=False,
@@ -325,7 +322,7 @@ class TestFusedAdam(TestFusedOptimizer):
             master_atol=2e-3,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_bf16_grad(self):
         self.gen_precision_aware_test(
             use_fp8_params=False,
@@ -339,7 +336,7 @@ class TestFusedAdam(TestFusedOptimizer):
             master_atol=2e-3,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_fp16_exp_avg(self):
         self.gen_precision_aware_test(
             use_fp8_params=False,
@@ -353,7 +350,7 @@ class TestFusedAdam(TestFusedOptimizer):
             master_atol=2e-3,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_bf16_exp_avg(self):
         self.gen_precision_aware_test(
             use_fp8_params=False,
@@ -367,7 +364,7 @@ class TestFusedAdam(TestFusedOptimizer):
             master_atol=2e-3,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
     def test_fp8_exp_avg(self):
         self.gen_precision_aware_test(
@@ -382,7 +379,7 @@ class TestFusedAdam(TestFusedOptimizer):
             master_atol=1e-2,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_fp16_exp_avg_sq(self):
         self.gen_precision_aware_test(
             use_fp8_params=False,
@@ -396,7 +393,7 @@ class TestFusedAdam(TestFusedOptimizer):
             master_atol=2e-3,
        )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_bf16_exp_avg_sq(self):
         self.gen_precision_aware_test(
             use_fp8_params=False,
@@ -410,7 +407,7 @@ class TestFusedAdam(TestFusedOptimizer):
             master_atol=2e-3,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
     def test_fp8_exp_avg_sq(self):
         self.gen_precision_aware_test(
@@ -424,7 +421,7 @@ class TestFusedAdam(TestFusedOptimizer):
             skip_assert=True,
         )
-    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported")
     def test_bf16_model_weight_cast(self):
         dtype = torch.bfloat16
         model = MultiheadAttention(
@@ -468,7 +465,7 @@ class TestFusedAdam(TestFusedOptimizer):
     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
     def test_fp8_model_weight_cast(self):
         dtype = torch.bfloat16
-        with fp8_model_init(enabled=True, recipe=DelayedScaling()):
+        with quantized_model_init(enabled=True, recipe=DelayedScaling()):
             model = MultiheadAttention(
                 hidden_size=1024,
                 num_attention_heads=16,
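The same migration applies to quantized parameter initialization and capability checks; a short sketch of the new spellings, mirroring the usage in the test above (the MultiheadAttention arguments are taken from that test and are only illustrative):

fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
with quantized_model_init(enabled=True, recipe=DelayedScaling()):
    model = MultiheadAttention(hidden_size=1024, num_attention_heads=16)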
tests/pytorch/test_fused_rope.py
@@ -373,3 +373,19 @@ def test_fused_qkv_rope(
     if not isinstance(start_positions, torch.Tensor):
         torch.testing.assert_close(grad_fused, grad_unfused)
+
+
+def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_without_autocast():
+    rope_layer = RotaryPositionEmbedding(128)
+    rope_embeddings_no_autocast = rope_layer(max_seq_len=1024)
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        rope_embeddings_autocast = rope_layer(max_seq_len=1024)
+    torch.testing.assert_close(
+        rope_embeddings_no_autocast.to(dtype=torch.bfloat16),
+        rope_embeddings_autocast.to(dtype=torch.bfloat16),
+        atol=1e-8,
+        rtol=1e-8,
+    )
tests/pytorch/test_fusible_ops.py
View file @
063ef88d
...
@@ -7,8 +7,6 @@ from __future__ import annotations
...
@@ -7,8 +7,6 @@ from __future__ import annotations
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
import
io
import
io
import
math
import
math
import
pathlib
import
sys
from
typing
import
Optional
from
typing
import
Optional
import
pytest
import
pytest
...
@@ -18,7 +16,6 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
...
@@ -18,7 +16,6 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
import
transformer_engine
import
transformer_engine
import
transformer_engine.common.recipe
import
transformer_engine.common.recipe
import
transformer_engine.pytorch
as
te
import
transformer_engine.pytorch
as
te
from
transformer_engine.pytorch.fp8
import
FP8GlobalStateManager
import
transformer_engine.pytorch.ops
as
te_ops
import
transformer_engine.pytorch.ops
as
te_ops
from
transformer_engine.pytorch.ops.fused
import
(
from
transformer_engine.pytorch.ops.fused
import
(
BackwardActivationBias
,
BackwardActivationBias
,
...
@@ -29,20 +26,18 @@ from transformer_engine.pytorch.ops.fused import (
...
@@ -29,20 +26,18 @@ from transformer_engine.pytorch.ops.fused import (
ForwardLinearBiasAdd
,
ForwardLinearBiasAdd
,
ForwardLinearScaleAdd
,
ForwardLinearScaleAdd
,
)
)
from
transformer_engine.pytorch.tensor
import
QuantizedTensor
from
transformer_engine.pytorch
import
(
from
transformer_engine.pytorch.tensor.float8_tensor
import
(
QuantizedTensor
,
Float8Tensor
,
Float8CurrentScalingQuantizer
,
Float8CurrentScalingQuantizer
,
Float8Quantizer
,
Float8Quantizer
,
MXFP8Quantizer
,
NVFP4Quantizer
,
is_bf16_available
,
)
)
from
transformer_engine.pytorch.tensor.mxfp8_tensor
import
MXFP8Tensor
,
MXFP8Quantizer
from
transformer_engine.pytorch.utils
import
is_bf16_compatible
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
# Import utility functions
# Import utility functions
_current_file
=
pathlib
.
Path
(
__file__
).
resolve
()
from
utils
import
dtype_tols
,
make_recipe
,
quantization_tols
,
reset_rng_states
sys
.
path
.
append
(
str
(
_current_file
.
parent
))
from
utils
import
dtype_tols
,
make_recipe
,
reset_rng_states
if
IS_HIP_EXTENSION
:
if
IS_HIP_EXTENSION
:
import
os
import
os
...
@@ -52,13 +47,14 @@ if IS_HIP_EXTENSION:
...
@@ -52,13 +47,14 @@ if IS_HIP_EXTENSION:
return
(
os
.
getenv
(
"NVTE_USE_HIPBLASLT"
)
is
not
None
return
(
os
.
getenv
(
"NVTE_USE_HIPBLASLT"
)
is
not
None
or
os
.
getenv
(
"NVTE_USE_ROCBLAS"
)
is
None
)
or
os
.
getenv
(
"NVTE_USE_ROCBLAS"
)
is
None
)
# Check if FP8 is supported
# Check for supported quantization schemes
fp8_available
,
reason_for_no_fp8
=
FP8GlobalStateManager
.
is_fp8_available
()
fp8_available
,
reason_for_no_fp8
=
te
.
is_fp8_available
(
return_reason
=
True
)
mxfp8_available
,
reason_for_no_mxfp8
=
FP8GlobalStateManager
.
is_mxfp8_available
()
mxfp8_available
,
reason_for_no_mxfp8
=
te
.
is_mxfp8_available
(
return_reason
=
True
)
nvfp4_available
,
reason_for_no_nvfp4
=
te
.
is_nvfp4_available
(
return_reason
=
True
)
# Supported data types
# Supported data types
_dtypes
:
list
[
torch
.
dtype
]
=
[
torch
.
float32
,
torch
.
float16
]
_dtypes
:
list
[
torch
.
dtype
]
=
[
torch
.
float32
,
torch
.
float16
]
if
is_bf16_
compati
ble
():
# bf16 requires sm_80 or higher
if
is_bf16_
availa
ble
():
# bf16 requires sm_80 or higher
_dtypes
.
append
(
torch
.
bfloat16
)
_dtypes
.
append
(
torch
.
bfloat16
)
# Supported devices
# Supported devices
...
@@ -70,6 +66,8 @@ if fp8_available:
...
@@ -70,6 +66,8 @@ if fp8_available:
_quantization_list
.
extend
((
"fp8_delayed_scaling"
,
"fp8_current_scaling"
))
_quantization_list
.
extend
((
"fp8_delayed_scaling"
,
"fp8_current_scaling"
))
if
mxfp8_available
:
if
mxfp8_available
:
_quantization_list
.
append
(
"mxfp8"
)
_quantization_list
.
append
(
"mxfp8"
)
if
nvfp4_available
:
_quantization_list
.
append
(
"nvfp4"
)
def
maybe_skip_quantization
(
def
maybe_skip_quantization
(
...
@@ -77,6 +75,7 @@ def maybe_skip_quantization(
...
@@ -77,6 +75,7 @@ def maybe_skip_quantization(
*
,
*
,
dims
:
Optional
[
Iterable
[
int
]
|
int
]
=
None
,
dims
:
Optional
[
Iterable
[
int
]
|
int
]
=
None
,
device
:
Optional
[
torch
.
device
|
str
]
=
None
,
device
:
Optional
[
torch
.
device
|
str
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
)
->
None
:
)
->
None
:
"""Skip test case if a quantization scheme is not supported"""
"""Skip test case if a quantization scheme is not supported"""
...
@@ -84,12 +83,17 @@ def maybe_skip_quantization(
...
@@ -84,12 +83,17 @@ def maybe_skip_quantization(
if
quantization
is
None
:
if
quantization
is
None
:
return
return
# Check if quantization scheme is supported
# Check if quantization scheme is supported on device
if
device
is
not
None
and
torch
.
device
(
device
).
type
!=
"cuda"
:
pytest
.
skip
(
"Quantization is only supported on CUDA devices"
)
if
quantization
in
(
"fp8"
,
"fp8_delayed_scaling"
,
"fp8_current_scaling"
)
and
not
fp8_available
:
if
quantization
in
(
"fp8"
,
"fp8_delayed_scaling"
,
"fp8_current_scaling"
)
and
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
if
quantization
==
"mxfp8"
and
not
mxfp8_available
:
if
quantization
==
"mxfp8"
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
pytest
.
skip
(
reason_for_no_mxfp8
)
if
quantization
==
"nvfp4"
and
not
nvfp4_available
:
pytest
.
skip
(
reason_for_no_nvfp4
)
# Check dims
if
dims
is
not
None
:
if
dims
is
not
None
:
if
not
isinstance
(
dims
,
Iterable
):
if
not
isinstance
(
dims
,
Iterable
):
dims
=
(
dims
,)
dims
=
(
dims
,)
...
@@ -99,10 +103,14 @@ def maybe_skip_quantization(
...
@@ -99,10 +103,14 @@ def maybe_skip_quantization(
elif
quantization
==
"mxfp8"
:
elif
quantization
==
"mxfp8"
:
if
math
.
prod
(
dims
[:
-
1
])
%
32
!=
0
or
dims
[
-
1
]
%
32
!=
0
:
if
math
.
prod
(
dims
[:
-
1
])
%
32
!=
0
or
dims
[
-
1
]
%
32
!=
0
:
pytest
.
skip
(
"MXFP8 GEMMs require dims that are divisible by 32"
)
pytest
.
skip
(
"MXFP8 GEMMs require dims that are divisible by 32"
)
elif
quantization
==
"nvfp4"
:
if
math
.
prod
(
dims
[:
-
1
])
%
16
!=
0
or
dims
[
-
1
]
%
16
!=
0
:
pytest
.
skip
(
"NVFP4 GEMMs require dims that are divisible by 16"
)
# Check if device is supported
# Check dtype
if
device
is
not
None
and
torch
.
device
(
device
).
type
!=
"cuda"
:
if
dtype
is
not
None
:
pytest
.
skip
(
"Quantization is only supported on CUDA devices"
)
if
quantization
==
"nvfp4"
and
dtype
!=
torch
.
bfloat16
:
pytest
.
skip
(
"NVFP4 quantization is only supported with BF16 data"
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
...
@@ -152,6 +160,14 @@ def make_reference_and_test_tensors(
...
@@ -152,6 +160,14 @@ def make_reference_and_test_tensors(
test
=
quantizer
(
test
)
test
=
quantizer
(
test
)
elif
quantization
==
"mxfp8"
:
elif
quantization
==
"mxfp8"
:
test
=
MXFP8Quantizer
(
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
)(
test
)
test
=
MXFP8Quantizer
(
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
)(
test
)
elif
quantization
==
"nvfp4"
:
test
=
NVFP4Quantizer
(
with_rht
=
False
,
with_post_rht_amax
=
False
,
with_2d_quantization
=
False
,
stochastic_rounding
=
False
,
with_random_sign_mask
=
False
,
)(
test
)
else
:
else
:
raise
ValueError
(
f
"Unsupported quantization scheme (
{
quantization
}
)"
)
raise
ValueError
(
f
"Unsupported quantization scheme (
{
quantization
}
)"
)
if
isinstance
(
test
,
QuantizedTensor
)
and
not
test_is_quantized
:
if
isinstance
(
test
,
QuantizedTensor
)
and
not
test_is_quantized
:
...
@@ -361,7 +377,7 @@ class TestFuser:
...
@@ -361,7 +377,7 @@ class TestFuser:
)
)
# Construct model
# Construct model
with
te
.
fp8
_model_init
(
recipe
=
recipe
):
with
te
.
quantized
_model_init
(
recipe
=
recipe
):
model
=
te_ops
.
basic
.
BasicLinear
(
model
=
te_ops
.
basic
.
BasicLinear
(
size
,
size
,
size
,
size
,
...
@@ -393,7 +409,7 @@ class TestFuser:
...
@@ -393,7 +409,7 @@ class TestFuser:
)
)
# Training step
# Training step
with
te
.
fp8_
autocast
(
fp8_
recipe
=
recipe
):
with
te
.
autocast
(
recipe
=
recipe
):
y
=
model
(
x
)
y
=
model
(
x
)
y
.
backward
(
dy
)
y
.
backward
(
dy
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
@@ -406,12 +422,12 @@ class TestFuser:
...
@@ -406,12 +422,12 @@ class TestFuser:
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
y
,
y
,
torch
.
full_like
(
y
,
y_val_ref
),
torch
.
full_like
(
y
,
y_val_ref
),
**
dtype_tols
(
tex
.
DType
.
kFloat8E4M3
),
**
quantization_tols
(
"fp8_delayed_scaling"
),
)
)
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
x
.
grad
,
x
.
grad
,
torch
.
full_like
(
x
.
grad
,
dx_val_ref
),
torch
.
full_like
(
x
.
grad
,
dx_val_ref
),
**
dtype_tols
(
tex
.
DType
.
kFloat8E5M2
),
**
quantization_tols
(
"fp8_delayed_scaling"
),
)
)
# Check that scaling factors match expected
# Check that scaling factors match expected
...
@@ -445,7 +461,8 @@ class TestFuser:
...
@@ -445,7 +461,8 @@ class TestFuser:
# Skip invalid configurations
# Skip invalid configurations
in_shape
=
(
size
,
size
)
in_shape
=
(
size
,
size
)
with_quantization
=
quantization
is
not
None
with_quantization
=
quantization
is
not
None
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
)
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
,
dtype
=
init_dtype
)
maybe_skip_quantization
(
quantization
,
dtype
=
final_dtype
)
# Random data
# Random data
dtype
=
torch
.
float32
dtype
=
torch
.
float32
...
@@ -461,7 +478,7 @@ class TestFuser:
...
@@ -461,7 +478,7 @@ class TestFuser:
)
)
# Construct operation
# Construct operation
with
te
.
fp8
_model_init
(
enabled
=
with_quantization
,
recipe
=
make_recipe
(
quantization
)):
with
te
.
quantized
_model_init
(
enabled
=
with_quantization
,
recipe
=
make_recipe
(
quantization
)):
op
=
te_ops
.
Linear
(
size
,
size
,
bias
=
False
,
device
=
device
,
dtype
=
init_dtype
)
op
=
te_ops
.
Linear
(
size
,
size
,
bias
=
False
,
device
=
device
,
dtype
=
init_dtype
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
op
.
weight
.
copy_
(
w_test
)
op
.
weight
.
copy_
(
w_test
)
...
@@ -513,11 +530,12 @@ class TestFuser:
...
@@ -513,11 +530,12 @@ class TestFuser:
# Skip invalid configurations
# Skip invalid configurations
in_shape
=
(
size
,
size
)
in_shape
=
(
size
,
size
)
quantized_compute
=
quantization
is
not
None
quantized_compute
=
quantization
is
not
None
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
)
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
,
dtype
=
model_dtype
)
maybe_skip_quantization
(
quantization
,
dtype
=
autocast_dtype
)
# Construct operation
# Construct operation
recipe
=
make_recipe
(
quantization
)
recipe
=
make_recipe
(
quantization
)
with
te
.
fp8
_model_init
(
enabled
=
quantized_weights
,
recipe
=
recipe
):
with
te
.
quantized
_model_init
(
enabled
=
quantized_weights
,
recipe
=
recipe
):
op
=
te_ops
.
Linear
(
size
,
size
,
bias
=
False
,
device
=
device
,
dtype
=
model_dtype
)
op
=
te_ops
.
Linear
(
size
,
size
,
bias
=
False
,
device
=
device
,
dtype
=
model_dtype
)
# Check forward and backward pass
# Check forward and backward pass
...
@@ -527,7 +545,7 @@ class TestFuser:
...
@@ -527,7 +545,7 @@ class TestFuser:
device
=
device
,
device
=
device
,
requires_grad
=
True
,
requires_grad
=
True
,
)
)
with
te
.
fp8_
autocast
(
enabled
=
quantized_compute
,
fp8_
recipe
=
recipe
):
with
te
.
autocast
(
enabled
=
quantized_compute
,
recipe
=
recipe
):
with
torch
.
autocast
(
device_type
=
device
.
type
,
dtype
=
autocast_dtype
):
with
torch
.
autocast
(
device_type
=
device
.
type
,
dtype
=
autocast_dtype
):
y
=
op
(
x
)
y
=
op
(
x
)
y
.
backward
(
torch
.
zeros_like
(
y
))
y
.
backward
(
torch
.
zeros_like
(
y
))
...
@@ -540,7 +558,7 @@ class TestFuser:
...
@@ -540,7 +558,7 @@ class TestFuser:
x
.
grad
=
None
x
.
grad
=
None
op
.
weight
.
grad
=
None
op
.
weight
.
grad
=
None
with
torch
.
autocast
(
device_type
=
device
.
type
,
dtype
=
autocast_dtype
):
with
torch
.
autocast
(
device_type
=
device
.
type
,
dtype
=
autocast_dtype
):
with
te
.
fp8_
autocast
(
enabled
=
quantized_compute
,
fp8_
recipe
=
recipe
):
with
te
.
autocast
(
enabled
=
quantized_compute
,
recipe
=
recipe
):
y
=
op
(
x
)
y
=
op
(
x
)
y
.
backward
(
torch
.
zeros_like
(
y
))
y
.
backward
(
torch
.
zeros_like
(
y
))
assert
y
.
dtype
==
autocast_dtype
assert
y
.
dtype
==
autocast_dtype
...
@@ -569,7 +587,7 @@ class TestBasicOps:
...
@@ -569,7 +587,7 @@ class TestBasicOps:
# Skip invalid configurations
# Skip invalid configurations
with_quantization
=
quantization
is
not
None
with_quantization
=
quantization
is
not
None
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
)
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
,
dtype
=
dtype
)
# Random data
# Random data
x_ref
,
x_test
=
make_reference_and_test_tensors
(
x_ref
,
x_test
=
make_reference_and_test_tensors
(
...
@@ -635,7 +653,7 @@ class TestBasicOps:
...
@@ -635,7 +653,7 @@ class TestBasicOps:
# Skip invalid configurations
# Skip invalid configurations
if
memory_format
==
torch
.
channels_last
and
len
(
in_shape
)
!=
4
:
if
memory_format
==
torch
.
channels_last
and
len
(
in_shape
)
!=
4
:
pytest
.
skip
(
"torch.channels_last only supports 4D tensors"
)
pytest
.
skip
(
"torch.channels_last only supports 4D tensors"
)
maybe_skip_quantization
(
quantization
,
device
=
device
)
maybe_skip_quantization
(
quantization
,
device
=
device
,
dtype
=
dtype
)
with_quantization
=
quantization
is
not
None
with_quantization
=
quantization
is
not
None
# Random data
# Random data
...
@@ -701,7 +719,7 @@ class TestBasicOps:
...
@@ -701,7 +719,7 @@ class TestBasicOps:
# Skip invalid configurations
# Skip invalid configurations
with_quantization
=
quantization
is
not
None
with_quantization
=
quantization
is
not
None
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
)
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
,
dtype
=
dtype
)
# Random data
# Random data
x_ref
,
x_test
=
make_reference_and_test_tensors
(
x_ref
,
x_test
=
make_reference_and_test_tensors
(
...
@@ -763,7 +781,7 @@ class TestBasicOps:
...
@@ -763,7 +781,7 @@ class TestBasicOps:
# Skip invalid configurations
# Skip invalid configurations
with_quantization
=
quantization
is
not
None
with_quantization
=
quantization
is
not
None
maybe_skip_quantization
(
quantization
,
device
=
device
)
maybe_skip_quantization
(
quantization
,
device
=
device
,
dtype
=
dtype
)
if
quantization
==
"mxfp8"
:
if
quantization
==
"mxfp8"
:
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
)
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
)
...
@@ -790,7 +808,7 @@ class TestBasicOps:
...
@@ -790,7 +808,7 @@ class TestBasicOps:
# Implementation with fusible operation
# Implementation with fusible operation
op
=
te_ops
.
Quantize
(
forward
=
cast_forward
,
backward
=
cast_backward
)
op
=
te_ops
.
Quantize
(
forward
=
cast_forward
,
backward
=
cast_backward
)
recipe
=
make_recipe
(
quantization
)
recipe
=
make_recipe
(
quantization
)
with
te
.
fp8_
autocast
(
enabled
=
with_quantization
,
fp8_
recipe
=
recipe
):
with
te
.
autocast
(
enabled
=
with_quantization
,
recipe
=
recipe
):
y_test
=
op
(
x_test
)
y_test
=
op
(
x_test
)
y_test
.
backward
(
dy_test
)
y_test
.
backward
(
dy_test
)
...
@@ -830,7 +848,7 @@ class TestBasicOps:
...
@@ -830,7 +848,7 @@ class TestBasicOps:
out_shape
=
in_shape
[:
-
1
]
+
[
out_features
]
out_shape
=
in_shape
[:
-
1
]
+
[
out_features
]
# Skip invalid configurations
# Skip invalid configurations
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
)
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
,
dtype
=
dtype
)
maybe_skip_quantization
(
quantization
,
dims
=
out_shape
)
maybe_skip_quantization
(
quantization
,
dims
=
out_shape
)
quantization_needed
=
any
(
quantization_needed
=
any
(
(
(
...
@@ -887,7 +905,7 @@ class TestBasicOps:
...
@@ -887,7 +905,7 @@ class TestBasicOps:
# Implementation with fusible operation
# Implementation with fusible operation
recipe
=
make_recipe
(
quantization
)
recipe
=
make_recipe
(
quantization
)
with
te
.
fp8
_model_init
(
enabled
=
quantized_weight
,
recipe
=
recipe
):
with
te
.
quantized
_model_init
(
enabled
=
quantized_weight
,
recipe
=
recipe
):
op
=
te_ops
.
BasicLinear
(
op
=
te_ops
.
BasicLinear
(
in_features
,
in_features
,
out_features
,
out_features
,
...
@@ -904,7 +922,7 @@ class TestBasicOps:
...
@@ -904,7 +922,7 @@ class TestBasicOps:
op
,
op
,
te_ops
.
Quantize
(
forward
=
quantized_output
,
backward
=
quantized_grad_output
),
te_ops
.
Quantize
(
forward
=
quantized_output
,
backward
=
quantized_grad_output
),
)
)
with
te
.
fp8_
autocast
(
enabled
=
quantized_compute
,
fp8_
recipe
=
recipe
):
with
te
.
autocast
(
enabled
=
quantized_compute
,
recipe
=
recipe
):
y_test
=
forward
(
x_test
)
y_test
=
forward
(
x_test
)
y_test
.
backward
(
dy_test
)
y_test
.
backward
(
dy_test
)
...
@@ -913,7 +931,7 @@ class TestBasicOps:
...
@@ -913,7 +931,7 @@ class TestBasicOps:
if
dtype
==
torch
.
float32
:
if
dtype
==
torch
.
float32
:
tols
=
dtype_tols
(
torch
.
float16
)
# TF32 GEMM
tols
=
dtype_tols
(
torch
.
float16
)
# TF32 GEMM
if
quantized_compute
or
quantized_output
or
quantized_grad_input
:
if
quantized_compute
or
quantized_output
or
quantized_grad_input
:
tols
=
dtype_tols
(
tex
.
DType
.
kFloat8E4M3
)
tols
=
quantization_tols
(
quantization
)
# Check results
# Check results
y_test
=
y_test
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
y_test
=
y_test
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
...
@@ -1024,7 +1042,7 @@ class TestBasicOps:
...
@@ -1024,7 +1042,7 @@ class TestBasicOps:
out_shape
=
in_shape
[:
-
1
]
+
[
out_features
]
out_shape
=
in_shape
[:
-
1
]
+
[
out_features
]
# Skip invalid configurations
# Skip invalid configurations
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
)
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
,
dtype
=
dtype
)
maybe_skip_quantization
(
quantization
,
dims
=
out_shape
)
maybe_skip_quantization
(
quantization
,
dims
=
out_shape
)
if
quantization
is
None
and
(
quantized_compute
or
quantized_weight
):
if
quantization
is
None
and
(
quantized_compute
or
quantized_weight
):
pytest
.
skip
(
"Quantization scheme is not specified"
)
pytest
.
skip
(
"Quantization scheme is not specified"
)
...
@@ -1065,7 +1083,7 @@ class TestBasicOps:
...
@@ -1065,7 +1083,7 @@ class TestBasicOps:
# Implementation with fusible operation
# Implementation with fusible operation
recipe
=
make_recipe
(
quantization
)
recipe
=
make_recipe
(
quantization
)
with
te
.
fp8
_model_init
(
enabled
=
quantized_weight
,
recipe
=
recipe
):
with
te
.
quantized
_model_init
(
enabled
=
quantized_weight
,
recipe
=
recipe
):
op
=
te_ops
.
Linear
(
op
=
te_ops
.
Linear
(
in_features
,
in_features
,
out_features
,
out_features
,
...
@@ -1081,7 +1099,7 @@ class TestBasicOps:
...
@@ -1081,7 +1099,7 @@ class TestBasicOps:
del
b_test
del
b_test
for
param
in
op
.
parameters
():
for
param
in
op
.
parameters
():
param
.
requires_grad_
(
requires_grad
=
weight_requires_grad
)
param
.
requires_grad_
(
requires_grad
=
weight_requires_grad
)
with
te
.
fp8_
autocast
(
enabled
=
quantized_compute
,
fp8_
recipe
=
recipe
):
with
te
.
autocast
(
enabled
=
quantized_compute
,
recipe
=
recipe
):
y_test
=
op
(
x_test
)
y_test
=
op
(
x_test
)
if
input_requires_grad
or
weight_requires_grad
:
if
input_requires_grad
or
weight_requires_grad
:
y_test
.
backward
(
dy_test
)
y_test
.
backward
(
dy_test
)
...
@@ -1091,7 +1109,7 @@ class TestBasicOps:
...
@@ -1091,7 +1109,7 @@ class TestBasicOps:
if
dtype
==
torch
.
float32
:
if
dtype
==
torch
.
float32
:
tols
=
dtype_tols
(
torch
.
float16
)
# TF32 GEMM
tols
=
dtype_tols
(
torch
.
float16
)
# TF32 GEMM
if
quantized_compute
:
if
quantized_compute
:
tols
=
dtype_tols
(
tex
.
DType
.
kFloat8E4M3
)
tols
=
quantization_tols
(
quantization
)
# Check results
# Check results
y_test
=
y_test
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
y_test
=
y_test
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
...
@@ -1128,7 +1146,7 @@ class TestBasicOps:
...
@@ -1128,7 +1146,7 @@ class TestBasicOps:
in_shape
=
list
(
in_shape
)[:
-
1
]
+
list
(
weight_shape
)
in_shape
=
list
(
in_shape
)[:
-
1
]
+
list
(
weight_shape
)
# Skip invalid configurations
# Skip invalid configurations
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
)
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
,
dtype
=
dtype
)
# Random data
# Random data
x_ref
,
x_test
=
make_reference_and_test_tensors
(
x_ref
,
x_test
=
make_reference_and_test_tensors
(
...
@@ -1182,14 +1200,14 @@ class TestBasicOps:
...
@@ -1182,14 +1200,14 @@ class TestBasicOps:
op
,
op
,
te_ops
.
Quantize
(
forward
=
quantized_compute
,
backward
=
False
),
te_ops
.
Quantize
(
forward
=
quantized_compute
,
backward
=
False
),
)
)
with
te
.
fp8_
autocast
(
enabled
=
quantized_compute
,
fp8_
recipe
=
recipe
):
with
te
.
autocast
(
enabled
=
quantized_compute
,
recipe
=
recipe
):
y_test
=
forward
(
x_test
)
y_test
=
forward
(
x_test
)
y_test
.
backward
(
dy_test
)
y_test
.
backward
(
dy_test
)
# Expected numerical error
# Expected numerical error
tols
=
dtype_tols
(
dtype
)
tols
=
dtype_tols
(
dtype
)
if
quantized_compute
:
if
quantized_compute
:
tols
=
dtype_tols
(
tex
.
DType
.
kFloat8E4M3
)
tols
=
quantization_tols
(
quantization
)
# Check results
# Check results
y_test
=
y_test
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
y_test
=
y_test
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
...
@@ -1298,7 +1316,7 @@ class TestBasicOps:
...
@@ -1298,7 +1316,7 @@ class TestBasicOps:
in_shape
=
list
(
in_shape
)[:
-
1
]
+
list
(
weight_shape
)
in_shape
=
list
(
in_shape
)[:
-
1
]
+
list
(
weight_shape
)
# Skip invalid configurations
# Skip invalid configurations
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
)
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
,
dtype
=
dtype
)
# Random data
# Random data
x_ref
,
x_test
=
make_reference_and_test_tensors
(
x_ref
,
x_test
=
make_reference_and_test_tensors
(
...
@@ -1344,14 +1362,14 @@ class TestBasicOps:
...
@@ -1344,14 +1362,14 @@ class TestBasicOps:
op
,
op
,
te_ops
.
Quantize
(
forward
=
quantized_compute
,
backward
=
False
),
te_ops
.
Quantize
(
forward
=
quantized_compute
,
backward
=
False
),
)
)
with
te
.
fp8_
autocast
(
enabled
=
quantized_compute
,
fp8_
recipe
=
recipe
):
with
te
.
autocast
(
enabled
=
quantized_compute
,
recipe
=
recipe
):
y_test
=
forward
(
x_test
)
y_test
=
forward
(
x_test
)
y_test
.
backward
(
dy_test
)
y_test
.
backward
(
dy_test
)
# Expected numerical error
# Expected numerical error
tols
=
dtype_tols
(
dtype
)
tols
=
dtype_tols
(
dtype
)
if
quantized_compute
:
if
quantized_compute
:
tols
=
dtype_tols
(
tex
.
DType
.
kFloat8E4M3
)
tols
=
quantization_tols
(
quantization
)
# Check results
# Check results
y_test
=
y_test
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
y_test
=
y_test
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
...
@@ -1431,7 +1449,7 @@ class TestBasicOps:
...
@@ -1431,7 +1449,7 @@ class TestBasicOps:
# Skip invalid configurations
# Skip invalid configurations
with_quantization
=
quantization
is
not
None
with_quantization
=
quantization
is
not
None
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
)
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
,
dtype
=
dtype
)
# Random data
# Random data
x1_ref
,
x1_test
=
make_reference_and_test_tensors
(
x1_ref
,
x1_test
=
make_reference_and_test_tensors
(
...
@@ -1470,8 +1488,11 @@ class TestBasicOps:
...
@@ -1470,8 +1488,11 @@ class TestBasicOps:
# Check results
# Check results
tols
=
dtype_tols
(
dtype
)
tols
=
dtype_tols
(
dtype
)
if
with_quantization
:
if
in_place
:
tols
=
dtype_tols
(
x1_test
.
_fp8_dtype
)
if
quantization
in
(
"fp8_delayed_scaling"
,
"fp8_current_scaling"
,
"mxfp8"
):
tols
=
dtype_tols
(
x1_test
.
_fp8_dtype
)
elif
quantization
==
"nvfp4"
:
tols
=
dtype_tols
(
x1_test
.
_fp4_dtype
)
y_test
=
y_test
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
y_test
=
y_test
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
dx1_test
=
x1_test
.
grad
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
dx1_test
=
x1_test
.
grad
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
dx2_test
=
x2_test
.
grad
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
dx2_test
=
x2_test
.
grad
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
...
@@ -1500,7 +1521,7 @@ class TestBasicOps:
...
@@ -1500,7 +1521,7 @@ class TestBasicOps:
# Skip invalid configurations
# Skip invalid configurations
with_quantization
=
quantization
is
not
None
with_quantization
=
quantization
is
not
None
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
)
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
,
dtype
=
dtype
)
# Random data
# Random data
x_ref
,
x_test
=
make_reference_and_test_tensors
(
x_ref
,
x_test
=
make_reference_and_test_tensors
(
...
@@ -1573,7 +1594,7 @@ class TestBasicOps:
...
@@ -1573,7 +1594,7 @@ class TestBasicOps:
# Skip invalid configurations
# Skip invalid configurations
quantized_compute
=
quantization
is
not
None
quantized_compute
=
quantization
is
not
None
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
)
maybe_skip_quantization
(
quantization
,
dims
=
in_shape
,
device
=
device
,
dtype
=
dtype
)
if
cache_quantized_input
:
if
cache_quantized_input
:
maybe_skip_quantization
(
"fp8_current_scaling"
,
device
=
device
)
maybe_skip_quantization
(
"fp8_current_scaling"
,
device
=
device
)
...
@@ -1641,14 +1662,16 @@ class TestBasicOps:
...
@@ -1641,14 +1662,16 @@ class TestBasicOps:
make_op
(
cache_quantized_input
=
cache_quantized_input
),
make_op
(
cache_quantized_input
=
cache_quantized_input
),
te_ops
.
Quantize
(
forward
=
quantized_compute
,
backward
=
False
),
te_ops
.
Quantize
(
forward
=
quantized_compute
,
backward
=
False
),
)
)
with
te
.
fp8_
autocast
(
enabled
=
quantized_compute
,
fp8_
recipe
=
recipe
):
with
te
.
autocast
(
enabled
=
quantized_compute
,
recipe
=
recipe
):
y_test
=
forward
(
x_test
)
y_test
=
forward
(
x_test
)
y_test
.
backward
(
dy_test
)
y_test
.
backward
(
dy_test
)
# Expected numerical error
# Expected numerical error
tols
=
dtype_tols
(
dtype
)
tols
=
dtype_tols
(
dtype
)
if
quantized_compute
or
cache_quantized_input
:
if
quantized_compute
:
tols
=
dtype_tols
(
tex
.
DType
.
kFloat8E4M3
)
tols
=
quantization_tols
(
quantization
)
elif
cache_quantized_input
:
tols
=
quantization_tols
(
"fp8_current_scaling"
)
# Check results
# Check results
y_test
=
y_test
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
y_test
=
y_test
.
to
(
dtype
=
torch
.
float64
,
device
=
"cpu"
)
...
@@ -1679,7 +1702,7 @@ class TestBasicOps:
...
@@ -1679,7 +1702,7 @@ class TestBasicOps:
quantized_compute
=
quantization
        quantized_compute = quantization is not None
        if not quantized_compute and (quantize_forward or quantize_backward):
            pytest.skip("Quantization scheme has not been provided")
-       maybe_skip_quantization(quantization, dims=in_shape, device=device)
+       maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
...
@@ -1706,13 +1729,87 @@ class TestBasicOps:
            te_ops.SwiGLU(),
            te_ops.Quantize(forward=quantize_forward, backward=False),
        )
-       with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+       with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
+           tols = quantization_tols(quantization)
+
+       # Check results
+       y_test = y_test.to(dtype=torch.float64, device="cpu")
+       dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+       torch.testing.assert_close(y_test, y_ref, **tols)
+       torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+
+   @pytest.mark.parametrize("dtype", _dtypes)
+   @pytest.mark.parametrize("quantization", _quantization_list)
+   @pytest.mark.parametrize("quantize_forward", (False, True))
+   @pytest.mark.parametrize("quantize_backward", (False, True))
+   def test_clamped_swiglu(
+       self,
+       *,
+       out_shape: Iterable[int] = (32, 32),
+       dtype: torch.dtype,
+       device: torch.device = "cuda",
+       quantization: Optional[str],
+       quantize_forward: bool,
+       quantize_backward: bool,
+       limit: float = 0.75,
+       alpha: float = 1.702,
+   ):
+       # Test SwiGLU variant used in GPT OSS.
+
+       # Tensor dimensions
+       in_shape = list(out_shape)
+       in_shape[-1] *= 2
+
+       # Skip invalid configurations
+       quantized_compute = quantization is not None
+       if not quantized_compute and (quantize_forward or quantize_backward):
+           pytest.skip("Quantization scheme has not been provided")
+       maybe_skip_quantization(quantization, dims=in_shape, device=device)
+
+       # Random data
+       x_ref, x_test = make_reference_and_test_tensors(
+           in_shape,
+           test_dtype=dtype,
+           test_device=device,
+       )
+       dy_ref, dy_test = make_reference_and_test_tensors(
+           out_shape,
+           test_dtype=dtype,
+           test_device=device,
+           requires_grad=False,
+       )
+
+       # Plain PyTorch implementation
+       x_glu, x_linear = x_ref.chunk(2, dim=-1)
+       x_glu = x_glu.clamp(min=None, max=limit)
+       x_linear = x_linear.clamp(min=-limit, max=limit)
+       out_glu = x_glu * torch.sigmoid(alpha * x_glu)
+       y_ref = out_glu * (x_linear + 1)
+       y_ref.backward(dy_ref)
+
+       # Implementation with fusible operation
+       recipe = make_recipe(quantization)
+       forward = te_ops.Sequential(
+           te_ops.Quantize(forward=False, backward=quantize_backward),
+           te_ops.ClampedSwiGLU(limit=limit, alpha=alpha),
+           te_ops.Quantize(forward=quantize_forward, backward=False),
+       )
+       with te.autocast(enabled=quantized_compute, recipe=recipe):
+           y_test = forward(x_test)
+       y_test.backward(dy_test)
+
+       # Expected numerical error
+       tols = dtype_tols(dtype)
+       if quantized_compute and quantization == "nvfp4":
+           tols = dtype_tols(tex.DType.kFloat4E2M1)
+       elif quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
...
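Aside: the clamped-SwiGLU reference math that the new test above checks against can be read in isolation. A minimal standalone sketch, assuming te_ops.ClampedSwiGLU implements the same plain-PyTorch form with the same limit/alpha semantics:

import torch

def clamped_swiglu_ref(x: torch.Tensor, limit: float = 0.75, alpha: float = 1.702) -> torch.Tensor:
    # Split the last dimension into a gated half and a linear half.
    x_glu, x_linear = x.chunk(2, dim=-1)
    # Clamp as in the test: the gated half only from above, the linear half symmetrically.
    x_glu = x_glu.clamp(max=limit)
    x_linear = x_linear.clamp(min=-limit, max=limit)
    # Sigmoid-weighted gate applied to the shifted linear branch.
    return x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1)

# The input's last dimension is twice the output's, matching in_shape[-1] *= 2 above.
y = clamped_swiglu_ref(torch.randn(32, 64))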
@@ -1781,7 +1878,7 @@ class TestBasicOps:
        # Skip invalid configurations
        quantized_input = quantization is not None
-       maybe_skip_quantization(quantization, dims=shape, device=device)
+       maybe_skip_quantization(quantization, dims=shape, device=device, dtype=dtype)

        # Random data
        # Note: Shift values to make sure inputs are non-zero
...
@@ -1872,7 +1969,7 @@ class TestFusedOps:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-       maybe_skip_quantization(quantization, dims=in_shape, device=device)
+       maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if dtype not in (torch.float16, torch.bfloat16):
            pytest.skip(
...
@@ -1913,7 +2010,7 @@ class TestFusedOps:
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-       with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
+       with te.quantized_model_init(enabled=quantized_compute, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
...
@@ -1929,7 +2026,7 @@ class TestFusedOps:
            model[0].bias.copy_(b_test)
        del w_test
        del b_test
-       with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+       with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)
...
@@ -1943,7 +2040,7 @@ class TestFusedOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-           tols = dtype_tols(tex.DType.kFloat8E4M3)
+           tols = quantization_tols(quantization)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
...
@@ -1979,7 +2076,7 @@ class TestFusedOps:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-       maybe_skip_quantization(quantization, dims=in_shape, device=device)
+       maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
...
@@ -2023,7 +2120,7 @@ class TestFusedOps:
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-       with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+       with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
...
@@ -2040,7 +2137,7 @@ class TestFusedOps:
            model[0].bias.copy_(b_test)
        del w_test
        del b_test
-       with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+       with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x1_test, x2_test)
        y_test.backward(dy_test)
...
@@ -2054,7 +2151,7 @@ class TestFusedOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-           tols = dtype_tols(tex.DType.kFloat8E4M3)
+           tols = quantization_tols(quantization)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
...
@@ -2093,7 +2190,7 @@ class TestFusedOps:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-       maybe_skip_quantization(quantization, dims=in_shape, device=device)
+       maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
...
@@ -2130,7 +2227,7 @@ class TestFusedOps:
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-       with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+       with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
...
@@ -2146,7 +2243,7 @@ class TestFusedOps:
        with torch.no_grad():
            model[0].weight.copy_(w_test)
        del w_test
-       with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+       with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x1_test, x2_test)
        y_test.backward(dy_test)
...
@@ -2161,7 +2258,7 @@ class TestFusedOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-           tols = dtype_tols(tex.DType.kFloat8E4M3)
+           tols = quantization_tols(quantization)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
...
@@ -2194,7 +2291,7 @@ class TestFusedOps:
        # Skip invalid configurations
        with_quantization = quantization is not None
-       maybe_skip_quantization(quantization, device=device)
+       maybe_skip_quantization(quantization, device=device, dtype=dtype)
        if quantization == "mxfp8" and (len(in_shape) < 2 or in_shape[-1] % 32 != 0):
            pytest.skip("Unsupported tensor size for MXFP8")
...
@@ -2237,7 +2334,7 @@ class TestFusedOps:
        with torch.no_grad():
            model[1].bias.copy_(b_test)
        del b_test
-       with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
+       with te.autocast(enabled=with_quantization, recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)
...
@@ -2256,7 +2353,7 @@ class TestFusedOps:
        # Expected numerical error
        tols = dtype_tols(dtype)
        if with_quantization:
-           tols = dtype_tols(tex.DType.kFloat8E4M3)
+           tols = quantization_tols(quantization)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
...
@@ -2375,7 +2472,7 @@ class TestFusedOps:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-       maybe_skip_quantization(quantization, dims=in_shape, device=device)
+       maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
...
@@ -2415,7 +2512,7 @@ class TestFusedOps:
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-       with te.fp8_model_init(enabled=quantized_weight):
+       with te.quantized_model_init(enabled=quantized_weight):
            model = te_ops.Sequential(
                te_ops.MakeExtraOutput(in_place=True),
                te_ops.Linear(
...
@@ -2429,7 +2526,7 @@ class TestFusedOps:
        with torch.no_grad():
            model[1].weight.copy_(w_test)
        del w_test
-       with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+       with te.autocast(enabled=quantized_compute, recipe=recipe):
            y1_test, y2_test = model(x_test)
        (y1_test * dy1_test + y2_test * dy2_test).sum().backward()
...
@@ -2443,7 +2540,7 @@ class TestFusedOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-           tols = dtype_tols(tex.DType.kFloat8E4M3)
+           tols = quantization_tols(quantization)

        # Check results
        y1_test = y1_test.to(dtype=torch.float64, device="cpu")
...
@@ -2479,7 +2576,7 @@ class TestFusedOps:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-       maybe_skip_quantization(quantization, dims=in_shape, device=device)
+       maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
...
@@ -2511,7 +2608,7 @@ class TestFusedOps:
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-       with te.fp8_model_init(enabled=quantized_weight):
+       with te.quantized_model_init(enabled=quantized_weight):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
...
@@ -2525,7 +2622,7 @@ class TestFusedOps:
        with torch.no_grad():
            model[0].weight.copy_(w_test)
        del w_test
-       with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+       with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x_test)
        (y_test * dy_test).sum().backward()
...
@@ -2539,7 +2636,7 @@ class TestFusedOps:
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
-           tols = dtype_tols(tex.DType.kFloat8E4M3)
+           tols = quantization_tols(quantization)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
...
@@ -2580,12 +2677,12 @@ class TestCheckpointing:
        # Skip invalid configurations
        quantized_compute = quantization is not None
-       maybe_skip_quantization(quantization, dims=in_shape, device=device)
+       maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)

        # Construct model
        recipe = make_recipe(quantization)
-       with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+       with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            model_save = te_ops.Sequential(
                te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
            )
...
@@ -2596,7 +2693,7 @@ class TestCheckpointing:
        x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
        dy = torch.randn(out_shape, dtype=dtype, device=device)
        optim_save.zero_grad()
-       with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+       with te.autocast(enabled=quantized_compute, recipe=recipe):
            y = model_save(x)
        y.backward(dy)
        optim_save.step()
...
@@ -2625,14 +2722,14 @@ class TestCheckpointing:
        ys_save = []
        for i in range(post_checkpoint_steps):
            optim_save.zero_grad()
-           with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+           with te.autocast(enabled=quantized_compute, recipe=recipe):
                y = model_save(xs_save[i])
            y.backward(dys[i])
            optim_save.step()
            ys_save.append(y)

        # Load checkpoint
-       with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+       with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            model_load = te_ops.Sequential(
                te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
            )
...
@@ -2645,7 +2742,7 @@ class TestCheckpointing:
        ys_load = []
        for i in range(post_checkpoint_steps):
            optim_load.zero_grad()
-           with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+           with te.autocast(enabled=quantized_compute, recipe=recipe):
                y = model_load(xs_load[i])
            y.backward(dys[i])
            optim_load.step()
...
@@ -2706,7 +2803,7 @@ class TestSequentialModules:
        ffn_shape = in_shape[:-1] + (ffn_hidden_size,)

        # Skip invalid configurations
-       maybe_skip_quantization(quantization, dims=in_shape, device=device)
+       maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=ffn_shape, device=device)
        quantization_needed = quantized_compute or quantized_weight
        if quantization is None and quantization_needed:
...
@@ -2732,7 +2829,7 @@ class TestSequentialModules:
        # Implementation with fusible operations
        recipe = make_recipe(quantization)
-       with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+       with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            if normalization == "LayerNorm":
                norm = te_ops.LayerNorm(
                    hidden_size,
...
@@ -2763,6 +2860,6 @@ class TestSequentialModules:
            dtype=dtype,
        )
        forward = te_ops.Sequential(norm, ffn1, act, ffn2)
-       with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+       with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
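Taken together, the test_fusible_ops changes above are one mechanical API migration: te.fp8_autocast(..., fp8_recipe=...) becomes te.autocast(..., recipe=...), te.fp8_model_init becomes te.quantized_model_init, and hard-coded FP8 tolerances become quantization_tols(quantization). A minimal before/after sketch of the new spelling (the recipe choice, layer size, and tensor shapes here are illustrative, not taken from the tests):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe as te_recipe

fp8_recipe = te_recipe.DelayedScaling()  # illustrative recipe choice

# Old spelling (removed by this commit):
#   with te.fp8_model_init(enabled=True, recipe=fp8_recipe):
#       model = te.Linear(128, 128, device="cuda", dtype=torch.bfloat16)
#   with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
#       y = model(x)

# New spelling used throughout the updated tests:
with te.quantized_model_init(enabled=True, recipe=fp8_recipe):
    model = te.Linear(128, 128, device="cuda", dtype=torch.bfloat16)

x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
with te.autocast(enabled=True, recipe=fp8_recipe):
    y = model(x)
y.sum().backward()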
tests/pytorch/test_hf_integration.py
@@ -6,7 +6,7 @@ import pytest
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
-from transformer_engine.pytorch.transformer import TransformerLayer
+from transformer_engine.pytorch import TransformerLayer

class SimpleTEModel(PreTrainedModel):
...
tests/pytorch/test_multi_tensor.py
@@ -5,7 +5,7 @@
import pytest
import torch
-import transformer_engine.pytorch as te
+import transformer_engine.pytorch
import transformer_engine_torch as tex
from transformer_engine.pytorch.optimizers import MultiTensorApply
...
tests/pytorch/test_numerics.py
@@ -13,18 +13,15 @@ import torch.nn as nn
from torch.nn import Parameter
from torch.utils.cpp_extension import IS_HIP_EXTENSION
-from transformer_engine.pytorch.fp8 import (
-    FP8GlobalStateManager,
-    fp8_autocast,
-    fp8_model_init,
-)
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
    attention_mask_func,
-   is_bf16_compatible,
)
from transformer_engine.pytorch import (
+   autocast,
+   quantized_model_init,
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
...
@@ -36,27 +33,29 @@ from transformer_engine.pytorch import (
    LayerNorm,
    Fp8Padding,
    Fp8Unpadding,
+   Float8Quantizer,
+   Float8CurrentScalingQuantizer,
+   MXFP8Quantizer,
+   get_device_compute_capability,
+   is_fp8_available,
+   is_mxfp8_available,
+   is_fp8_block_scaling_available,
+   is_bf16_available,
)
from transformer_engine.pytorch import torch_version
-from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
+from transformer_engine.pytorch import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
-from transformer_engine.pytorch.tensor.float8_tensor import (
-    Float8Quantizer,
-    Float8CurrentScalingQuantizer,
-)
-from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
-from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex

from utils import ModelConfig, reset_rng_states, get_available_attention_backends

# Only run FP8 tests on supported devices.
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
+fp8_block_scaling_available = is_fp8_block_scaling_available(return_reason=True)
sm_80plus = get_device_compute_capability() >= (8, 0)
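The import rework above also promotes the availability checks to top-level helpers that return (available, reason) pairs. A short sketch of the new guard pattern (the test body and skip usage here are illustrative):

import pytest
from transformer_engine.pytorch import is_fp8_available, is_mxfp8_available, is_bf16_available

fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_requires_fp8():
    ...  # body elided

if is_bf16_available():  # bf16 requires sm_80 or higher
    pass  # e.g. append torch.bfloat16 to the parametrized dtypes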
...
@@ -82,7 +81,7 @@ module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]
param_types = [torch.float32, torch.float16]
-if is_bf16_compatible():  # bf16 requires sm_80 or higher
+if is_bf16_available():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)
batch_sizes = [1, 2]
...
@@ -553,7 +552,7 @@ def _test_e2e_selective_recompute(
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
-   with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+   with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
...
@@ -580,7 +579,7 @@ def _test_e2e_selective_recompute(
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
-   with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+   with autocast(enabled=fp8, recipe=recipe):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=te_inp_attn_mask,
...
@@ -649,7 +648,7 @@ def _test_e2e_full_recompute(
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
-   with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+   with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
...
@@ -677,7 +676,7 @@ def _test_e2e_full_recompute(
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
-   with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+   with autocast(enabled=fp8, recipe=recipe):
        if recompute:
            te_out = te_checkpoint(
                block,
...
@@ -1107,7 +1106,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False,
    )
    inp_hidden_states.retain_grad()
-   with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+   with autocast(enabled=fp8, recipe=recipe):
        out = block(inp_hidden_states)
    if isinstance(out, (List, Tuple)):
        out = out[0]
...
@@ -1328,7 +1327,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")
-   with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+   with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        te_linear_ref = Linear(
            config.hidden_size,
            4 * config.hidden_size,
...
@@ -1782,7 +1781,7 @@ def _test_grouped_linear_accuracy(
    else:
        m_splits = torch.tensor([config.max_seqlen_q])
-   with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+   with autocast(enabled=fp8, recipe=recipe):
        if isinstance(block, GroupedLinear):
            m_splits = m_splits * bs
            out = block(inp_hidden_states, m_splits.tolist())
...
@@ -1850,7 +1849,7 @@ def test_grouped_linear_accuracy(
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")
-   with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+   with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
...
@@ -1994,7 +1993,7 @@ def test_grouped_linear_accuracy_save_original_input(
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")
-   with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+   with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
...
@@ -2154,7 +2153,7 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
    m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs)
-   with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+   with autocast(enabled=fp8, recipe=recipe):
        if isinstance(block, TorchGroupedLinearWithPadding):
            out = block(inp_hidden_states, m_splits)
        else:
...
@@ -2208,7 +2207,7 @@ def test_padding_grouped_linear_accuracy(
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")
-   with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+   with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = TorchGroupedLinearWithPadding(
            num_gemms,
            config.hidden_size,
...
@@ -2219,7 +2218,7 @@ def test_padding_grouped_linear_accuracy(
        fp8=fp8,
    ).eval()
-   with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+   with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        ref_grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
...
@@ -2285,7 +2284,7 @@ def test_padding_grouped_linear_accuracy_save_original_input(
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")
-   with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+   with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = TorchGroupedLinearWithPadding(
            num_gemms,
            config.hidden_size,
...
@@ -2296,7 +2295,7 @@ def test_padding_grouped_linear_accuracy_save_original_input(
        fp8=fp8,
    ).eval()
-   with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+   with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        ref_grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
...
@@ -2446,7 +2445,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
-   with fp8_model_init(enabled=fp8_model_params, recipe=recipe):
+   with quantized_model_init(enabled=fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
...
@@ -2473,7 +2472,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
-   with fp8_autocast(enabled=True, fp8_recipe=recipe):
+   with autocast(enabled=True, recipe=recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
...
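The recompute and grouped-GEMM paths above keep the same structure after the rename; only the context-manager spellings change. A condensed sketch of full activation recompute under the new names (the layer sizes, dtypes, and recipe are illustrative; te_checkpoint is imported as in the updated file):

import torch
from transformer_engine.pytorch import TransformerLayer, autocast, checkpoint as te_checkpoint
from transformer_engine.common import recipe as te_recipe

block = TransformerLayer(256, 1024, 4, params_dtype=torch.bfloat16, device="cuda")
x = torch.randn(16, 2, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)

fp8_recipe = te_recipe.DelayedScaling()  # illustrative
with autocast(enabled=True, recipe=fp8_recipe):
    # Re-run the layer in the backward pass instead of storing activations.
    out = te_checkpoint(block, x)
out.sum().backward()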
tests/pytorch/test_onnx_export.py
@@ -33,10 +33,9 @@ from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_torch as tex
-from transformer_engine.pytorch.onnx_extensions import te_translation_table
from torch.utils.cpp_extension import IS_HIP_EXTENSION
-from transformer_engine.pytorch.export import is_in_onnx_export_mode
+from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.utils import get_default_init_method
import tensorrt as trt
...
@@ -59,8 +58,8 @@ NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(
# The directory where this file is stored.
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)

fp8_recipes = []
if mxfp8_available:
...
@@ -179,8 +178,8 @@ def do_export(
    input_names = input_names or ["input"]
    output_names = output_names or ["output"]
-   with torch.inference_mode(), te.fp8_autocast(
-       enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
+   with torch.inference_mode(), te.autocast(
+       enabled=fp8_recipe is not None, recipe=fp8_recipe
    ), warnings.catch_warnings():
        warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*")
...
@@ -234,8 +233,8 @@ def te_infer(
    fp8_recipe: recipe.Recipe,
):
    """Transformer Engine forward propagation."""
-   with torch.inference_mode(), te.fp8_autocast(
-       enabled=is_fp8, fp8_recipe=fp8_recipe
+   with torch.inference_mode(), te.autocast(
+       enabled=is_fp8, recipe=fp8_recipe
    ), warnings.catch_warnings():
        te_outputs = model(*inps if isinstance(inps, tuple) else (inps,))
        if not isinstance(te_outputs, tuple):
...
@@ -441,7 +440,7 @@ def _test_export_linear(
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx"
-   with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+   with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to(
            device="cuda"
        )
...
@@ -507,7 +506,7 @@ def _test_export_layernorm(
    fname = f"te.layernorm_linear{fp8_str}{high_prec_str}.onnx"
    with torch.no_grad():
-       with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+       with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
            layernorm_cls = te.LayerNorm if normalization == "LayerNorm" else te.RMSNorm
            model = layernorm_cls(
                hidden_size,
...
@@ -577,7 +576,7 @@ def _test_export_layernorm_linear(
    fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx"
    with torch.no_grad():
-       with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+       with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
            model = te.LayerNormLinear(
                hidden_size,
                3 * hidden_size,
...
@@ -673,7 +672,7 @@ def _test_export_layernorm_mlp(
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx"
-   with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+   with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        model = te.LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
...
@@ -1215,13 +1214,13 @@ def test_trt_integration(fp8_recipe: recipe.Recipe):
    ).eval()
    inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)

-   with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+   with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        out_ref = model(*inps)

    onnx_fd, onnx_path = tempfile.mkstemp(suffix=".onnx")
    os.close(onnx_fd)
    try:
-       with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+       with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
            with te.onnx_export(enabled=True):
                torch.onnx.export(
                    model,
...
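After the rename, the export flow in these tests is: build the module, run it once under te.autocast to establish quantized state, then export inside te.onnx_export. A condensed sketch (the module, shapes, and the remaining torch.onnx.export arguments are illustrative; fp8_recipe may be None to export without quantization):

import os, tempfile
import torch
import transformer_engine.pytorch as te

model = te.Linear(128, 128, device="cuda").eval()
inps = (torch.randn(16, 128, device="cuda", requires_grad=False),)

fp8_recipe = None  # or a recipe from transformer_engine.common.recipe
with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
    out_ref = model(*inps)

onnx_fd, onnx_path = tempfile.mkstemp(suffix=".onnx")
os.close(onnx_fd)
with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
    with te.onnx_export(enabled=True):
        torch.onnx.export(model, inps, onnx_path, input_names=["input"], output_names=["output"])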