OpenDAS / TransformerEngine / Commits

Commit 2b05e121, authored Jun 17, 2025 by yuguo
Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine

Parents: 0fd441c2, a69692ac
Changes: 245
Showing 20 changed files with 1356 additions and 349 deletions (+1356, -349)
tests/pytorch/test_cpu_offloading.py                                           +5    -0
tests/pytorch/test_float8_blockwise_scaling_exact.py                           +120  -0
tests/pytorch/test_float8_current_scaling_exact.py                             +4    -4
tests/pytorch/test_float8blockwisetensor.py                                    +42   -1
tests/pytorch/test_fusible_ops.py                                              +251  -175
tests/pytorch/test_hf_integration.py                                           +40   -0
tests/pytorch/test_jit.py                                                      +59   -0
tests/pytorch/test_numerics.py                                                 +26   -0
tests/pytorch/test_qk_norm.py                                                  +242  -0
tests/pytorch/test_recipe.py                                                   +135  -1
tests/pytorch/test_sanity.py                                                   +81   -1
tests/pytorch/utils.py                                                         +22   -0
transformer_engine/__init__.py                                                 +47   -2
transformer_engine/common/CMakeLists.txt                                       +5    -1
transformer_engine/common/__init__.py                                          +64   -90
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp              +26   -27
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp   +41   -19
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h          +2    -1
transformer_engine/common/common.cu                                            +39   -17
transformer_engine/common/common.h                                             +105  -10
tests/pytorch/test_cpu_offloading.py (view file @ 2b05e121)

@@ -97,6 +97,8 @@ def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload
    max_mem_used = torch.cuda.memory_allocated() / (1024**2)
    torch.cuda.synchronize()
    tensor.sum().backward()
    return max_mem_used

@@ -115,6 +117,9 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
    the difference being the size of the FP8 cache that is not offloaded to the CPU.
    We also expect this memory consumption to be smaller than in scenario (1).
    """
    import gc

    gc.collect()

    model_cls = model_types[model_key]
    models_list = [model_cls() for _ in range(NUM_LAYERS)]
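The memory comparison above follows a common CUDA accounting pattern: run the forward pass, synchronize, and read torch.cuda.memory_allocated() before the backward pass frees or allocates anything further. A minimal, self-contained sketch of that pattern with a toy layer stack (the module and sizes are placeholders, not the test's model_types):

import torch

def mib_between_forward_and_backward(model: torch.nn.Module, x: torch.Tensor) -> float:
    """Run forward, record allocated CUDA memory in MiB, then run backward."""
    out = model(x)
    torch.cuda.synchronize()
    mem_mib = torch.cuda.memory_allocated() / (1024**2)  # bytes -> MiB
    out.sum().backward()
    return mem_mib

# Hypothetical usage with a stack of plain Linear layers:
layers = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(4)]).cuda()
x = torch.randn(32, 1024, device="cuda", requires_grad=True)
print(f"allocated between forward and backward: {mib_between_forward_and_backward(layers, x):.1f} MiB")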
tests/pytorch/test_float8_blockwise_scaling_exact.py (view file @ 2b05e121)

@@ -88,6 +88,126 @@ def initialize_for_many_scales(
    return result


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        # full tile cases
        (128, 128),
        (256, 256),
        (256, 1024),
        (1024, 256),
        # Padding required cases
        (256, 272),
        (303, 300),
        (305, 256),
        # Some larger tiles.
        (2000, 2000),
        (2048, 2000),
        (2000, 1024),
        (2048, 1024),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
def test_quantization_1D_block_tiling_with_compact_data_and_scales(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    quant_dtype: torch.dtype,
    eps: float,
    pow_2_scales: bool,
) -> None:
    te_dtype = TE_DType[quant_dtype]
    tile_size = (1, 128)
    # This test runs a comparison of the ref class versus the class using
    # CUDA kernels to quantize. They should quantize identically for pixels
    # that are not DC values in the scale factor shape.
    ref_quantizer = BlockwiseQuantizerReference()
    sut_quantizer = Float8BlockQuantizer(
        fp8_dtype=te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=eps,
        force_pow_2_scales=pow_2_scales,
        block_scaling_dim=1,
        all_gather_usage=True,
    )

    # Setup device and random seed
    device = "cuda"
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Input
    x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device)

    x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False)
    x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut)
    x_fp8_sut_cpp_alloc = sut_quantizer(x)

    assert x_fp8_sut._rowwise_data is not None
    qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype)
    assert x_fp8_sut._rowwise_scale_inv is not None
    sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv
    qx_t = x_fp8_sut._columnwise_data
    sx_t = x_fp8_sut._columnwise_scale_inv

    qresult_ref = ref_quantizer.quantize(
        x,
        quant_dtype=quant_dtype,
        return_transpose=True,
        eps=eps,
        pow_2_scales=pow_2_scales,
        quant_tile_shape=tile_size,
        munge_scale_shapes=False,
    )
    qx_ref, sx_ref, qx_t_ref, sx_t_ref = (
        qresult_ref.data,
        qresult_ref.scale,
        qresult_ref.data_t,
        qresult_ref.scale_t,
    )
    # match the reference quantize transpose output with the columnwise non-transpose method
    qx_t_ref = qx_t_ref.transpose(-1, -2).contiguous()
    sx_t_ref = sx_t_ref.transpose(-1, -2).contiguous()

    # Check
    torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
    torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
    assert qx_t is not None
    qx_t = qx_t.view(dtype=quant_dtype)
    assert qx_t_ref is not None
    assert sx_t is not None
    assert sx_t_ref is not None
    torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
    torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)

    # check that the C++ and Python allocators are equivalent
    torch.testing.assert_close(
        x_fp8_sut._rowwise_data, x_fp8_sut_cpp_alloc._rowwise_data, atol=0.0, rtol=0.0
    )
    torch.testing.assert_close(
        x_fp8_sut._rowwise_scale_inv, x_fp8_sut_cpp_alloc._rowwise_scale_inv, atol=0.0, rtol=0.0
    )
    torch.testing.assert_close(
        x_fp8_sut._columnwise_data, x_fp8_sut_cpp_alloc._columnwise_data, atol=0.0, rtol=0.0
    )
    torch.testing.assert_close(
        x_fp8_sut._columnwise_scale_inv,
        x_fp8_sut_cpp_alloc._columnwise_scale_inv,
        atol=0.0,
        rtol=0.0,
    )
    # check if the fp8 output between C++ and Python are the same
    assert x_fp8_sut._data_format == x_fp8_sut_cpp_alloc._data_format


def check_quantization_block_tiling_versus_reference(
    x_dtype: torch.dtype,
    M: int,
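For orientation, the (1, 128) tile_size above means each row is quantized in independent blocks of 128 consecutive elements, each block carrying its own scale factor. A rough, framework-free sketch of such a rowwise 1x128 scheme in plain PyTorch (the helper name, the 448 E4M3 max-value constant, and the power-of-two rounding are illustrative assumptions, not the Float8BlockQuantizer implementation; a PyTorch build with float8 dtypes is assumed):

import torch

def quantize_rowwise_1x128(x: torch.Tensor, pow_2_scales: bool = True):
    """Quantize a 2D tensor to FP8 E4M3 with one scale per 128-element block of each row."""
    M, N = x.shape
    assert N % 128 == 0, "illustration only handles full tiles"
    blocks = x.reshape(M, N // 128, 128).float()
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = 448.0 / amax  # 448 is the largest normal E4M3 value
    if pow_2_scales:
        scale = torch.exp2(torch.floor(torch.log2(scale)))
    q = (blocks * scale).to(torch.float8_e4m3fn)
    return q.reshape(M, N), (1.0 / scale).squeeze(-1)  # quantized data and per-block scale_inv

x = torch.randn(256, 1024, device="cuda")
q, scale_inv = quantize_rowwise_1x128(x)
print(q.shape, scale_inv.shape)  # torch.Size([256, 1024]) torch.Size([256, 8])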
tests/pytorch/test_float8_current_scaling_exact.py (view file @ 2b05e121)

@@ -385,7 +385,7 @@ class TestFP8RecipeLinearBase:
        )
        # recipe1
        using_fp8_recipe = recipe1() is not None
        using_fp8_recipe = recipe1() != GetRecipes.none()
        if using_fp8_recipe:
            with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
                y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)

@@ -393,7 +393,7 @@ class TestFP8RecipeLinearBase:
            y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
        # recipe2
        using_fp8_recipe = recipe2 != GetRecipes.none
        using_fp8_recipe = recipe2() != GetRecipes.none()
        if using_fp8_recipe:
            with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
                y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)

@@ -608,7 +608,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
        )
        # recipe1
        using_fp8_recipe = recipe1() is not None
        using_fp8_recipe = recipe1() != GetRecipes.none()
        if using_fp8_recipe:
            with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
                y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(

@@ -630,7 +630,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
        )
        # recipe2
        using_fp8_recipe = recipe2 != GetRecipes.none
        using_fp8_recipe = recipe2() != GetRecipes.none()
        if using_fp8_recipe:
            with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
                y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear(
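In each hunk above the first assignment is the old line and the second is its replacement: the checks now compare the invoked recipes (recipe() != GetRecipes.none()) instead of comparing the un-invoked factory against GetRecipes.none or testing only for None. Comparing callables rather than their results is fragile, as this small stand-alone illustration shows (GetRecipes here is a stand-in, not the test suite's real helper class):

# Minimal sketch, assuming a "none" factory that acts as the no-FP8 sentinel.
class GetRecipes:
    @staticmethod
    def none():
        return None  # "no FP8 recipe"

    @staticmethod
    def fp8_current_scaling():
        return {"format": "e4m3"}  # placeholder for a real Recipe object

recipe = GetRecipes.none  # selected somewhere upstream

# Comparing factories only works when they are literally the same function object;
# a wrapper such as functools.partial(GetRecipes.none) would compare as "different".
print(recipe != GetRecipes.none)      # False
# Comparing the invoked results states the actual intent:
print(recipe() != GetRecipes.none())  # False -> the FP8 path is correctly skipped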
tests/pytorch/test_float8blockwisetensor.py (view file @ 2b05e121)

@@ -176,7 +176,40 @@ class TestFloat8BlockwiseTensor:
    )
    @pytest.mark.parametrize("block_scaling_dim", [1, 2])
    @pytest.mark.parametrize("dq_columnwise", [True, False])
    @pytest.mark.parametrize("all_gather_usage", [True, False])
    def test_quantize_dequantize_dims(
        self,
        dims: DimsType,
        block_scaling_dim: int,
        dq_columnwise: bool,
        all_gather_usage: bool,
    ) -> None:
        if all_gather_usage and block_scaling_dim != 1:
            pytest.skip("all_gather_usage only implemented for 1D block quantization.")
        atol = _tols[tex.DType.kFloat8E4M3]["atol"]
        rtol = _tols[tex.DType.kFloat8E4M3]["rtol"]
        quantizer = Float8BlockQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            rowwise=True,
            columnwise=dq_columnwise,
            block_scaling_dim=block_scaling_dim,
            all_gather_usage=all_gather_usage,
        )
        self._test_quantize_dequantize(
            quantizer=quantizer,
            dims=dims,
            atol=atol,
            rtol=rtol,
            dequant_columnwise=dq_columnwise,
        )

    @pytest.mark.parametrize(
        "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]]
    )
    @pytest.mark.parametrize("block_scaling_dim", [1, 2])
    @pytest.mark.parametrize("dq_columnwise", [True, False])
    @pytest.mark.xfail(raises=NotImplementedError)
    def test_quantize_dequantize_compact_format(
        self, dims: DimsType, block_scaling_dim: int, dq_columnwise: bool
    ) -> None:
        atol = _tols[tex.DType.kFloat8E4M3]["atol"]

@@ -186,6 +219,7 @@ class TestFloat8BlockwiseTensor:
            rowwise=True,
            columnwise=dq_columnwise,
            block_scaling_dim=block_scaling_dim,
            all_gather_usage=True,
        )
        self._test_quantize_dequantize(
            quantizer=quantizer,

@@ -250,8 +284,13 @@ class TestFloat8BlockwiseTensor:
    @pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
    @pytest.mark.parametrize("block_scaling_dim", [1, 2])
    def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None:
    @pytest.mark.parametrize("all_gather_usage", [True, False])
    def test_serialization(
        self, dims: DimsType, block_scaling_dim: int, all_gather_usage: bool
    ) -> None:
        """Test serialization of Float8BlockwiseQTensor"""
        if all_gather_usage and block_scaling_dim != 1:
            pytest.skip("all_gather_usage only implemented for 1D block quantization.")
        device = "cuda"
        dtype = torch.bfloat16
        x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)

@@ -260,6 +299,7 @@ class TestFloat8BlockwiseTensor:
            rowwise=True,
            columnwise=True,
            block_scaling_dim=block_scaling_dim,
            all_gather_usage=all_gather_usage,
        )

        # Create FP8 tensor

@@ -283,6 +323,7 @@ class TestFloat8BlockwiseTensor:
        assert x_fp8_loaded._is_2D_scaled == x_fp8._is_2D_scaled
        assert x_fp8_loaded.dtype == x_fp8.dtype
        assert x_fp8_loaded._fp8_dtype == x_fp8._fp8_dtype
        assert x_fp8_loaded._data_format == x_fp8._data_format

        # Test that dequantized values match
        x_fp8_dequant = x_fp8.dequantize()
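The roundtrip tests above reuse E4M3-level tolerances from the _tols table. As a rough orientation on the size of error those tolerances have to absorb, here is a generic per-tensor roundtrip through torch.float8_e4m3fn in plain PyTorch (not the Float8BlockQuantizer path; the 448 constant and the tolerance values are illustrative, and a PyTorch build with float8 dtypes is assumed):

import torch

x = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
# Per-tensor scale chosen from the current amax, as a current-scaling quantizer would do.
scale = 448.0 / x.abs().amax().float().clamp(min=1e-12)
x_q = (x.float() * scale).to(torch.float8_e4m3fn)
x_dq = (x_q.float() / scale).to(torch.bfloat16)

# E4M3 keeps a 3-bit mantissa, so a few percent of relative error is expected;
# tolerances in the same ballpark as the E4M3 entries of such a table pass comfortably.
torch.testing.assert_close(x_dq, x, atol=0.125, rtol=0.125)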
tests/pytorch/test_fusible_ops.py (view file @ 2b05e121)

@@ -7,6 +7,8 @@ from __future__ import annotations
from collections.abc import Iterable
import io
import math
import pathlib
import sys
from typing import Optional

import pytest

@@ -25,10 +27,20 @@ from transformer_engine.pytorch.ops.fused import (
    ForwardLinearBiasAdd,
)
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Tensor,
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex

# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent))
from utils import dtype_tols, make_recipe

if IS_HIP_EXTENSION:
    import os
    from functools import cache

@@ -49,6 +61,13 @@ if is_bf16_compatible():  # bf16 requires sm_80 or higher
# Supported devices
_devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")]

# Supported quantization recipes
_quantization_list: list[Optional[str]] = [None]
if fp8_available:
    _quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
    _quantization_list.append("mxfp8")


def maybe_skip_quantization(
    quantization: Optional[str],

@@ -56,13 +75,14 @@ def maybe_skip_quantization(
    dims: Optional[Iterable[int] | int] = None,
    device: Optional[torch.device | str] = None,
) -> None:
    """Skip test case if a quantization scheme is not supported"""

    # Don't skip if there is no quantization
    if quantization is None:
        return

    # Check if quantization scheme is supported
    if quantization == "fp8" and not fp8_available:
    if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

@@ -70,7 +90,7 @@ def maybe_skip_quantization(
    if dims is not None:
        if not isinstance(dims, Iterable):
            dims = (dims,)
        if quantization == "fp8":
        if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"):
            if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0:
                pytest.skip("FP8 GEMMs require dims that are divisible by 16")
        elif quantization == "mxfp8":

@@ -82,47 +102,15 @@ def maybe_skip_quantization(
        pytest.skip("Quantization is only supported on CUDA devices")


def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.

    """

    # Transformer Engine dtypes
    if isinstance(dtype, tex.DType):
        if dtype == tex.DType.kFloat8E4M3:
            return dict(rtol=0.125, atol=0.0675)  # epsilon = 0.0625
        if dtype == tex.DType.kFloat8E5M2:
            return dict(rtol=0.25, atol=0.125)  # epsilon = 0.152
        dtype = {
            tex.DType.kByte: torch.uint8,
            tex.DType.kInt32: torch.int32,
            tex.DType.kFloat32: torch.float32,
            tex.DType.kFloat16: torch.half,
            tex.DType.kBFloat16: torch.bfloat16,
        }[dtype]

    # PyTorch dtypes
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float64:
        return dict(rtol=1e-7, atol=1e-7)
    raise ValueError(f"Unsupported dtype ({dtype})")


@torch.no_grad()
def make_reference_and_test_tensors(
    shape: int | Iterable[int],
    quantization: Optional[str] = None,
    ref_dtype: torch.dtype = torch.float64,
    ref_device: torch.device = "cpu",
    test_dtype: torch.dtype = torch.float32,
    test_device: torch.device = "cuda",
    test_is_fp8: bool = False,
    test_is_quantized: bool = False,
    requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Construct tensors with the same values

@@ -131,39 +119,49 @@ def make_reference_and_test_tensors(
    operations in high precision. The test tensor is intended for use
    in Transformer Engine operations.

    If a quantization scheme is provided, the tensor values are
    quantized so that they are representable.

    """

    # Random reference tensor
    ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)

    # Construct test tensor from reference tensor
    test = ref.to(device=test_device, dtype=test_dtype)
    if test_is_fp8:
    if quantization is None:
        if test_is_quantized:
            raise ValueError("Quantization scheme not provided")
        if test.data_ptr() == ref.data_ptr():
            test = test.clone()
    elif quantization in ("fp8", "fp8_delayed_scaling"):
        quantizer = Float8Quantizer(
            scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
            amax=torch.zeros(1, dtype=torch.float32, device=test_device),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
        test = quantizer(test)
    elif test.data_ptr() == ref.data_ptr():
        test = test.clone()
    elif quantization == "fp8_current_scaling":
        quantizer = Float8CurrentScalingQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            device=test_device,
        )
        test = quantizer(test)
    elif quantization == "mxfp8":
        test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
    else:
        raise ValueError(f"Unsupported quantization scheme ({quantization})")
    if isinstance(test, QuantizedTensor) and not test_is_quantized:
        test = test.dequantize()

    # Make sure reference and test tensors match each other
    ref.copy_(test)
    ref.requires_grad_(requires_grad)
    test.requires_grad_(requires_grad)

    return ref, test


def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
    """Make recipe for quantization scheme"""
    if name is None:
        return None
    if name == "fp8":
        return transformer_engine.common.recipe.DelayedScaling(
            fp8_format=transformer_engine.common.recipe.Format.E4M3,
        )
    if name == "mxfp8":
        return transformer_engine.common.recipe.MXFP8BlockScaling(
            fp8_format=transformer_engine.common.recipe.Format.E4M3,
        )
    raise ValueError(f"Unsupported quantization scheme ({name})")


class TestSequential:
    """Tests for sequential container"""
@@ -373,7 +371,7 @@ class TestFuser:
    @pytest.mark.parametrize("init_dtype", _dtypes)
    @pytest.mark.parametrize("final_dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_dtype_cast(
        self,
        *,

@@ -386,8 +384,9 @@ class TestFuser:
        """Check dtype cast functions"""

        # Skip invalid configurations
        maybe_skip_quantization(quantization, device=device)
        in_shape = (size, size)
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        dtype = torch.float32

@@ -397,9 +396,9 @@ class TestFuser:
            dtype = torch.bfloat16
        w_ref, w_test = make_reference_and_test_tensors(
            (size, size),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=with_quantization,
        )

        # Construct operation

@@ -421,11 +420,11 @@ class TestFuser:
        assert isinstance(op.weight, QuantizedTensor) == with_quantization
        assert op.weight.dtype == final_dtype
        w_test = op.weight.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(w_test, w_ref, rtol=0, atol=0)
        torch.testing.assert_close(w_test, w_ref, **dtype_tols(dtype))

        # Check forward and backward pass
        x = torch.zeros(
            (size, size),
            in_shape,
            dtype=init_dtype,
            device=device,
            requires_grad=True,

@@ -438,7 +437,7 @@ class TestFuser:
    @pytest.mark.parametrize("model_dtype", _dtypes)
    @pytest.mark.parametrize("autocast_dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_pyt_autocast(
        self,
        *,

@@ -453,8 +452,9 @@ class TestFuser:
        device = torch.device(device)

        # Skip invalid configurations
        in_shape = (size, size)
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization)
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Construct operation
        recipe = make_recipe(quantization)

@@ -463,7 +463,7 @@ class TestFuser:
        # Check forward and backward pass
        x = torch.zeros(
            (size, size),
            in_shape,
            dtype=model_dtype,
            device=device,
            requires_grad=True,
@@ -501,33 +501,34 @@ class TestBasicOps:
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", ("cuda", "cpu"))
    @pytest.mark.parametrize("fp8", (False, True))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_identity(
        self,
        *,
        in_shape: Iterable[int] = (1,),
        in_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype,
        device: torch.device,
        fp8: bool,
        quantization: Optional[str],
    ) -> None:

        # Skip invalid configurations
        if fp8 and not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8 and torch.device(device).type != "cuda":
            pytest.skip("FP8 is only supported on CUDA devices")
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=fp8,
            test_is_quantized=with_quantization,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
            requires_grad=False,
        )

@@ -563,7 +564,7 @@ class TestBasicOps:
        ),
    )
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("fp8", (False, True))
    @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling"))
    def test_reshape(
        self,
        *,

@@ -571,31 +572,32 @@ class TestBasicOps:
        dtype: torch.dtype,
        device: torch.device = "cuda",
        memory_format: torch.memory_format = torch.contiguous_format,
        fp8: bool,
        quantization: Optional[str],
    ) -> None:
        in_shape, out_shape = shapes

        # Skip invalid configurations
        if memory_format == torch.channels_last and len(in_shape) != 4:
            pytest.skip("torch.channels_last only supports 4D tensors")
        if fp8 and not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8 and torch.device(device).type != "cuda":
            pytest.skip("FP8 is only supported on CUDA devices")
        maybe_skip_quantization(quantization, device=device)
        with_quantization = quantization is not None

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=fp8,
            test_is_quantized=with_quantization,
        )
        x_test = x_test.contiguous(memory_format=memory_format)
        x_test = x_test.detach().requires_grad_()
        dy_ref, dy_test = make_reference_and_test_tensors(
            x_ref.reshape(out_shape).size(),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
            requires_grad=False,
        )

@@ -624,10 +626,10 @@ class TestBasicOps:
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

    @pytest.mark.parametrize("size", (1, 7, 32))
    @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (2, 3, 4, -1)))
    @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (4, 3, 8, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", _devices)
    @pytest.mark.parametrize("fp8", (False, True))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_bias(
        self,
        *,

@@ -635,24 +637,23 @@ class TestBasicOps:
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device,
        fp8: bool,
        quantization: Optional[str],
    ) -> None:

        # Make input and bias shapes consistent
        in_shape = list(in_shape)[:-1] + [size]

        # Skip invalid configurations
        if fp8 and not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8 and torch.device(device).type != "cuda":
            pytest.skip("FP8 is only supported on CUDA devices")
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=fp8,
            test_is_quantized=with_quantization,
        )
        b_ref, b_test = make_reference_and_test_tensors(
            size,

@@ -661,8 +662,10 @@ class TestBasicOps:
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
            requires_grad=False,
        )

@@ -687,7 +690,7 @@ class TestBasicOps:
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(db_test, b_ref.grad, **tols)

    @pytest.mark.parametrize("quantization", ("fp8", "mxfp8"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("cast_forward", (False, True))
    @pytest.mark.parametrize("cast_backward", (False, True))
    def test_quantize(

@@ -703,25 +706,26 @@ class TestBasicOps:
        """Quantize"""

        # Skip invalid configurations
        maybe_skip_quantization(quantization)
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, device=device)
        if quantization == "mxfp8":
            maybe_skip_quantization(quantization, dims=in_shape)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
            test_is_fp8=True,
            requires_grad=True,
        )
        x_test = x_test.dequantize().requires_grad_()
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
            test_is_fp8=True,
        )
        dy_test = dy_test.dequantize()

        # Plain PyTorch implementation
        y_ref = x_ref

@@ -730,11 +734,12 @@ class TestBasicOps:
        # Implementation with fusible operation
        op = te_ops.Quantize(forward=cast_forward, backward=cast_backward)
        recipe = make_recipe(quantization)
        with te.fp8_autocast(fp8_recipe=recipe):
        with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
            y_test = op(x_test)
        y_test.backward(dy_test)

        # Check tensor types
        if with_quantization:
            assert isinstance(y_test, QuantizedTensor) == cast_forward
            assert isinstance(x_test.grad, QuantizedTensor) == cast_backward

@@ -771,9 +776,24 @@ class TestBasicOps:
        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantization == "fp8" and quantized_output and not quantized_compute:
        quantization_needed = any(
            (
                quantized_compute,
                quantized_input,
                quantized_weight,
                quantized_output,
                quantized_grad_output,
                quantized_grad_input,
            )
        )
        if quantization is None and quantization_needed:
            pytest.skip("Quantization scheme is not specified")
        if quantization is not None and not quantization_needed:
            pytest.skip("Quantization scheme is not used")
        if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"):
            if quantized_output and not quantized_compute:
                pytest.skip("FP8 output is only supported with FP8 GEMMs")
        if quantization == "fp8" and quantized_grad_input and not quantized_compute:
            if quantized_grad_input and not quantized_compute:
                pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
        if quantization == "mxfp8" and quantized_output:
            pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
@@ -786,28 +806,25 @@ class TestBasicOps:
        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=(quantized_compute or quantized_input),
            test_is_quantized=quantized_input,
        )
        if isinstance(x_test, QuantizedTensor):
            with torch.no_grad():
                x_test = x_test.dequantize().requires_grad_()
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=(quantized_compute or quantized_weight),
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=(quantized_compute or quantized_grad_output),
            test_is_quantized=quantized_grad_output,
            requires_grad=False,
        )
        if isinstance(dy_test, QuantizedTensor):
            dy_test = dy_test.dequantize()

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x_ref, w_ref)

@@ -870,7 +887,7 @@ class TestBasicOps:
    @pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5)))
    @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("accumulate_into_main_grad", (False, True))
    def test_basic_linear(
        self,

@@ -892,7 +909,7 @@ class TestBasicOps:
    )
    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    @pytest.mark.parametrize("quantization", ("fp8", "mxfp8"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_compute", (False, True))
    @pytest.mark.parametrize("quantized_input", (False, True))
    @pytest.mark.parametrize("quantized_weight", (False, True))

@@ -911,6 +928,8 @@ class TestBasicOps:
        quantized_grad_input: bool,
    ) -> None:
        """GEMM with FP8 inputs and outputs"""
        if quantization is None:
            pytest.skip("Skipping case without quantization")
        self._test_basic_linear(
            dtype=torch.bfloat16,
            quantization=quantization,

@@ -923,8 +942,11 @@ class TestBasicOps:
    )
    @pytest.mark.parametrize("bias", (False, True))
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_compute", (False, True))
    @pytest.mark.parametrize("quantized_weight", (False, True))
    @pytest.mark.parametrize("input_requires_grad", (False, True))
    @pytest.mark.parametrize("weight_requires_grad", (False, True))
    def test_linear(
        self,
        *,

@@ -934,7 +956,10 @@ class TestBasicOps:
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_compute: bool,
        quantized_weight: bool,
        input_requires_grad: bool,
        weight_requires_grad: bool,
    ) -> None:
        """GEMM + bias"""

@@ -944,25 +969,25 @@ class TestBasicOps:
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantization is None and (quantized_compute or quantized_weight):
            pytest.skip("Quantization scheme is not specified")
        if quantization is not None and not (quantized_compute or quantized_weight):
            pytest.skip("Quantization scheme is not used")

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=quantized_compute,
        )
        if isinstance(x_test, QuantizedTensor):
            with torch.no_grad():
                x_test = x_test.dequantize().requires_grad_()
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=(quantized_compute or quantized_weight),
        )
        b_ref, b_test = None, None
        if bias:

@@ -973,6 +998,7 @@ class TestBasicOps:
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,

@@ -998,8 +1024,11 @@ class TestBasicOps:
                op.bias.copy_(b_test)
            del w_test
            del b_test
        for param in op.parameters():
            param.requires_grad_(requires_grad=weight_requires_grad)
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = op(x_test)
        if input_requires_grad or weight_requires_grad:
            y_test.backward(dy_test)

        # Expected numerical error

@@ -1011,10 +1040,12 @@ class TestBasicOps:
        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        if input_requires_grad:
            dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        if weight_requires_grad:
            dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(dw_test, w_ref.grad, **tols)
        if bias:
            db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")

@@ -1024,7 +1055,7 @@ class TestBasicOps:
    @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("zero_centered_gamma", (False, True))
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_layer_norm(
        self,
        *,
@@ -1194,7 +1225,7 @@ class TestBasicOps:
    @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("zero_centered_gamma", (False, True))
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_rmsnorm(
        self,
        *,

@@ -1275,16 +1306,68 @@ class TestBasicOps:
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)

    @pytest.mark.parametrize("in_shape", ((32,), (6, 16, 64), (32, 64)))
    @pytest.mark.parametrize("dtype", _dtypes)
    def test_l2normalization(
        self,
        *,
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        eps: float = 1e-6,
    ) -> None:
        """L2 Normalization"""

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        # L2 norm: x / ||x||_2 = x / sqrt(sum(x^2) + eps)
        l2_norm_squared = x_ref.pow(2).sum(dim=-1, keepdim=True)
        rsqrt_norm = torch.rsqrt(l2_norm_squared + eps)
        y_ref = x_ref * rsqrt_norm
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        op = te_ops.L2Normalization(
            eps=eps,
        )
        y_test = op(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        # L2Norm backward pass requires slightly looser atol for bfloat16
        if dtype == torch.bfloat16:
            tols["atol"] = 2e-3
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", ("cuda", "cpu"))
    @pytest.mark.parametrize("fp8", (False, True))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_add_in_place(
        self,
        *,
        in_shape: Iterable[int] = (1,),
        in_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype,
        device: torch.device,
        fp8: bool,
        quantization: Optional[str],
    ) -> None:
        """Add two tensors

@@ -1293,28 +1376,30 @@ class TestBasicOps:
        """

        # Skip invalid configurations
        if fp8 and not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8 and torch.device(device).type != "cuda":
            pytest.skip("FP8 is only supported on CUDA devices")
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x1_ref, x1_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=fp8,
            test_is_quantized=with_quantization,
        )
        x2_ref, x2_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=fp8,
            test_is_quantized=with_quantization,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
            requires_grad=False,
        )

@@ -1331,7 +1416,7 @@ class TestBasicOps:
        # Check results
        tols = dtype_tols(dtype)
        if fp8:
        if with_quantization:
            tols = dtype_tols(x1_test._fp8_dtype)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")

@@ -1342,14 +1427,14 @@ class TestBasicOps:
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", ("cuda", "cpu"))
    @pytest.mark.parametrize("fp8", (False, True))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_make_extra_output(
        self,
        *,
        in_shape: Iterable[int] = (1,),
        in_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype,
        device: torch.device,
        fp8: bool,
        quantization: Optional[str],
    ) -> None:
        """Output tensor twice

@@ -1358,28 +1443,31 @@ class TestBasicOps:
        """

        # Skip invalid configurations
        if fp8 and not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8 and torch.device(device).type != "cuda":
            pytest.skip("FP8 is only supported on CUDA devices")
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=fp8,
            test_is_quantized=with_quantization,
        )
        dy1_ref, dy1_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
            requires_grad=False,
        )
        dy2_ref, dy2_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
            requires_grad=False,
        )

@@ -1405,7 +1493,7 @@ class TestBasicOps:
    @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
    @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("cache_quantized_input", (False, True))
    def test_activation(
        self,

@@ -1428,26 +1516,21 @@ class TestBasicOps:
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        if cache_quantized_input:
            maybe_skip_quantization("fp8", device=device)
            maybe_skip_quantization("fp8_current_scaling", device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization="fp8_current_scaling" if cache_quantized_input else None,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=quantized_compute,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=quantized_compute,
            requires_grad=False,
        )
        if quantized_compute:
            with torch.no_grad():
                x_test = x_test.dequantize().requires_grad_()
            dy_test = dy_test.dequantize()

        # Plain PyTorch implementation
        y_ref: torch.Tensor

@@ -1490,8 +1573,6 @@ class TestBasicOps:
        tols = dtype_tols(dtype)
        if quantized_compute or cache_quantized_input:
            tols = dtype_tols(tex.DType.kFloat8E4M3)
        if activation == "relu" and not cache_quantized_input:
            tols = {"atol": 0, "rtol": 0}

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")

@@ -1500,7 +1581,7 @@ class TestBasicOps:
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantize_forward", (False, True))
    @pytest.mark.parametrize("quantize_backward", (False, True))
    def test_swiglu(
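The reference in test_l2normalization computes y = x * rsqrt(sum(x^2, dim=-1) + eps) and leans on autograd for the backward pass; the looser bfloat16 atol reflects the extra cancellation in that gradient. For readers checking the math, the closed-form backward the fused kernel has to reproduce is dx = rsqrt_norm * (dy - y * sum(dy * y, dim=-1)), which this small self-contained check verifies against autograd in float64 (a sanity derivation, not Transformer Engine code):

import torch

def l2norm_backward_reference(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Closed-form gradient of y = x * rsqrt(sum(x^2) + eps) along the last dim."""
    r = torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)
    y = x * r
    return r * (dy - y * (dy * y).sum(dim=-1, keepdim=True))

x = torch.randn(6, 16, 64, dtype=torch.float64, requires_grad=True)
dy = torch.randn_like(x)
y = x * torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + 1e-6)
y.backward(dy)
torch.testing.assert_close(x.grad, l2norm_backward_reference(x.detach(), dy))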
@@ -1578,7 +1659,7 @@ class TestFusedOps:
    @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
    @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_weight", (False, True))
    def test_forward_linear_bias_activation(
        self,

@@ -1610,18 +1691,15 @@ class TestFusedOps:
        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=quantized_compute,
        )
        if quantized_compute:
            with torch.no_grad():
                x_test = x_test.dequantize().requires_grad_()
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=(quantized_compute or quantized_weight),
        )
        b_ref, b_test = None, None
        if bias:

@@ -1632,6 +1710,7 @@ class TestFusedOps:
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,

@@ -1688,7 +1767,7 @@ class TestFusedOps:
    @pytest.mark.parametrize("bias", (False, True))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_forward_linear_bias_add(
        self,
        *,

@@ -1717,18 +1796,15 @@ class TestFusedOps:
        # Random data
        x1_ref, x1_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=quantized_compute,
        )
        if isinstance(x1_test, QuantizedTensor):
            with torch.no_grad():
                x1_test = x1_test.dequantize().requires_grad_()
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=(quantized_compute or quantized_weight),
        )
        b_ref, b_test = None, None
        if bias:

@@ -1744,6 +1820,7 @@ class TestFusedOps:
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,

@@ -1802,7 +1879,7 @@ class TestFusedOps:
        torch.testing.assert_close(db_test, b_ref.grad, **tols)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_backward_linear_add(
        self,
        *,

@@ -1830,27 +1907,26 @@ class TestFusedOps:
        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=quantized_compute,
        )
        if isinstance(x_test, QuantizedTensor):
            with torch.no_grad():
                x_test = x_test.dequantize().requires_grad_()
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=(quantized_compute or quantized_weight),
        )
        dy1_ref, dy1_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )
        dy2_ref, dy2_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,

@@ -1914,7 +1990,7 @@ class TestCheckpointing:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_weight", (False, True))
    def test_linear(
        self,
tests/pytorch/test_hf_integration.py (new file, 0 → 100644, view file @ 2b05e121)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel

from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.utils import is_bf16_compatible


class SimpleTEModel(PreTrainedModel):
    config_class = PretrainedConfig

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.my_layer = TransformerLayer(
            hidden_size=320,
            num_attention_heads=16,
            ffn_hidden_size=1024,
            layer_number=None,
        )

    def forward(self, hidden_states, attention_mask):
        return self.my_layer(hidden_states, attention_mask)


def test_save_hf_model(tmp_path):
    model = SimpleTEModel(PretrainedConfig())
    model.save_pretrained(tmp_path / "simple_te_model")


@pytest.mark.xfail(reason="This test is failing until huggingface/transformers#38155 is merged.")
def test_save_and_load_hf_model(tmp_path):
    model = SimpleTEModel(PretrainedConfig())
    model.save_pretrained(tmp_path / "simple_te_model")
    del model
    model = SimpleTEModel.from_pretrained(tmp_path / "simple_te_model")
    assert model is not None
tests/pytorch/test_jit.py (view file @ 2b05e121)

@@ -63,3 +63,62 @@ def test_lazy_compile():
    from transformer_engine.pytorch.jit import dgelu_fused_

    dgelu_fused_(torch.randn(10, 10), torch.randn(10, 10))


def test_l2normalization_fused():
    """Smoke test for L2Normalization fusion functions."""
    from transformer_engine.pytorch.jit import (
        l2normalization_fused,
        l2normalization_fwd_fused,
        l2normalization_backward_fused,
    )

    # Basic smoke test like other JIT functions
    x = torch.randn(10, 128, device="cuda", dtype=torch.float32)
    eps = 1e-6

    # Test inference version
    output_inf = l2normalization_fused(x, eps)

    # Test training version with backward
    x_train = torch.randn(10, 128, device="cuda", dtype=torch.float32, requires_grad=True)
    output_train, rsqrt_norm = l2normalization_fwd_fused(x_train, eps)

    grad_output = torch.randn_like(output_train)
    grad_input = l2normalization_backward_fused(grad_output, x_train, rsqrt_norm, eps)


def test_l2normalization_fused_correctness():
    """Simple verification that L2Normalization fusion matches reference implementation."""
    from transformer_engine.pytorch.jit import (
        l2normalization_fwd_fused,
        l2normalization_backward_fused,
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(16, 64, device=device, dtype=torch.float32, requires_grad=True)
    eps = 1e-6

    # Test fused forward
    output_fused, rsqrt_norm = l2normalization_fwd_fused(x, eps)

    # Reference implementation
    x_ref = x.clone().detach().requires_grad_(True)
    x_squared = x_ref.pow(2)
    l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
    rsqrt_norm_ref = torch.rsqrt(l2_norm_squared + eps)
    output_ref = x_ref * rsqrt_norm_ref

    # Check forward pass matches
    torch.testing.assert_close(output_fused, output_ref, atol=1e-6, rtol=1e-5)
    torch.testing.assert_close(rsqrt_norm, rsqrt_norm_ref, atol=1e-6, rtol=1e-5)

    # Test fused backward
    grad_output = torch.randn_like(output_fused)
    grad_input_fused = l2normalization_backward_fused(grad_output, x, rsqrt_norm, eps)

    # Reference backward
    output_ref.backward(grad_output)
    grad_input_ref = x_ref.grad

    # Check backward pass matches
    torch.testing.assert_close(grad_input_fused, grad_input_ref, atol=1e-5, rtol=1e-4)
tests/pytorch/test_numerics.py (view file @ 2b05e121)

@@ -106,6 +106,20 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
mask_types = ["causal", "no_mask"]

NVTE_TEST_NVINSPECT_ENABLED = os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False)

if NVTE_TEST_NVINSPECT_ENABLED:
    # The numerics of all the layers should work the same,
    # when debug=True. I fed them with dummy feature
    # to prevent switching off debug, which can happen if
    # no feature is active.
    import nvdlfw_inspect.api as debug_api

    debug_api.initialize(
        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )

fp8_recipes = [
    recipe.MXFP8BlockScaling(),
    recipe.DelayedScaling(),

@@ -572,6 +586,8 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)

@@ -686,6 +702,8 @@ def test_gpt_full_activation_recompute(
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)

@@ -1730,6 +1748,8 @@ def test_grouped_linear_accuracy(
        pytest.skip(reason_for_no_fp8)
    if fp8 and recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)

@@ -1934,6 +1954,8 @@ def test_padding_grouped_linear_accuracy(
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)

@@ -2049,6 +2071,8 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_cuda_graph(dtype, bs, model):
    if NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("Cuda Graphs are not supported in debug mode.")
    config = model_configs[model]

    sigma = 0.023

@@ -2146,6 +2170,8 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe):
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
tests/pytorch/test_qk_norm.py (new file, 0 → 100644, view file @ 2b05e121)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from transformer_engine.pytorch import MultiheadAttention
import pytest
import torch


@pytest.mark.parametrize("use_qk_norm", [False, True])
@pytest.mark.parametrize("attention_type", ["self", "cross"])
@pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5])
def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None:
    """Test QK normalization functionality, module structure, and numerical behavior."""
    hidden_size = 256
    num_attention_heads = 8
    seq_len = 128

    # Create MultiheadAttention module
    mha = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        attention_type=attention_type,
        use_qk_norm=use_qk_norm,
        qk_norm_eps=qk_norm_eps,
        bias=False,
        device="cuda",
    ).cuda()

    # Check module structure based on use_qk_norm parameter
    if use_qk_norm:
        assert hasattr(mha, "qk_norm"), "Should have qk_norm module when use_qk_norm=True"
        assert not hasattr(mha, "q_l2norm"), "Should not have separate q_l2norm module"
        assert not hasattr(mha, "k_l2norm"), "Should not have separate k_l2norm module"
        # Check that the module is L2Norm type
        from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization

        assert isinstance(
            mha.qk_norm, L2Normalization
        ), "qk_norm should be an L2Normalization module"
    else:
        assert not hasattr(mha, "qk_norm"), "Should not have qk_norm module when use_qk_norm=False"

    # Create input tensors
    batch_size = 2  # Use a fixed batch size for testing
    hidden_states = torch.randn(
        seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
    )
    if attention_type == "cross":
        encoder_output = torch.randn(
            seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
        )
    else:
        encoder_output = None

    # Test forward pass
    with torch.no_grad():
        if attention_type == "cross":
            output = mha(hidden_states, encoder_output=encoder_output)
        else:
            output = mha(hidden_states)

    # Check output shape and numerical properties
    assert output.shape == (
        seq_len,
        batch_size,
        hidden_size,
    ), f"Output shape mismatch: {output.shape}"
    assert not torch.isnan(output).any(), "Output contains NaN"
    assert not torch.isinf(output).any(), "Output contains Inf"

    # Test with RoPE (if self-attention)
    if attention_type == "self":
        head_dim = hidden_size // num_attention_heads
        rotary_dim = head_dim // 2
        rotary_pos_emb = torch.randn(
            seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32
        )
        with torch.no_grad():
            output_with_rope = mha(hidden_states, rotary_pos_emb=rotary_pos_emb)
        assert output_with_rope.shape == (
            seq_len,
            batch_size,
            hidden_size,
        ), "Output shape with RoPE mismatch"
        assert not torch.isnan(output_with_rope).any(), "RoPE output contains NaN"
        assert not torch.isinf(output_with_rope).any(), "RoPE output contains Inf"


def test_qk_norm_output_difference() -> None:
    """Test that QK normalization actually changes the output compared to no normalization."""
    hidden_size = 256
    num_attention_heads = 8
    seq_len = 128
    batch_size = 2

    # Use same random seed to ensure identical weight initialization
    current_rng_state = torch.get_rng_state()
    current_cuda_rng_state = torch.cuda.get_rng_state()

    # Reset to a known seed for reproducible initialization
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Create model with QK normalization
    mha_with_norm = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        use_qk_norm=True,
        bias=False,
        device="cuda",
    ).cuda()

    # Reset to same seed for identical initialization
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Create identical model without QK normalization
    mha_no_norm = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        use_qk_norm=False,
        bias=False,
        device="cuda",
    ).cuda()

    # Create input tensors
    hidden_states = torch.randn(
        seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
    )

    # Compare outputs with identical weights but different QK norm settings
    with torch.no_grad():
        output_with_norm = mha_with_norm(hidden_states)
        output_no_norm = mha_no_norm(hidden_states)

    # Outputs should be different when QK normalization is enabled
    assert not torch.allclose(
        output_with_norm, output_no_norm, atol=1e-6
    ), "QK normalization should change the output, but outputs are identical"


def test_qk_norm_with_fused_qkv() -> None:
    """Test QK normalization works with fused QKV parameters."""
    hidden_size = 256
    num_attention_heads = 8
    seq_len = 64

    mha = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        fuse_qkv_params=True,
        use_qk_norm=True,
        bias=False,
        device="cuda",
    ).cuda()

    # Create input and test forward pass
    batch_size = 2  # Use a fixed batch size for testing
    hidden_states = torch.randn(
        seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
    )
    with torch.no_grad():
        output = mha(hidden_states)

    assert output.shape == (
        seq_len,
        batch_size,
        hidden_size,
    ), f"Output shape mismatch: {output.shape}"


def test_qk_norm_transformer_layer_output_difference() -> None:
    """Test that QK normalization actually changes TransformerLayer output compared to no normalization."""
    from transformer_engine.pytorch import TransformerLayer

    hidden_size = 256
    ffn_hidden_size = 1024
    num_attention_heads = 8
    seq_len = 128
    batch_size = 2

    # Use same random seed to ensure identical weight initialization
    current_rng_state = torch.get_rng_state()
    current_cuda_rng_state = torch.cuda.get_rng_state()

    # Reset to a known seed for reproducible initialization
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Create TransformerLayer with QK normalization
    transformer_with_norm = TransformerLayer(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_attention_heads=num_attention_heads,
        use_qk_norm=True,
        bias=False,
        device="cuda",
    ).cuda()

    # Reset to same seed for identical initialization
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Create identical TransformerLayer without QK normalization
    transformer_no_norm = TransformerLayer(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_attention_heads=num_attention_heads,
        use_qk_norm=False,
        bias=False,
        device="cuda",
    ).cuda()

    # Create input tensors
    hidden_states = torch.randn(
        seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
)
# Compare outputs with identical weights but different QK norm settings
with
torch
.
no_grad
():
output_with_norm
=
transformer_with_norm
(
hidden_states
)
output_no_norm
=
transformer_no_norm
(
hidden_states
)
# Outputs should be different when QK normalization is enabled
assert
not
torch
.
allclose
(
output_with_norm
,
output_no_norm
,
atol
=
1e-6
),
"QK normalization should change the TransformerLayer output, but outputs are identical"
# Check that outputs have expected shapes and properties
assert
output_with_norm
.
shape
==
(
seq_len
,
batch_size
,
hidden_size
,
),
f
"Output shape mismatch:
{
output_with_norm
.
shape
}
"
assert
not
torch
.
isnan
(
output_with_norm
).
any
(),
"Output with QK norm contains NaN"
assert
not
torch
.
isinf
(
output_with_norm
).
any
(),
"Output with QK norm contains Inf"
assert
not
torch
.
isnan
(
output_no_norm
).
any
(),
"Output without QK norm contains NaN"
assert
not
torch
.
isinf
(
output_no_norm
).
any
(),
"Output without QK norm contains Inf"
tests/pytorch/test_recipe.py View file @ 2b05e121
...
@@ -6,22 +6,32 @@ from typing import Iterable, Optional
import pytest
import torch
import warnings

from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    _amax_and_scale_update,
    get_default_fp8_recipe,
    fp8_model_init,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear
from transformer_engine.pytorch.distributed import fp8_autocast
from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
import transformer_engine_torch as tex

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
    FP8GlobalStateManager.is_fp8_block_scaling_available()
)

# FP8 per tensor delayed scaling
...
@@ -368,3 +378,127 @@ class TestFP8Recipe:
            )
            torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale)

    @pytest.mark.parametrize(
        "model_init_recipe",
        [
            pytest.param(
                MXFP8BlockScaling(),
                marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8),
            ),
            pytest.param(
                Float8BlockScaling(),
                marks=pytest.mark.skipif(
                    not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling
                ),
            ),
        ],
    )
    def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_recipe):
        with fp8_model_init(enabled=True, recipe=model_init_recipe):
            linear = Linear(32, 32).cuda()
        x = torch.randn(32, 32, device="cuda")
        with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
            with pytest.raises(RuntimeError) as excinfo:
                _ = linear(x)
            assert "Recipe mismatch for " in str(excinfo.value)

    @pytest.mark.parametrize(
        "target_recipe_class, expected_quantizer_type, available_flag, reason",
        [
            pytest.param(
                MXFP8BlockScaling,
                MXFP8Quantizer,
                mxfp8_available,
                reason_for_no_mxfp8,
                id="DelayedScaling->MXFP8BlockScaling",
            ),
            pytest.param(
                Float8BlockScaling,
                Float8BlockQuantizer,
                fp8_block_scaling_available,
                reason_for_no_fp8_block_scaling,
                id="DelayedScaling->Float8BlockScaling",
            ),
        ],
    )
    def test_dynamic_recipe_update(
        self, target_recipe_class, expected_quantizer_type, available_flag, reason
    ):
        if not available_flag:
            pytest.skip(reason)

        in_features = 32
        out_features = 32
        batch_size = 32
        linear = Linear(in_features, out_features).cuda()
        initial_recipe = DelayedScaling()

        # Run initial iterations with DelayedScaling
        for _ in range(3):
            x = torch.randn(batch_size, in_features, device="cuda")
            with fp8_autocast(enabled=True, fp8_recipe=initial_recipe):
                y = linear(x)
            loss = y.mean()
            loss.backward()

        for quantizer in linear.quantizers["scaling_fwd"]:
            assert isinstance(quantizer, Float8Quantizer)

        # Change recipe
        target_recipe = target_recipe_class()

        # Run subsequent iterations with the target recipe
        for i in range(3):
            x = torch.randn(batch_size, in_features, device="cuda")
            if i == 0:
                # Expect a warning on the first iteration with the new recipe
                with pytest.warns(UserWarning, match="Recipe type changed"):
                    with fp8_autocast(enabled=True, fp8_recipe=target_recipe):
                        y = linear(x)
                for quantizer in linear.quantizers["scaling_fwd"]:
                    assert isinstance(quantizer, expected_quantizer_type)
            else:
                # No warning expected on subsequent iterations
                with warnings.catch_warnings():
                    warnings.simplefilter("error")  # Raise error if unexpected warning occurs
                    with fp8_autocast(enabled=True, fp8_recipe=target_recipe):
                        y = linear(x)
            loss = y.mean()
            loss.backward()

        # Final check
        for quantizer in linear.quantizers["scaling_fwd"]:
            assert isinstance(quantizer, expected_quantizer_type)

    @pytest.mark.parametrize(
        "module_class",
        [
            Linear,
            LayerNormLinear,
            LayerNormMLP,
            GroupedLinear,
        ],
    )
    def test_quantizer_update(self, module_class):
        in_features = 32
        out_features = 32
        batch_size = 32
        recipe = DelayedScaling(amax_history_len=1024)
        with fp8_model_init(recipe=recipe):
            if module_class == GroupedLinear:
                module = module_class(1, in_features, out_features).cuda()
            else:
                module = module_class(in_features, out_features).cuda()
        x = torch.randn(batch_size, in_features, device="cuda")
        recipe = DelayedScaling(amax_history_len=1)
        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            warn_msg = "Quantizer is being updated, this may affect model behavior"
            with pytest.warns(UserWarning, match=warn_msg):
                if module_class == GroupedLinear:
                    y = module(x, [batch_size])
                else:
                    y = module(x)
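For reference, the recipe switch covered by `test_dynamic_recipe_update` reduces to the following pattern (an illustrative sketch, not part of this commit, assuming MXFP8 support on the device): the first forward pass under the new recipe emits a `UserWarning` about the recipe type change and rebuilds the module's quantizers.

# Illustrative sketch only, not part of this commit; assumes MXFP8 support.
import torch
from transformer_engine.pytorch import Linear, fp8_autocast
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling

linear = Linear(32, 32).cuda()
x = torch.randn(32, 32, device="cuda")
with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    linear(x).mean().backward()
# Switching recipes mid-training warns once and swaps the quantizers.
with fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()):
    linear(x).mean().backward()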
tests/pytorch/test_sanity.py View file @ 2b05e121
...
@@ -11,6 +11,7 @@ import pytest
import os

from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine.pytorch
from transformer_engine.pytorch.fp8 import (
    fp8_autocast,
    FP8GlobalStateManager,
...
@@ -39,9 +40,11 @@ from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
    Float8Tensor,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from test_numerics import reset_rng_states, dtype_tols
...
@@ -1349,3 +1352,80 @@ def test_sanity_checkpointing_on_callables():
    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)


@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
    "quantization",
    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
    module_name: str,
    quantization: Optional[str],
) -> None:
    """Test heuristics for initializing quantized weights"""

    # Tensor dimensions
    sequence_length = 32
    hidden_size = 32

    # Skip invalid configurations
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Construct quantization recipe
    with_quantization = quantization not in (None, "None")
    quantization_recipe = None
    if quantization == "fp8_delayed_scaling":
        quantization_recipe = recipe.DelayedScaling()
    elif quantization == "fp8_current_scaling":
        quantization_recipe = recipe.Float8CurrentScaling()
    elif quantization == "mxfp8":
        quantization_recipe = recipe.MXFP8BlockScaling()

    # Construct module
    module = None
    with torch.no_grad():
        with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe):
            if module_name == "Linear":
                module = Linear(hidden_size, hidden_size)
            elif module_name == "LayerNormLinear":
                module = LayerNormLinear(hidden_size, hidden_size)
            elif module_name == "LayerNormMLP":
                module = LayerNormMLP(hidden_size, hidden_size)
            elif module_name == "GroupedLinear":
                module = GroupedLinear(1, hidden_size, hidden_size)
            elif module_name == "ops.Linear":
                module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size)

    def check_weights():
        """Helper function to check that weight parameters have expected data"""
        for param in module.parameters():
            if isinstance(param, Float8Tensor):
                assert param._data is not None, "Missing FP8 data"
                assert (
                    param._transpose is None and param._transpose_invalid
                ), "FP8 transpose is not expected for inference"
            if isinstance(param, MXFP8Tensor):
                assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
                assert (
                    param._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"

    # Check that modules have expected weights after initialization
    check_weights()

    # Check that modules have expected weights after forward pass
    with torch.inference_mode():
        x = torch.zeros(sequence_length, hidden_size, device="cuda")
        kwargs = {}
        if module_name == "GroupedLinear":
            kwargs["m_splits"] = [sequence_length]
        with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe):
            y = module(x, **kwargs)
    check_weights()
tests/pytorch/utils.py View file @ 2b05e121
...
@@ -7,6 +7,7 @@ from __future__ import annotations
import torch

import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
...
@@ -83,3 +84,24 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
    if dtype == torch.float8_e5m2:
        return dict(rtol=0.25, atol=0.125)  # epsilon = 0.152
    raise ValueError(f"Unsupported dtype ({dtype})")


def make_recipe(name: Optional[str]) -> Optional[Recipe]:
    """Make recipe for quantization scheme"""
    if name is None:
        return None
    if name in ("fp8", "fp8_delayed_scaling"):
        return transformer_engine.common.recipe.DelayedScaling(
            fp8_format=transformer_engine.common.recipe.Format.E4M3,
        )
    if name == "fp8_current_scaling":
        return transformer_engine.common.recipe.Float8CurrentScaling(
            fp8_format=transformer_engine.common.recipe.Format.E4M3,
        )
    if name == "mxfp8":
        return transformer_engine.common.recipe.MXFP8BlockScaling(
            fp8_format=transformer_engine.common.recipe.Format.E4M3,
        )
    if name == "fp8_block_scaling":
        return transformer_engine.common.recipe.Float8BlockScaling()
    raise ValueError(f"Unsupported quantization scheme ({name})")
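A hedged usage note: `make_recipe` maps a short scheme name onto a recipe object, so tests can be parametrized over strings instead of recipe classes. A minimal sketch (not part of this commit; it assumes the helper above is in scope):

# Illustrative sketch only, not part of this commit; assumes make_recipe is in scope.
recipe = make_recipe("fp8_current_scaling")
assert recipe is not None and recipe.float8_current_scaling()
assert make_recipe(None) is None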
transformer_engine/__init__.py View file @ 2b05e121
...
@@ -6,17 +6,62 @@
# pylint: disable=unused-import
import os
from importlib import metadata

import transformer_engine.common

try:
    from . import pytorch
except ImportError as e:
except ImportError:
    pass
except FileNotFoundError as e:
    if "Could not find shared object file" not in str(e):
        raise e  # Unexpected error
    else:
        if os.getenv("NVTE_FRAMEWORK"):
            frameworks = os.getenv("NVTE_FRAMEWORK").split(",")
            if "pytorch" in frameworks or "all" in frameworks:
                raise e
        else:
            # If we got here, we could import `torch` but could not load the framework extension.
            # This can happen when a user wants to work only with `transformer_engine.jax` on a system that
            # also has a PyTorch installation. In order to enable that use case, we issue a warning here
            # about the missing PyTorch extension in case the user hasn't set NVTE_FRAMEWORK.
            import warnings

            warnings.warn(
                "Detected a PyTorch installation but could not find the shared object file for the "
                "Transformer Engine PyTorch extension library. If this is not intentional, please "
                "reinstall Transformer Engine with `pip install transformer_engine[pytorch]` or "
                "build from source with `NVTE_FRAMEWORK=pytorch`.",
                category=RuntimeWarning,
            )

try:
    from . import jax
except ImportError as e:
except ImportError:
    pass
except FileNotFoundError as e:
    if "Could not find shared object file" not in str(e):
        raise e  # Unexpected error
    else:
        if os.getenv("NVTE_FRAMEWORK"):
            frameworks = os.getenv("NVTE_FRAMEWORK").split(",")
            if "jax" in frameworks or "all" in frameworks:
                raise e
        else:
            # If we got here, we could import `jax` but could not load the framework extension.
            # This can happen when a user wants to work only with `transformer_engine.pytorch` on a system
            # that also has a Jax installation. In order to enable that use case, we issue a warning here
            # about the missing Jax extension in case the user hasn't set NVTE_FRAMEWORK.
            import warnings

            warnings.warn(
                "Detected a Jax installation but could not find the shared object file for the "
                "Transformer Engine Jax extension library. If this is not intentional, please "
                "reinstall Transformer Engine with `pip install transformer_engine[jax]` or "
                "build from source with `NVTE_FRAMEWORK=jax`.",
                category=RuntimeWarning,
            )

__version__ = str(metadata.version("transformer_engine"))
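The new import guards only escalate a missing framework extension to an error when the user has explicitly opted into that framework; otherwise they downgrade it to a `RuntimeWarning`. A minimal sketch of how this is exercised from user code (illustrative only, not part of this commit):

# Illustrative sketch only, not part of this commit.
import os

os.environ["NVTE_FRAMEWORK"] = "jax"  # only the JAX extension is required
import transformer_engine             # a missing PyTorch extension now only warns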
transformer_engine/common/CMakeLists.txt View file @ 2b05e121
...
@@ -30,7 +30,9 @@ endif()
# Language options
if(USE_CUDA)
  if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
    if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
    if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
      set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
    elseif(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
      set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
    else()
      set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
...
@@ -149,6 +151,7 @@ if(USE_CUDA)
      util/cuda_driver.cpp
      util/cuda_nvml.cpp
      util/cuda_runtime.cpp
      util/multi_stream.cpp
      util/rtc.cpp
      swizzle/swizzle.cu
      fused_softmax/scaled_masked_softmax.cu
...
@@ -201,6 +204,7 @@ else()
      util/cuda_driver.cpp
      util/cuda_nvml.cpp
      util/cuda_runtime.cpp
      util/multi_stream.cpp
      util/rtc.cpp
      swizzle/swizzle.cu
      fused_softmax/scaled_masked_softmax.cu
...
transformer_engine/common/__init__.py View file @ 2b05e121
...
@@ -4,25 +4,26 @@
"""FW agnostic user-end APIs"""

import sys
import glob
import sysconfig
import subprocess
import ctypes
import functools
import glob
import importlib
from importlib.metadata import version, metadata, PackageNotFoundError
import logging
import os
import platform
import importlib
import functools
from pathlib import Path
from importlib.metadata import version, metadata, PackageNotFoundError
import platform
import subprocess
import sys
import sysconfig
from typing import Optional

_logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=None)
def _is_pip_package_installed(package):
def _is_pip_package_installed(package) -> bool:
    """Check if the given package is installed via pip."""
    # This is needed because we only want to return true
...
@@ -37,37 +38,37 @@ def _is_pip_package_installed(package):

@functools.lru_cache(maxsize=None)
def _find_shared_object_in_te_dir(te_path: Path, prefix: str):
def _find_shared_object_in_te_dir(te_path: Path, prefix: str) -> Optional[Path]:
    """
    Find a shared object file of given prefix in the top level TE directory.
    Only the following locations are searched to avoid stray SOs and build artifacts:
    1. The given top level directory (editable install).
    2. `transformer_engine` named directories (source install).
    3. `wheel_lib` named directories (PyPI install).
    Find a shared object file with the given prefix within the top level TE directory.
    The following locations are searched:
    1. Top level directory (editable install).
    2. `transformer_engine` directory (source install).
    3. `wheel_lib` directory (PyPI install).

    Returns None if no shared object files are found.
    Raises an error if multiple shared object files are found.
    """
    # Ensure top level dir exists and has the module before searching.
    if not te_path.exists() or not (te_path / "transformer_engine").exists():
    # Ensure top level dir exists and has the module before searching.
    if not te_path.is_dir() or not (te_path / "transformer_engine").exists():
        return None

    files = []
    search_paths = (
        te_path,
        te_path / "transformer_engine",
        te_path / "transformer_engine/wheel_lib",
        te_path / "wheel_lib",
        te_path,  # Editable build.
        te_path / "transformer_engine",  # Regular source build.
        te_path / "transformer_engine/wheel_lib",  # PyPI.
    )

    # Search.
    for dirname, _, names in os.walk(te_path):
        if Path(dirname) in search_paths:
            for name in names:
                if name.startswith(prefix) and name.endswith(f".{_get_sys_extension()}"):
                    files.append(Path(dirname, name))
    for dir_path in search_paths:
        if not dir_path.is_dir():
            continue
        for file_path in dir_path.iterdir():
            if file_path.name.startswith(prefix) and file_path.suffix == _get_sys_extension():
                files.append(file_path)

    if len(files) == 0:
        return None
...
@@ -79,16 +80,12 @@ def _find_shared_object_in_te_dir(te_path: Path, prefix: str):

@functools.lru_cache(maxsize=None)
def _get_shared_object_file(library: str) -> Path:
    """
    Return the path of the shared object file for the given TE
    library, one of 'core', 'torch', or 'jax'.

    Several factors affect finding the correct location of the shared object:
    1. System and environment.
    2. If the installation is from source or via PyPI.
       - Source installed .sos are placed in top level dir
       - Wheel/PyPI installed .sos are placed in 'wheel_lib' dir to avoid conflicts.
    3. For source installations, is the install editable/inplace?
    4. The user directory from where TE is being imported.
    Path to shared object file for a Transformer Engine library.

    TE libraries are 'core', 'torch', or 'jax'. This function first
    searches in the imported TE directory, and then in the
    site-packages directory.
    """

    # Check provided input and determine the correct prefix for .so.
...
@@ -98,47 +95,25 @@ def _get_shared_object_file(library: str) -> Path:
    else:
        so_prefix = f"transformer_engine_{library}"

    # Check TE install location (will be local if TE is available in current dir for import).
    te_install_dir = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent
    so_path_in_install_dir = _find_shared_object_in_te_dir(te_install_dir, so_prefix)
    # Search for shared lib in imported directory
    te_path = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent
    so_path = _find_shared_object_in_te_dir(te_path, so_prefix)
    if so_path is not None:
        return so_path

    # Check default python package install location in system.
    site_packages_dir = Path(sysconfig.get_paths()["purelib"])
    so_path_in_default_dir = _find_shared_object_in_te_dir(site_packages_dir, so_prefix)
    # Search for shared lib in site-packages directory
    te_path = Path(sysconfig.get_paths()["purelib"])
    so_path = _find_shared_object_in_te_dir(te_path, so_prefix)
    if so_path is not None:
        return so_path

    # Case 1: Typical user workflow: Both locations are the same, return any result.
    if te_install_dir == site_packages_dir:
        assert (
            so_path_in_install_dir is not None
        ), f"Could not find shared object file for Transformer Engine {library} lib."
        return so_path_in_install_dir

    # Case 2: ERR! Both locations are different but returned a valid result.
    # NOTE: Unlike for source installations, pip does not wipe out artifacts from
    # editable builds. In case developers are executing inside a TE directory via
    # an inplace build, and then move to a regular build, the local shared object
    # file will be incorrectly picked up without the following logic.
    if so_path_in_install_dir is not None and so_path_in_default_dir is not None:
        raise RuntimeError(
            f"Found multiple shared object files: {so_path_in_install_dir} and"
            f" {so_path_in_default_dir}. Remove local shared objects installed"
            f" here {so_path_in_install_dir} or change the working directory to"
            "execute from outside TE."
    raise FileNotFoundError(
        f"Could not find shared object file for Transformer Engine {library} lib."
    )

    # Case 3: Typical dev workflow: Editable install
    if so_path_in_install_dir is not None:
        return so_path_in_install_dir

    # Case 4: Executing from inside a TE directory without an inplace build available.
    if so_path_in_default_dir is not None:
        return so_path_in_default_dir

    raise RuntimeError(f"Could not find shared object file for Transformer Engine {library} lib.")


@functools.lru_cache(maxsize=None)
def load_framework_extension(framework: str):
def load_framework_extension(framework: str) -> None:
    """
    Load shared library with Transformer Engine framework bindings
    and check verify correctness if installed via PyPI.
...
@@ -196,19 +171,18 @@ def load_framework_extension(framework: str):

@functools.lru_cache(maxsize=None)
def _get_sys_extension():
def _get_sys_extension() -> str:
    """File extension for shared objects."""
    system = platform.system()

    if system == "Linux":
        extension = "so"
    elif system == "Darwin":
        extension = "dylib"
    elif system == "Windows":
        extension = "dll"
    else:
        return ".so"
    if system == "Darwin":
        return ".dylib"
    if system == "Windows":
        return ".dll"
    raise RuntimeError(f"Unsupported operating system ({system})")

    return extension


@functools.lru_cache(maxsize=None)
def _load_nvidia_cuda_library(lib_name: str):
...
@@ -221,7 +195,7 @@ def _load_nvidia_cuda_library(lib_name: str):
    so_paths = glob.glob(
        os.path.join(
            sysconfig.get_path("purelib"),
            f"nvidia/{lib_name}/lib/lib*.{_get_sys_extension()}.*[0-9]",
            f"nvidia/{lib_name}/lib/lib*{_get_sys_extension()}.*[0-9]",
        )
    )
...
@@ -236,7 +210,7 @@ def _load_nvidia_cuda_library(lib_name: str):

@functools.lru_cache(maxsize=None)
def _nvidia_cudart_include_dir():
def _nvidia_cudart_include_dir() -> str:
    """Returns the include directory for cuda_runtime.h if exists in python environment."""

    try:
...
@@ -255,14 +229,14 @@ def _load_cudnn():
    # Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
    cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
    if cudnn_home:
        libs = glob.glob(f"{cudnn_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
        libs = glob.glob(f"{cudnn_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
        libs.sort(reverse=True, key=os.path.basename)
        if libs:
            return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

    # Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
    libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
    libs = glob.glob(f"{cuda_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
    libs.sort(reverse=True, key=os.path.basename)
    if libs:
        return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
...
@@ -273,7 +247,7 @@ def _load_cudnn():
        return handle

    # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
    return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
    return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)


@functools.lru_cache(maxsize=None)
...
@@ -281,7 +255,7 @@ def _load_nvrtc():
    """Load NVRTC shared library."""
    # Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
    libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True)
    libs = glob.glob(f"{cuda_home}/**/libnvrtc{_get_sys_extension()}*", recursive=True)
    libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs))
    libs.sort(reverse=True, key=os.path.basename)
    if libs:
...
@@ -305,7 +279,7 @@ def _load_nvrtc():
        return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)

    # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
    return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
    return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)


@functools.lru_cache(maxsize=None)
...
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp View file @ 2b05e121
...
@@ -248,7 +248,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
    if (param_type == NVTETensorParam::kNVTERowwiseData ||
        param_type == NVTETensorParam::kNVTEColumnwiseData) {
      // Offset data pointer
      param_dptr += chunk_offset * typeToSize(param_dtype);
      param_dptr += get_buffer_size_bytes(chunk_offset, param_dtype);
      param_shape = chunk_shape;
      if (param_type == NVTETensorParam::kNVTEColumnwiseData &&
...
@@ -269,7 +269,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
      } else {
        chunk_scale_height /= 32;
      }
      param_dptr += (chunk_offset / 32) * typeToSize(param_dtype);
      param_dptr += get_buffer_size_bytes(chunk_offset / 32, param_dtype);
      param_shape = {chunk_scale_height, chunk_scale_width};
    }
...
@@ -288,7 +288,7 @@ TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source
  auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape);

  // Update chunk with offset data pointers from the communication buffer
  auto ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr()) + (chunk_offset * _ubuf.element_size());
  auto ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr()) + chunk_offset * _ubuf.element_size();
  if (chunk.dptr() != nullptr) {
    chunk.set_rowwise_data(reinterpret_cast<void *>(ubuf_ptr), chunk.dtype(), chunk.shape());
  }
...
@@ -326,7 +326,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
            "or 2 (multi-atomic).");

  NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
  size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype);
  size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
  void *buffer_ptr;
  _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
  if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
...
@@ -398,7 +398,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_comm, 0));

  // Communication: AG and RS
  int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size();  // UBUF uses 2Byte element size
  int comm_elements = _ubuf.bytes() / 2;  // UBUF uses 2Byte element size
  if (comm_type == CommOverlapType::AG) {
    allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
                                (cudaEvent_t)_comm_launch_event);
...
@@ -723,7 +723,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
  // Create workspace tensor with userbuffer
  NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
  size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype);
  size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
  int buffer_chunk_bytes = buffer_bytes / tp_size;
  _num_ubuf_chunks = tp_size;
  if (_is_reduce_scatter) {
...
@@ -827,7 +827,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(
  assert(pre_gelu_out.numel() == 0);

  // Get communication and GEMM output chunk sizes
  const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
  const int comm_bytes = _ubufs[0].bytes();

  // Create an GEMM output buffer with N+1 chunks in a contiguous memory
  void *D_buffer_ptr;
...
@@ -885,21 +885,20 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(
  if (B_copy.numel() > 0) {
    assert(B_copy.numel() == _ubufs[_self_chunk_id].numel());
    assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size());
    NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(),
                                    _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(),
                                    cudaMemcpyDeviceToDevice, _stream_send[0]));
    NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(),
                                    _ubufs[_self_chunk_id].bytes(),
                                    cudaMemcpyDeviceToDevice, _stream_send[0]));
    NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0]));
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
  }

  // Copy the first GEMM output chunk to the end chunk position of D_buffer
  char *src_ptr = reinterpret_cast<char *>(D_buffer.dptr());
  NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, D_chunk_bytes,
  NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + D.bytes(), src_ptr, D_chunk_bytes,
                                  cudaMemcpyDeviceToDevice, stream_main));
  // Return the last N rows of D_buffer
  NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.numel() * D.element_size(),
  NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.bytes(),
                                  cudaMemcpyDeviceToDevice, stream_main));

  // Clean up buffer allocation
...
@@ -929,7 +928,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
  const size_t n_chunk = _ubufs[0].size(0);

  // Get communication and GEMM output chunk sizes
  const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
  const int comm_bytes = _ubufs[0].bytes();
  const bool do_gelu = pre_gelu_out.numel() > 0;
  size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
...
@@ -945,7 +944,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
    // Chunk dims
    std::vector<size_t> input_b_chunk_shape =
        (transb ? std::vector<size_t>{k, 2 * n_chunk} : std::vector<size_t>{2 * n_chunk, k});
    std::vector<size_t> output_chunk_shape = {2 * n_chunk, k};
    std::vector<size_t> output_chunk_shape = {2 * n_chunk, m};
    size_t input_b_chunk_size = 2 * n_chunk * k;
    size_t output_chunk_size = 2 * n_chunk * m;
...
@@ -976,12 +975,12 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
      // GEMM
      auto input_b_chunk =
          get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
          get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape);
      auto output_chunk =
          get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
      auto aux_chunk =
          (do_gelu)
              ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k})
          get_tensor_chunk(D, output_chunk_size * send_chunk_id / 2, output_chunk_shape);
      auto aux_chunk =
          (do_gelu)
              ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id / 2, {n_chunk * 2, k})
              : TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
      auto workspace_chunk = get_tensor_chunk(
          workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
...
@@ -1012,8 +1011,8 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
          assert(B_copy.numel() == _ubufs[_tp_id].numel());
          assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
          NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
                                          _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
                                          cudaMemcpyDeviceToDevice, _stream_send[0]));
                                          _ubufs[_tp_id].bytes(),
                                          cudaMemcpyDeviceToDevice, _stream_send[0]));
        }
      }
    }
  } else {
...
@@ -1072,8 +1071,8 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
          assert(B_copy.numel() == _ubufs[_tp_id].numel());
          assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
          NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
                                          _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
                                          cudaMemcpyDeviceToDevice, _stream_send[0]));
                                          _ubufs[_tp_id].bytes(),
                                          cudaMemcpyDeviceToDevice, _stream_send[0]));
        }
      }
    }
...
@@ -1103,7 +1102,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(
  _ub_comm->cga_size = _cga_size;

  // Get communication and GEMM input chunk sizes
  const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
  const int comm_bytes = _ubufs[0].bytes();

  // Reset counters
  int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
...
@@ -1170,7 +1169,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
  size_t m = transa ? A.size(0) : A.size(1);
  size_t k = transa ? A.size(1) : A.size(0);
  size_t n_chunk = _ubufs[0].size(0);
  const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
  const int comm_bytes = _ubufs[0].bytes();

  // Get input and workspace data pointers
  size_t input_chunk_size = n_chunk * k;
...
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp View file @ 2b05e121
...
@@ -248,7 +248,8 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
    CUmemFabricHandle *tmphndl = reinterpret_cast<CUmemFabricHandle *>(malloc(sizeof(CUmemFabricHandle)));
    CUmemFabricHandle *exphndls;
    NVTE_CHECK_CUDA(cudaMallocHost(&exphndls, (*comm)->nvsize * sizeof(CUmemFabricHandle)));
    NVTE_CHECK_CUDA(cudaMallocHost(reinterpret_cast<void **>(&exphndls),
                                   (*comm)->nvsize * sizeof(CUmemFabricHandle)));
    if ((*comm)->ar2_nvrank == 0)
      NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, static_cast<void *>(tmphndl),
                                  (*comm)->mc_handle, CU_MEM_HANDLE_TYPE_FABRIC, 0);
...
@@ -345,8 +346,10 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
  register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, true);
  NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int)));
  NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
  NVTE_CHECK_CUDA(cudaMalloc(reinterpret_cast<void **>(&(*comm)->send_id),
                             (*comm)->nranks * sizeof(int)));
  NVTE_CHECK_CUDA(cudaMalloc(reinterpret_cast<void **>(&(*comm)->recv_id),
                             NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
  NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int)));
  NVTE_CHECK_CUDA(cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
...
@@ -358,13 +361,14 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
#define GPU_PAGE_OFFSET (GPU_PAGE_SIZE - 1)
#define GPU_PAGE_MASK (~GPU_PAGE_OFFSET)
  NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE));
  NVTE_CHECK_CUDA(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
  (*comm)->flags =
  NVTE_CHECK_CUDA(cudaMalloc(reinterpret_cast<void **>(&(*comm)->flags_baseptr), 2 * GPU_PAGE_SIZE));
  NVTE_CHECK_CUDA(cudaMemset((*comm)->flags_baseptr, 0, 2 * GPU_PAGE_SIZE));
  (*comm)->flags = reinterpret_cast<int *>(
#ifdef USE_ROCM
      reinterpret_cast<int *>((reinterpret_cast<uintptr_t>((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
      (reinterpret_cast<uintptr_t>((*comm)->flags_baseptr) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
#else
      reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
      ((CUdeviceptr)(*comm)->flags_baseptr + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
#endif

  using namespace std;
...
@@ -442,20 +446,31 @@ int create_communicator_mpi(communicator **comm) {
}

void destroy_communicator(communicator *comm) {
  for (int hndl = 0; hndl < comm->free_region; hndl++) {
  // Clear memory allocated in register_user_buffer_collective calls
  for (int hndl = comm->free_region - 1; hndl >= 0; hndl--) {
    if (comm->use_mc && comm->mem_dealloc[hndl]) {
      // Unbind the local device buffer from the Multicast handle
      CUdevice dev;
      NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, comm->mydev);
      NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastUnbind, comm->mc_handle, dev,
                                  comm->uc_offsets[hndl], comm->mem_size[hndl]);
      // Unmap memory addresses and release handles for both peer and own buffers
      for (int rank = 0; rank < comm->nvsize; rank++) {
        if (rank == comm->nvrank) {
          NVTE_CALL_CHECK_CUDA_DRIVER(cuMemUnmap,
                                      reinterpret_cast<CUdeviceptr>(comm->peer_ptr[hndl][rank]),
                                      comm->mem_size[hndl]);
          NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]);
        } else {
          comm->uchandles[hndl][rank] = 0;
        }
      }
      free(reinterpret_cast<void *>(comm->uchandles[hndl]));
      // Free memory reserved for buffer allocations
      NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressFree, comm->ucbase_ptr[hndl],
                                  static_cast<size_t>(comm->mem_size[hndl] * comm->nvsize));
    } else {
      for (int rank = 0; rank < comm->nvsize; rank++) {
        if (rank != comm->nvrank) {
          cudaIpcCloseMemHandle(comm->peer_ptr[hndl][rank]);
          NVTE_CHECK_CUDA(cudaIpcCloseMemHandle(comm->peer_ptr[hndl][rank]));
        } else if (comm->mem_dealloc[hndl]) {
          NVTE_CHECK_CUDA(cudaFree(comm->peer_ptr[hndl][rank]));
        } else {
...
@@ -464,11 +479,16 @@ void destroy_communicator(communicator *comm) {
        }
      }
      free(comm->peer_ptr[hndl]);
      comm->mem_ptr[hndl] = nullptr;
      comm->mem_ptr[hndl] = nullptr;  // this points to already cleaned up local device buffer
    }
  }
  cudaFree(reinterpret_cast<void *>(comm->recv_id));
  cudaFree(reinterpret_cast<void *>(comm->send_id));
  // Clear memory allocated in the communicator constructor
  NVTE_CHECK_CUDA(cudaFree(reinterpret_cast<void *>(comm->recv_id)));
  NVTE_CHECK_CUDA(cudaFree(reinterpret_cast<void *>(comm->send_id)));
  NVTE_CHECK_CUDA(cudaFree(reinterpret_cast<void *>(comm->flags_baseptr)));
  if (comm->use_mc) {
    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemUnmap, reinterpret_cast<CUdeviceptr>(comm->mc_baseptr), comm->mc_maxsize);
    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressFree, comm->mc_baseptr, comm->mc_maxsize);
    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->mc_handle);
  }
  delete comm;
...
@@ -535,7 +555,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
      CUmemFabricHandle myhndl;
      NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, &myhndl,
                                  comm->uchandles[hndl][myrank], CU_MEM_HANDLE_TYPE_FABRIC, 0);
      NVTE_CHECK_CUDA(cudaMallocHost(&exphndl, comm->nvsize * sizeof(CUmemFabricHandle)));
      NVTE_CHECK_CUDA(cudaMallocHost(reinterpret_cast<void **>(&exphndl),
                                     comm->nvsize * sizeof(CUmemFabricHandle)));
      comm->_allgather(reinterpret_cast<void *>(exphndl), comm->nvsize * sizeof(CUmemFabricHandle),
                       reinterpret_cast<void *>(&myhndl), sizeof(CUmemFabricHandle), comm->comm_intra);
...
@@ -619,6 +640,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
                                aligned_size, (uint64_t)0);
    comm->memflags[hndl] |= NVTE_UB_MEM_MC_CREATED;
    comm->mc_ptr[hndl] = reinterpret_cast<char *>(comm->mc_baseptr) + comm->mc_offset;
    comm->uc_offsets[hndl] = comm->mc_offset;
    comm->mc_offset += aligned_size;
  } else if (!comm->myrank) {
    printf("UB: warning region %d size %ld MB registered without MC access\n", hndl,
...
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h View file @ 2b05e121
...
@@ -111,6 +111,7 @@ struct communicator {
  CUmemGenericAllocationHandle *uchandles[NVTE_MAX_REGIONS];
#endif
  void *ucbase_ptr[NVTE_MAX_REGIONS];  // only for cuMem allocated memory
  size_t uc_offsets[NVTE_MAX_REGIONS];
  size_t mem_size[NVTE_MAX_REGIONS];
  bool mem_dealloc[NVTE_MAX_REGIONS];
...
@@ -133,7 +134,7 @@ struct communicator {
  // max value for running block counters in hostflags
  int basecounter[userbuffers_op_types];  // NOLINT(*)
  int *flags, *map_flags;
  int *flags_baseptr, *flags, *map_flags;
  void *mem_mr[NVTE_MAX_REGIONS];
...
transformer_engine/common/common.cu View file @ 2b05e121
...
@@ -121,13 +121,20 @@ void checkCuDriverContext(CUstream stream) {
#ifndef __HIP_PLATFORM_AMD__
CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
  static const std::unordered_map<DType, CUtensorMapDataType> dtypeMapping = {
  static const std::unordered_map<DType, CUtensorMapDataType> dtypeMapping = []() {
    std::unordered_map<DType, CUtensorMapDataType> typeMapping = {
        {DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
        {DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32},
        {DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16},
        {DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16},
        {DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
        {DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}};
#if FP4_TYPE_SUPPORTED
    typeMapping.insert(
        {DType::kFloat4E2M1, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B});
#endif
    return typeMapping;
  }();
  return dtypeMapping.at(dtype);
}
...
@@ -135,18 +142,19 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
                          const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
                          const uint32_t shmemX, const uint32_t stride_elems,
                          const uint32_t offset_elems, const size_t type_size) {
                          const uint32_t offset_elems, const size_t type_num_bits) {
  // Get a function pointer to the cuTensorMapEncodeTiled driver API
  static PFN_cuTensorMapEncodeTiled cuDriverTensorMapEncodeTiled = []() {
  // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13
  static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() {
    void *driver_ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled");
    return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(driver_ptr);
    return reinterpret_cast<PFN_cuTensorMapEncodeTiled_v12000>(driver_ptr);
  }();
  // rank is the number of dimensions of the array
  constexpr uint32_t rank = 2;
  uint64_t size[rank] = {globalX, globalY};

  // The stride is the number of bytes to traverse from the first element of one row to the next
  uint64_t stride[rank - 1] = {stride_elems * type_size};
  uint64_t stride[rank - 1] = {(stride_elems * type_num_bits) / 8};

  // The boxSize is the size of the shared memory buffer that is used as the
  // source/destination of a TMA transfer
...
@@ -156,15 +164,15 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
  uint32_t elemStride[rank] = {1, 1};

  const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype);
  void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) + offset_elems * type_size);
  void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) +
                                           (offset_elems * type_num_bits) / 8);

  NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment),
             "Tensor data pointer must be 16B aligned");

  const int TMA_needed_size = TMA_gmem_alignment / type_size;
  NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_size,
             "-byte data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
  const int TMA_needed_size = (TMA_gmem_alignment * 8) / type_num_bits;
  NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits,
             "-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX);

  // Create the tensor descriptor.
  NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled(
...
@@ -209,10 +217,24 @@ std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensor
  for (size_t i = 0; i < outer_size; ++i) {
    ret.emplace_back();
    for (size_t j = 0; j < inner_size; ++j) {
      ret.back().push_back(reinterpret_cast<Tensor *>(nvte_tensors[i][j]));
      ret.back().push_back(convertNVTETensor(nvte_tensors[i][j]));
    }
  }
  return ret;
}

size_t get_buffer_size_bytes(const size_t elements_num, const DType buffer_dtype) {
  return (elements_num * typeToNumBits(buffer_dtype)) / 8;
}

size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last, const DType buffer_dtype) {
  if (buffer_dtype == DType::kFloat4E2M1) {
    NVTE_CHECK(dim_last % 2 == 0,
               "Last dimension of a tensor with FP4 type of data must be an even number!");
  }
  const size_t elements_num = dim_first * dim_last;
  return get_buffer_size_bytes(elements_num, buffer_dtype);
}

}  // namespace transformer_engine
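The two `get_buffer_size_bytes` overloads express element widths in bits so that sub-byte types such as FP4 size correctly: a 128 by 64 buffer takes 128 * 64 * 4 / 8 = 4096 bytes in FP4 and 16384 bytes in BF16. A quick check of the same arithmetic (illustrative sketch only, not part of this commit):

# Illustrative sketch only, not part of this commit: mirrors the get_buffer_size_bytes arithmetic.
def buffer_size_bytes(dim_first, dim_last, num_bits):
    assert num_bits != 4 or dim_last % 2 == 0, "FP4 tensors need an even last dimension"
    return (dim_first * dim_last * num_bits) // 8

assert buffer_size_bytes(128, 64, 4) == 4096     # FP4
assert buffer_size_bytes(128, 64, 16) == 16384   # BF16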
transformer_engine/common/common.h View file @ 2b05e121
...
@@ -9,9 +9,15 @@
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif

#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif
#include <cuda_runtime_api.h>
#include <transformer_engine/transformer_engine.h>
...
@@ -90,9 +96,16 @@ struct SimpleTensor {
    }
    return acc;
  }

  void clear() {
    dptr = nullptr;
    shape.resize(0);
    dtype = DType::kFloat32;
  }
};

struct Tensor {
 public:
  SimpleTensor data;
  SimpleTensor columnwise_data;
  SimpleTensor amax;
...
@@ -100,8 +113,8 @@ struct Tensor {
  SimpleTensor scale_inv;
  SimpleTensor columnwise_scale_inv;

 public:
  NVTEScalingMode scaling_mode;
  NVTETensor nvte_tensor;

  Tensor()
      : data(),
...
@@ -110,7 +123,20 @@ struct Tensor {
        scale(nullptr, {1}, DType::kFloat32),
        scale_inv(nullptr, {1}, DType::kFloat32),
        columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
        scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}
        scaling_mode(NVTE_DELAYED_TENSOR_SCALING),
        nvte_tensor(0) {}

  void clear() {
    data.clear();
    columnwise_data.clear();
    amax.clear();
    scale.clear();
    scale_inv.clear();
    columnwise_scale_inv.clear();
    scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
  }

  explicit operator NVTETensor() const noexcept { return nvte_tensor; }

  size_t numel() const {
    size_t acc = 1;
...
@@ -164,6 +190,7 @@ struct Tensor {
        }
        break;
      case NVTE_MXFP8_1D_SCALING:
      case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING:
        if (!has_data() && has_columnwise_data()) {
          return columnwise_data.shape;
        } else {
...
@@ -233,11 +260,14 @@ struct QuantizationConfig {
  bool force_pow_2_scales = false;
  float amax_epsilon = 0.0f;
  NVTETensor noop_tensor = nullptr;
  Float8BlockScaleTensorFormat float8_block_scale_tensor_format =
      Float8BlockScaleTensorFormat::GEMM_READY;

  static constexpr size_t attr_sizes[] = {
      sizeof(bool),   // force_pow_2_scales
      sizeof(float),  // amax_epsilon
      sizeof(NVTETensor)  // noop_tensor
      sizeof(NVTETensor),  // noop_tensor
      sizeof(Float8BlockScaleTensorFormat)  // float8_block_scale_tensor_format
  };
};
...
@@ -246,6 +276,13 @@ constexpr T DIVUP(const T &x, const T &y) {
  return (((x) + ((y)-1)) / (y));
}

template <typename T1, typename T2>
constexpr __device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(const T1 &N, const T2 &M) {
  static_assert(std::is_integral<T1>::value && std::is_integral<T2>::value,
                "Integral type required.");
  return DIVUP(static_cast<uint64_t>(N), static_cast<uint64_t>(M)) * M;
}

using byte = uint8_t;
using int16 = int16_t;
using int32 = int32_t;
...
@@ -259,8 +296,10 @@ using fp8e5m2 = __nv_fp8_e5m2;
#if CUDA_VERSION >= 12080
using fp8e8m0 = __nv_fp8_e8m0;
#endif
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
#endif
using e8m0_t = uint8_t;
using int8 = int8_t;

namespace detail {
...
@@ -284,11 +323,21 @@ TRANSFORMER_ENGINE_TYPE_NAME(int8_t)
#if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
#endif
#if FP4_TYPE_SUPPORTED
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp4_e2m1)
#endif
#undef TRANSFORMER_ENGINE_TYPE_NAME

template <typename T>
struct TypeExtrema;

#if FP4_TYPE_SUPPORTED
template <>
struct TypeExtrema<fp4e2m1> {
  static constexpr float max = 6.0f;
};
#endif

template <>
struct TypeExtrema<fp8e4m3> {
  static constexpr float max = 448.0f;
...
@@ -323,9 +372,28 @@ struct TypeExtrema {
}  // namespace detail

template <typename T>
struct BitsNumber;

#if FP4_TYPE_SUPPORTED
template <>
struct BitsNumber<fp4e2m1> {
  static constexpr size_t num_bits = 4;
};
#endif

template <typename T>
struct BitsNumber {
  static constexpr size_t num_bits = 8 * sizeof(T);
};

template <typename T>
struct TypeInfo {
#if FP4_TYPE_SUPPORTED
  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp4e2m1>;
#else
  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8>;
#endif

  template <typename U, DType current>
  struct Helper {
...
@@ -350,11 +418,21 @@ struct TypeInfo {
  }

  constexpr static DType dtype = getType<T>();
  constexpr static size_t size = sizeof(T);
  constexpr static size_t size = BitsNumber<T>::num_bits;
  constexpr static float max_finite_value = detail::TypeExtrema<T>::max;
  constexpr static const char *name = detail::type_name<T>();
};

#if FP4_TYPE_SUPPORTED
#define SWITCH_FP4_TYPE_HANDLE(type, ...) \
  case DType::kFloat4E2M1: {              \
    using type = fp4e2m1;                 \
    { __VA_ARGS__ }                       \
  } break;
#else
#define SWITCH_FP4_TYPE_HANDLE(type, ...)  // do nothing
#endif

#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
  switch (dtype) {                                           \
    using namespace transformer_engine;                      \
...
@@ -398,6 +476,7 @@ struct TypeInfo {
      using type = byte;                                     \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__)                \
    default:                                                 \
      NVTE_ERROR("Invalid type.");                           \
  }
...
@@ -559,6 +638,9 @@ struct TypeInfo {
    case DType::kFloat8E4M3: {                               \
      NVTE_ERROR("FP8 type not instantiated for input.");    \
    } break;                                                 \
    case DType::kFloat4E2M1: {                               \
      NVTE_ERROR("FP4 type not instantiated for input.");    \
    } break;                                                 \
    default:                                                 \
      NVTE_ERROR("Invalid type.");                           \
  }
...
@@ -629,6 +711,14 @@ struct is_fp8<fp8e4m3> : std::true_type {};
template <>
struct is_fp8<fp8e5m2> : std::true_type {};

template <typename T>
struct is_fp4 : std::false_type {};

#if FP4_TYPE_SUPPORTED
template <>
struct is_fp4<fp4e2m1> : std::true_type {};
#endif

// [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
...
@@ -647,13 +737,16 @@ inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) {
}

size_t typeToSize(const DType type);
size_t typeToNumBits(const DType type);
size_t get_buffer_size_bytes(const size_t N, const DType buffer_dtype);
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last, const DType buffer_dtype);

void CheckNoopTensor(const Tensor &t, const std::string &name);
void CheckInputTensor(const Tensor &t, const std::string &name);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false);

bool is_fp8_dtype(const DType t);

/*! \brief Update a tensor's FP8 scale-inverse
 *
 * The FP8 scale-inverse (dequantization scaling factor) is updated
...
@@ -673,7 +766,7 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype);
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
                          const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
                          const uint32_t shmemX, const uint32_t stride_elems,
                          const uint32_t offset_elems, const size_t type_size);
                          const uint32_t offset_elems, const size_t type_num_bits);
#endif

bool is_supported_by_CC_100();
...
@@ -681,6 +774,8 @@ bool is_supported_by_CC_100();
std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
                                                        size_t outer_size, size_t inner_size);

Tensor *convertNVTETensor(const NVTETensor tensor);
Tensor *convertNVTETensorCheck(const NVTETensor tensor);

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_COMMON_H_
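The new `DIVUP_TO_MULTIPLE` helper rounds a count up to the next multiple of M using the existing `DIVUP`, so DIVUP_TO_MULTIPLE(130, 128) is 256. A small check of the same arithmetic (illustrative sketch only, not part of this commit):

# Illustrative sketch only, not part of this commit: mirrors DIVUP / DIVUP_TO_MULTIPLE.
def divup(x, y):
    return (x + y - 1) // y

def divup_to_multiple(n, m):
    return divup(n, m) * m

assert divup(130, 128) == 2
assert divup_to_multiple(130, 128) == 256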