OpenDAS / TransformerEngine · Commits

Commit 53fa872c — authored Oct 11, 2025 by wenjh

    Merge branch 'nv_release_v2.8' into release_v2.8

Parents: 27ddce40, 40c69e75
Changes: 159 in total; shown below are 20 changed files with 1338 additions and 35 deletions (+1338 −35).
.github/workflows/build.yml                                 +4    −4
benchmarks/benchmark_rht_cast.py                            +152  −0
build_tools/VERSION.txt                                     +1    −1
build_tools/jax.py                                          +1    −0
build_tools/pytorch.py                                      +1    −1
build_tools/utils.py                                        +9    −6
examples/jax/collective_gemm/common.py                      +245  −0
examples/jax/collective_gemm/conftest.py                    +29   −0
examples/jax/collective_gemm/run_test_cgemm.sh              +119  −0
examples/jax/collective_gemm/test_dense_grad.py             +214  −0
examples/jax/collective_gemm/test_gemm.py                   +206  −0
examples/jax/collective_gemm/test_layernorm_mlp_grad.py     +272  −0
examples/jax/encoder/run_test_multiprocessing_encoder.sh    +53   −8
qa/L0_jax_distributed_unittest/test.sh                      +4    −0
qa/L0_jax_unittest/test.sh                                  +1    −1
qa/L0_pytorch_debug_unittest/test.sh                        +10   −9
qa/L0_pytorch_unittest/test.sh                              +1    −0
qa/L1_pytorch_distributed_unittest/test.sh                  +3    −2
qa/L1_pytorch_onnx_unittest/test.sh                         +5    −3
tests/cpp/operator/CMakeLists.txt                           +8    −0
.github/workflows/build.yml  (+4 −4)

```diff
@@ -19,7 +19,7 @@ jobs:
       run: |
         apt-get update
         apt-get install -y git python3.9 pip cudnn9-cuda-12
-        pip install cmake==3.21.0 pybind11[global] ninja
+        pip install cmake==3.21.0 pybind11[global] ninja nvidia-mathdx==25.1.1
     - name: 'Checkout'
       uses: actions/checkout@v3
       with:
@@ -43,7 +43,7 @@ jobs:
       run: |
         apt-get update
         apt-get install -y git python3.9 pip cudnn9-cuda-12
-        pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript
+        pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1
     - name: 'Checkout'
       uses: actions/checkout@v3
       with:
@@ -63,7 +63,7 @@ jobs:
     options: --user root
     steps:
       - name: 'Dependencies'
-        run: pip install pybind11[global]
+        run: pip install pybind11[global] nvidia-mathdx==25.1.1
      - name: 'Checkout'
        uses: actions/checkout@v3
        with:
@@ -83,7 +83,7 @@ jobs:
     options: --user root
     steps:
      - name: 'Dependencies'
-        run: pip install torch pybind11[global] einops onnxscript
+        run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
      - name: 'Checkout'
        uses: actions/checkout@v3
        with:
```
benchmarks/benchmark_rht_cast.py  (new file, +152 −0)

```python
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import argparse

import torch
import pandas as pd
import torch.utils.benchmark as benchmark

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer

scale_padding_to = 1
permute_scale = False

TORCH_TO_TE_FLOAT_MAP = {
    torch.bfloat16: tex.DType.kBFloat16,
}


def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16):
    # Generate random input data
    M, K = shape
    x = torch.randn([M, K], dtype=input_dtype, device="cuda")
    assert shape[0] % 16 == 0, "Shape must be divisible by 16"
    assert shape[1] % 16 == 0, "Shape must be divisible by 16"

    # Quantize
    nvfp4_quantizer = NVFP4Quantizer(
        fp4_dtype=tex.DType.kFloat4E2M1,
        rowwise=True,
        columnwise=True,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=True,
        with_post_rht_amax=True,
        with_random_sign_mask=True,
        stochastic_rounding=stochastic_rounding,
    )
    x_nvfp4_sut = nvfp4_quantizer.make_empty(
        (M, K), dtype=x.dtype, device=x.device, requires_grad=False
    )
    x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)

    with torch.no_grad():
        stmt = "kernel_func(input, output)"
        globals_dict = {
            "kernel_func": nvfp4_quantizer.update_quantized,
            "input": x,
            "output": x_nvfp4_sut,
        }
        timing = benchmark.Timer(
            stmt=stmt,
            globals=globals_dict,
            num_threads=1,
        ).blocked_autorange(min_run_time=5)
        print(timing)
        timing_us = timing.median * 1e6

    input_nbytes = shape[0] * shape[1] * 2  # bf16
    output_nbytes = shape[0] * shape[1] // 2  # //2 for fp4
    sf_nbytes = shape[0] * shape[1] // 16  # //16 for 1 byte per 16 elems
    total_nbytes = (
        0
        + input_nbytes * 3  # Reading input for Amax(x) & Amax(RHT(x.T)), reading input for Cast(x), reading input for Cast(RHT(x.T))
        + 2 * 4  # Output 2 * float for scale & amax
        + 2 * 4  # Input 2 * float
        + output_nbytes * 2  # Output from Cast(x) and Cast(RHT(x.T))
        + sf_nbytes * 2  # Scale factor
    )
    throughput_GBps = total_nbytes / (1024 * 1024 * 1024) / (timing_us / 1e6)
    print(
        f"Stochastic rounding: {stochastic_rounding}, Total: {total_nbytes} bytes, Throughput:"
        f" {throughput_GBps} GB/s"
    )
    return timing_us, throughput_GBps


# Nsight Compute Profiling Command:
# ncu -f -o block_scaled_1d_cast_transpose_kernel --set=full --kernel-name "block_scaled_1d_cast_transpose_kernel" -s 5 -c 5 python benchmark_cast_transpose_1d_block.py --profile
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
    args = parser.parse_args()
    if args.profile:
        print("Profiling is enabled.")
    else:
        print("Profiling is disabled.")

    shapes = [
        (8192, 5120),
        (8192, 10240),
        (8192, 2560),
        (8192, 11328),
        (8192, 512),
        (8192, 3584),
        (5120, 8192),
        (10240, 8192),
        (2560, 8192),
        (11328, 8192),
        (512, 8192),
        (3584, 8192),
        (4096, 16384),
        (14336, 16384),
    ]
    if args.profile:
        shapes = [
            (16384, 6144),
        ]

    data = []
    for stochastic_rounding in [True]:  # , False]:
        for shape in shapes:
            print(
                f"Running benchmark_func with shape {shape} and stochastic_rounding"
                f" {stochastic_rounding}"
            )
            timing_us, throughput_GBps = run_kernel(shape, stochastic_rounding)
            data.append(
                [
                    "benchmark_func",
                    shape,
                    stochastic_rounding,
                    timing_us,
                    throughput_GBps,
                ]
            )

    df = pd.DataFrame(
        data=data,
        columns=[
            "kernel",
            "shape",
            "stochastic_rounding",
            "timing_us",
            "throughput(GB/s)",
        ],
    )
    print(df)
    df.to_csv("benchmark_cast_nvfp4.csv", index=False)
```
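
As a quick sanity check of the byte accounting in `run_kernel`, here is the same arithmetic for the first benchmark shape, evaluated standalone. The 50 µs timing is purely illustrative, not a measured result:

```python
# Sanity-check the byte accounting from run_kernel for shape (8192, 5120).
# The 50 us timing below is illustrative only, not a measurement.
M, K = 8192, 5120
input_nbytes = M * K * 2       # bf16 input, read 3 times (amax pass + two cast passes)
output_nbytes = M * K // 2     # fp4 packs two elements per byte
sf_nbytes = M * K // 16        # one scale-factor byte per 16 elements
total_nbytes = input_nbytes * 3 + 2 * 4 + 2 * 4 + output_nbytes * 2 + sf_nbytes * 2
timing_us = 50.0
throughput_GBps = total_nbytes / (1024**3) / (timing_us / 1e6)
print(f"{total_nbytes} bytes -> {throughput_GBps:.0f} GB/s")  # 298844176 bytes -> ~5566 GB/s
```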

build_tools/VERSION.txt  (+1 −1)

```diff
-2.9.0.dev0
+2.8.0
```

build_tools/jax.py  (+1 −0)

```diff
@@ -87,4 +87,5 @@ def setup_jax_extension(
         sources=[str(path) for path in sources],
         include_dirs=[str(path) for path in include_dirs],
         extra_compile_args=cxx_flags,
+        libraries=["nccl"],
     )
```

build_tools/pytorch.py  (+1 −1)

```diff
@@ -14,7 +14,7 @@ from typing import List
 def install_requirements() -> List[str]:
     """Install dependencies for TE/PyTorch extensions."""
-    return ["torch>=2.1", "einops"]  # "onnxscript==0.3.1", "onnx"]
+    return ["torch>=2.1", "einops"]  # "onnxscript", "onnx"]


 def test_requirements() -> List[str]:
```

build_tools/utils.py  (+9 −6)

```diff
@@ -272,15 +272,18 @@ def get_cuda_include_dirs() -> Tuple[str, str]:
 @functools.lru_cache(maxsize=None)
 def cuda_archs() -> str:
-    version = cuda_version()
-    if os.getenv("NVTE_CUDA_ARCHS") is None:
+    archs = os.getenv("NVTE_CUDA_ARCHS")
+    if archs is None:
+        version = cuda_version()
         if version >= (13, 0):
-            os.environ["NVTE_CUDA_ARCHS"] = "75;80;89;90;100;120"
+            archs = "75;80;89;90;100;100a;103a;120"
+        elif version >= (12, 9):
+            archs = "70;80;89;90;100;100a;103a;120"
         elif version >= (12, 8):
-            os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90;100;120"
+            archs = "70;80;89;90;100;100a;120"
         else:
-            os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90"
-    return os.getenv("NVTE_CUDA_ARCHS")
+            archs = "70;80;89;90"
+    return archs


 def cuda_version() -> Tuple[int, ...]:
```
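
One behavioral consequence of this refactor worth noting: `cuda_archs()` no longer writes back to `os.environ`; it simply honors `NVTE_CUDA_ARCHS` when set. A minimal sketch of exercising the override path, assuming `build_tools` is importable from the repo root (the cache must be cleared between calls because of `functools.lru_cache`):

```python
# Minimal sketch: the refactored cuda_archs() honors NVTE_CUDA_ARCHS without
# mutating the environment. Assumes build_tools is importable (repo root).
import os
from build_tools.utils import cuda_archs

os.environ["NVTE_CUDA_ARCHS"] = "90;100"
cuda_archs.cache_clear()         # lru_cache would otherwise return a stale value
assert cuda_archs() == "90;100"  # override wins; os.environ is left untouched
```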

examples/jax/collective_gemm/common.py  (new file, +245 −0)

```python
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the comm_overlap tests"""
import jax.numpy as jnp
import numpy as np


# Add this after your existing imports
def dtype_tols(dtype, rtol=None, atol=None):
    """Expected numerical tolerance for a data type."""
    # Return immediately if tolerances are fully specified
    if rtol is not None and atol is not None:
        return {"rtol": rtol, "atol": atol}

    # Default tolerances for common dtypes
    if dtype in [jnp.float32, "float32"]:
        return {"rtol": 1e-5, "atol": 1e-8}
    elif dtype in [jnp.float16, "float16"]:
        return {"rtol": 1e-3, "atol": 1e-6}
    elif dtype in [jnp.bfloat16, "bfloat16"]:
        return {"rtol": 1e-2, "atol": 1e-5}
    else:
        return {"rtol": 1e-5, "atol": 1e-8}


def assert_allclose(
    actual,
    desired,
    rtol=None,
    atol=None,
    dtype=None,
    **kwargs,
):
    """Check if two tensors are close."""
    # Infer data type if needed
    if dtype is None:
        if isinstance(actual, float):
            dtype = "float32"
        else:
            dtype = actual.dtype

    # Determine tolerances
    tols = {}
    if rtol is None or atol is None:
        tols = dtype_tols(dtype)
    if rtol is not None:
        tols["rtol"] = rtol
    if atol is not None:
        tols["atol"] = atol

    # Cast tensors to fp32
    if not isinstance(actual, float):
        actual = actual.astype(jnp.float32)
    if not isinstance(desired, float):
        desired = desired.astype(jnp.float32)

    # Check if tensors are close
    np.testing.assert_allclose(actual, desired, **tols, **kwargs)


def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e-8):
    if not jnp.allclose(ref_output, gathered_output, rtol=rtol, atol=atol):
        diff = jnp.abs(ref_output - gathered_output)
        mask = diff > (atol + rtol * jnp.abs(gathered_output))
        print(mask.astype(int))
        print(jnp.where(mask, diff, 0))


# Shared constants for all tests
DP_AXIS = "data"
TPSP_AXIS = "tensor_sequence"
PARAMS_KEY = "params"

# Shared functions for distributed testing
import argparse
import jax
from jax.experimental import mesh_utils
from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap

# Global flag to track if distributed has been initialized
_distributed_initialized = False


def _is_distributed_initialized():
    """Check if JAX distributed has been initialized."""
    return _distributed_initialized


def _initialize_distributed(args):
    """Initialize JAX distributed with custom arguments."""
    global _distributed_initialized

    # Check if already initialized
    if _distributed_initialized:
        return

    if (
        args.coordinator_address is None
        or args.num_processes is None
        or args.process_id is None
    ):
        raise ValueError(
            "All distributed initialization arguments are required: "
            "--coordinator-address, --num-processes, --process-id"
        )

    if args.local_device_ids is None:
        assert (
            args.num_devices_per_process is not None
        ), "Either local_device_ids or num_devices_per_process must be provided"
        # Calculate device range for this process
        # Single process single device: each process gets one unique device
        # Single process multiple devices: each process gets a unique range of devices
        start_device = args.process_id * args.num_devices_per_process
        device_range = range(start_device, start_device + args.num_devices_per_process)
        global_device_ids_for_this_process = ",".join(map(str, device_range))
    else:
        # Use explicitly provided global device IDs
        global_device_ids_for_this_process = args.local_device_ids
        args.num_devices_per_process = len(args.local_device_ids.split(","))

    assert args.num_devices_per_process == 1, "Only single process single GPU is supported!"

    print(
        f"Initializing JAX distributed with coordinator={args.coordinator_address}, "
        f"num_processes={args.num_processes}, process_id={args.process_id}"
    )
    # Note: "local_device_ids" is a JAX term meaning "global CUDA devices managed by this process"
    jax.distributed.initialize(
        coordinator_address=args.coordinator_address,
        num_processes=args.num_processes,
        process_id=args.process_id,
        local_device_ids=global_device_ids_for_this_process,
    )
    _distributed_initialized = True
    jax.clear_caches()
    jax.config.update("jax_use_shardy_partitioner", False)  # CollectiveGEMM does not work with Shardy yet

    assert jax.local_device_count() == 1, (
        f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found"
        f" {jax.local_device_count()}"
    )

    devices_per_process = 1
    num_total_devices = args.num_processes
    print(
        f"Initializing CGEMM communicator with num_total_devices={num_total_devices},"
        f" devices_per_process={devices_per_process}, process_id={args.process_id}"
    )
    collective_gemm_bootstrap(
        num_total_devices=num_total_devices,
        num_devices_per_process=devices_per_process,
        process_id=args.process_id,
        tensor_parallel_size=args.tensor_parallel_size,
    )


def _get_dp_and_tp_sizes(args):
    num_gpu = args.num_processes * args.num_devices_per_process
    if args.tensor_parallel_size is None:
        num_gpu_dp = 2 if args.enable_data_parallel else 1
        assert (
            num_gpu > 1 and num_gpu % num_gpu_dp == 0
        ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
        num_gpu_tp = num_gpu // num_gpu_dp
    else:
        num_gpu_tp = args.tensor_parallel_size
        assert (
            num_gpu > 1 and num_gpu % num_gpu_tp == 0
        ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
        num_gpu_dp = num_gpu // num_gpu_tp
    return num_gpu_dp, num_gpu_tp


def _create_mesh(args):
    """Create mesh configuration with proper validation."""
    num_gpu = args.num_processes * args.num_devices_per_process
    assert num_gpu == len(jax.devices()), "Number of GPUs must be equal to number of devices"
    num_gpu_dp, num_gpu_tp = _get_dp_and_tp_sizes(args)
    print(f"Using {num_gpu_dp}x{num_gpu_tp} mesh ({num_gpu_dp * num_gpu_tp} total GPUs)")
    device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
    mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=(DP_AXIS, TPSP_AXIS))
    return mesh


def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"):
    """Create common argument parser for all collective GEMM tests."""
    parser = argparse.ArgumentParser(description=description)
    # Distributed initialization arguments
    parser.add_argument(
        "--coordinator-address",
        type=str,
        default=None,
        help="Coordinator address for distributed initialization",
    )
    parser.add_argument(
        "--num-processes",
        type=int,
        default=None,
        help="Number of processes for distributed initialization",
    )
    parser.add_argument(
        "--process-id", type=int, default=None, help="Process ID for distributed initialization"
    )
    parser.add_argument(
        "--local-device-ids",
        type=str,
        default=None,
        help="Local device IDs for distributed initialization (comma-separated)",
    )
    parser.add_argument(
        "--num-devices-per-process", type=int, default=1, help="Number of devices per process"
    )
    # Test configuration arguments
    parser.add_argument("--tensor-parallel-size", type=int, default=None, help="Tensor parallel size")
    parser.add_argument("--batch-size", type=int, default=4, help="Batch size for testing")
    parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing")
    parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension")
    parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension")
    parser.add_argument(
        "--collective-type",
        type=str,
        default="all_gather",
        choices=["all_gather", "reduce_scatter"],
        help="Type of collective operation",
    )
    parser.add_argument("--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use")
    parser.add_argument("--enable-data-parallel", action="store_true", help="Enable data parallelism")
    parser.add_argument(
        "--enable-result-check", action="store_true", default=True, help="Enable result checking"
    )
    return parser
```
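
For reference, a minimal standalone use of the helpers above, run from the `examples/jax/collective_gemm/` directory; single-process values are shown here as stand-ins, whereas the real tests receive these flags on the command line:

```python
# Illustrative use of common.py's helpers; values are single-process stand-ins.
from common import cgemm_parser, dtype_tols

args = cgemm_parser().parse_args(
    ["--coordinator-address", "localhost:1234", "--num-processes", "1", "--process-id", "0"]
)
print(args.collective_type)    # "all_gather" (the default)
print(dtype_tols("bfloat16"))  # {'rtol': 0.01, 'atol': 1e-05}
```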

examples/jax/collective_gemm/conftest.py  (new file, +29 −0)

```python
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""config for collective_gemm tests"""
import pytest


def pytest_addoption(parser):
    """Pytest hook for collective_gemm tests"""
    parser.addoption("--coordinator-address", action="store", default="localhost:12345")
    parser.addoption("--num-processes", action="store", default=1)
    parser.addoption("--process-id", action="store", default=0)
    parser.addoption("--local-device-ids", action="store", default=None)


@pytest.fixture(autouse=True)
def distributed_args(request):
    """Fixture for querying distributed initialization arguments"""
    if request.cls:
        request.cls.coordinator_address = request.config.getoption("--coordinator-address")
        request.cls.num_processes = int(request.config.getoption("--num-processes"))
        request.cls.process_id = int(request.config.getoption("--process-id"))
        request.cls.local_device_ids = request.config.getoption("--local-device-ids")
        request.cls.num_devices_per_process = (
            1
            if request.cls.local_device_ids is None
            else len(request.cls.local_device_ids.split(","))
        )
```

examples/jax/collective_gemm/run_test_cgemm.sh  (new file, +119 −0)

```bash
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

# Check if NVLINK is supported before running tests
echo "*** Checking NVLINK support ***"
NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1)
NVLINK_EXIT_CODE=$?

# Check if command failed OR output indicates no NVLINK
if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then
    echo "NVLINK is not supported on this platform"
    echo "Collective GEMM tests require NVLINK connectivity"
    echo "SKIPPING all tests"
    exit 0
else
    echo "NVLINK support detected"
fi

# Define the test files to run
TEST_FILES=(
    "test_gemm.py"
    "test_dense_grad.py"
    "test_layernorm_mlp_grad.py"
)

echo
echo "*** Executing tests in examples/jax/collective_gemm/ ***"

HAS_FAILURE=0  # Global failure flag
PIDS=()        # Array to store all process PIDs

# Cleanup function to kill all processes
cleanup() {
    for pid in "${PIDS[@]}"; do
        if kill -0 "$pid" 2>/dev/null; then
            echo "Killing process $pid"
            kill -TERM "$pid" 2>/dev/null || true
        fi
    done
    # Wait a bit and force kill if needed
    sleep 2
    for pid in "${PIDS[@]}"; do
        if kill -0 "$pid" 2>/dev/null; then
            echo "Force killing process $pid"
            kill -KILL "$pid" 2>/dev/null || true
        fi
    done
}

# Set up signal handlers to cleanup on exit
trap cleanup EXIT INT TERM

# Run each test file across all GPUs
for TEST_FILE in "${TEST_FILES[@]}"; do
    echo
    echo "=== Starting test file: $TEST_FILE ..."
    # Clear PIDs array for this test file
    PIDS=()
    for i in $(seq 0 $(($NUM_GPUS - 1))); do
        # Define output file for logs
        LOG_FILE="${TEST_FILE}_gpu_${i}.log"
        if [ $i -eq 0 ]; then
            # For process 0: show live output AND save to log file using tee
            echo "=== Live output from process 0 ==="
            pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
                -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \
                "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
                --num-processes=$NUM_GPUS \
                --process-id=$i 2>&1 | tee "$LOG_FILE" &
            PID=$!
            PIDS+=($PID)
        else
            # For other processes: redirect to log files only
            pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
                -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
                --num-processes=$NUM_GPUS \
                --process-id=$i > "$LOG_FILE" 2>&1 &
            PID=$!
            PIDS+=($PID)
        fi
    done

    # Wait for all processes to finish
    wait

    # Check and print the log content from process 0 (now has log file thanks to tee)
    if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then
        echo "... $TEST_FILE SKIPPED"
    elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then
        echo "... $TEST_FILE FAILED"
        HAS_FAILURE=1
    elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then
        echo "... $TEST_FILE PASSED"
    else
        echo "... $TEST_FILE INVALID"
        HAS_FAILURE=1
    fi

    # Remove the log files after processing them
    wait
    rm ${TEST_FILE}_gpu_*.log
done

wait

# Final cleanup (trap will also call cleanup on exit)
cleanup

exit $HAS_FAILURE
```

examples/jax/collective_gemm/test_dense_grad.py  (new file, +214 −0)

```python
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Collective Dense Gradient test on multi-GPU with tensor parallelism"""
import argparse
import unittest
import os

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding
import flax

from common import (
    assert_allclose,
    _initialize_distributed,
    _get_dp_and_tp_sizes,
    _create_mesh,
    DP_AXIS,
    TPSP_AXIS,
    PARAMS_KEY,
    cgemm_parser,
)
from transformer_engine.jax.dense import dense
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.cpp_extensions.gemm import (
    CollectiveOp,
    CollectiveOpSet,
    noop_collective_op_set,
)
from transformer_engine.jax.sharding import MeshResource
import transformer_engine.jax.flax as te_flax


def _get_logical_axes(collective_op):
    if collective_op.is_all_gather:
        input_axes = (DP_AXIS, TPSP_AXIS, None)
        weight_axes = (None, TPSP_AXIS)
        bias_axes = (TPSP_AXIS,)
        output_axes = (DP_AXIS, None, TPSP_AXIS)
    else:  # RS
        input_axes = (DP_AXIS, None, TPSP_AXIS)
        weight_axes = (TPSP_AXIS, None)
        bias_axes = (None,)
        output_axes = (DP_AXIS, TPSP_AXIS, None)
    return input_axes, weight_axes, bias_axes, output_axes


def _get_operand_sharding(mesh, collective_op):
    input_axes, weight_axes, bias_axes, _ = _get_logical_axes(collective_op)
    x_sharding = NamedSharding(mesh, PartitionSpec(*input_axes))
    weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_axes))
    bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes))
    return x_sharding, weight_sharding, bias_sharding


def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set):
    output = dense(
        x,
        weight,
        bias,
        contracting_dims=((2,), (0,)),
        input_axes=input_axes,
        kernel_axes=weight_axes,
        output_axes=output_axes,
        collective_op_set=collective_op_set,
    )
    return jnp.mean(output.astype(jnp.float32))


def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set):
    return jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))(
        x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set
    )


def run_dense_grad_tests(args, mesh=None):
    """Execute Dense Gradient tests."""
    print(args)
    _initialize_distributed(args)
    mesh = mesh or _create_mesh(args)

    # Create test data
    rng = jax.random.PRNGKey(0)
    rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4)
    x = jax.random.normal(
        x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
    )
    weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16)
    bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16)

    collective_op = (
        CollectiveOp.ALL_GATHER
        if args.collective_type == "all_gather"
        else CollectiveOp.REDUCE_SCATTER
    )
    collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op)

    with mesh, fp8_autocast(
        enabled=False,
        fp8_recipe=None,
        mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
    ):
        # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
        axis_rules = flax.linen.get_logical_axis_rules()
        axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS))
        te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
        with flax.linen.logical_axis_rules(te_extended_axis_rules):
            x_sharding, weight_sharding, bias_sharding = _get_operand_sharding(mesh, collective_op)
            x_sharded = jax.device_put(x, x_sharding)
            weight_sharded = jax.device_put(weight, weight_sharding)
            bias_sharded = jax.device_put(bias, bias_sharding)
            input_axes, weight_axes, _, output_axes = _get_logical_axes(collective_op)

            ref_output, ref_grads = _value_and_grad_dense(
                x_sharded,
                weight_sharded,
                bias_sharded,
                input_axes,
                weight_axes,
                output_axes,
                noop_collective_op_set,
            )
            output, sharded_grads = _value_and_grad_dense(
                x_sharded,
                weight_sharded,
                bias_sharded,
                input_axes,
                weight_axes,
                output_axes,
                collective_op_set,
            )
            jax.block_until_ready(ref_output)
            jax.block_until_ready(output)

            gathered_grads = []
            gathered_ref_grads = []
            for ref_grad, grad in zip(ref_grads, sharded_grads):
                gathered_grads.append(
                    jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None)))
                )
                gathered_ref_grads.append(
                    jax.lax.with_sharding_constraint(
                        ref_grad, NamedSharding(mesh, PartitionSpec(None))
                    )
                )
            jax.block_until_ready(gathered_grads)
            jax.block_until_ready(gathered_ref_grads)

            if args.enable_result_check and args.process_id == 0:
                assert_allclose(ref_output, output, dtype=jnp.bfloat16)
                for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads):
                    assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16)


class TestCollectiveDenseGradient(unittest.TestCase):
    """Collective Dense Gradient unittests"""

    def setUp(self):
        self.args = cgemm_parser(
            "Collective Dense Gradient test on multi-GPU with tensor parallelism"
        ).parse_args([])
        self.args.coordinator_address = self.coordinator_address
        self.args.num_processes = self.num_processes
        self.args.process_id = self.process_id
        self.args.local_device_ids = self.local_device_ids
        self.args.num_devices_per_process = self.num_devices_per_process
        self.args.enable_data_parallel = True
        self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
        _initialize_distributed(self.args)
        # Create mesh once for all tests
        self.mesh = _create_mesh(self.args)
        jax.sharding.set_mesh(self.mesh)
        self.args.enable_result_check = True
        os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"

    def tearDown(self):
        os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)

    def test_te_bf16_all_gather(self):
        """Test Collective Dense Gradient with AllGather"""
        self.args.collective_type = "all_gather"
        run_dense_grad_tests(self.args, self.mesh)

    def test_te_bf16_reduce_scatter(self):
        """Test Collective Dense Gradient with ReduceScatter"""
        self.args.collective_type = "reduce_scatter"
        run_dense_grad_tests(self.args, self.mesh)


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 7:  # Need at least the 3 required distributed args
        print("Error: This script requires distributed initialization arguments.")
        print(
            "Usage: python test_dense_grad.py --coordinator-address <address> --num-processes <num>"
            " --process-id <id> [--local-device-ids <ids>] [other args]"
        )
        print(
            "Example: python test_dense_grad.py --coordinator-address localhost:1234"
            " --num-processes 4 --process-id 0"
        )
        print(
            "Example: python test_dense_grad.py --coordinator-address localhost:1234"
            " --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3"
        )
        sys.exit(1)

    args = cgemm_parser(
        "Collective Dense Gradient test on multi-GPU with tensor parallelism"
    ).parse_args([])
    _initialize_distributed(args)
    run_dense_grad_tests(args, mesh=None)
```
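
The correctness check in `run_dense_grad_tests` boils down to a standard JAX pattern: differentiate a scalar reduction of the layer output with respect to every operand, once with the no-op collective set and once with the real one, then compare. A stand-alone sketch of that pattern with a plain matmul in place of TE's `dense` (names here are stand-ins, not TE APIs):

```python
# The value-and-grad comparison pattern from the test, reduced to plain JAX.
# `f` is a stand-in for TE's dense; the real test swaps collective_op_set.
import jax
import jax.numpy as jnp

def f(x, w, b):
    return jnp.mean((x @ w + b).astype(jnp.float32))

x = jnp.ones((4, 8), jnp.bfloat16)
w = jnp.ones((8, 2), jnp.bfloat16)
b = jnp.zeros((2,), jnp.bfloat16)
val, grads = jax.jit(jax.value_and_grad(f, (0, 1, 2)))(x, w, b)
print(val, [g.shape for g in grads])  # scalar loss plus one gradient per operand
```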

examples/jax/collective_gemm/test_gemm.py  (new file, +206 −0)

```python
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Collective GEMM test on multi-GPU with tensor parallelism

This script uses custom distributed initialization with the following arguments:
- --coordinator-address: Coordinator address for distributed initialization
- --num-processes: Number of processes for distributed initialization
- --process-id: Process ID for distributed initialization
- --local-device-ids: Local device IDs for distributed initialization

Example:
python test_gemm.py --coordinator-address localhost:1234 --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3
"""
import unittest
import os
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding

from common import (
    assert_allclose,
    _initialize_distributed,
    _get_dp_and_tp_sizes,
    _create_mesh,
    DP_AXIS,
    TPSP_AXIS,
    PARAMS_KEY,
    cgemm_parser,
)
import transformer_engine.jax.cpp_extensions as tex
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.cpp_extensions.gemm import CollectiveOp
from transformer_engine.jax.sharding import MeshResource


def _get_operand_sharding(mesh, collective_op, is_with_dp):
    dp_axis = DP_AXIS if is_with_dp else None
    if collective_op == CollectiveOp.ALL_GATHER:
        x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None))
        weight_sharding = NamedSharding(mesh, PartitionSpec(None, TPSP_AXIS))
        bias_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS))
        output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS))
    else:  # RS
        x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS))
        weight_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS, None))
        bias_sharding = NamedSharding(mesh, PartitionSpec(None))
        output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None))
    return x_sharding, weight_sharding, bias_sharding, output_sharding


def _get_dp_and_tp_sizes(args):
    num_gpu = args.num_processes * args.num_devices_per_process
    if args.tensor_parallel_size is None:
        num_gpu_dp = 2 if args.enable_data_parallel else 1
        assert (
            num_gpu > 1 and num_gpu % num_gpu_dp == 0
        ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
        num_gpu_tp = num_gpu // num_gpu_dp
    else:
        num_gpu_tp = args.tensor_parallel_size
        assert (
            num_gpu > 1 and num_gpu % num_gpu_tp == 0
        ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
        num_gpu_dp = num_gpu // num_gpu_tp
    return num_gpu_dp, num_gpu_tp


@partial(jax.jit, static_argnames=("contracting_dims", "collective_op", "output_sharding"))
def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_sharding):
    output = tex.gemm(
        x,
        weight,
        bias=bias,
        contracting_dims=contracting_dims,
        collective_op=collective_op,
    )
    if output_sharding is not None:
        output = jax.lax.with_sharding_constraint(output, output_sharding)
    return output


def run_gemm_tests(args, mesh=None):
    """Execute GEMM tests."""
    print(args)

    # Collective GEMM requires Shardy partitioner to be disabled
    jax.config.update("jax_use_shardy_partitioner", False)

    # Initialize distributed with provided arguments
    _initialize_distributed(args)
    mesh = mesh or _create_mesh(args)

    # Create test data
    rng = jax.random.PRNGKey(0)
    rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4)
    x = jax.random.normal(
        x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
    )
    weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16)
    bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16)

    collective_op = (
        CollectiveOp.ALL_GATHER
        if args.collective_type == "all_gather"
        else CollectiveOp.REDUCE_SCATTER
    )

    with mesh, fp8_autocast(
        enabled=False,
        fp8_recipe=None,
        mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
    ):
        print(f"Device mesh: {mesh}")
        x_sharding, weight_sharding, bias_sharding, output_sharding = _get_operand_sharding(
            mesh, collective_op, args.enable_data_parallel
        )
        x_sharded = jax.device_put(x, x_sharding)
        weight_sharded = jax.device_put(weight, weight_sharding)
        bias_sharded = jax.device_put(bias, bias_sharding)

        ref_output = _jitted_cgemm(
            x_sharded,
            weight_sharded,
            bias_sharded,
            contracting_dims=((2,), (0,)),
            collective_op=CollectiveOp.NONE,
            output_sharding=output_sharding,
        )
        output = _jitted_cgemm(
            x_sharded,
            weight_sharded,
            bias_sharded,
            contracting_dims=((2,), (0,)),
            collective_op=collective_op,
            # CollectiveGEMM output should have a correct sharding without applying sharding constraint
            output_sharding=None,
        )
        assert (
            ref_output.sharding == output.sharding
        ), f"ref_output.sharding={ref_output.sharding}, output.sharding={output.sharding}"

        gathered_ref_output = jax.lax.with_sharding_constraint(
            ref_output, NamedSharding(mesh, PartitionSpec(None))
        )
        gathered_output = jax.lax.with_sharding_constraint(
            output, NamedSharding(mesh, PartitionSpec(None))
        )
        jax.block_until_ready(gathered_ref_output)
        jax.block_until_ready(gathered_output)

        if args.enable_result_check and args.process_id == 0:
            assert_allclose(gathered_ref_output, gathered_output)


class TestCollectiveGemmWithDP(unittest.TestCase):
    """Collective GEMM with DP unittests"""

    def setUp(self):
        self.args = cgemm_parser(
            "Collective GEMM test on multi-GPU with tensor parallelism"
        ).parse_args([])
        self.args.coordinator_address = self.coordinator_address
        self.args.num_processes = self.num_processes
        self.args.process_id = self.process_id
        self.args.local_device_ids = self.local_device_ids
        self.args.num_devices_per_process = self.num_devices_per_process
        self.args.enable_data_parallel = True
        self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
        _initialize_distributed(self.args)
        self.mesh = _create_mesh(self.args)
        jax.sharding.set_mesh(self.mesh)
        self.args.enable_result_check = True
        os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"

    def tearDown(self):
        os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)

    def test_te_bf16_all_gather_with_dp(self):
        """Test Collective GEMM with AllGather"""
        self.args.collective_type = "all_gather"
        run_gemm_tests(self.args, self.mesh)

    def test_te_bf16_reduce_scatter_with_dp(self):
        """Test Collective GEMM with ReduceScatter"""
        self.args.collective_type = "reduce_scatter"
        run_gemm_tests(self.args, self.mesh)


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 5:  # Need at least the 3 required distributed args
        print("Error: This script requires distributed initialization arguments.")
        print(
            "Usage: python test_gemm.py --coordinator-address <address> --num-processes <num>"
            " --process-id <id> [--local-device-ids <ids>] [other args]"
        )
        sys.exit(1)

    args = cgemm_parser("Collective GEMM test on multi-GPU with tensor parallelism").parse_args()
    _initialize_distributed(args)
    run_gemm_tests(args, mesh=None)
```
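
The gather-for-comparison idiom used in `run_gemm_tests` is worth calling out on its own: constraining an array to `PartitionSpec(None)` replicates it across the mesh so process 0 can compare full tensors. A minimal single-host sketch (the mesh below is a stand-in, not the tests' DP × TPSP mesh):

```python
# Replicate a sharded array so one process can inspect the full tensor.
# Single-host stand-in mesh; the tests use their DP x TPSP mesh instead.
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("tp",))

@jax.jit
def gather_full(x):
    # PartitionSpec(None) means "no axis sharded", i.e. fully replicated.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, PartitionSpec(None)))
```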

examples/jax/collective_gemm/test_layernorm_mlp_grad.py  (new file, +272 −0)

```python
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism"""
import argparse
import unittest
import os

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding
import flax

from common import (
    assert_allclose,
    _initialize_distributed,
    _get_dp_and_tp_sizes,
    _create_mesh,
    DP_AXIS,
    TPSP_AXIS,
    PARAMS_KEY,
    cgemm_parser,
)
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.cpp_extensions.gemm import (
    CollectiveOpSet,
    CollectiveOp,
    noop_collective_op_set,
)
from transformer_engine.jax.sharding import MeshResource
import transformer_engine.jax.flax as te_flax


def _get_logical_axes():
    input_1_axes = (DP_AXIS, TPSP_AXIS, None)
    weight_1_axes = (None, None, TPSP_AXIS)
    bias_axes_1 = (None, TPSP_AXIS)
    input_2_axes = (DP_AXIS, None, TPSP_AXIS)
    weight_2_axes = (TPSP_AXIS, None)
    bias_axes_2 = (None,)
    return input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2


def _get_operand_sharding(mesh):
    input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2 = (
        _get_logical_axes()
    )
    x_sharding = NamedSharding(mesh, PartitionSpec(*input_1_axes))
    weight_1_sharding = NamedSharding(mesh, PartitionSpec(*weight_1_axes))
    bias_1_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_1))
    weight_2_sharding = NamedSharding(mesh, PartitionSpec(*weight_2_axes))
    bias_2_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_2))
    return x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding


def _mean_layernorm_mlp(
    x,
    weight_1,
    bias_1,
    weight_2,
    bias_2,
    gamma,
    input_1_axes,
    input_2_axes,
    weight_1_axes,
    weight_2_axes,
    collective_op_sets,
):
    output = layernorm_mlp(
        x,
        gamma,
        beta=None,
        kernels=[weight_1, weight_2],
        biases=[bias_1, bias_2],
        norm_type="rmsnorm",
        dot_1_input_axes=input_1_axes,
        dot_2_input_axes=input_2_axes,
        kernel_1_axes=weight_1_axes,
        kernel_2_axes=weight_2_axes,
        activation_type=("gelu",),
        collective_op_sets=collective_op_sets,
    )
    return jnp.mean(output)


def _value_and_grad_layernorm_mlp(
    x,
    weight_1,
    bias_1,
    weight_2,
    bias_2,
    gamma,
    input_1_axes,
    input_2_axes,
    weight_1_axes,
    weight_2_axes,
    collective_op_sets,
):
    return jax.jit(
        jax.value_and_grad(_mean_layernorm_mlp, (0, 1, 2, 3, 4, 5)),
        static_argnums=(6, 7, 8, 9, 10),
    )(
        x,
        weight_1,
        bias_1,
        weight_2,
        bias_2,
        gamma,
        input_1_axes,
        input_2_axes,
        weight_1_axes,
        weight_2_axes,
        collective_op_sets,
    )


def run_layernorm_mlp_grad_tests(args, mesh=None):
    """Execute LayerNorm MLP Gradient tests."""
    print(args)

    # Collective GEMM requires Shardy partitioner to be disabled
    jax.config.update("jax_use_shardy_partitioner", False)

    # Initialize distributed with provided arguments
    _initialize_distributed(args)
    mesh = mesh or _create_mesh(args)

    # Create test data
    rng = jax.random.PRNGKey(0)
    rng, x_rng, weight_1_rng, bias_1_rng, weight_2_rng, bias_2_rng, gamma_rng = jax.random.split(
        rng, 7
    )
    x = jax.random.normal(
        x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
    )
    weight_1 = jax.random.normal(
        weight_1_rng, (args.hidden_in, 1, args.hidden_out), dtype=jnp.bfloat16
    ) / jnp.sqrt(args.hidden_in)
    bias_1 = jax.random.normal(bias_1_rng, (1, args.hidden_out), dtype=jnp.bfloat16)
    weight_2 = jax.random.normal(
        weight_2_rng, (args.hidden_out, args.hidden_in), dtype=jnp.bfloat16
    ) / jnp.sqrt(args.hidden_out)
    bias_2 = jax.random.normal(bias_2_rng, (args.hidden_in,), dtype=jnp.bfloat16)
    gamma = jax.random.normal(gamma_rng, (args.hidden_in,), dtype=jnp.bfloat16) / jnp.sqrt(
        args.hidden_in
    )

    collective_op_set_1 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.ALL_GATHER)
    collective_op_set_2 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.REDUCE_SCATTER)
    collective_op_sets = (collective_op_set_1, collective_op_set_2)
    noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set)

    with mesh, fp8_autocast(
        enabled=False,
        fp8_recipe=None,
        mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
    ):
        # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
        axis_rules = flax.linen.get_logical_axis_rules()
        axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS))
        te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
        with flax.linen.logical_axis_rules(te_extended_axis_rules):
            x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding = (
                _get_operand_sharding(mesh)
            )
            x_sharded = jax.device_put(x, x_sharding)
            weight_1_sharded = jax.device_put(weight_1, weight_1_sharding)
            bias_1_sharded = jax.device_put(bias_1, bias_1_sharding)
            weight_2_sharded = jax.device_put(weight_2, weight_2_sharding)
            bias_2_sharded = jax.device_put(bias_2, bias_2_sharding)
            input_1_axes, weight_1_axes, _, input_2_axes, weight_2_axes, _ = _get_logical_axes()

            ref_output, ref_grads = _value_and_grad_layernorm_mlp(
                x_sharded,
                weight_1_sharded,
                bias_1_sharded,
                weight_2_sharded,
                bias_2_sharded,
                gamma,
                input_1_axes,
                input_2_axes,
                weight_1_axes,
                weight_2_axes,
                noop_collective_op_sets,
            )
            output, sharded_grads = _value_and_grad_layernorm_mlp(
                x_sharded,
                weight_1_sharded,
                bias_1_sharded,
                weight_2_sharded,
                bias_2_sharded,
                gamma,
                input_1_axes,
                input_2_axes,
                weight_1_axes,
                weight_2_axes,
                collective_op_sets,
            )
            jax.block_until_ready(ref_output)
            jax.block_until_ready(output)

            gathered_grads = []
            gathered_ref_grads = []
            for ref_grad, grad in zip(ref_grads, sharded_grads):
                gathered_grads.append(
                    jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None)))
                )
                gathered_ref_grads.append(
                    jax.lax.with_sharding_constraint(
                        ref_grad, NamedSharding(mesh, PartitionSpec(None))
                    )
                )
            jax.block_until_ready(gathered_grads)
            jax.block_until_ready(gathered_ref_grads)

            if args.enable_result_check and args.process_id == 0:
                assert_allclose(ref_output, output, dtype=jnp.bfloat16)
                for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads):
                    assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16)


class TestCollectiveLayerNormMLPGradient(unittest.TestCase):
    """Collective LayerNorm MLP Gradient unittests"""

    def setUp(self):
        self.args = cgemm_parser(
            "Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism"
        ).parse_args([])
        self.args.coordinator_address = self.coordinator_address
        self.args.num_processes = self.num_processes
        self.args.process_id = self.process_id
        self.args.local_device_ids = self.local_device_ids
        self.args.num_devices_per_process = self.num_devices_per_process
        self.args.enable_data_parallel = True
        self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
        _initialize_distributed(self.args)
        # Create mesh once for all tests
        self.mesh = _create_mesh(self.args)
        jax.sharding.set_mesh(self.mesh)
        self.args.enable_result_check = True
        os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"

    def tearDown(self):
        os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)

    def test_te_bf16_layernorm_mlp_grad(self):
        """Test Collective LayerNorm MLP Gradient with AllGather and ReduceScatter"""
        run_layernorm_mlp_grad_tests(self.args, self.mesh)


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 7:  # Need at least the 3 required distributed args
        print("Error: This script requires distributed initialization arguments.")
        print(
            "Usage: python test_layernorm_mlp_grad.py --coordinator-address <address>"
            " --num-processes <num> --process-id <id> [--local-device-ids <ids>] [other args]"
        )
        print(
            "Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234"
            " --num-processes 4 --process-id 0"
        )
        print(
            "Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234"
            " --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3"
        )
        sys.exit(1)

    args = cgemm_parser(
        "Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism"
    ).parse_args([])
    _initialize_distributed(args)
    run_layernorm_mlp_grad_tests(args, mesh=None)
```

examples/jax/encoder/run_test_multiprocessing_encoder.sh  (+53 −8)

```diff
@@ -15,11 +15,37 @@ TEST_CASES=(
     "test_te_current_scaling_fp8_shardy"
 )
 
 : ${TE_PATH:=/opt/transformerengine}
+: ${XML_LOG_DIR:=/logs}
+mkdir -p "$XML_LOG_DIR"
 
 echo
 echo "*** Executing tests in examples/jax/encoder/test_multiprocessing_encoder.py ***"
 
 HAS_FAILURE=0  # Global failure flag
+PIDS=()        # Array to store all process PIDs
+
+# Cleanup function to kill all processes
+cleanup() {
+    for pid in "${PIDS[@]}"; do
+        if kill -0 "$pid" 2>/dev/null; then
+            echo "Killing process $pid"
+            kill -TERM "$pid" 2>/dev/null || true
+        fi
+    done
+    # Wait a bit and force kill if needed
+    sleep 2
+    for pid in "${PIDS[@]}"; do
+        if kill -0 "$pid" 2>/dev/null; then
+            echo "Force killing process $pid"
+            kill -KILL "$pid" 2>/dev/null || true
+        fi
+    done
+}
+
+# Set up signal handlers to cleanup on exit
+trap cleanup EXIT INT TERM
 
 # Run each test case across all GPUs
 for TEST_CASE in "${TEST_CASES[@]}"; do
     echo
@@ -29,25 +55,40 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
         # Define output file for logs
         LOG_FILE="${TEST_CASE}_gpu_${i}.log"
-        # Run pytest and redirect stdout and stderr to the log file
-        pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-            -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
-            --num-process=$NUM_GPUS \
-            --process-id=$i > "$LOG_FILE" 2>&1 &
-    done
+        # For process 0: show live output AND save to log file using tee
+        if [ $i -eq 0 ]; then
+            echo "=== Live output from process 0 ==="
+            pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
+                -vs --junitxml=$XML_LOG_DIR/multiprocessing_encoder_${TEST_CASE}.xml \
+                "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
+                --num-process=$NUM_GPUS \
+                --process-id=$i 2>&1 | tee "$LOG_FILE" &
+            PID=$!
+            PIDS+=($PID)
+        else
+            pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
+                -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
+                --num-process=$NUM_GPUS \
+                --process-id=$i > "$LOG_FILE" 2>&1 &
+            PID=$!
+            PIDS+=($PID)
+        fi
+    done
 
     # Wait for the process to finish
     wait
+    tail -n +7 "${TEST_CASE}_gpu_0.log"
 
     # Check and print the log content accordingly
     if grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then
         echo "... $TEST_CASE SKIPPED"
     elif grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then
         echo "... $TEST_CASE FAILED"
         HAS_FAILURE=1
     elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then
         echo "... $TEST_CASE PASSED"
     else
+        echo "... $TEST_CASE INVALID"
         HAS_FAILURE=1
-        echo "... $TEST_CASE FAILED"
     fi
     # Remove the log file after processing it
@@ -56,4 +97,8 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
 done
 
 wait
+
+# Final cleanup (trap will also call cleanup on exit)
+cleanup
+
 exit $HAS_FAILURE
```

qa/L0_jax_distributed_unittest/test.sh  (+4 −0)

```diff
@@ -29,6 +29,10 @@ wait
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
 wait
 TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
 wait
+TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh"
+wait
+
 if [ $RET -ne 0 ]; then
     echo "Error: some sub-tests failed: $FAILED_CASES"
```

qa/L0_jax_unittest/test.sh  (+1 −1)

```diff
@@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
 
 # Test without custom calls
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
-NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
+NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_without_custom_call.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
 
 if [ $RET -ne 0 ]; then
     echo "Error: some sub-tests failed: $FAILED_CASES"
```

qa/L0_pytorch_debug_unittest/test.sh  (+10 −9)

```diff
@@ -7,6 +7,8 @@
 : ${TE_PATH:=/opt/transformerengine}
 : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
 : ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/}
+: ${XML_LOG_DIR:=/logs}
+mkdir -p "$XML_LOG_DIR"
 
 # Config with the dummy feature which prevents nvinspect from being disabled.
 # Nvinspect will be disabled if no feature is active.
@@ -20,17 +22,16 @@ pip uninstall -y nvdlfw-inspect
 pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git
 pip install pytest==8.2.1
 
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
-NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
+NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
 
 # standard sanity and numerics tests with initialized debug
-NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
-NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
+NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
+NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
 
 exit $FAIL
```

qa/L0_pytorch_unittest/test.sh  (+1 −0)

```diff
@@ -31,6 +31,7 @@ ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0
 ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
```

qa/L1_pytorch_distributed_unittest/test.sh  (+3 −2)

```diff
@@ -30,6 +30,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
 python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
@@ -47,9 +48,9 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_
 : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
 : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
 
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
+pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
 # standard numerics tests with initialized debug
-NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
+NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
 
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases: $FAILED_CASES"
```

qa/L1_pytorch_onnx_unittest/test.sh  (+5 −3)

```diff
@@ -3,9 +3,11 @@
 # See LICENSE for license information.
 
-pip3 install onnxruntime==1.20.1
-pip3 install onnxruntime_extensions==0.13.0
+pip3 install onnxruntime
+pip3 install onnxruntime_extensions
 
 : ${TE_PATH:=/opt/transformerengine}
+: ${XML_LOG_DIR:=/logs}
+mkdir -p "$XML_LOG_DIR"
 
-python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
```

tests/cpp/operator/CMakeLists.txt  (+8 −0)

```diff
@@ -11,6 +11,7 @@ list(APPEND test_cuda_sources
     test_cast_mxfp8_gated_swiglu.cu
     test_qdq.cu
     test_cast_mxfp8.cu
+    test_cast_nvfp4_transpose.cu
     test_cast_float8blockwise.cu
     test_dequantize_mxfp8.cu
     test_transpose.cu
@@ -66,6 +67,13 @@ else()
   add_executable(test_operator ${test_hip_sources})
 endif()
 
+# Add profiling and debug flags for CUDA compilation
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo")         # Generate line info for device code
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g")                # Add debug symbols for host code
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v") # Add info about registers usage
+# Note: Using -lineinfo instead of -G to avoid conflicts and get line mapping
+
 # Find required packages
 find_package(OpenMP REQUIRED)
 if(USE_CUDA)
   list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
```