OpenDAS / TransformerEngine / Commits

Commit 27ddce40, authored Oct 11, 2025 by wenjh

    Merge branch 'nv_main'

Parents: d262ef4c, 5b3092a0
Changes: 208. Showing 20 changed files on this page (1 of 11), with 1661 additions and 68 deletions (+1661, -68).
docs/examples/te_gemma/te_gemma_loading_weights.py                 +189  -0
docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb     +941  -0
docs/examples/te_gemma/utils.py                                    +370  -0
docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb  +1    -1
docs/index.rst                                                     +1    -0
examples/jax/encoder/test_model_parallel_encoder.py                +4    -1
examples/jax/encoder/test_multigpu_encoder.py                      +1    -1
examples/jax/encoder/test_multiprocessing_encoder.py               +4    -1
examples/jax/encoder/test_single_gpu_encoder.py                    +3    -1
examples/jax/mnist/test_single_gpu_mnist.py                        +3    -1
examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py        +7    -1
qa/L0_pytorch_unittest/test.sh                                     +23   -28
qa/L1_cpp_distributed/test.sh                                      +17   -0
qa/L1_jax_distributed_unittest/test.sh                             +1    -0
qa/L1_pytorch_distributed_unittest/test.sh                         +1    -0
qa/L1_pytorch_onnx_unittest/test.sh                                +11   -0
setup.py                                                           +14   -0
tests/cpp/CMakeLists.txt                                           +1    -0
tests/cpp/operator/test_normalization.cu                           +65   -31
tests/cpp/operator/test_normalization.h                            +4    -2
docs/examples/te_gemma/te_gemma_loading_weights.py (new file, mode 100755) @ 27ddce40
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import re
import gc
import torch

from typing import List

from transformer_engine.pytorch.fp8 import fp8_model_init

from transformers.modeling_utils import load_state_dict
from transformers.utils.hub import get_checkpoint_shard_files

"""
This file contains the logic for mapping HuggingFace GemmaModel parameters
onto TransformerEngine TransformerLayer parameters. Once Transformer models
have been initialized both with HF and with TE, we can copy the parameters
from the first into the second.
"""


def _load_weights_for_fp8_model(vanilla_model, hyperparams):
    """
    Loads weights and FP8 metadata from a calibrated-weights file.
    The weights are in BF16 precision, but the state dict also contains
    FP8 metadata computed by the calibration procedure.
    """
    fp8_metadata_sd = torch.load(hyperparams.fp8_model_weights_filename)
    # A hack to drop the core_attention module's extra-state entries
    # from fp8_metadata_sd.
    fp8_metadata_sd = {
        k: v for k, v in fp8_metadata_sd.items() if "core_attention._extra_state" not in k
    }
    vanilla_model.load_state_dict(
        fp8_metadata_sd,
        strict=False,
        # Because some parameters have multiple pointers to the same weight
        # (vanilla_model._model_context_phase.model and
        # vanilla_model._model_generation_phase.model), we need to load the
        # weights in a non-strict manner.
    )


def _load_weights_for_standard_model(vanilla_model, config):
    """
    Loads weights from the HuggingFace checkpoint.
    """
    archive_file = os.path.join(config.weights_cache_dir, "model.safetensors.index.json")
    resolved_archive_file, _ = get_checkpoint_shard_files(config.weights_cache_dir, archive_file)
    total_dict = {}
    for shard_file in resolved_archive_file:
        state_dict = load_state_dict(shard_file)
        total_dict.update(state_dict)

    replace_params(
        total_dict,
        vanilla_model.state_dict(),
        config,
        qkv_fused_and_interleaved=config.fuse_qkv_params,
    )
    # Copy the remaining parameters, such as the embedding.
    vanilla_model.load_state_dict(total_dict, strict=False)

    # Force memory release. Taken from HuggingFace code.
    del total_dict
    gc.collect()


def load_te_model(cls, config):
    """
    Loads the TE model with the proper weights.

    Custom method adapted from the `from_pretrained` method in the HuggingFace
    Transformers repo:
    https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
    """
    # Force the dtype to bfloat16 while loading the model.
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)

    config.use_cache = False  # To make TransformerLayer compatible with GemmaModel

    # Loading a model with FP8-only weights needs both of the following context managers:
    # 1. fp8_model_init(config.fp8_model_init) to tell TE to use FP8-only weights.
    # 2. torch.no_grad() during the TE modules' initialization so that they respect
    #    the `fp8_model_init` context manager.
    with torch.no_grad(), fp8_model_init(config.fp8_model_init):
        # Just create a model with random weights.
        vanilla_model = cls(config).cuda()

    # Copy the proper weights into the model. When loading weights with FP8 metadata,
    # the source weights are essentially the same as the weights already in the model.
    # Otherwise, we need to load the weights from the HuggingFace checkpoint and map
    # the weight names from the HF model onto the TE model.
    if config.fp8_model_weights_filename is not None:
        _load_weights_for_fp8_model(vanilla_model, config)
    else:
        _load_weights_for_standard_model(vanilla_model, config)

    # Restore the original dtype.
    torch.set_default_dtype(old_dtype)
    return vanilla_model


def _get_all_layer_prefixes_to_update(hf_state_dict):
    """
    Many parameter names in hf_state_dict start with "model.layers.[number].".
    This function extracts all such prefixes of keys in hf_state_dict.
    """
    all_layer_prefixes = set()
    for param_key in hf_state_dict.keys():
        layer_prefix_pat = r"model.layers.\d+."  # raw string so \d is a regex class
        m = re.match(layer_prefix_pat, param_key)
        if m is not None:
            all_layer_prefixes.add(m.group())
    return all_layer_prefixes


def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False):
    """
    Overwrites the parameters in the TE TransformerLayer state_dict with the
    corresponding parameters from the HuggingFace GemmaModel state_dict.
    """
    all_layer_prefixes: List[str] = _get_all_layer_prefixes_to_update(hf_state_dict)

    for layer_prefix in all_layer_prefixes:

        def copy_from_ht_to_te(te_name, hf_name, start=None, end=None):
            te_state_dict[layer_prefix + te_name].data[start:end].copy_(
                hf_state_dict[layer_prefix + hf_name]
            )

        copy_from_ht_to_te(
            "self_attention.layernorm_qkv.layer_norm_weight", "input_layernorm.weight"
        )
        copy_from_ht_to_te("self_attention.proj.weight", "self_attn.o_proj.weight")
        copy_from_ht_to_te("layernorm_mlp.layer_norm_weight", "post_attention_layernorm.weight")
        copy_from_ht_to_te("layernorm_mlp.fc2_weight", "mlp.down_proj.weight")
        copy_from_ht_to_te(
            "layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=config.intermediate_size
        )
        copy_from_ht_to_te(
            "layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=config.intermediate_size
        )

        if qkv_fused_and_interleaved:
            """
            When qkv_fused_and_interleaved=True, the query, key and value layers
            live in one tensor in the TE TransformerLayer, and they are
            interleaved within each head. Let q_i, k_i and v_i be the query, key
            and value weights for the i-th head respectively. TE then stores the
            weight tensor in the form:
            [q1 k1 v1 q2 k2 v2 ...]
            This layout is chosen to maximize performance.
            """
            te_qkv_layer = te_state_dict[layer_prefix + "self_attention.layernorm_qkv.weight"]

            def copy_interleave(hf_name, idx):
                src = hf_state_dict[layer_prefix + hf_name]
                for head_nr in range(config.num_attention_heads):
                    dst_offset = head_nr * config.head_dim * 3
                    dst_slice = slice(
                        dst_offset + idx * config.head_dim,
                        dst_offset + (idx + 1) * config.head_dim,
                    )
                    src_slice = slice(
                        head_nr * config.head_dim, head_nr * config.head_dim + config.head_dim
                    )
                    te_qkv_layer[dst_slice, :] = src[src_slice, :]

            copy_interleave("self_attn.q_proj.weight", 0)
            copy_interleave("self_attn.k_proj.weight", 1)
            copy_interleave("self_attn.v_proj.weight", 2)
        else:
            copy_from_ht_to_te(
                "self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight"
            )
            copy_from_ht_to_te("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight")
            copy_from_ht_to_te(
                "self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight"
            )

    return all_layer_prefixes
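To make the interleaving concrete, here is a standalone toy sketch (not part of the commit; the Config values and constant matrices are invented) that replays the same dst_slice/src_slice arithmetic as copy_interleave and prints the resulting [q1 k1 v1 q2 k2 v2 ...] row pattern:

import torch

class Config:  # hypothetical stand-in for the model config
    num_attention_heads = 2
    head_dim = 3

config = Config()
hidden = config.num_attention_heads * config.head_dim  # 6

# Toy weights: q rows are all 1s, k rows all 2s, v rows all 3s.
q = torch.ones(hidden, hidden)
k = torch.full((hidden, hidden), 2.0)
v = torch.full((hidden, hidden), 3.0)
qkv = torch.empty(3 * hidden, hidden)

for idx, src in enumerate((q, k, v)):
    for head_nr in range(config.num_attention_heads):
        dst_offset = head_nr * config.head_dim * 3
        dst_slice = slice(
            dst_offset + idx * config.head_dim,
            dst_offset + (idx + 1) * config.head_dim,
        )
        src_slice = slice(
            head_nr * config.head_dim,
            head_nr * config.head_dim + config.head_dim,
        )
        qkv[dst_slice, :] = src[src_slice, :]

# Per head, rows now read q q q k k k v v v:
print(qkv[:, 0])
# tensor([1., 1., 1., 2., 2., 2., 3., 3., 3., 1., 1., 1., 2., 2., 2., 3., 3., 3.])

Running it shows why start/end offsets in copy_from_ht_to_te are not enough for the fused case: each head's q, k and v rows are scattered through the fused tensor rather than stored in three contiguous blocks.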
docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb (new file, mode 100755) @ 27ddce40 (diff collapsed, +941 -0)
docs/examples/te_gemma/utils.py (new file, mode 100755) @ 27ddce40 (diff collapsed, +370 -0)
docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb @ 27ddce40

@@ -5,7 +5,7 @@
    "id": "6a5b2993",
    "metadata": {},
    "source": [
-    "# Accelerating a Hugging Face Llama 2 and Llama 3 models with Transformer Engine\n",
+    "# Accelerating Hugging Face Llama 2 and 3 Fine-Tuning with Transformer Engine\n",
     "\n",
     "<div class=\"alert alert-info\">\n",
     "\n",
docs/index.rst @ 27ddce40

@@ -46,6 +46,7 @@ Transformer Engine documentation
    examples/fp8_primer.ipynb
    examples/advanced_optimizations.ipynb
    examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
+   examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
    examples/onnx/onnx_export.ipynb

 .. toctree::
examples/jax/encoder/test_model_parallel_encoder.py @ 27ddce40

@@ -267,7 +267,10 @@ def train_and_evaluate(args):
     ) as mesh, te.fp8_autocast(
         enabled=args.use_fp8,
         fp8_recipe=fp8_recipe,
-        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
+        mesh_resource=te.MeshResource(
+            dp_resource=DEVICE_DP_AXIS,
+            tpsp_resource=DEVICE_TP_AXIS,
+        ),
     ):
         rng = jax.random.PRNGKey(args.seed)
         rng, params_rng = jax.random.split(rng)
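The hunk above (and the matching ones in the other encoder scripts) replaces MeshResource's positional arguments with explicit keywords. A minimal sketch of the resulting pattern, assuming two visible JAX devices and illustrative axis names; only the fp8_autocast and MeshResource keywords are taken from the diff itself:

import numpy as np
import jax
import transformer_engine.jax as te
from jax.sharding import Mesh

DEVICE_DP_AXIS = "data"   # data-parallel axis name (illustrative)
DEVICE_TP_AXIS = "model"  # tensor/sequence-parallel axis name (illustrative)

# Arrange two devices into a 2x1 (dp x tp) mesh; assumes >= 2 JAX devices.
device_mesh = np.array(jax.devices()[:2]).reshape(2, 1)

with Mesh(device_mesh, (DEVICE_DP_AXIS, DEVICE_TP_AXIS)) as mesh, te.fp8_autocast(
    enabled=False,  # flip to True on FP8-capable GPUs
    mesh_resource=te.MeshResource(
        dp_resource=DEVICE_DP_AXIS,    # named, not positional
        tpsp_resource=DEVICE_TP_AXIS,
    ),
):
    rng = jax.random.PRNGKey(0)

Naming the resources avoids depending on MeshResource's positional argument order, which is what the trailing None, None padding in the old call was working around.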
examples/jax/encoder/test_multigpu_encoder.py @ 27ddce40

@@ -264,7 +264,7 @@ def train_and_evaluate(args):
     ) as mesh, te.fp8_autocast(
         enabled=args.use_fp8,
         fp8_recipe=fp8_recipe,
-        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
+        mesh_resource=te.MeshResource(dp_resource=DEVICE_DP_AXIS),
     ):
         rng = jax.random.PRNGKey(args.seed)
examples/jax/encoder/test_multiprocessing_encoder.py @ 27ddce40

@@ -382,7 +382,10 @@ def train_and_evaluate(args):
     ) as mesh, te.fp8_autocast(
         enabled=args.use_fp8,
         fp8_recipe=fp8_recipe,
-        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
+        mesh_resource=te.MeshResource(
+            dp_resource=DEVICE_DP_AXIS,
+            tpsp_resource=DEVICE_TP_AXIS,
+        ),
     ):
         rng = jax.random.PRNGKey(args.seed)
         rng, params_rng = jax.random.split(rng)
examples/jax/encoder/test_single_gpu_encoder.py @ 27ddce40

@@ -219,7 +219,9 @@ def train_and_evaluate(args):
     else:
         fp8_recipe = None

-    with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
+    with te.fp8_autocast(
+        enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
+    ):
         encoder = Net(num_embed)
         # We use nn.Embed, thus inputs need to be in int
         inputs = jnp.zeros(input_shape, dtype=jnp.int32)
examples/jax/mnist/test_single_gpu_mnist.py @ 27ddce40

@@ -193,7 +193,9 @@ def train_and_evaluate(args):
     else:
         fp8_recipe = None

-    with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
+    with te.fp8_autocast(
+        enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
+    ):
         cnn = Net(args.use_te)
         var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
         tx = optax.sgd(args.lr, args.momentum)
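In both single-device examples the only functional change is the explicitly passed empty te.sharding.MeshResource(), which maps no mesh axes to device resources; as far as these hunks show, it brings the single-GPU scripts in line with the keyword-style mesh_resource argument now used by the distributed encoders rather than changing their behavior.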
examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @ 27ddce40

@@ -263,7 +263,13 @@ def _train(opts):
     te.module.base.initialize_ub(
         [batched_size, hidden_size],
         tp_size,
-        use_fp8=opts.fp8,
+        quantization_modes=[
+            (
+                te.module.base.UserBufferQuantizationMode.FP8
+                if opts.fp8
+                else te.module.base.UserBufferQuantizationMode.NONE
+            )
+        ],
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
     )
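The hunk above replaces the boolean use_fp8 argument of initialize_ub with a list-valued quantization_modes argument. A minimal sketch of the migration, assuming only the enum members the diff itself shows (the helper name is invented):

import transformer_engine.pytorch as te

def ub_quantization_modes(use_fp8: bool) -> list:
    # Return the quantization_modes value equivalent to the old use_fp8 flag.
    if use_fp8:
        return [te.module.base.UserBufferQuantizationMode.FP8]
    return [te.module.base.UserBufferQuantizationMode.NONE]

With that helper, quantization_modes=ub_quantization_modes(opts.fp8) stands in for the old use_fp8=opts.fp8 keyword; the list form presumably leaves room for per-buffer quantization modes beyond a single global flag.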
qa/L0_pytorch_unittest/test.sh @ 27ddce40

@@ -23,38 +23,33 @@ set -x
 mkdir -p "$XML_LOG_DIR"
-pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
-pip3 install onnxruntime==1.20.1 || error_exit "Failed to install onnxruntime"
-pip3 install onnxruntime_extensions==0.13.0 || error_exit "Failed to install onnxruntime_extensions"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
-ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
-PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/test_batched_linear.xml $TE_PATH/tests/pytorch/test_batched_linear.py || test_fail "test_batched_linear.py"
-PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
+ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
+ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
+NVTE_INT8_SIM_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
 # channelwise int8 test
 NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py
 NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
-NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
 python3 $TE_PATH/tests/pytorch/test_int8_blockwise_gemm_exact.py
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
-NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
-NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
+NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
+NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases: $FAILED_CASES"
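A note on the error-handling convention visible here: each pytest invocation ends in || test_fail "<file>.py", and the script finishes by checking "$RET" and printing $FAILED_CASES. The test_fail helper is defined outside this hunk; from its usage it presumably records the failing file name in FAILED_CASES and sets RET to a nonzero value, so a single run reports every failing suite instead of aborting at the first failure.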
qa/L1_cpp_distributed/test.sh (new file, mode 100755) @ 27ddce40

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

set -e

# Find TE
: ${TE_PATH:=/opt/transformerengine}
TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}')
export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH

if [[ $(nvidia-smi --list-gpus | wc -l) -ge 4 ]]; then
    cd $TE_PATH/tests/cpp_distributed
    cmake -GNinja -S . -Bbuild
    cmake --build build
    mpirun --allow-run-as-root --np 4 --oversubscribe ./build/test_comm_gemm
fi
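Two details worth noting about this new script: it is a no-op unless nvidia-smi reports at least four GPUs (the whole build-and-run block sits behind that guard), and TE_PATH can be overridden from the environment, defaulting to /opt/transformerengine via the ": ${TE_PATH:=...}" idiom.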
qa/L1_jax_distributed_unittest/test.sh @ 27ddce40

@@ -9,3 +9,4 @@ set -xe
 mkdir -p "$XML_LOG_DIR"

 NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
+SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh
qa/L1_pytorch_distributed_unittest/test.sh @ 27ddce40

@@ -35,6 +35,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_
 python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
qa/L1_pytorch_onnx_unittest/test.sh (new file, mode 100644) @ 27ddce40

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

pip3 install onnxruntime==1.20.1
pip3 install onnxruntime_extensions==0.13.0

: ${TE_PATH:=/opt/transformerengine}

python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py
setup.py @ 27ddce40 (diff collapsed, +14 -0)
tests/cpp/CMakeLists.txt @ 27ddce40

@@ -77,6 +77,7 @@ find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_
 message(STATUS "Found transformer_engine library: ${TE_LIB}")
 include_directories(../../transformer_engine/common/include)
 include_directories(../../transformer_engine/common)
 include_directories(../../transformer_engine)
+include_directories(${CMAKE_SOURCE_DIR})

 if(USE_CUDA)
tests/cpp/operator/test_normalization.cu @ 27ddce40 (diff collapsed, +65 -31)
tests/cpp/operator/test_normalization.h @ 27ddce40 (diff collapsed, +4 -2)