OpenDAS / TransformerEngine, Commit 0d874a4e
Merge branch 'nv_main' of v2.12
Authored Mar 03, 2026 by wenjh
Parents: a68e5f87, dfdd3820

Changes: 640 files in total; showing 20 changed files with 175 additions and 88 deletions (+175, -88)
tests/pytorch/debug/conftest.py                                  +1  -1
tests/pytorch/debug/run_distributed.py                           +1  -1
tests/pytorch/debug/test_api_features.py                         +1  -1
tests/pytorch/debug/test_config.py                               +1  -1
tests/pytorch/debug/test_distributed.py                          +1  -1
tests/pytorch/debug/test_log.py                                  +23 -1
tests/pytorch/debug/test_numerics.py                             +1  -3
tests/pytorch/debug/test_perf.py                                 +1  -1
tests/pytorch/debug/test_sanity.py                               +1  -1
tests/pytorch/debug/utils.py                                     +1  -1
tests/pytorch/distributed/run_fsdp2_model.py                     +1  -1
tests/pytorch/distributed/run_gemm_with_overlap.py               +3  -12
tests/pytorch/distributed/run_layer_with_overlap.py              +1  -1
tests/pytorch/distributed/run_numerics.py                        +10 -2
tests/pytorch/distributed/run_numerics_exact.py                  +1  -1
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py     +116 -54
tests/pytorch/distributed/test_comm_gemm_overlap.py              +7  -1
tests/pytorch/distributed/test_fusible_ops.py                    +1  -1
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py   +1  -1
tests/pytorch/distributed/test_numerics.py                       +2  -2
tests/pytorch/debug/conftest.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
 import pytest
...
tests/pytorch/debug/run_distributed.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
tests/pytorch/debug/test_api_features.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
tests/pytorch/debug/test_config.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
 import pathlib
...
tests/pytorch/debug/test_distributed.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
tests/pytorch/debug/test_log.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -363,6 +363,28 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
     TEDebugState._reset()


+def test_log_grouped_gemm(feature_dirs):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+
+    log_all_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats))
+    with debug_session(log_all_stats_config, feature_dirs) as log_dir:
+        model = te.GroupedLinear(3, 128, 128, name="linear1", params_dtype=torch.bfloat16)
+        inp = torch.randn((1, 128, 128), dtype=torch.bfloat16).cuda()
+        m_splits = [64, 32, 32]
+        with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
+            output = model(inp, m_splits=m_splits)
+        loss = output.sum()
+        loss.backward()
+        debug_api.step()
+
+    output = read_log(log_dir)
+    assert "gemm_0" in output, "gemm0 not found in output"
+    assert "gemm_1" in output, "gemm1 not found in output"
+    assert "gemm_2" in output, "gemm2 not found in output"
+
+
 def test_compute_max_blockwise_dynamic_range_direct():
     """Direct unit test for compute_max_blockwise_dynamic_range function.
...
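
Note on the new test above: te.GroupedLinear(3, 128, 128, ...) runs three independent GEMMs over contiguous row groups of the input, and m_splits says how many rows each GEMM receives, so the logging feature is expected to emit one stats entry per GEMM (gemm_0, gemm_1, gemm_2). A minimal CPU-only sketch of the shape bookkeeping the test relies on (illustration only, not part of the test):

    import torch

    num_gemms, in_features = 3, 128
    inp = torch.randn((1, 128, 128), dtype=torch.bfloat16)
    m_splits = [64, 32, 32]

    # Each split is a contiguous block of input rows handled by its own GEMM,
    # so the splits must cover every row of the flattened input exactly once.
    assert sum(m_splits) == inp.numel() // in_features
    expected_log_keys = [f"gemm_{i}" for i in range(num_gemms)]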
tests/pytorch/debug/test_numerics.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -82,7 +82,6 @@ def _fp8_gemm_kernel(tensor1, scale1, dtype1, tensor2, scale2, dtype2, use_split
     out, *_ = tepytorch.cpp_extensions.general_gemm(
         fp8_tensor1,
         fp8_tensor2,
-        tepytorch.module.base.get_workspace(),
         torch.float32,
         use_split_accumulator=use_split_accumulator,
     )
...
@@ -199,7 +198,6 @@ def _emulate_linear(
     wgrad, *_ = tepytorch.cpp_extensions.general_gemm(
         wgrad_input,
         wgrad_gradient,
-        tepytorch.module.base.get_workspace(),
         torch.float32,
         layout="NT",
         grad=True,
...
tests/pytorch/debug/test_perf.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
tests/pytorch/debug/test_sanity.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
tests/pytorch/debug/utils.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
tests/pytorch/distributed/run_fsdp2_model.py

 #!/usr/bin/python3
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
tests/pytorch/distributed/run_gemm_with_overlap.py

 #!/usr/bin/python3
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -25,10 +25,8 @@ from transformer_engine.pytorch import (
     MXFP8Quantizer,
 )
 import transformer_engine.pytorch.cpp_extensions as tex
-from transformer_engine.pytorch.module.base import (
-    fill_userbuffers_buffer_for_all_gather,
-    get_cublas_workspace_size_bytes,
-)
+from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes
+from transformer_engine.pytorch.module.base import fill_userbuffers_buffer_for_all_gather

 warnings.filterwarnings("ignore", category=DeprecationWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
...
@@ -420,10 +418,6 @@ def _main(opts):
         std=opts.std,
     )

-    # Allocate cuBLAS workspace
-    workspace_size = 1 * get_cublas_workspace_size_bytes()
-    workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")
-
     # Gather global tensors and calculate reference result (need these first for Fp8 scales)
     if opts.bulk_overlap:
         ker_g = torch.transpose(kernel_t, 0, 1)
...
@@ -620,7 +614,6 @@ def _main(opts):
         return tex.general_gemm(
             kernel_t_fp8,
             gemm_inp,
-            workspace,
             out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16,
             quantization_params=out_quantizer,
             use_split_accumulator=te.module.base._2X_ACC_FPROP,
...
@@ -638,7 +631,6 @@ def _main(opts):
         return tex.general_gemm(
             kernel2_t_fp8,
             gemm2_inp,
-            workspace,
             out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16,
             quantization_params=out2_quantizer,
             use_split_accumulator=te.module.base._2X_ACC_FPROP,
...
@@ -651,7 +643,6 @@ def _main(opts):
         return tex.general_gemm(
             kernel_t,
             gemm_inp,
-            workspace,
             out_dtype=torch.bfloat16,
             use_split_accumulator=te.module.base._2X_ACC_FPROP,
             ub=ub_obj,
...
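
Context for the hunks above: the explicit cuBLAS workspace allocation in _main() is deleted and the tex.general_gemm(...) calls no longer pass a workspace tensor, which suggests (an inference from this diff, not a documented guarantee) that the workspace is now provisioned inside the call. For reference, the removed allocation pattern was roughly the following sketch; get_cublas_workspace_size_bytes is the helper the file still imports, now from transformer_engine.pytorch.cpp_extensions.gemm:

    # Sketch of the cuBLAS workspace allocation this commit removes from _main()
    # (reproduced from the deleted lines for context; requires a CUDA device).
    import torch
    from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes

    workspace_size = 1 * get_cublas_workspace_size_bytes()
    workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")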
tests/pytorch/distributed/run_layer_with_overlap.py

 #!/usr/bin/python3
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
tests/pytorch/distributed/run_numerics.py

 #!/usr/bin/python3
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -39,8 +39,9 @@ WORLD_RANK, WORLD_SIZE = None, None
 NCCL_WORLD = None
 LOSS_FN = nn.MSELoss()
 QUANTIZATION = None
+NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED") or "0")

-if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
+if NVTE_TEST_NVINSPECT_ENABLED:
     # The numerics of all the layers should work the same,
     # when debug=True. I fed them with dummy feature
     # to prevent switching off debug, which can happen if
...
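
The reworked flag parsing above matters because environment values are strings: the old check os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False) treats any non-empty value, including "0", as enabled, while the new int(... or "0") form parses the value, so "0" (or an unset variable) disables debug mode and the flag can be reused in the delay_wgrad_compute skips added below. A standalone illustration:

    import os

    os.environ["NVTE_TEST_NVINSPECT_ENABLED"] = "0"
    assert bool(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False))      # old check: "0" is still truthy
    assert int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED") or "0") == 0  # new check: "0" disables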
@@ -754,6 +755,8 @@ def test_linear():
     for kwargs in kwargs_list:
         if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
             continue
+        if kwargs.get("delay_wgrad_compute", False) and NVTE_TEST_NVINSPECT_ENABLED:
+            continue
         for parallel_mode in ["column", "row"]:
             for sequence_parallel in [False, True]:
                 _test_linear(parallel_mode, sequence_parallel, **kwargs)
...
@@ -941,6 +944,8 @@ def test_layernorm_linear():
     else:
         kwargs_list = base_kwargs_list
     for kwargs in kwargs_list:
+        if kwargs.get("delay_wgrad_compute", False) and NVTE_TEST_NVINSPECT_ENABLED:
+            continue
         for parallel_mode in ["column"]:
             for sequence_parallel in [False, True]:
                 _test_layernorm_linear(parallel_mode, sequence_parallel, **kwargs)
...
@@ -1047,6 +1052,7 @@ def test_layernorm_mlp():
         {"return_bias": True},
         {"return_layernorm_output": True},
         {"delay_wgrad_compute": True},
+        {"checkpoint": True},
     ]
     #TODO:The blockwise recipe does not currently support calculations with bias set to true.
     """
...
@@ -1058,6 +1064,8 @@ def test_layernorm_mlp():
     else:
         kwargs_list = base_kwargs_list
     for kwargs in kwargs_list:
+        if kwargs.get("delay_wgrad_compute", False) and NVTE_TEST_NVINSPECT_ENABLED:
+            continue
         for set_parallel_mode in [True]:
             for sequence_parallel in [False, True]:
                 _test_layernorm_mlp(set_parallel_mode, sequence_parallel, **kwargs)
...
tests/pytorch/distributed/run_numerics_exact.py

 #!/usr/bin/python3
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -20,6 +20,7 @@ from transformer_engine.common.recipe import (
     DelayedScaling,
     Float8CurrentScaling,
     Float8BlockScaling,
+    MXFP8BlockScaling,
     Format,
     Recipe,
 )
...
@@ -27,9 +28,11 @@ import transformer_engine.pytorch as te
 from transformer_engine.pytorch import (
     is_fp8_available,
     is_fp8_block_scaling_available,
+    is_mxfp8_available,
     QuantizedTensor,
     Float8Tensor,
     Float8BlockwiseQTensor,
+    MXFP8Tensor,
 )
 from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8
 from transformer_engine.pytorch.tensor.utils import post_all_gather_processing, replace_raw_data
...
@@ -44,17 +47,21 @@ def _get_quantization_recipe(quantization) -> Recipe:
         return Float8CurrentScaling(fp8_format=fp8_format)
     elif quantization == "fp8_block":
         return Float8BlockScaling(fp8_format=fp8_format)
+    elif quantization == "mxfp8":
+        return MXFP8BlockScaling()
     else:
         raise ValueError(f"Unsupported quantization: {quantization}")


-def _get_raw_data(quantized_tensor):
+def _get_raw_data(quantized_tensor, colwise=False):
     """Get the underlying data of a quantized tensor, used in zero-1 optimizer"""
     if isinstance(quantized_tensor, Float8Tensor):
+        assert not colwise, "Float8Tensor does not support get colwise data"
         assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute"
         assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8"
         return quantized_tensor._data
     elif isinstance(quantized_tensor, Float8BlockwiseQTensor):
+        assert not colwise, "Float8BlockwiseQTensor does not support get colwise data"
         assert hasattr(
             quantized_tensor, "_rowwise_data"
         ), "Float8BlockwiseQTensor does not have _rowwise_data attribute"
...
@@ -62,6 +69,23 @@ def _get_raw_data(quantized_tensor):
             quantized_tensor._rowwise_data.dtype == torch.uint8
         ), "Float8BlockwiseQTensor _rowwise_data must be uint8"
         return quantized_tensor._rowwise_data
+    elif isinstance(quantized_tensor, MXFP8Tensor):
+        if colwise:
+            assert hasattr(
+                quantized_tensor, "_columnwise_data"
+            ), "MXFP8Tensor does not have columnwise_data attribute"
+            assert (
+                quantized_tensor._columnwise_data.dtype == torch.uint8
+            ), "MXFP8Tensor columnwise_data must be uint8"
+            return quantized_tensor._columnwise_data
+        else:
+            assert hasattr(
+                quantized_tensor, "_rowwise_data"
+            ), "MXFP8Tensor does not have rowwise_data attribute"
+            assert (
+                quantized_tensor._rowwise_data.dtype == torch.uint8
+            ), "MXFP8Tensor rowwise_data must be uint8"
+            return quantized_tensor._rowwise_data
     else:
         raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")
...
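
As the new branch above reflects, an MXFP8Tensor carries two quantized uint8 buffers, _rowwise_data and _columnwise_data, each scaled along a different axis, so callers must say which copy they want; the other quantized tensor types only expose a row-wise buffer. A short usage sketch, where w is assumed to be an MXFP8Tensor parameter as in the tests below:

    # Hedged usage sketch of the extended helper (w is a hypothetical MXFP8Tensor weight)
    rowwise_bytes = _get_raw_data(w)                # default: the row-wise uint8 buffer
    colwise_bytes = _get_raw_data(w, colwise=True)  # MXFP8 only: the column-wise uint8 buffer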
@@ -231,38 +255,43 @@ class MiniZero_1:
             end = start_offset + master_weight.numel()
             weight.data.view(-1)[start:end].copy_(master_weight)

-        # -----------------------------------------------------------------------------------------
-        # Step 5: Copy the updated weights (not all weights) to the weight buffer
-        # -----------------------------------------------------------------------------------------
-        for i in range(len(self.weights)):
-            master_weight = self.master_weights[i]
-            if master_weight is None:
-                continue
-            start_offset = self.start_offsets[i]
-            if isinstance(self.weights[i], QuantizedTensor):
-                weight = _get_raw_data(self.weights[i])
-            else:
-                weight = self.weights[i]
-            weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()]
-            overlapping_start, overlapping_end = self.overlapping_areas[i]
-            self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice)
-
-        # -----------------------------------------------------------------------------------------
-        # Step 6: Weight all-gather (FP8 or BF16)
-        # -----------------------------------------------------------------------------------------
-        dist.all_gather_into_tensor(
-            self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
-        )
-
-        # -----------------------------------------------------------------------------------------
-        # Step 7: Copy the gathered weights from weight buffer to the actual weights
-        # -----------------------------------------------------------------------------------------
-        for weight, offset in zip(self.weights, self.offsets[:-1]):
-            start = offset
-            end = offset + weight.numel()
-            if isinstance(weight, QuantizedTensor):
-                weight = _get_raw_data(weight)
-            weight.view(-1).data.copy_(self.weight_buffer[start:end])
+        colwise_list = [False]
+        if isinstance(self.weights[0], MXFP8Tensor):
+            colwise_list.append(True)
+        for colwise in colwise_list:
+            # -------------------------------------------------------------------------------------
+            # Step 5: Copy the updated weights (not all weights) to the weight buffer
+            # -------------------------------------------------------------------------------------
+            for i in range(len(self.weights)):
+                master_weight = self.master_weights[i]
+                if master_weight is None:
+                    continue
+                start_offset = self.start_offsets[i]
+                if isinstance(self.weights[i], QuantizedTensor):
+                    weight = _get_raw_data(self.weights[i], colwise)
+                else:
+                    weight = self.weights[i]
+                weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()]
+                overlapping_start, overlapping_end = self.overlapping_areas[i]
+                self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice)
+
+            # -------------------------------------------------------------------------------------
+            # Step 6: Weight all-gather (FP8 or BF16)
+            # -------------------------------------------------------------------------------------
+            dist.all_gather_into_tensor(
+                self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
+            )
+
+            # -------------------------------------------------------------------------------------
+            # Step 7: Copy the gathered weights from weight buffer to the actual weights
+            # -------------------------------------------------------------------------------------
+            for weight, offset in zip(self.weights, self.offsets[:-1]):
+                start = offset
+                end = offset + weight.numel()
+                if isinstance(weight, QuantizedTensor):
+                    weight = _get_raw_data(weight, colwise)
+                weight.view(-1).data.copy_(self.weight_buffer[start:end])

         if self.manual_post_all_gather_processing:
             quantized_weights = [
...
@@ -287,9 +316,15 @@ class MiniFSDP:
         else:
             raw_data_list = [w.view(-1) for w in weights]
         self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list)
+        if isinstance(weights[0], MXFP8Tensor):
+            self.flatten_columnwise = self.flatten_weight.clone()
+        else:
+            self.flatten_columnwise = None

         # Split flattened weights into shards
         self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
+        if self.flatten_columnwise is not None:
+            self.local_columnwise_shard = torch.chunk(self.flatten_columnwise, world_size)[rank]
         self.local_main_grad_shard = torch.zeros_like(
             self.local_weight_shard, dtype=torch.float32, device="cuda"
         )
...
@@ -321,14 +356,25 @@ class MiniFSDP:
                 self.shard_indices.append((None, None))
             if isinstance(weights[idx], QuantizedTensor):
-                replace_raw_data(
-                    weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
-                )
+                if self.flatten_columnwise is not None:
+                    new_rowwise_data = self.flatten_weight[start:end].view(weights[idx].shape)
+                    new_rowwise_data.copy_(weights[idx]._rowwise_data)
+                    weights[idx]._rowwise_data = new_rowwise_data
+                    new_columnwise_data = self.flatten_columnwise[start:end].view(weights[idx].shape)
+                    new_columnwise_data.copy_(weights[idx]._columnwise_data)
+                    weights[idx]._columnwise_data = new_columnwise_data
+                else:
+                    replace_raw_data(
+                        weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
+                    )
             else:
                 weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape)

         # Initialize local model weights and high-precision master weights
         self.local_weights = []
+        self.local_columnwise = []
         self.master_weights = []
         for i, weight in enumerate(self.weights):
             weight_start, weight_end = self.weight_indices[i]
...
@@ -336,6 +382,11 @@ class MiniFSDP:
             if shard_start is not None and shard_end is not None:
                 local_weight_shard = self.local_weight_shard[shard_start:shard_end]
                 self.local_weights.append(local_weight_shard)
+                if self.flatten_columnwise is not None:
+                    local_columnwise_shard = self.local_columnwise_shard[shard_start:shard_end]
+                else:
+                    local_columnwise_shard = None
+                self.local_columnwise.append(local_columnwise_shard)
                 if isinstance(weight, QuantizedTensor):
                     high_precision_init_val = weight.get_high_precision_init_val().view(-1)
...
@@ -347,6 +398,7 @@ class MiniFSDP:
                 self.master_weights.append(master_weight_shard)
             else:
                 self.local_weights.append(None)
+                self.local_columnwise.append(None)
                 self.master_weights.append(None)
             setattr(
                 weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda")
             )
...
@@ -417,12 +469,12 @@ class MiniFSDP:
         # Step 3: Cast master weights to FP8 or BF16 precision
         if isinstance(self.weights[0], QuantizedTensor):
             local_weights = []
-            for local_weight in self.local_weights:
-                if local_weight is None:
-                    local_weights.append(None)
-                    continue
-                local_weights.append(local_weight)
+            for i, local_weight in enumerate(self.local_weights):
+                if self.flatten_columnwise is not None:
+                    local_columnwise = self.local_columnwise[i]
+                    local_weights.append((local_weight, local_columnwise))
+                else:
+                    local_weights.append(local_weight)
             cast_master_weights_to_fp8(
                 self.weights,
...
@@ -444,6 +496,10 @@ class MiniFSDP:
         dist.all_gather_into_tensor(
             self.flatten_weight, self.local_weight_shard, group=self.dp_group
         )
+        if self.flatten_columnwise is not None:
+            dist.all_gather_into_tensor(
+                self.flatten_columnwise, self.local_columnwise_shard, group=self.dp_group
+            )

         if self.manual_post_all_gather_processing:
             quantized_weights = [
...
@@ -515,15 +571,15 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat
         preserve_high_precision_init_val=True,
     ):
         model_fp8 = nn.Sequential(
-            te.Linear(128, 256 + 16, **linear_kwargs),
-            te.Linear(256 + 16, 256 * 3, **linear_kwargs),
+            te.Linear(128, 256 + 32, **linear_kwargs),
+            te.Linear(256 + 32, 256 * 3, **linear_kwargs),
             te.Linear(256 * 3, 128, **linear_kwargs),
         )

     # Create model with BF16 weights
     model = nn.Sequential(
-        te.Linear(128, 256 + 16, **linear_kwargs),
-        te.Linear(256 + 16, 256 * 3, **linear_kwargs),
+        te.Linear(128, 256 + 32, **linear_kwargs),
+        te.Linear(256 + 32, 256 * 3, **linear_kwargs),
         te.Linear(256 * 3, 128, **linear_kwargs),
     )
...
@@ -548,7 +604,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat
         w.main_grad.zero_()

     inputs = [
-        torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
+        torch.randn(32, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
     ]
     # Choose based on rank to make sure the inputs of different ranks are different.
     x = inputs[rank]
...
@@ -579,7 +635,9 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat
         optimizer_fp8.step()
         optimizer.step()

-        torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
+        assert torch.allclose(
+            loss_fp8, loss, atol=0, rtol=0
+        ), f"Loss mismatch at rank {rank}, step {i} for {quantization}"


 def _test_fsdp_cast_master_weights_to_fp8(
...
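
The assertion swap above lets a failure report which rank, step, and recipe diverged (the FSDP test below also renames its loop variable from `_` to `i` for the same reason): torch.testing.assert_close raises with its own generic mismatch report, while a plain assert torch.allclose(...) returns a bool and lets the test attach an f-string message. A standalone illustration with placeholder values:

    import torch

    loss_fp8 = torch.tensor(1.0)
    loss = torch.tensor(1.0)
    rank, step, quantization = 0, 3, "mxfp8"  # placeholder values for illustration
    assert torch.allclose(loss_fp8, loss, atol=0, rtol=0), (
        f"Loss mismatch at rank {rank}, step {step} for {quantization}"
    )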
@@ -611,15 +669,15 @@ def _test_fsdp_cast_master_weights_to_fp8(
         preserve_high_precision_init_val=True,
     ):
         model_fp8 = nn.Sequential(
-            te.Linear(128, 256 + 16, **linear_kwargs),
-            te.Linear(256 + 16, 256 * 3, **linear_kwargs),
+            te.Linear(128, 256 + 32, **linear_kwargs),
+            te.Linear(256 + 32, 256 * 3, **linear_kwargs),
             te.Linear(256 * 3, 128, **linear_kwargs),
         )

     # Create model with BF16 weights
     model = nn.Sequential(
-        te.Linear(128, 256 + 16, **linear_kwargs),
-        te.Linear(256 + 16, 256 * 3, **linear_kwargs),
+        te.Linear(128, 256 + 32, **linear_kwargs),
+        te.Linear(256 + 32, 256 * 3, **linear_kwargs),
         te.Linear(256 * 3, 128, **linear_kwargs),
     )
...
@@ -633,12 +691,12 @@ def _test_fsdp_cast_master_weights_to_fp8(
     )
     optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group)

-    for _ in range(100):
+    for i in range(100):
         optimizer_fp8.zero_grad()
         optimizer.zero_grad()

         inputs = [
-            torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
+            torch.randn(32, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
         ]
         # Choose based on rank to make sure the inputs of different ranks are different.
         x = inputs[rank]
...
@@ -669,7 +727,9 @@ def _test_fsdp_cast_master_weights_to_fp8(
         optimizer_fp8.step()
         optimizer.step()

-        torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
+        assert torch.allclose(
+            loss_fp8, loss, atol=0, rtol=0
+        ), f"Loss mismatch at rank {rank}, step {i} for {quantization} (FSDP)"


 def run_parallel_tests() -> None:
...
@@ -700,6 +760,8 @@ def run_parallel_tests() -> None:
         quantizations.extend(["fp8", "fp8_cs"])
     if is_fp8_block_scaling_available():
         quantizations.append("fp8_block")
+    if is_mxfp8_available():
+        quantizations.append("mxfp8")

     manual_post_all_gather_processings = [False, True]
...
tests/pytorch/distributed/test_comm_gemm_overlap.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
 # mpirun -np 8 --allow-run-as-root --oversubscribe --quiet python3 /home/TransformerEngine/tests/pytorch/distributed/run_gemm_with_overlap.py --check-numerics --seed=42 --seq-length=1024 --batch-size=2 --num-heads=48 --head-dim=64 --comm-type=AG --p2p
...
@@ -127,12 +127,18 @@ def _run_layer_with_overlap(
     os.environ["PYTORCH_JIT"] = "0"
     os.environ["NVTE_TORCH_COMPILE"] = "0"
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
+    if te.get_device_compute_capability() <= (8, 0):
+        # We've experienced numerical discrepancies in Flash Attention
+        # backward when running with Userbuffers on A100s. This does
+        # not show up in more recent GPUs.
+        os.environ["NVTE_FLASH_ATTN"] = "0"

     result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)

     os.unsetenv("PYTORCH_JIT")
     os.unsetenv("NVTE_TORCH_COMPILE")
     os.unsetenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO")
+    os.unsetenv("NVTE_FLASH_ATTN")

     if (
         result.returncode != 0
...
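
A note on the guard added above: te.get_device_compute_capability() returns a (major, minor) tuple and Python compares tuples element-wise, so the `<= (8, 0)` check matches devices with compute capability 8.0 or lower (for example A100, sm_80) and leaves NVTE_FLASH_ATTN untouched on newer GPUs. For example:

    # Tuple comparison semantics behind the "<= (8, 0)" check (standalone illustration)
    assert (8, 0) <= (8, 0)         # A100: flash attention disabled for the subprocess
    assert not ((9, 0) <= (8, 0))   # H100 and newer: env var left unset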
tests/pytorch/distributed/test_fusible_ops.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
tests/pytorch/distributed/test_numerics.py

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -13,7 +13,7 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
 """
 Distributed numerics tests

-These tests test the numerical corectness of the TransformerEngine layers.
+These tests test the numerical correctness of the TransformerEngine layers.
 Tests are parametrized by the layer and fp8 precision.
 One test consists of running multiple configurations from file run_numerics.py
 Such design is due to the fact the initialization of one test is long
...
...
Prev
1
…
7
8
9
10
11
12
13
14
15
…
32
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment