Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
0d874a4e
Commit
0d874a4e
authored
Mar 03, 2026
by
wenjh
Browse files
Merge branch 'nv_main' of v2.12
parents
a68e5f87
dfdd3820
Changes
640
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
34 additions
and
24 deletions
+34
-24
examples/jax/encoder/common.py
examples/jax/encoder/common.py
+1
-1
examples/jax/encoder/conftest.py
examples/jax/encoder/conftest.py
+1
-1
examples/jax/encoder/run_test_multiprocessing_encoder.sh
examples/jax/encoder/run_test_multiprocessing_encoder.sh
+1
-1
examples/jax/encoder/test_model_parallel_encoder.py
examples/jax/encoder/test_model_parallel_encoder.py
+9
-5
examples/jax/encoder/test_multigpu_encoder.py
examples/jax/encoder/test_multigpu_encoder.py
+1
-1
examples/jax/encoder/test_multiprocessing_encoder.py
examples/jax/encoder/test_multiprocessing_encoder.py
+1
-1
examples/jax/encoder/test_single_gpu_encoder.py
examples/jax/encoder/test_single_gpu_encoder.py
+1
-1
examples/jax/mnist/test_single_gpu_mnist.py
examples/jax/mnist/test_single_gpu_mnist.py
+1
-1
examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py
examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py
+1
-1
examples/pytorch/fsdp/README.md
examples/pytorch/fsdp/README.md
+1
-1
examples/pytorch/fsdp/fsdp.py
examples/pytorch/fsdp/fsdp.py
+1
-1
examples/pytorch/mnist/main.py
examples/pytorch/mnist/main.py
+1
-1
pyproject.toml
pyproject.toml
+1
-1
qa/L0_cppunittest/test.sh
qa/L0_cppunittest/test.sh
+1
-1
qa/L0_jax_distributed_unittest/test.sh
qa/L0_jax_distributed_unittest/test.sh
+4
-1
qa/L0_jax_lint/test.sh
qa/L0_jax_lint/test.sh
+1
-1
qa/L0_jax_unittest/test.sh
qa/L0_jax_unittest/test.sh
+4
-1
qa/L0_jax_wheel/test.sh
qa/L0_jax_wheel/test.sh
+1
-1
qa/L0_license/copyright_checker.py
qa/L0_license/copyright_checker.py
+1
-1
qa/L0_license/test.sh
qa/L0_license/test.sh
+1
-1
No files found.
Too many changes to show.
To preserve performance only
640 of 640+
files are displayed.
Plain diff
Email patch
examples/jax/encoder/common.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the encoder tests"""
...
...
examples/jax/encoder/conftest.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
examples/jax/encoder/run_test_multiprocessing_encoder.sh
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
examples/jax/encoder/test_model_parallel_encoder.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training on multi-GPU with tensor parallelism"""
import
argparse
import
os
import
unittest
from
functools
import
partial
...
...
@@ -489,6 +490,9 @@ class TestEncoder(unittest.TestCase):
def
setUp
(
self
):
"""Run 5 epochs for testing"""
# TODO(jberchtold): Remove once fused attention from cuDNN supports determinism on Blackwell
if
"NVTE_FUSED_ATTN"
not
in
os
.
environ
:
os
.
environ
[
"NVTE_FUSED_ATTN"
]
=
"0"
self
.
args
=
encoder_parser
([
"--epochs"
,
"5"
])
@
unittest
.
skipIf
(
not
is_bf16_supported
(),
"Device compute capability 8.0+ is required for BF16"
)
...
...
@@ -503,7 +507,7 @@ class TestEncoder(unittest.TestCase):
self
.
args
.
use_fp8
=
True
self
.
args
.
fp8_recipe
=
"DelayedScaling"
actual
=
train_and_evaluate
(
self
.
args
)
assert
actual
[
0
]
<
0.36
1
and
actual
[
1
]
>
0.84
assert
actual
[
0
]
<
0.36
2
and
actual
[
1
]
>
0.84
@
unittest
.
skipIf
(
not
is_mxfp8_supported
,
mxfp8_reason
)
def
test_te_mxfp8
(
self
):
...
...
@@ -535,7 +539,7 @@ class TestEncoder(unittest.TestCase):
self
.
args
.
use_fp8
=
True
self
.
args
.
fp8_recipe
=
"DelayedScaling"
actual
=
train_and_evaluate
(
self
.
args
)
assert
actual
[
0
]
<
0.36
and
actual
[
1
]
>
0.84
assert
actual
[
0
]
<
0.36
2
and
actual
[
1
]
>
0.84
@
unittest
.
skipIf
(
not
is_mxfp8_supported
,
mxfp8_reason
)
def
test_te_mxfp8_with_sp
(
self
):
...
...
@@ -569,7 +573,7 @@ class TestEncoder(unittest.TestCase):
self
.
args
.
use_fp8
=
True
self
.
args
.
fp8_recipe
=
"DelayedScaling"
actual
=
train_and_evaluate
(
self
.
args
)
assert
actual
[
0
]
<
0.36
and
actual
[
1
]
>
0.84
assert
actual
[
0
]
<
0.36
2
and
actual
[
1
]
>
0.84
@
unittest
.
skipIf
(
not
is_fp8_supported
,
fp8_reason
)
def
test_te_delayed_scaling_fp8_with_sp_shardy
(
self
):
...
...
@@ -579,7 +583,7 @@ class TestEncoder(unittest.TestCase):
self
.
args
.
use_fp8
=
True
self
.
args
.
fp8_recipe
=
"DelayedScaling"
actual
=
train_and_evaluate
(
self
.
args
)
assert
actual
[
0
]
<
0.36
1
and
actual
[
1
]
>
0.84
assert
actual
[
0
]
<
0.36
2
and
actual
[
1
]
>
0.84
@
unittest
.
skipIf
(
not
is_mxfp8_supported
,
mxfp8_reason
)
def
test_te_mxfp8_shardy
(
self
):
...
...
examples/jax/encoder/test_multigpu_encoder.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training on multi-GPU with data parallelism"""
...
...
examples/jax/encoder/test_multiprocessing_encoder.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training with multi-GPU, multiprocessing, and tensor parallelism"""
...
...
examples/jax/encoder/test_single_gpu_encoder.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training on single GPU"""
...
...
examples/jax/mnist/test_single_gpu_mnist.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""MNIST training on single GPU"""
...
...
examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py
View file @
0d874a4e
#!/usr/bin/python3
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
examples/pytorch/fsdp/README.md
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
examples/pytorch/fsdp/fsdp.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
examples/pytorch/mnist/main.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
pyproject.toml
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
qa/L0_cppunittest/test.sh
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
qa/L0_jax_distributed_unittest/test.sh
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
export
TRITON_PTXAS_PATH
=
/usr/local/cuda/bin/ptxas
function
error_exit
()
{
echo
"Error:
$1
"
...
...
@@ -16,6 +17,8 @@ function test_fail() {
RET
=
0
FAILED_CASES
=
""
export
NVTE_JAX_TEST_TIMING
=
1
:
${
TE_PATH
:
=/opt/transformerengine
}
:
${
XML_LOG_DIR
:
=/logs
}
mkdir
-p
"
$XML_LOG_DIR
"
...
...
qa/L0_jax_lint/test.sh
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
qa/L0_jax_unittest/test.sh
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
export
TRITON_PTXAS_PATH
=
/usr/local/cuda/bin/ptxas
set
-x
...
...
@@ -18,6 +19,8 @@ function test_fail() {
RET
=
0
FAILED_CASES
=
""
export
NVTE_JAX_TEST_TIMING
=
1
pip3
install
"nltk>=3.8.2"
||
error_exit
"Failed to install nltk"
pip3
install
pytest
==
8.2.1
||
error_exit
"Failed to install pytest"
...
...
qa/L0_jax_wheel/test.sh
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
qa/L0_license/copyright_checker.py
View file @
0d874a4e
#!/usr/bin/env python3
# coding: utf-8
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
qa/L0_license/test.sh
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
32
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment