Commit 27ddce40 authored by wenjh's avatar wenjh
Browse files

Merge branch 'nv_main'

parents d262ef4c 5b3092a0
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import re
import gc
import torch
from typing import List
from transformer_engine.pytorch.fp8 import fp8_model_init
from transformers.modeling_utils import load_state_dict
from transformers.utils.hub import get_checkpoint_shard_files
"""
This file contains logic of mapping the HuggingFace GemmaModel parameters
with TransformerEngine TransformerLayer. When we have initialized Transformer models
both with HF and with TE, we can copy parameters from the first to the second.
"""
def _load_weights_for_fp8_model(vanilla_model, hyperparams):
"""
Loads weights and FP8 metadata from a calibrated weights file.
The weights are in BF16 precision, but the state dict also contains
fp8 metadata computed by the calibration procedure.
"""
fp8_metadata_sd = torch.load(hyperparams.fp8_model_weights_filename)
# A hack to remove the extra state from the fp8_metadata_sd
# that contains the extra state from the core_attention module.
fp8_metadata_sd = {
k: v for k, v in fp8_metadata_sd.items() if "core_attention._extra_state" not in k
}
vanilla_model.load_state_dict(
fp8_metadata_sd,
strict=False,
# Because some parameters have multiple pointers to the same weight
# vanilla_model._model_context_phase.model and
# vanilla_model._model_generation_phase.model we need to load the
# weights in a non-strict manner.
)
def _load_weights_for_standard_model(vanilla_model, config):
    """
    Loads weights from the HuggingFace checkpoint.

    Resolves the sharded safetensors index under ``config.weights_cache_dir``,
    merges all shards into one state dict, rewrites the per-layer parameter
    names from HF to TE conventions via ``replace_params``, and finally loads
    the remaining (non-layer) parameters such as the embedding.
    """
    index_file = os.path.join(config.weights_cache_dir, "model.safetensors.index.json")
    shard_files, _ = get_checkpoint_shard_files(config.weights_cache_dir, index_file)
    merged_state_dict = {}
    for shard in shard_files:
        merged_state_dict.update(load_state_dict(shard))
    replace_params(
        merged_state_dict,
        vanilla_model.state_dict(),
        config,
        qkv_fused_and_interleaved=config.fuse_qkv_params,
    )
    # Copy remaining parameters like embedding.
    vanilla_model.load_state_dict(merged_state_dict, strict=False)
    # Force mem release. Taken from huggingface code.
    del merged_state_dict
    gc.collect()
def load_te_model(cls, config):
    """
    Loads the TE model with proper weights.

    Custom method adapted from the `from_pretrained` method in the HuggingFace
    Transformers repo:
    https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579

    Args:
        cls: model class to instantiate (called as ``cls(config)``).
        config: model config; reads ``fp8_model_init`` and
            ``fp8_model_weights_filename``, and sets ``use_cache = False``.

    Returns:
        The instantiated CUDA model with weights loaded.
    """
    # Force the dtype to bfloat16 while loading the model. Restored in the
    # finally block so a failure during loading never leaks a modified global
    # default dtype to the rest of the process.
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)
    try:
        config.use_cache = False  # To make TransformerLayer compatible with GemmaModel
        # Loading model with FP8 only weights needs both the following context managers.
        # 1. fp8_model_init(config.fp8_model_init) to tell TE to use FP8 only weights.
        # 2. torch.no_grad() during TE modules' initialization so that they respect
        #    the `fp8_model_init` context manager.
        with torch.no_grad(), fp8_model_init(config.fp8_model_init):
            # Just create a model with random weights.
            vanilla_model = cls(config).cuda()
        # Copy proper weights into the model. If loading weights with FP8 metadata,
        # then the source weights are basically the same as the weights in the model.
        # If not, then we need to load the weights from the HuggingFace checkpoint
        # and do mapping of the weight names from HF to the TE model.
        if config.fp8_model_weights_filename is not None:
            _load_weights_for_fp8_model(vanilla_model, config)
        else:
            _load_weights_for_standard_model(vanilla_model, config)
        return vanilla_model
    finally:
        # Restore the original dtype.
        torch.set_default_dtype(old_dtype)
