Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
0d874a4e
Commit
0d874a4e
authored
Mar 03, 2026
by
wenjh
Browse files
Merge branch 'nv_main' of v2.12
parents
a68e5f87
dfdd3820
Changes
640
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
175 additions
and
88 deletions
+175
-88
tests/pytorch/debug/conftest.py
tests/pytorch/debug/conftest.py
+1
-1
tests/pytorch/debug/run_distributed.py
tests/pytorch/debug/run_distributed.py
+1
-1
tests/pytorch/debug/test_api_features.py
tests/pytorch/debug/test_api_features.py
+1
-1
tests/pytorch/debug/test_config.py
tests/pytorch/debug/test_config.py
+1
-1
tests/pytorch/debug/test_distributed.py
tests/pytorch/debug/test_distributed.py
+1
-1
tests/pytorch/debug/test_log.py
tests/pytorch/debug/test_log.py
+23
-1
tests/pytorch/debug/test_numerics.py
tests/pytorch/debug/test_numerics.py
+1
-3
tests/pytorch/debug/test_perf.py
tests/pytorch/debug/test_perf.py
+1
-1
tests/pytorch/debug/test_sanity.py
tests/pytorch/debug/test_sanity.py
+1
-1
tests/pytorch/debug/utils.py
tests/pytorch/debug/utils.py
+1
-1
tests/pytorch/distributed/run_fsdp2_model.py
tests/pytorch/distributed/run_fsdp2_model.py
+1
-1
tests/pytorch/distributed/run_gemm_with_overlap.py
tests/pytorch/distributed/run_gemm_with_overlap.py
+3
-12
tests/pytorch/distributed/run_layer_with_overlap.py
tests/pytorch/distributed/run_layer_with_overlap.py
+1
-1
tests/pytorch/distributed/run_numerics.py
tests/pytorch/distributed/run_numerics.py
+10
-2
tests/pytorch/distributed/run_numerics_exact.py
tests/pytorch/distributed/run_numerics_exact.py
+1
-1
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py
+116
-54
tests/pytorch/distributed/test_comm_gemm_overlap.py
tests/pytorch/distributed/test_comm_gemm_overlap.py
+7
-1
tests/pytorch/distributed/test_fusible_ops.py
tests/pytorch/distributed/test_fusible_ops.py
+1
-1
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
.../pytorch/distributed/test_fusible_ops_with_userbuffers.py
+1
-1
tests/pytorch/distributed/test_numerics.py
tests/pytorch/distributed/test_numerics.py
+2
-2
No files found.
Too many changes to show.
To preserve performance only
640 of 640+
files are displayed.
Plain diff
Email patch
tests/pytorch/debug/conftest.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import
pytest
...
...
tests/pytorch/debug/run_distributed.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/debug/test_api_features.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/debug/test_config.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import
pathlib
...
...
tests/pytorch/debug/test_distributed.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/debug/test_log.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -363,6 +363,28 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
TEDebugState
.
_reset
()
def
test_log_grouped_gemm
(
feature_dirs
):
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
log_all_stats_config
=
LOG_QUANTIZED_CONFIG_BASE
.
format
(
stats
=
", "
.
join
(
all_stats
))
with
debug_session
(
log_all_stats_config
,
feature_dirs
)
as
log_dir
:
model
=
te
.
GroupedLinear
(
3
,
128
,
128
,
name
=
"linear1"
,
params_dtype
=
torch
.
bfloat16
)
inp
=
torch
.
randn
((
1
,
128
,
128
),
dtype
=
torch
.
bfloat16
).
cuda
()
m_splits
=
[
64
,
32
,
32
]
with
te
.
fp8_autocast
(
fp8_recipe
=
recipe
.
DelayedScaling
()):
output
=
model
(
inp
,
m_splits
=
m_splits
)
loss
=
output
.
sum
()
loss
.
backward
()
debug_api
.
step
()
output
=
read_log
(
log_dir
)
assert
"gemm_0"
in
output
,
"gemm0 not found in output"
assert
"gemm_1"
in
output
,
"gemm1 not found in output"
assert
"gemm_2"
in
output
,
"gemm2 not found in output"
def
test_compute_max_blockwise_dynamic_range_direct
():
"""Direct unit test for compute_max_blockwise_dynamic_range function.
...
...
tests/pytorch/debug/test_numerics.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -82,7 +82,6 @@ def _fp8_gemm_kernel(tensor1, scale1, dtype1, tensor2, scale2, dtype2, use_split
out
,
*
_
=
tepytorch
.
cpp_extensions
.
general_gemm
(
fp8_tensor1
,
fp8_tensor2
,
tepytorch
.
module
.
base
.
get_workspace
(),
torch
.
float32
,
use_split_accumulator
=
use_split_accumulator
,
)
...
...
@@ -199,7 +198,6 @@ def _emulate_linear(
wgrad
,
*
_
=
tepytorch
.
cpp_extensions
.
general_gemm
(
wgrad_input
,
wgrad_gradient
,
tepytorch
.
module
.
base
.
get_workspace
(),
torch
.
float32
,
layout
=
"NT"
,
grad
=
True
,
...
...
tests/pytorch/debug/test_perf.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/debug/test_sanity.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/debug/utils.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/distributed/run_fsdp2_model.py
View file @
0d874a4e
#!/usr/bin/python3
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/distributed/run_gemm_with_overlap.py
View file @
0d874a4e
#!/usr/bin/python3
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -25,10 +25,8 @@ from transformer_engine.pytorch import (
MXFP8Quantizer
,
)
import
transformer_engine.pytorch.cpp_extensions
as
tex
from
transformer_engine.pytorch.module.base
import
(
fill_userbuffers_buffer_for_all_gather
,
get_cublas_workspace_size_bytes
,
)
from
transformer_engine.pytorch.cpp_extensions.gemm
import
get_cublas_workspace_size_bytes
from
transformer_engine.pytorch.module.base
import
fill_userbuffers_buffer_for_all_gather
warnings
.
filterwarnings
(
"ignore"
,
category
=
DeprecationWarning
)
warnings
.
filterwarnings
(
"ignore"
,
category
=
FutureWarning
)
...
...
@@ -420,10 +418,6 @@ def _main(opts):
std
=
opts
.
std
,
)
# Allocate cuBLAS workspace
workspace_size
=
1
*
get_cublas_workspace_size_bytes
()
workspace
=
torch
.
empty
(
workspace_size
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
)
# Gather global tensors and calculate reference result (need these first for Fp8 scales)
if
opts
.
bulk_overlap
:
ker_g
=
torch
.
transpose
(
kernel_t
,
0
,
1
)
...
...
@@ -620,7 +614,6 @@ def _main(opts):
return
tex
.
general_gemm
(
kernel_t_fp8
,
gemm_inp
,
workspace
,
out_dtype
=
torch
.
float8_e4m3fn
if
opts
.
fp8_output
else
torch
.
bfloat16
,
quantization_params
=
out_quantizer
,
use_split_accumulator
=
te
.
module
.
base
.
_2X_ACC_FPROP
,
...
...
@@ -638,7 +631,6 @@ def _main(opts):
return
tex
.
general_gemm
(
kernel2_t_fp8
,
gemm2_inp
,
workspace
,
out_dtype
=
torch
.
float8_e4m3fn
if
opts
.
fp8_output
else
torch
.
bfloat16
,
quantization_params
=
out2_quantizer
,
use_split_accumulator
=
te
.
module
.
base
.
_2X_ACC_FPROP
,
...
...
@@ -651,7 +643,6 @@ def _main(opts):
return
tex
.
general_gemm
(
kernel_t
,
gemm_inp
,
workspace
,
out_dtype
=
torch
.
bfloat16
,
use_split_accumulator
=
te
.
module
.
base
.
_2X_ACC_FPROP
,
ub
=
ub_obj
,
...
...
tests/pytorch/distributed/run_layer_with_overlap.py
View file @
0d874a4e
#!/usr/bin/python3
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/distributed/run_numerics.py
View file @
0d874a4e
#!/usr/bin/python3
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -39,8 +39,9 @@ WORLD_RANK, WORLD_SIZE = None, None
NCCL_WORLD
=
None
LOSS_FN
=
nn
.
MSELoss
()
QUANTIZATION
=
None
NVTE_TEST_NVINSPECT_ENABLED
=
int
(
os
.
environ
.
get
(
"NVTE_TEST_NVINSPECT_ENABLED"
)
or
"0"
)
if
os
.
environ
.
get
(
"
NVTE_TEST_NVINSPECT_ENABLED
"
,
False
)
:
if
NVTE_TEST_NVINSPECT_ENABLED
:
# The numerics of all the layers should work the same,
# when debug=True. I fed them with dummy feature
# to prevent switching off debug, which can happen if
...
...
@@ -754,6 +755,8 @@ def test_linear():
for
kwargs
in
kwargs_list
:
if
kwargs
.
get
(
"save_original_input"
,
False
)
and
QUANTIZATION
==
"fp8"
:
continue
if
kwargs
.
get
(
"delay_wgrad_compute"
,
False
)
and
NVTE_TEST_NVINSPECT_ENABLED
:
continue
for
parallel_mode
in
[
"column"
,
"row"
]:
for
sequence_parallel
in
[
False
,
True
]:
_test_linear
(
parallel_mode
,
sequence_parallel
,
**
kwargs
)
...
...
@@ -941,6 +944,8 @@ def test_layernorm_linear():
else
:
kwargs_list
=
base_kwargs_list
for
kwargs
in
kwargs_list
:
if
kwargs
.
get
(
"delay_wgrad_compute"
,
False
)
and
NVTE_TEST_NVINSPECT_ENABLED
:
continue
for
parallel_mode
in
[
"column"
]:
for
sequence_parallel
in
[
False
,
True
]:
_test_layernorm_linear
(
parallel_mode
,
sequence_parallel
,
**
kwargs
)
...
...
@@ -1047,6 +1052,7 @@ def test_layernorm_mlp():
{
"return_bias"
:
True
},
{
"return_layernorm_output"
:
True
},
{
"delay_wgrad_compute"
:
True
},
{
"checkpoint"
:
True
},
]
#TODO:The blockwise recipe does not currently support calculations with bias set to true.
"""
...
...
@@ -1058,6 +1064,8 @@ def test_layernorm_mlp():
else
:
kwargs_list
=
base_kwargs_list
for
kwargs
in
kwargs_list
:
if
kwargs
.
get
(
"delay_wgrad_compute"
,
False
)
and
NVTE_TEST_NVINSPECT_ENABLED
:
continue
for
set_parallel_mode
in
[
True
]:
for
sequence_parallel
in
[
False
,
True
]:
_test_layernorm_mlp
(
set_parallel_mode
,
sequence_parallel
,
**
kwargs
)
...
...
tests/pytorch/distributed/run_numerics_exact.py
View file @
0d874a4e
#!/usr/bin/python3
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -20,6 +20,7 @@ from transformer_engine.common.recipe import (
DelayedScaling
,
Float8CurrentScaling
,
Float8BlockScaling
,
MXFP8BlockScaling
,
Format
,
Recipe
,
)
...
...
@@ -27,9 +28,11 @@ import transformer_engine.pytorch as te
from
transformer_engine.pytorch
import
(
is_fp8_available
,
is_fp8_block_scaling_available
,
is_mxfp8_available
,
QuantizedTensor
,
Float8Tensor
,
Float8BlockwiseQTensor
,
MXFP8Tensor
,
)
from
transformer_engine.pytorch.tensor
import
cast_master_weights_to_fp8
from
transformer_engine.pytorch.tensor.utils
import
post_all_gather_processing
,
replace_raw_data
...
...
@@ -44,17 +47,21 @@ def _get_quantization_recipe(quantization) -> Recipe:
return
Float8CurrentScaling
(
fp8_format
=
fp8_format
)
elif
quantization
==
"fp8_block"
:
return
Float8BlockScaling
(
fp8_format
=
fp8_format
)
elif
quantization
==
"mxfp8"
:
return
MXFP8BlockScaling
()
else
:
raise
ValueError
(
f
"Unsupported quantization:
{
quantization
}
"
)
def
_get_raw_data
(
quantized_tensor
):
def
_get_raw_data
(
quantized_tensor
,
colwise
=
False
):
"""Get the underlying data of a quantized tensor, used in zero-1 optimizer"""
if
isinstance
(
quantized_tensor
,
Float8Tensor
):
assert
not
colwise
,
"Float8Tensor does not support get colwise data"
assert
hasattr
(
quantized_tensor
,
"_data"
),
"Float8Tensor does not have _data attribute"
assert
quantized_tensor
.
_data
.
dtype
==
torch
.
uint8
,
"Float8Tensor _data must be uint8"
return
quantized_tensor
.
_data
elif
isinstance
(
quantized_tensor
,
Float8BlockwiseQTensor
):
assert
not
colwise
,
"Float8BlockwiseQTensor does not support get colwise data"
assert
hasattr
(
quantized_tensor
,
"_rowwise_data"
),
"Float8BlockwiseQTensor does not have _rowwise_data attribute"
...
...
@@ -62,6 +69,23 @@ def _get_raw_data(quantized_tensor):
quantized_tensor
.
_rowwise_data
.
dtype
==
torch
.
uint8
),
"Float8BlockwiseQTensor _rowwise_data must be uint8"
return
quantized_tensor
.
_rowwise_data
elif
isinstance
(
quantized_tensor
,
MXFP8Tensor
):
if
colwise
:
assert
hasattr
(
quantized_tensor
,
"_columnwise_data"
),
"MXFP8Tensor does not have columnwise_data attribute"
assert
(
quantized_tensor
.
_columnwise_data
.
dtype
==
torch
.
uint8
),
"MXFP8Tensor columnwise_data must be uint8"
return
quantized_tensor
.
_columnwise_data
else
:
assert
hasattr
(
quantized_tensor
,
"_rowwise_data"
),
"MXFP8Tensor does not have rowwise_data attribute"
assert
(
quantized_tensor
.
_rowwise_data
.
dtype
==
torch
.
uint8
),
"MXFP8Tensor rowwise_data must be uint8"
return
quantized_tensor
.
_rowwise_data
else
:
raise
ValueError
(
f
"Unsupported quantized tensor type:
{
type
(
quantized_tensor
)
}
"
)
...
...
@@ -231,38 +255,43 @@ class MiniZero_1:
end
=
start_offset
+
master_weight
.
numel
()
weight
.
data
.
view
(
-
1
)[
start
:
end
].
copy_
(
master_weight
)
# -----------------------------------------------------------------------------------------
# Step 5: Copy the updated weights (not all weights) to the weight buffer
# -----------------------------------------------------------------------------------------
for
i
in
range
(
len
(
self
.
weights
)):
master_weight
=
self
.
master_weights
[
i
]
if
master_weight
is
None
:
continue
start_offset
=
self
.
start_offsets
[
i
]
if
isinstance
(
self
.
weights
[
i
],
QuantizedTensor
):
weight
=
_get_raw_data
(
self
.
weights
[
i
])
else
:
weight
=
self
.
weights
[
i
]
weight_slice
=
weight
.
view
(
-
1
)[
start_offset
:
start_offset
+
master_weight
.
numel
()]
overlapping_start
,
overlapping_end
=
self
.
overlapping_areas
[
i
]
self
.
weight_buffer
[
overlapping_start
:
overlapping_end
].
copy_
(
weight_slice
)
colwise_list
=
[
False
]
if
isinstance
(
self
.
weights
[
0
],
MXFP8Tensor
):
colwise_list
.
append
(
True
)
# -----------------------------------------------------------------------------------------
# Step 6: Weight all-gather (FP8 or BF16)
# -----------------------------------------------------------------------------------------
dist
.
all_gather_into_tensor
(
self
.
weight_buffer
,
self
.
weight_buffer_slice
,
group
=
self
.
dp_group
)
for
colwise
in
colwise_list
:
# -------------------------------------------------------------------------------------
# Step 5: Copy the updated weights (not all weights) to the weight buffer
# -------------------------------------------------------------------------------------
for
i
in
range
(
len
(
self
.
weights
)):
master_weight
=
self
.
master_weights
[
i
]
if
master_weight
is
None
:
continue
start_offset
=
self
.
start_offsets
[
i
]
if
isinstance
(
self
.
weights
[
i
],
QuantizedTensor
):
weight
=
_get_raw_data
(
self
.
weights
[
i
],
colwise
)
else
:
weight
=
self
.
weights
[
i
]
weight_slice
=
weight
.
view
(
-
1
)[
start_offset
:
start_offset
+
master_weight
.
numel
()]
overlapping_start
,
overlapping_end
=
self
.
overlapping_areas
[
i
]
self
.
weight_buffer
[
overlapping_start
:
overlapping_end
].
copy_
(
weight_slice
)
# -------------------------------------------------------------------------------------
# Step 6: Weight all-gather (FP8 or BF16)
# -------------------------------------------------------------------------------------
dist
.
all_gather_into_tensor
(
self
.
weight_buffer
,
self
.
weight_buffer_slice
,
group
=
self
.
dp_group
)
# ----
-------------------------------------------------------------------------------------
# Step 7: Copy the gathered weights from weight buffer to the actual weights
# ----
-------------------------------------------------------------------------------------
for
weight
,
offset
in
zip
(
self
.
weights
,
self
.
offsets
[:
-
1
]):
start
=
offset
end
=
offset
+
weight
.
numel
()
if
isinstance
(
weight
,
QuantizedTensor
):
weight
=
_get_raw_data
(
weight
)
weight
.
view
(
-
1
).
data
.
copy_
(
self
.
weight_buffer
[
start
:
end
])
#
-------------------------------------------------------------------------------------
# Step 7: Copy the gathered weights from weight buffer to the actual weights
#
-------------------------------------------------------------------------------------
for
weight
,
offset
in
zip
(
self
.
weights
,
self
.
offsets
[:
-
1
]):
start
=
offset
end
=
offset
+
weight
.
numel
()
if
isinstance
(
weight
,
QuantizedTensor
):
weight
=
_get_raw_data
(
weight
,
colwise
)
weight
.
view
(
-
1
).
data
.
copy_
(
self
.
weight_buffer
[
start
:
end
])
if
self
.
manual_post_all_gather_processing
:
quantized_weights
=
[
...
...
@@ -287,9 +316,15 @@ class MiniFSDP:
else
:
raw_data_list
=
[
w
.
view
(
-
1
)
for
w
in
weights
]
self
.
flatten_weight
,
original_length
=
self
.
_flatten_tensors_with_pad
(
raw_data_list
)
if
isinstance
(
weights
[
0
],
MXFP8Tensor
):
self
.
flatten_columnwise
=
self
.
flatten_weight
.
clone
()
else
:
self
.
flatten_columnwise
=
None
# Split flattened weights into shards
self
.
local_weight_shard
=
torch
.
chunk
(
self
.
flatten_weight
,
world_size
)[
rank
]
if
self
.
flatten_columnwise
is
not
None
:
self
.
local_columnwise_shard
=
torch
.
chunk
(
self
.
flatten_columnwise
,
world_size
)[
rank
]
self
.
local_main_grad_shard
=
torch
.
zeros_like
(
self
.
local_weight_shard
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
...
...
@@ -321,14 +356,25 @@ class MiniFSDP:
self
.
shard_indices
.
append
((
None
,
None
))
if
isinstance
(
weights
[
idx
],
QuantizedTensor
):
replace_raw_data
(
weights
[
idx
],
self
.
flatten_weight
[
start
:
end
].
view
(
weights
[
idx
].
shape
)
)
if
self
.
flatten_columnwise
is
not
None
:
new_rowwise_data
=
self
.
flatten_weight
[
start
:
end
].
view
(
weights
[
idx
].
shape
)
new_rowwise_data
.
copy_
(
weights
[
idx
].
_rowwise_data
)
weights
[
idx
].
_rowwise_data
=
new_rowwise_data
new_columnwise_data
=
self
.
flatten_columnwise
[
start
:
end
].
view
(
weights
[
idx
].
shape
)
new_columnwise_data
.
copy_
(
weights
[
idx
].
_columnwise_data
)
weights
[
idx
].
_columnwise_data
=
new_columnwise_data
else
:
replace_raw_data
(
weights
[
idx
],
self
.
flatten_weight
[
start
:
end
].
view
(
weights
[
idx
].
shape
)
)
else
:
weights
[
idx
].
data
=
self
.
flatten_weight
[
start
:
end
].
view
(
weights
[
idx
].
shape
)
# Initialize local model weights and high-precision master weights
self
.
local_weights
=
[]
self
.
local_columnwise
=
[]
self
.
master_weights
=
[]
for
i
,
weight
in
enumerate
(
self
.
weights
):
weight_start
,
weight_end
=
self
.
weight_indices
[
i
]
...
...
@@ -336,6 +382,11 @@ class MiniFSDP:
if
shard_start
is
not
None
and
shard_end
is
not
None
:
local_weight_shard
=
self
.
local_weight_shard
[
shard_start
:
shard_end
]
self
.
local_weights
.
append
(
local_weight_shard
)
if
self
.
flatten_columnwise
is
not
None
:
local_columnwise_shard
=
self
.
local_columnwise_shard
[
shard_start
:
shard_end
]
else
:
local_columnwise_shard
=
None
self
.
local_columnwise
.
append
(
local_columnwise_shard
)
if
isinstance
(
weight
,
QuantizedTensor
):
high_precision_init_val
=
weight
.
get_high_precision_init_val
().
view
(
-
1
)
...
...
@@ -347,6 +398,7 @@ class MiniFSDP:
self
.
master_weights
.
append
(
master_weight_shard
)
else
:
self
.
local_weights
.
append
(
None
)
self
.
local_columnwise
.
append
(
None
)
self
.
master_weights
.
append
(
None
)
setattr
(
weight
,
"main_grad"
,
torch
.
zeros_like
(
weight
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
...
...
@@ -417,12 +469,12 @@ class MiniFSDP:
# Step 3: Cast master weights to FP8 or BF16 precision
if
isinstance
(
self
.
weights
[
0
],
QuantizedTensor
):
local_weights
=
[]
for
local_weight
in
self
.
local_weights
:
if
local_weight
is
None
:
local_
weights
.
append
(
None
)
continue
local_weights
.
append
(
local_weight
)
for
i
,
local_weight
in
enumerate
(
self
.
local_weights
)
:
if
self
.
flatten_columnwise
is
not
None
:
local_
columnwise
=
self
.
local_columnwise
[
i
]
local_weights
.
append
((
local_weight
,
local_columnwise
))
else
:
local_weights
.
append
(
local_weight
)
cast_master_weights_to_fp8
(
self
.
weights
,
...
...
@@ -444,6 +496,10 @@ class MiniFSDP:
dist
.
all_gather_into_tensor
(
self
.
flatten_weight
,
self
.
local_weight_shard
,
group
=
self
.
dp_group
)
if
self
.
flatten_columnwise
is
not
None
:
dist
.
all_gather_into_tensor
(
self
.
flatten_columnwise
,
self
.
local_columnwise_shard
,
group
=
self
.
dp_group
)
if
self
.
manual_post_all_gather_processing
:
quantized_weights
=
[
...
...
@@ -515,15 +571,15 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat
preserve_high_precision_init_val
=
True
,
):
model_fp8
=
nn
.
Sequential
(
te
.
Linear
(
128
,
256
+
16
,
**
linear_kwargs
),
te
.
Linear
(
256
+
16
,
256
*
3
,
**
linear_kwargs
),
te
.
Linear
(
128
,
256
+
32
,
**
linear_kwargs
),
te
.
Linear
(
256
+
32
,
256
*
3
,
**
linear_kwargs
),
te
.
Linear
(
256
*
3
,
128
,
**
linear_kwargs
),
)
# Create model with BF16 weights
model
=
nn
.
Sequential
(
te
.
Linear
(
128
,
256
+
16
,
**
linear_kwargs
),
te
.
Linear
(
256
+
16
,
256
*
3
,
**
linear_kwargs
),
te
.
Linear
(
128
,
256
+
32
,
**
linear_kwargs
),
te
.
Linear
(
256
+
32
,
256
*
3
,
**
linear_kwargs
),
te
.
Linear
(
256
*
3
,
128
,
**
linear_kwargs
),
)
...
...
@@ -548,7 +604,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat
w
.
main_grad
.
zero_
()
inputs
=
[
torch
.
randn
(
16
,
128
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
for
_
in
range
(
world_size
)
torch
.
randn
(
32
,
128
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
for
_
in
range
(
world_size
)
]
# Choose based on rank to make sure the inputs of different ranks are different.
x
=
inputs
[
rank
]
...
...
@@ -579,7 +635,9 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat
optimizer_fp8
.
step
()
optimizer
.
step
()
torch
.
testing
.
assert_close
(
loss_fp8
,
loss
,
atol
=
0
,
rtol
=
0
)
assert
torch
.
allclose
(
loss_fp8
,
loss
,
atol
=
0
,
rtol
=
0
),
f
"Loss mismatch at rank
{
rank
}
, step
{
i
}
for
{
quantization
}
"
def
_test_fsdp_cast_master_weights_to_fp8
(
...
...
@@ -611,15 +669,15 @@ def _test_fsdp_cast_master_weights_to_fp8(
preserve_high_precision_init_val
=
True
,
):
model_fp8
=
nn
.
Sequential
(
te
.
Linear
(
128
,
256
+
16
,
**
linear_kwargs
),
te
.
Linear
(
256
+
16
,
256
*
3
,
**
linear_kwargs
),
te
.
Linear
(
128
,
256
+
32
,
**
linear_kwargs
),
te
.
Linear
(
256
+
32
,
256
*
3
,
**
linear_kwargs
),
te
.
Linear
(
256
*
3
,
128
,
**
linear_kwargs
),
)
# Create model with BF16 weights
model
=
nn
.
Sequential
(
te
.
Linear
(
128
,
256
+
16
,
**
linear_kwargs
),
te
.
Linear
(
256
+
16
,
256
*
3
,
**
linear_kwargs
),
te
.
Linear
(
128
,
256
+
32
,
**
linear_kwargs
),
te
.
Linear
(
256
+
32
,
256
*
3
,
**
linear_kwargs
),
te
.
Linear
(
256
*
3
,
128
,
**
linear_kwargs
),
)
...
...
@@ -633,12 +691,12 @@ def _test_fsdp_cast_master_weights_to_fp8(
)
optimizer
=
MiniFSDP
([
w
for
w
in
model
.
parameters
()],
10.0
,
dp_group
)
for
_
in
range
(
100
):
for
i
in
range
(
100
):
optimizer_fp8
.
zero_grad
()
optimizer
.
zero_grad
()
inputs
=
[
torch
.
randn
(
16
,
128
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
for
_
in
range
(
world_size
)
torch
.
randn
(
32
,
128
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
for
_
in
range
(
world_size
)
]
# Choose based on rank to make sure the inputs of different ranks are different.
x
=
inputs
[
rank
]
...
...
@@ -669,7 +727,9 @@ def _test_fsdp_cast_master_weights_to_fp8(
optimizer_fp8
.
step
()
optimizer
.
step
()
torch
.
testing
.
assert_close
(
loss_fp8
,
loss
,
atol
=
0
,
rtol
=
0
)
assert
torch
.
allclose
(
loss_fp8
,
loss
,
atol
=
0
,
rtol
=
0
),
f
"Loss mismatch at rank
{
rank
}
, step
{
i
}
for
{
quantization
}
(FSDP)"
def
run_parallel_tests
()
->
None
:
...
...
@@ -700,6 +760,8 @@ def run_parallel_tests() -> None:
quantizations
.
extend
([
"fp8"
,
"fp8_cs"
])
if
is_fp8_block_scaling_available
():
quantizations
.
append
(
"fp8_block"
)
if
is_mxfp8_available
():
quantizations
.
append
(
"mxfp8"
)
manual_post_all_gather_processings
=
[
False
,
True
]
...
...
tests/pytorch/distributed/test_comm_gemm_overlap.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# mpirun -np 8 --allow-run-as-root --oversubscribe --quiet python3 /home/TransformerEngine/tests/pytorch/distributed/run_gemm_with_overlap.py --check-numerics --seed=42 --seq-length=1024 --batch-size=2 --num-heads=48 --head-dim=64 --comm-type=AG --p2p
...
...
@@ -127,12 +127,18 @@ def _run_layer_with_overlap(
os
.
environ
[
"PYTORCH_JIT"
]
=
"0"
os
.
environ
[
"NVTE_TORCH_COMPILE"
]
=
"0"
os
.
environ
[
"NVTE_ALLOW_NONDETERMINISTIC_ALGO"
]
=
"0"
if
te
.
get_device_compute_capability
()
<=
(
8
,
0
):
# We've experienced numerical discrepancies in Flash Attention
# backward when running with Userbuffers on A100s. This does
# not show up in more recent GPUs.
os
.
environ
[
"NVTE_FLASH_ATTN"
]
=
"0"
result
=
subprocess
.
run
(
test_cmd
,
env
=
os
.
environ
,
capture_output
=
True
,
check
=
False
)
os
.
unsetenv
(
"PYTORCH_JIT"
)
os
.
unsetenv
(
"NVTE_TORCH_COMPILE"
)
os
.
unsetenv
(
"NVTE_ALLOW_NONDETERMINISTIC_ALGO"
)
os
.
unsetenv
(
"NVTE_FLASH_ATTN"
)
if
(
result
.
returncode
!=
0
...
...
tests/pytorch/distributed/test_fusible_ops.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/distributed/test_numerics.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -13,7 +13,7 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
"""
Distributed numerics tests
These tests test the numerical corectness of the TransformerEngine layers.
These tests test the numerical cor
r
ectness of the TransformerEngine layers.
Tests are parametrized by the layer and fp8 precision.
One test consists of running multiple configurations from file run_numerics.py
Such design is due to the fact the initialization of one test is long
...
...
Prev
1
…
7
8
9
10
11
12
13
14
15
…
32
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment