OpenDAS / TransformerEngine · Commits

Commit 44740c6c, authored Jul 22, 2025 by yuguo

    Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine

Parents: 8113d9e0, 7a9a0825

Changes: 162 files in the full commit; this page shows the first 20 changed files, with 488 additions and 106 deletions (+488 / -106).
Changed files on this page:

    .github/workflows/build.yml                                +2    -2
    .gitignore                                                 +2    -1
    3rdparty/cudnn-frontend                                    +0    -1
    benchmarks/linear/benchmark_grouped_linear.py              +290  -0
    build_tools/VERSION.txt                                    +1    -1
    build_tools/jax.py                                         +8    -1
    build_tools/pytorch.py                                     +8    -1
    examples/jax/encoder/requirements.txt                      +1    -1
    examples/jax/encoder/run_test_multiprocessing_encoder.sh   +11   -12
    examples/jax/encoder/test_model_parallel_encoder.py        +59   -28
    examples/jax/encoder/test_multigpu_encoder.py              +41   -21
    examples/jax/encoder/test_multiprocessing_encoder.py       +42   -24
    examples/jax/mnist/requirements.txt                        +1    -1
    qa/L0_cppunittest/test.sh                                  +1    -1
    qa/L0_jax_distributed_unittest/test.sh                     +5    -5
    qa/L0_pytorch_debug_unittest/test.sh                       +2    -1
    qa/L0_pytorch_unittest/test.sh                             +5    -0
    qa/L3_pytorch_FA_versions_test/test.sh                     +2    -2
    tests/cpp/CMakeLists.txt                                   +5    -3
    tests/cpp/operator/CMakeLists.txt                          +2    -0
.github/workflows/build.yml (+2 -2)

```diff
@@ -43,7 +43,7 @@ jobs:
       run: |
         apt-get update
         apt-get install -y git python3.9 pip cudnn9-cuda-12
-        pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops
+        pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript
     - name: 'Checkout'
       uses: actions/checkout@v3
       with:
@@ -83,7 +83,7 @@ jobs:
       options: --user root
     steps:
     - name: 'Dependencies'
-      run: pip install torch pybind11[global] einops
+      run: pip install torch pybind11[global] einops onnxscript
     - name: 'Checkout'
       uses: actions/checkout@v3
       with:
```
.gitignore (+2 -1)

```diff
@@ -38,4 +38,5 @@ downloads/
 .pytest_cache/
 compile_commands.json
 .nfs
-tensor_dumps/
\ No newline at end of file
+tensor_dumps/
+artifacts/
```

3rdparty/cudnn-frontend (+0 -1)

The cudnn-frontend submodule pointer line (`Subproject commit 20c28ea798fe99e31d7274e009ee2fbf0e88abfd`, i.e. cudnn-frontend @ 20c28ea7) is removed.
benchmarks/linear/benchmark_grouped_linear.py
0 → 100644
View file @
44740c6c
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import
argparse
import
torch
import
torch.utils.benchmark
as
benchmark
import
pandas
as
pd
import
pathlib
from
transformer_engine.pytorch.module
import
GroupedLinear
from
transformer_engine.common.recipe
import
Float8BlockScaling
,
MXFP8BlockScaling
from
transformer_engine.pytorch.fp8
import
fp8_autocast
,
FP8GlobalStateManager
from
contextlib
import
nullcontext
"""
# Profile BF16 recipe with Nsight Systems
nsys profile
\
--output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_bf16
\
--force-overwrite true
\
--trace=cuda,nvtx,cudnn,cublas
\
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe bf16
# Profile FP8 sub-channel recipe with Nsight Systems
nsys profile
\
--output=./benchmarks/linear/h100hbm_mkn_4096_4096_4096_numgemm_8_fp8_sub_channel
\
--force-overwrite true
\
--trace=cuda,nvtx,cudnn,cublas
\
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe fp8_sub_channel
# Profile MXFP8 recipe with Nsight Systems
nsys profile
\
--output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_mxfp8
\
--force-overwrite true
\
--trace=cuda,nvtx,cudnn,cublas
\
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe mxfp8
"""
RECIPES
=
{
"bf16"
:
None
,
"fp8_sub_channel"
:
Float8BlockScaling
(),
"mxfp8"
:
MXFP8BlockScaling
(),
}
mxfp8_available
,
reason_for_no_mxfp8
=
FP8GlobalStateManager
.
is_mxfp8_available
()
fp8_block_scaling_available
,
reason_for_no_fp8_block_scaling
=
(
FP8GlobalStateManager
.
is_fp8_block_scaling_available
()
)
def
run_linear_multiple_steps
(
layer
,
x
,
m_splits
,
mode
,
gradient
,
run_num_steps
=
1
,
recipe
=
None
):
assert
mode
in
[
"fwd_only"
,
"fwd_bwd"
]
fp8_context
=
(
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
recipe
)
if
recipe
is
not
None
else
nullcontext
()
)
# print(f"fp8_context: {fp8_context} and is it nullcontext? {isinstance(fp8_context, nullcontext)}")
if
mode
==
"fwd_only"
:
with
torch
.
no_grad
(),
fp8_context
:
for
i
in
range
(
run_num_steps
):
y_q
=
layer
.
forward
(
x
,
m_splits
,
is_first_microbatch
=
(
i
==
0
),
)
return
y_q
else
:
# reset gradients
layer
.
zero_grad
()
x
.
grad
=
None
with
fp8_context
:
for
i
in
range
(
run_num_steps
):
label
=
f
"step_
{
i
}
"
torch
.
cuda
.
nvtx
.
range_push
(
label
)
y_q
=
layer
.
forward
(
x
,
m_splits
,
is_first_microbatch
=
(
i
==
0
),
)
y_q
.
backward
(
gradient
)
torch
.
cuda
.
nvtx
.
range_pop
()
grads_q
=
[]
grads_q
.
append
(
x
.
grad
)
# remaining derivatives are in respect to model parameters
for
p
in
layer
.
parameters
():
if
p
.
requires_grad
:
grads_q
.
append
(
p
.
grad
)
return
y_q
,
grads_q
def
benchmark_linear
(
x
,
ws
,
m_splits
,
bias
,
recipe_name
,
mode
,
num_gemms
=
4
,
):
params_dtype
=
torch
.
bfloat16
recipe
=
RECIPES
[
recipe_name
]
in_features
=
x
.
shape
[
1
]
out_features
=
ws
[
0
].
shape
[
0
]
gradient
=
torch
.
ones
((
x
.
shape
[
0
],
out_features
),
dtype
=
torch
.
bfloat16
,
device
=
x
.
device
)
layer
=
GroupedLinear
(
num_gemms
,
in_features
,
out_features
,
bias
=
bias
is
not
None
,
params_dtype
=
params_dtype
,
)
layer
=
layer
.
to
(
"cuda"
)
with
torch
.
no_grad
():
for
i
in
range
(
num_gemms
):
weight_i
=
getattr
(
layer
,
f
"weight
{
i
}
"
)
weight_i
.
copy_
(
ws
[
i
])
if
bias
is
not
None
:
bias_i
=
getattr
(
layer
,
f
"bias
{
i
}
"
)
bias_i
.
copy_
(
bias
)
num_microbatches
=
32
label
=
f
"
{
recipe_name
}
_
{
'grouped'
}
"
torch
.
cuda
.
nvtx
.
range_push
(
label
)
timing
=
benchmark
.
Timer
(
stmt
=
(
"run_linear_multiple_steps(layer, x, m_splits, mode, gradient, num_microbatches,"
" recipe)"
),
globals
=
{
"run_linear_multiple_steps"
:
run_linear_multiple_steps
,
"layer"
:
layer
,
"x"
:
x
,
"m_splits"
:
m_splits
,
"mode"
:
mode
,
"gradient"
:
gradient
,
"num_microbatches"
:
num_microbatches
,
"recipe"
:
recipe
,
},
num_threads
=
1
,
).
blocked_autorange
(
min_run_time
=
5
)
print
(
f
"
{
recipe_name
}
:
{
timing
}
\n
"
)
timing_ms
=
timing
.
median
*
1000
/
num_microbatches
return
timing_ms
def
run_benchmark_linear
(
mkns
,
recipe_name
,
use_bias
,
num_gemms
=
4
):
data
=
[]
assert
not
use_bias
,
"Bias is not supported for GroupedLinear benchmark"
print
(
f
"========== Benchmarking
{
recipe_name
}
=========="
)
for
m
,
k
,
n
in
mkns
:
device
=
"cuda"
x
=
torch
.
randn
((
m
,
k
),
dtype
=
torch
.
bfloat16
,
device
=
device
,
requires_grad
=
True
)
ws
=
[
torch
.
randn
((
n
,
k
),
dtype
=
torch
.
bfloat16
,
device
=
device
)
for
_
in
range
(
num_gemms
)]
assert
m
%
num_gemms
==
0
m_splits
=
[
m
//
num_gemms
]
*
num_gemms
# Bias is not supported for GroupedLinear benchmark
bias
=
None
# Run the benchmark
print
(
f
"fwd_m=
{
m
}
, fwd_k=
{
k
}
, fwd_n=
{
n
}
"
)
grouped_fwd_bwd_timing_ms
=
benchmark_linear
(
x
,
ws
,
m_splits
,
bias
,
recipe_name
,
mode
=
"fwd_bwd"
,
num_gemms
=
num_gemms
,
)
# Append the results
data
.
append
(
[
m
,
k
,
n
,
recipe_name
,
num_gemms
,
grouped_fwd_bwd_timing_ms
,
]
)
df
=
pd
.
DataFrame
(
data
=
data
,
columns
=
[
"m"
,
"k"
,
"n"
,
"recipe"
,
"num_gemms"
,
"grouped_fwd_bwd_time_ms"
,
],
)
print
(
df
,
"
\n
"
)
return
df
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
,
help
=
"Enable profiling mode"
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"benchmark_output/"
,
help
=
"output path for report"
,
)
# arguments for recipe, options are fp8_sub_channel, mxfp8, bf16, all
parser
.
add_argument
(
"--recipe"
,
type
=
str
,
default
=
"bf16"
,
help
=
"Recipe to use, options are fp8_sub_channel, mxfp8, bf16, or all"
,
)
args
=
parser
.
parse_args
()
use_bias
=
False
# Set the MKN values to benchmark
mkns
=
[]
for
m
in
[
8192
]:
# for m in [4096, 8192, 16384]:
# for n in [1024, 2048, 4096, 8192, 16384]:
for
n
in
[
8192
]:
for
k
in
[
4096
]:
mkns
.
append
((
m
,
k
,
n
))
# default recipes to run if not specified
recipe_list
=
[
"bf16"
]
if
args
.
recipe
==
"all"
:
recipe_list
=
[
"bf16"
,
"fp8_sub_channel"
,
"mxfp8"
]
else
:
recipe_list
=
[
args
.
recipe
]
num_gemms_list
=
[
8
]
if
args
.
profile
:
mkns
=
[(
4096
,
4096
,
4096
)]
# in profile mode, only run one recipe specified in args.recipe
assert
args
.
recipe
!=
"all"
,
(
"In profile mode, only one recipe can be specified, please specify the recipe as"
" fp8_sub_channel, mxfp8, or bf16"
)
recipe_list
=
[
args
.
recipe
]
num_gemms_list
=
[
8
]
torch
.
autograd
.
profiler
.
emit_nvtx
(
record_shapes
=
True
).
__enter__
()
# Initialize a dataframe to store the results
df_linears
=
pd
.
DataFrame
()
# Run the fp8 benchmarks
for
num_gemms
in
num_gemms_list
:
print
(
f
"========== Benchmarking with num_gemms=
{
num_gemms
}
=========="
)
for
recipe_name
in
recipe_list
:
assert
recipe_name
in
[
"bf16"
,
"fp8_sub_channel"
,
"mxfp8"
,
],
"Recipe must be one of bf16, fp8_sub_channel, or mxfp8"
if
recipe_name
==
"mxfp8"
and
not
mxfp8_available
:
print
(
f
"MXFP8 is not available, skipping
{
recipe_name
}
"
)
continue
if
recipe_name
==
"fp8_sub_channel"
and
not
fp8_block_scaling_available
:
print
(
f
"FP8 block scaling is not available, skipping
{
recipe_name
}
"
)
continue
df
=
run_benchmark_linear
(
mkns
,
recipe_name
,
use_bias
,
num_gemms
=
num_gemms
,
)
df_linears
=
pd
.
concat
([
df_linears
,
df
])
print
(
df_linears
)
if
args
.
profile
:
torch
.
autograd
.
profiler
.
emit_nvtx
().
__exit__
(
None
,
None
,
None
)
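Stripped of the `Timer`/NVTX harness, the call pattern this new benchmark times reduces to a few lines. A minimal sketch of that pattern (assumes a CUDA device and a Transformer Engine build that provides `GroupedLinear`; the shapes here are illustrative, not the benchmark's defaults):

```python
import torch
from transformer_engine.pytorch.module import GroupedLinear

num_gemms, m, k, n = 4, 4096, 1024, 2048
layer = GroupedLinear(num_gemms, k, n, bias=False, params_dtype=torch.bfloat16).to("cuda")

x = torch.randn((m, k), dtype=torch.bfloat16, device="cuda", requires_grad=True)
m_splits = [m // num_gemms] * num_gemms  # rows of x routed to each GEMM

y = layer(x, m_splits)           # fused grouped forward over num_gemms GEMMs
y.backward(torch.ones_like(y))   # grouped backward (dgrad + wgrad)
print(y.shape, x.grad.shape)
```

Wrapping the forward/backward in `fp8_autocast(enabled=True, fp8_recipe=...)`, as the benchmark does, switches the same call pattern to an FP8 recipe.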
build_tools/VERSION.txt (+1 -1)

```diff
-2.6.0.dev0
+2.7.0.dev0
```
build_tools/jax.py (+8 -1)

```diff
@@ -5,6 +5,7 @@
 """JAX related extensions."""
 import os
 from pathlib import Path
+from packaging import version
 
 import setuptools
@@ -27,7 +28,13 @@ def xla_path() -> str:
     Throws FileNotFoundError if XLA source is not found."""
     try:
-        from jax.extend import ffi
+        import jax
+
+        if version.parse(jax.__version__) >= version.parse("0.5.0"):
+            from jax import ffi  # pylint: disable=ungrouped-imports
+        else:
+            from jax.extend import ffi  # pylint: disable=ungrouped-imports
     except ImportError:
         if os.getenv("XLA_HOME"):
             xla_home = Path(os.getenv("XLA_HOME"))
```
build_tools/pytorch.py (+8 -1)

```diff
@@ -13,12 +13,19 @@ from typing import List
 def install_requirements() -> List[str]:
-    """Install dependencies for TE/JAX extensions."""
+    """Install dependencies for TE/PyTorch extensions."""
     reqs = ["torch>=2.1", "einops"]
     # reqs.append(
     #     "nvdlfw-inspect @"
     #     " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
     # )
+    reqs.extend(
+        [
+            "torch>=2.1",
+            # "onnx",
+            # "onnxscript@git+https://github.com/microsoft/onnxscript.git@51ecf47523ef079c53b0e620c62d56d70cfd3871",
+        ]
+    )
     return reqs
```
examples/jax/encoder/requirements.txt (+1 -1)

```diff
-datasets
+datasets<4.0.0
 flax>=0.7.1
 nltk>=3.8.2
 optax
```
examples/jax/encoder/run_test_multiprocessing_encoder.sh (+11 -12)

```diff
@@ -6,13 +6,13 @@ NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
 # Define the test cases to run
 TEST_CASES=(
-    # "test_te_bf16"
+    "test_te_bf16"
     "test_te_delayed_scaling_fp8"
-    # "test_te_current_scaling_fp8"
-    # "test_te_mxfp8"
-    # "test_te_bf16_shardy"
+    "test_te_current_scaling_fp8"
+    "test_te_mxfp8"
+    "test_te_bf16_shardy"
     "test_te_delayed_scaling_fp8_shardy"
-    # "test_te_current_scaling_fp8_shardy"
+    "test_te_current_scaling_fp8_shardy"
 )
 
 echo
@@ -30,7 +30,7 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
     LOG_FILE="${TEST_CASE}_gpu_${i}.log"
     # Run pytest and redirect stdout and stderr to the log file
-    pytest -c "$TE_PATH/tests/jax/pytest.ini" \
+    pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
         -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
         --num-process=$NUM_GPUS \
         --process-id=$i > "$LOG_FILE" 2>&1 &
@@ -40,21 +40,20 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
   wait
   tail -n +7 "${TEST_CASE}_gpu_0.log"
 
   # Check and print the log content accordingly
   if grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then
     HAS_FAILURE=1
     echo "... $TEST_CASE FAILED"
   elif grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then
     echo "... $TEST_CASE SKIPPED"
   elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then
     echo "... $TEST_CASE PASSED"
   else
     echo "Invalid ${TEST_CASE}_gpu_0.log"
     HAS_FAILURE=1
     echo "... $TEST_CASE FAILED"
   fi
 
   # Remove the log file after processing it
-  wait
   rm ${TEST_CASE}_gpu_*.log
 done

 wait
 exit $HAS_FAILURE
```
examples/jax/encoder/test_model_parallel_encoder.py (+59 -28)

```diff
@@ -25,6 +25,7 @@ from common import (
     assert_params_sufficiently_sharded,
 )
 import transformer_engine.jax as te
+import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
 from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
@@ -263,8 +264,10 @@ def train_and_evaluate(args):
     device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
     with jax.sharding.Mesh(
         devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
-    ) as mesh, nn_partitioning.axis_rules(
-        ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
+    ) as mesh, te.fp8_autocast(
+        enabled=args.use_fp8,
+        fp8_recipe=fp8_recipe,
+        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
     ):
         rng = jax.random.PRNGKey(args.seed)
         rng, params_rng = jax.random.split(rng)
@@ -275,22 +278,21 @@ def train_and_evaluate(args):
         mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
         label_shape = [args.batch_size]
 
-        with te.fp8_autocast(
-            enabled=args.use_fp8,
-            fp8_recipe=fp8_recipe,
-            mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
-        ):
+        # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
+        axis_rules = flax.linen.get_logical_axis_rules()
+        axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
+        te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
+        with flax.linen.logical_axis_rules(te_extended_axis_rules):
+            print(f"Device mesh: {mesh}")
+            print(f"Axis rules: {te_extended_axis_rules}")
             encoder = Net(num_embed, args.enable_sp)
             inputs = jnp.zeros(input_shape, dtype=jnp.int32)
             masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
             abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
-            # Get the base axis rules and extend them with TE's rules.
-            axis_rules = nn_partitioning.get_axis_rules()
-            te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
-            print(f"Device mesh: {mesh}")
-            print(f"Axis rules: {te_extended_axis_rules}")
 
             logical_partition_spec = nn.get_partition_spec(abs_var_collect)
             # Note that `nn.logical_to_mesh_sharding` returns a dict with an extra
@@ -307,7 +309,9 @@ def train_and_evaluate(args):
         out_shardings = {
             key: params_sharding[PARAMS_KEY] if key is PARAMS_KEY else None
             for key in abs_var_collect
         }
-        jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
+        jit_encoder_init = jax.jit(
+            encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
+        )
         var_collect = jit_encoder_init(init_rngs, inputs, masks)
 
         # Check if params are sufficiently sharded after initialization
@@ -344,11 +348,15 @@ def train_and_evaluate(args):
             None,
         )
         out_shardings = (state_sharding, None, None, None)
-        jit_train_step = jax.jit(train_step, in_shardings, out_shardings)
+        jit_train_step = jax.jit(
+            train_step, in_shardings=in_shardings, out_shardings=out_shardings
+        )
 
         in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
         out_shardings = (None, None)
-        jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)
+        jit_eval_step = jax.jit(
+            eval_step, in_shardings=in_shardings, out_shardings=out_shardings
+        )
 
         if args.use_fp8:
             labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
@@ -459,14 +467,14 @@ class TestEncoder(unittest.TestCase):
     is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
 
     def setUp(self):
-        """Run 3 epochs for testing"""
-        self.args = encoder_parser(["--epochs", "3"])
+        """Run 5 epochs for testing"""
+        self.args = encoder_parser(["--epochs", "5"])
 
     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.39 and actual[1] > 0.83
```

The hunks from @@ -474,7 +482,7 @@ through @@ -533,9 +541,32 @@ make the same assertion change in test_te_delayed_scaling_fp8, test_te_mxfp8, test_te_bf16_with_sp, test_te_delayed_scaling_fp8_with_sp, test_te_mxfp8_with_sp, test_te_bf16_shardy, test_te_delayed_scaling_fp8_shardy, and test_te_delayed_scaling_fp8_with_sp_shardy:

```diff
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.39 and actual[1] > 0.83
```

The final hunk also replaces the mxfp8-Shardy TODO with two new tests:

```diff
@@ -533,9 +541,32 @@ class TestEncoder(unittest.TestCase):
-    # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(),
+        "`jax.nn.scaled_matmul()` does not support the Shardy partitioner.",
+    )
+    def test_te_mxfp8_shardy(self):
+        """Test Transformer Engine with MXFP8"""
+        self.args.enable_shardy = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.39 and actual[1] > 0.83
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(),
+        "`jax.nn.scaled_matmul()` does not support the Shardy partitioner.",
+    )
+    def test_te_mxfp8_with_sp_shardy(self):
+        """Test Transformer Engine with MXFP8 + SP"""
+        self.args.enable_shardy = True
+        self.args.enable_sp = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.39 and actual[1] > 0.83
 
 if __name__ == "__main__":
```
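A recurring change across these encoder examples is the switch from positional to keyword arguments in `jax.jit(fn, in_shardings, out_shardings)`; newer JAX releases deprecate passing anything other than the function positionally to `jax.jit`. A standalone sketch of the updated calling convention (toy function and a one-device mesh so it also runs on CPU; the names here are illustrative, not from the diff):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a trivial one-device mesh so the example runs anywhere.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("dp",))
sharding = NamedSharding(mesh, PartitionSpec("dp"))

def f(x):
    return x * 2.0

# Keyword form used by the updated examples; the positional form
# `jax.jit(f, sharding, sharding)` is what the diffs move away from.
jit_f = jax.jit(f, in_shardings=sharding, out_shardings=sharding)
print(jit_f(jnp.ones((8,))))
```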
examples/jax/encoder/test_multigpu_encoder.py (+41 -21)

```diff
@@ -21,6 +21,7 @@ from jax.sharding import PartitionSpec, NamedSharding
 from common import is_bf16_supported, get_fp8_recipe_from_name_string
 import transformer_engine.jax as te
+import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
 from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
@@ -258,7 +259,13 @@ def train_and_evaluate(args):
         fp8_recipe = None
 
     device_mesh = mesh_utils.create_device_mesh((num_gpu,))
-    with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh:
+    with jax.sharding.Mesh(
+        devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)
+    ) as mesh, te.fp8_autocast(
+        enabled=args.use_fp8,
+        fp8_recipe=fp8_recipe,
+        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
+    ):
         rng = jax.random.PRNGKey(args.seed)
         rng, params_rng = jax.random.split(rng)
@@ -269,17 +276,14 @@ def train_and_evaluate(args):
         mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
         label_shape = [args.batch_size]
 
-        with te.fp8_autocast(
-            enabled=args.use_fp8,
-            fp8_recipe=fp8_recipe,
-            mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
-        ):
+        # Add TE logical axis rules to our Flax logical axis rule context. This must be done inside fp8_autocast
+        sharding_rules = te_flax.extend_logical_axis_rules(tuple())
+        with flax.linen.logical_axis_rules(sharding_rules):
             encoder = Net(num_embed)
             inputs = jnp.zeros(input_shape, dtype=jnp.int32)
             masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
             abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
-            sharding_rules = te_flax.extend_logical_axis_rules(tuple())
 
         params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
         inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
         masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))
@@ -288,7 +292,9 @@ def train_and_evaluate(args):
         out_shardings = {
             key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
         }
-        jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
+        jit_encoder_init = jax.jit(
+            encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
+        )
         var_collect = jit_encoder_init(init_rngs, inputs, masks)
 
         optimizer = optax.adamw(args.lr)
@@ -312,11 +318,15 @@ def train_and_evaluate(args):
             None,
         )
         out_shardings = (state_sharding, None, None, None)
-        jit_train_step = jax.jit(train_step, in_shardings, out_shardings)
+        jit_train_step = jax.jit(
+            train_step, in_shardings=in_shardings, out_shardings=out_shardings
+        )
 
         in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
         out_shardings = (None, None)
-        jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)
+        jit_eval_step = jax.jit(
+            eval_step, in_shardings=in_shardings, out_shardings=out_shardings
+        )
 
         if args.use_fp8:
             labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
@@ -424,14 +434,14 @@ class TestEncoder(unittest.TestCase):
     is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
 
     def setUp(self):
-        """Run 3 epochs for testing"""
-        self.args = encoder_parser(["--epochs", "3"])
+        """Run 5 epochs for testing"""
+        self.args = encoder_parser(["--epochs", "5"])
```

The hunks from @@ -424,14 +434,14 @@ through @@ -482,7 +490,19 @@ also loosen/retune the same assertion in test_te_bf16, test_te_delayed_scaling_fp8, test_te_current_scaling_fp8, test_te_mxfp8, test_te_bf16_shardy, test_te_delayed_scaling_fp8_shardy, and test_te_current_scaling_fp8_shardy, and drop the `# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.` comment:

```diff
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.52 and actual[1] > 0.74
```

The final hunk adds a new mxfp8 Shardy test:

```diff
@@ -482,7 +490,19 @@ class TestEncoder(unittest.TestCase):
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.52 and actual[1] > 0.74
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(),
+        "`jax.nn.scaled_matmul()` does not support the Shardy partitioner.",
+    )
+    def test_te_mxfp8_shardy(self):
+        """Test Transformer Engine with MXFP8"""
+        self.args.enable_shardy = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.52 and actual[1] > 0.74
 
 if __name__ == "__main__":
```
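Structurally, these encoder examples now enter the device mesh and `te.fp8_autocast` in a single `with` statement instead of nesting the autocast deeper in the function, so the FP8 state is active for everything built under the mesh. The semantics of a combined `with` are worth spelling out: managers are entered left to right and exited in reverse order. A minimal stand-in using plain `contextlib` (no TE or JAX required; all names here are illustrative):

```python
from contextlib import contextmanager

@contextmanager
def mesh_ctx():
    print("mesh open")
    yield "mesh"
    print("mesh closed")

@contextmanager
def autocast_ctx(enabled):
    print(f"autocast(enabled={enabled}) open")
    yield
    print("autocast closed")

# One statement: entered left to right, exited right to left,
# mirroring `with Mesh(...) as mesh, te.fp8_autocast(...):` in the diffs.
with mesh_ctx() as mesh, autocast_ctx(enabled=True):
    print(f"building model under {mesh} with autocast active")
```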
examples/jax/encoder/test_multiprocessing_encoder.py (+42 -24)

```diff
@@ -28,8 +28,8 @@ from common import (
     get_fp8_recipe_from_name_string,
 )
 import transformer_engine.jax as te
+import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
 from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@@ -379,8 +379,11 @@ def train_and_evaluate(args):
     device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
     with jax.sharding.Mesh(
         devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
-    ) as mesh:
+    ) as mesh, te.fp8_autocast(
+        enabled=args.use_fp8,
+        fp8_recipe=fp8_recipe,
+        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
+    ):
         rng = jax.random.PRNGKey(args.seed)
         rng, params_rng = jax.random.split(rng)
         rng, dropout_rng = jax.random.split(rng)
@@ -390,18 +393,18 @@ def train_and_evaluate(args):
         mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
         label_shape = [args.batch_size]
 
-        with te.fp8_autocast(
-            enabled=args.use_fp8,
-            fp8_recipe=fp8_recipe,
-            mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
-        ):
+        # Create custom Flax logical axis rules for sharding.
+        customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
+        # Extend the logical axis rules with TE's rules. This must be done inside fp8_autocast.
+        sharding_rules = te_flax.extend_logical_axis_rules(customized_rules)
+        with flax.linen.logical_axis_rules(sharding_rules):
             encoder = Net(num_embed)
             inputs = jnp.zeros(input_shape, dtype=jnp.int32)
             masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
             abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
-            customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
-            sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
 
         params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
         inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
         masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
@@ -412,7 +415,9 @@ def train_and_evaluate(args):
         out_shardings = {
             key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
         }
-        jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
+        jit_encoder_init = jax.jit(
+            encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
+        )
         var_collect = jit_encoder_init(init_rngs, inputs, masks)
 
         optimizer = optax.adamw(args.lr)
@@ -432,11 +437,15 @@ def train_and_evaluate(args):
             None,
         )
         out_shardings = (state_sharding, None, None, None)
-        jit_train_step = jax.jit(train_step, in_shardings, out_shardings)
+        jit_train_step = jax.jit(
+            train_step, in_shardings=in_shardings, out_shardings=out_shardings
+        )
 
         in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
         out_shardings = (None, None)
-        jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)
+        jit_eval_step = jax.jit(
+            eval_step, in_shardings=in_shardings, out_shardings=out_shardings
+        )
 
         if args.use_fp8:
             labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
@@ -578,8 +587,8 @@ class TestEncoder(unittest.TestCase):
     """Encoder unittests"""
 
     def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
-        """Run 3 epochs for testing"""
-        args = encoder_parser([])
+        """Run 5 epochs for testing"""
+        args = encoder_parser(["--epochs", "5"])
         num_gpu = self.num_process
         tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1
```

The hunks from @@ -601,7 +610,7 @@ through @@ -649,7 +656,18 @@ unify the per-test thresholds (previously 0.505–0.507 for the loss bound and 0.753–0.755 for the accuracy bound) across test_te_bf16, test_te_delayed_scaling_fp8, test_te_current_scaling_fp8, test_te_mxfp8, test_te_bf16_shardy, test_te_delayed_scaling_fp8_shardy, and test_te_current_scaling_fp8_shardy, and drop the `# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.` comment:

```diff
-        assert result[0] < 0.505 and result[1] > 0.755
+        assert result[0] < 0.43 and result[1] > 0.80
```

The final hunk adds a new mxfp8 Shardy test:

```diff
@@ -649,7 +656,18 @@ class TestEncoder(unittest.TestCase):
         result = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
-        assert result[0] < 0.507 and result[1] > 0.753
+        assert result[0] < 0.43 and result[1] > 0.80
+
+    @unittest.skipIf(
+        not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
+    )
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(),
+        "`jax.nn.scaled_matmul()` does not support the Shardy partitioner.",
+    )
+    def test_te_mxfp8_shardy(self):
+        """Test Transformer Engine with MXFP8"""
+        result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True)
+        assert result[0] < 0.43 and result[1] > 0.80
 
 if __name__ == "__main__":
```
examples/jax/mnist/requirements.txt (+1 -1)

```diff
-datasets
+datasets<4.0.0
 flax>=0.7.1
 optax
 Pillow
```
qa/L0_cppunittest/test.sh (+1 -1)

```diff
@@ -6,7 +6,7 @@ set -e
 # Find TE
 : ${TE_PATH:=/opt/transformerengine}
-TE_LIB_PATH=`pip3 show transformer-engine | grep Location | cut -d ' ' -f 2`
+TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}')
 export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH
 
 # Set parallelization parameters
```
qa/L0_jax_distributed_unittest/test.sh (+5 -5)

```diff
@@ -24,11 +24,11 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa
 # Make encoder tests to have run-to-run deterministic to have the stable CI results
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
 
-# python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
-# wait
-# python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
-# wait
-. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
+python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
+wait
+python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
+wait
+TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
 
 if [ $RET -ne 0 ]; then
     echo "Error: some sub-tests failed: $FAILED_CASES"
```
qa/L0_pytorch_debug_unittest/test.sh (+2 -1)

```diff
@@ -20,7 +20,8 @@ pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TE
 pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
 NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
 
-# standard numerics tests with initialized debug
+# standard sanity and numerics tests with initialized debug
+NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
 NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
 
 exit $FAIL
```
qa/L0_pytorch_unittest/test.sh (+5 -0)

```diff
@@ -23,6 +23,8 @@ set -x
 mkdir -p "$XML_LOG_DIR"
 
 pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
+pip3 install onnxruntime==1.20.1 || error_exit "Failed to install onnxruntime"
+pip3 install onnxruntime_extensions==0.13.0 || error_exit "Failed to install onnxruntime_extensions"
 
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
@@ -43,6 +45,7 @@ NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
@@ -50,6 +53,8 @@ NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
+NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
 
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases: $FAILED_CASES"
```
qa/L3_pytorch_FA_versions_test/test.sh (+2 -2)

```diff
@@ -18,10 +18,10 @@ sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); pri
 export FLASH_ATTN_CUDA_ARCHS=$sm_arch
 if [ $sm_arch -gt 90 ]
 then
-    FA_versions=(2.7.3)
+    FA_versions=(2.8.1)
 elif [ $sm_arch -eq 90 ]
 then
-    FA_versions=(2.5.7 2.7.3 3.0.0b1)
+    FA_versions=(2.7.3 2.8.1 3.0.0b1)
 fi
 
 for fa_version in "${FA_versions[@]}"
```
tests/cpp/CMakeLists.txt (+5 -3)

```diff
@@ -66,11 +66,13 @@ enable_testing()
 include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
 
 if(NOT DEFINED TE_LIB_PATH)
-    execute_process(COMMAND bash -c "pip3 show transformer-engine | grep Location | cut -d ' ' -f 2 | tr -d '\n'"
-                    OUTPUT_VARIABLE TE_LIB_PATH)
+    execute_process(COMMAND bash -c "python3 -c 'import transformer_engine as te; print(te.__file__)'"
+                    OUTPUT_VARIABLE TE_LIB_FILE
+                    OUTPUT_STRIP_TRAILING_WHITESPACE)
+    get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY)
 endif()
 
-find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/transformer_engine" ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED)
+find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED)
 message(STATUS "Found transformer_engine library: ${TE_LIB}")
 include_directories(../../transformer_engine/common/include)
```
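The updated CMake logic locates the installed library from the package's import path instead of parsing `pip3 show` output, which is what the `grep -E "Location:|Editable project location:"` change in qa/L0_cppunittest/test.sh also works around for editable installs. The equivalent lookup in Python (assumes `transformer_engine` is importable; the printed directory is what CMake's `get_filename_component(... DIRECTORY)` extracts):

```python
# Resolve the transformer_engine install directory from its import path,
# mirroring the execute_process()/get_filename_component() pair above.
import os
import transformer_engine as te

te_lib_path = os.path.dirname(te.__file__)
print(te_lib_path)  # find_library() then searches this directory and its parent
```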
tests/cpp/operator/CMakeLists.txt (+2 -0)

```diff
@@ -22,8 +22,10 @@ list(APPEND test_cuda_sources
     test_act.cu
     test_normalization.cu
     test_normalization_mxfp8.cu
+    test_memset.cu
     test_multi_cast_transpose.cu
     test_multi_padding.cu
+    test_multi_unpadding.cu
     test_causal_softmax.cu
     test_swizzle.cu
     ../test_common.cu
     )
```