def _get_all_layer_prefixes_to_update(hf_state_dict):
"""
There are many parameters in hf_state_dict, whose name start with "model.layers.[number]."
This function extracts all strings like "model.layers.[number]."
that are starting strings of keys in hf_state_dict.
"""
all_layer_prefixes = set()
for param_key in hf_state_dict.keys():
layer_prefix_pat = "model.layers.\d+."
m = re.match(layer_prefix_pat, param_key)
if m is not None:
all_layer_prefixes.add(m.group())
return all_layer_prefixes
def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False):
    """
    Replaces params from TE TransformerLayer state_dict with corresponding parameters
    from HuggingFace GemmaModel state_dict.

    The tensors held by ``te_state_dict`` are overwritten in place;
    ``hf_state_dict`` is only read.

    Args:
        hf_state_dict: HuggingFace GemmaModel state dict (source of the weights).
        te_state_dict: state dict of the TE model whose per-layer tensors are updated.
        config: model config; reads ``intermediate_size``, ``num_attention_heads``
            and ``head_dim``.
        qkv_fused_and_interleaved: when True, query/key/value weights live in a
            single fused TE tensor interleaved per head (see inline comment).

    Returns:
        The collection of "model.layers.<i>." prefixes that were processed.
    """
    # NOTE(review): annotated as List[str], but the helper actually returns a set.
    all_layer_prefixes: List[str] = _get_all_layer_prefixes_to_update(hf_state_dict)
    for layer_prefix in all_layer_prefixes:
        # Copy one HF tensor into (a row-slice of) the corresponding TE tensor
        # of the current layer. start/end select destination rows; used below to
        # pack gate_proj and up_proj into the single fused fc1 weight.
        def copy_from_ht_to_te(te_name, hf_name, start=None, end=None):
            te_state_dict[layer_prefix + te_name].data[start:end].copy_(
                hf_state_dict[layer_prefix + hf_name]
            )
        copy_from_ht_to_te(
            "self_attention.layernorm_qkv.layer_norm_weight", "input_layernorm.weight"
        )
        copy_from_ht_to_te("self_attention.proj.weight", "self_attn.o_proj.weight")
        copy_from_ht_to_te("layernorm_mlp.layer_norm_weight", "post_attention_layernorm.weight")
        copy_from_ht_to_te("layernorm_mlp.fc2_weight", "mlp.down_proj.weight")
        # fc1 stacks gate_proj in rows [0, intermediate_size) and up_proj in
        # rows [intermediate_size, 2 * intermediate_size).
        copy_from_ht_to_te(
            "layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=config.intermediate_size
        )
        copy_from_ht_to_te(
            "layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=config.intermediate_size
        )
        if qkv_fused_and_interleaved:
            """
            When qkv_fused_and_interleaved=True, key, query and value layers are on one tensor
            in TE TransformerLayer. Moreover they are interleaved within each head.
            Let q_i, k_i and v_i be query, key and value layers for i-th head respectively.
            Then TE stores weight tensor in the form:
            [q1 k1 v1 q2 k2 v2 ...]
            This is done to maximally optimize performance time.
            """
            te_qkv_layer = te_state_dict[layer_prefix + "self_attention.layernorm_qkv.weight"]
            # Scatter one HF projection (idx: 0=query, 1=key, 2=value) into its
            # interleaved head_dim-row stripes of the fused QKV tensor.
            # NOTE(review): iterates num_attention_heads for key/value too, i.e.
            # assumes MHA (no grouped-query attention) — confirm for the target
            # Gemma variant.
            def copy_interleave(hf_name, idx):
                src = hf_state_dict[layer_prefix + hf_name]
                for head_nr in range(config.num_attention_heads):
                    # Each head owns 3 * head_dim consecutive destination rows:
                    # [q_head | k_head | v_head].
                    dst_offset = head_nr * config.head_dim * 3
                    dst_slice = slice(
                        dst_offset + idx * config.head_dim, dst_offset + (idx + 1) * config.head_dim
                    )
                    src_slice = slice(
                        head_nr * config.head_dim, head_nr * config.head_dim + config.head_dim
                    )
                    te_qkv_layer[dst_slice, :] = src[src_slice, :]
            copy_interleave("self_attn.q_proj.weight", 0)
            copy_interleave("self_attn.k_proj.weight", 1)
            copy_interleave("self_attn.v_proj.weight", 2)
        else:
            # Unfused case: TE keeps separate query/key/value weight tensors, so
            # a plain name-for-name copy suffices.
            copy_from_ht_to_te(
                "self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight"
            )
            copy_from_ht_to_te("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight")
            copy_from_ht_to_te(
                "self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight"
            )
    return all_layer_prefixes
This diff is collapsed.
This diff is collapsed.
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
"id": "6a5b2993", "id": "6a5b2993",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Accelerating a Hugging Face Llama 2 and Llama 3 models with Transformer Engine\n", "# Accelerating Hugging Face Llama 2 and 3 Fine-Tuning with Transformer Engine\n",
"\n", "\n",
"<div class=\"alert alert-info\">\n", "<div class=\"alert alert-info\">\n",
"\n", "\n",
......
...@@ -46,6 +46,7 @@ Transformer Engine documentation ...@@ -46,6 +46,7 @@ Transformer Engine documentation
examples/fp8_primer.ipynb examples/fp8_primer.ipynb
examples/advanced_optimizations.ipynb examples/advanced_optimizations.ipynb
examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
examples/onnx/onnx_export.ipynb examples/onnx/onnx_export.ipynb
.. toctree:: .. toctree::
......
...@@ -267,7 +267,10 @@ def train_and_evaluate(args): ...@@ -267,7 +267,10 @@ def train_and_evaluate(args):
) as mesh, te.fp8_autocast( ) as mesh, te.fp8_autocast(
enabled=args.use_fp8, enabled=args.use_fp8,
fp8_recipe=fp8_recipe, fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), mesh_resource=te.MeshResource(
dp_resource=DEVICE_DP_AXIS,
tpsp_resource=DEVICE_TP_AXIS,
),
): ):
rng = jax.random.PRNGKey(args.seed) rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng) rng, params_rng = jax.random.split(rng)
......
...@@ -264,7 +264,7 @@ def train_and_evaluate(args): ...@@ -264,7 +264,7 @@ def train_and_evaluate(args):
) as mesh, te.fp8_autocast( ) as mesh, te.fp8_autocast(
enabled=args.use_fp8, enabled=args.use_fp8,
fp8_recipe=fp8_recipe, fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None), mesh_resource=te.MeshResource(dp_resource=DEVICE_DP_AXIS),
): ):
rng = jax.random.PRNGKey(args.seed) rng = jax.random.PRNGKey(args.seed)
......
...@@ -382,7 +382,10 @@ def train_and_evaluate(args): ...@@ -382,7 +382,10 @@ def train_and_evaluate(args):
) as mesh, te.fp8_autocast( ) as mesh, te.fp8_autocast(
enabled=args.use_fp8, enabled=args.use_fp8,
fp8_recipe=fp8_recipe, fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), mesh_resource=te.MeshResource(
dp_resource=DEVICE_DP_AXIS,
tpsp_resource=DEVICE_TP_AXIS,
),
): ):
rng = jax.random.PRNGKey(args.seed) rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng) rng, params_rng = jax.random.split(rng)
......
...@@ -219,7 +219,9 @@ def train_and_evaluate(args): ...@@ -219,7 +219,9 @@ def train_and_evaluate(args):
else: else:
fp8_recipe = None fp8_recipe = None
with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe): with te.fp8_autocast(
enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
):
encoder = Net(num_embed) encoder = Net(num_embed)
# We use nn.Embed, thus inputs need to be in int # We use nn.Embed, thus inputs need to be in int
inputs = jnp.zeros(input_shape, dtype=jnp.int32) inputs = jnp.zeros(input_shape, dtype=jnp.int32)
......
...@@ -193,7 +193,9 @@ def train_and_evaluate(args): ...@@ -193,7 +193,9 @@ def train_and_evaluate(args):
else: else:
fp8_recipe = None fp8_recipe = None
with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe): with te.fp8_autocast(
enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
):
cnn = Net(args.use_te) cnn = Net(args.use_te)
var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16)) var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
tx = optax.sgd(args.lr, args.momentum) tx = optax.sgd(args.lr, args.momentum)
......
...@@ -263,7 +263,13 @@ def _train(opts): ...@@ -263,7 +263,13 @@ def _train(opts):
te.module.base.initialize_ub( te.module.base.initialize_ub(
[batched_size, hidden_size], [batched_size, hidden_size],
tp_size, tp_size,
use_fp8=opts.fp8, quantization_modes=[
(
te.module.base.UserBufferQuantizationMode.FP8
if opts.fp8
else te.module.base.UserBufferQuantizationMode.NONE
)
],
dtype=torch.bfloat16, dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend, bootstrap_backend=opts.bootstrap_backend,
) )
......
...@@ -23,38 +23,33 @@ set -x ...@@ -23,38 +23,33 @@ set -x
mkdir -p "$XML_LOG_DIR" mkdir -p "$XML_LOG_DIR"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
pip3 install onnxruntime==1.20.1 || error_exit "Failed to install onnxruntime"
pip3 install onnxruntime_extensions==0.13.0 || error_exit "Failed to install onnxruntime_extensions"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/test_batched_linear.xml $TE_PATH/tests/pytorch/test_batched_linear.py || test_fail "test_batched_linear.py" ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
NVTE_INT8_SIM_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
# channelwise int8 test # channelwise int8 test
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py
NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 $TE_PATH/tests/pytorch/test_int8_blockwise_gemm_exact.py python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py" NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
if [ "$RET" -ne 0 ]; then if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES" echo "Error in the following test cases:$FAILED_CASES"
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
# Find TE
: ${TE_PATH:=/opt/transformerengine}
TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}')
export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH
if [[ $(nvidia-smi --list-gpus | wc -l) -ge 4 ]]; then
cd $TE_PATH/tests/cpp_distributed
cmake -GNinja -S. -Bbuild
cmake --build build
mpirun --allow-run-as-root --np 4 --oversubscribe ./build/test_comm_gemm
fi
...@@ -9,3 +9,4 @@ set -xe ...@@ -9,3 +9,4 @@ set -xe
mkdir -p "$XML_LOG_DIR" mkdir -p "$XML_LOG_DIR"
NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh
...@@ -35,6 +35,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_ ...@@ -35,6 +35,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_
python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
pip3 install onnxruntime==1.20.1
pip3 install onnxruntime_extensions==0.13.0
: ${TE_PATH:=/opt/transformerengine}
python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py
This diff is collapsed.
...@@ -77,6 +77,7 @@ find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_ ...@@ -77,6 +77,7 @@ find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_
message(STATUS "Found transformer_engine library: ${TE_LIB}") message(STATUS "Found transformer_engine library: ${TE_LIB}")
include_directories(../../transformer_engine/common/include) include_directories(../../transformer_engine/common/include)
include_directories(../../transformer_engine/common) include_directories(../../transformer_engine/common)
include_directories(../../transformer_engine)
include_directories(${CMAKE_SOURCE_DIR}) include_directories(${CMAKE_SOURCE_DIR})
if(USE_CUDA) if(USE_CUDA)
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment