OpenDAS / TransformerEngine / Commits

Commit c1a1c04e
Authored Dec 27, 2025 by wenjh

    Merge nv_main(2.10) to main

    Signed-off-by: wenjh <wenjh@sugon.com>

Parents: e698a0a7, 66aed3ae
Changes: 208
Showing 20 changed files with 2303 additions and 1062 deletions.
Changed files (additions / deletions):

  tests/pytorch/attention/test_attention.py                              +72   -21
  tests/pytorch/attention/test_attention_with_cp.py                       +3    -3
  tests/pytorch/debug/run_distributed.py                                  +2    -1
  tests/pytorch/debug/test_configs/test_switch_to_nondebug_mode.yaml     +11    -0
  tests/pytorch/debug/test_log.py                                        +196    -2
  tests/pytorch/debug/test_perf.py                                       +55   -56
  tests/pytorch/distributed/run_cast_master_weights_to_fp8.py             +0  -684
  tests/pytorch/distributed/run_fsdp2_model.py                          +247   -79
  tests/pytorch/distributed/run_numerics_exact.py                         +3    -3
  tests/pytorch/distributed/test_cast_master_weights_to_fp8.py          +729   -22
  tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py          +1    -0
  tests/pytorch/distributed/test_numerics_exact.py                        +1    -1
  tests/pytorch/distributed/test_torch_fsdp2.py                          +13    -5
  tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py                            +2    -2
  tests/pytorch/nvfp4/test_nvfp4_module_exact.py                          +2    -2
  tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py                        +2    -2
  tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py                    +3    -3
  tests/pytorch/test_cpu_offloading.py                                  +713  -171
  tests/pytorch/test_cpu_offloading_v1.py                               +215    -0
  tests/pytorch/test_cuda_graphs.py                                      +33    -5
tests/pytorch/attention/test_attention.py

@@ -45,7 +45,8 @@ from transformer_engine.pytorch.utils import (
 )
 from transformer_engine.pytorch.utils import get_cudnn_version
 import transformer_engine_torch as tex
-from transformer_engine.pytorch.tensor.quantized_tensor import (
+from transformer_engine.pytorch.quantized_tensor import (
+    Quantizer,
     prepare_for_saving,
     restore_from_saved,
 )
...
@@ -60,8 +61,16 @@ from utils import (
     get_available_attention_backends,
 )
 # Check if hardware supports FP8
 fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
+fp8_attn_available, reason_for_no_fp8_attn = fp8_available, reason_for_no_fp8
+device_compute_capability = get_device_compute_capability()
+if fp8_available and (device_compute_capability < (9, 0) or device_compute_capability >= (12, 0)):
+    fp8_attn_available = False
+    reason_for_no_fp8_attn = (
+        "FP8 attention is not supported for compute capability ="
+        f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
+    )
 # Reset RNG seed and states
 seed = 1234
...
@@ -130,6 +139,11 @@ def test_dot_product_attention(
     if config.window_size == (-1, -1) and swa:
         config.window_size = [2, 2]
     config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
+    qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
+    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
+        config.attn_mask_type = (
+            "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
+        )
     # Get backends
     is_training = True
...
@@ -171,7 +185,7 @@ def test_dot_product_attention(
     # UnfusedDotProductAttention backend
     if unfused_attn_supported:
-        unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
+        unfused_attn_fwd, unfused_max_logit, unfused_attn_bwd = _run_dot_product_attention(
             dtype,
             config,
             "UnfusedDotProductAttention",
...
@@ -185,7 +199,7 @@ def test_dot_product_attention(
     # FusedAttention backend
     if fused_attn_supported:
         if len(fused_attn_backends) == 1:
-            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
+            fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run_dot_product_attention(
                 dtype,
                 config,
                 "FusedAttention",
...
@@ -197,7 +211,7 @@ def test_dot_product_attention(
             )
         if len(fused_attn_backends) == 2:
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
-            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
+            fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention(
                 dtype,
                 config,
                 "FusedAttention",
...
@@ -208,7 +222,7 @@ def test_dot_product_attention(
                 is_training,
             )
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
-            fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
+            fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention(
                 dtype,
                 config,
                 "FusedAttention",
...
@@ -221,7 +235,7 @@ def test_dot_product_attention(
     # FlashAttention backend
     if flash_attn_supported:
-        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
+        flash_attn_fwd, _, flash_attn_bwd = _run_dot_product_attention(
             dtype,
             config,
             "FlashAttention",
...
@@ -242,6 +256,8 @@ def test_dot_product_attention(
     if unfused_attn_supported and fused_attn_supported:
         logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
         torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
+        if config.return_max_logit:
+            torch.testing.assert_close(fused_max_logit, unfused_max_logit, **tols)
         for i, _ in enumerate(unfused_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
     if fused_attn_supported and flash_attn_supported:
...
@@ -265,6 +281,33 @@ def test_dpa_checkpoint(dtype, model_configs, model):
     test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


+model_configs_max_logit = {
+    # test: ModelConfig(b, sq, hq, dqk)
+    "max_logit_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
+    "max_logit_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
+    "max_logit_3": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
+    "max_logit_4": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias"),
+    "max_logit_5": ModelConfig(
+        8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0)
+    ),
+    "max_logit_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
+}
+
+
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("model_configs", [model_configs_max_logit])
+@pytest.mark.parametrize("model", model_configs_max_logit.keys())
+@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
+def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
+    """Test DotProductAttention module with checkpointing"""
+    config = model_configs[model]
+    config.return_max_logit = True
+    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
+
+
 model_configs_softmax = {
     # test: ModelConfig(b, sq, hq, dqk)
     "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
...
@@ -962,6 +1005,8 @@ def _run_dot_product_attention(
         layout = layout.replace("d", "dqk")
         tensor_shape = [dim_to_num[j] for j in layout.split("_")]
         tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
+        # tensor: with padding tokens
+        # tensor_orig: without padding tokens
         tensor_orig = tensor
         if qkv_format == "thd" and pad_between_seqs:
             tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
...
@@ -1071,6 +1116,7 @@ def _run_dot_product_attention(
         layer_number=1,
         attention_type=config.attn_type,
         softmax_type=config.softmax_type,
+        return_max_logit=config.return_max_logit,
     ).to(dtype=dtype, device="cuda")
     if not is_training:
         block = block.eval()
...
@@ -1108,16 +1154,21 @@ def _run_dot_product_attention(
         alibi_slopes=alibi_slopes,
         fast_zero_fill=True,
     )
+    max_logit = None
+    if config.return_max_logit:
+        out, max_logit = out
     if is_training:
         out.backward(d_out)
     d_softmax_offset = None
     if is_training and config.softmax_type != "vanilla":
         d_softmax_offset = block.softmax_offset.grad
     if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
         if is_training:
-            return out, (q.grad, k.grad, v.grad, d_softmax_offset)
+            return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
         else:
-            return out, (None, None, None, d_softmax_offset)
+            return out, max_logit, (None, None, None, d_softmax_offset)
     if backend == "FusedAttention":
         if qkv_format == "thd" and pad_between_seqs:
             out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
...
@@ -1146,14 +1197,18 @@ def _run_dot_product_attention(
                 [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
             )
         if is_training:
-            return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset)
+            return (
+                out_orig,
+                max_logit,
+                (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset),
+            )
         else:
-            return out_orig, (None, None, None, d_softmax_offset)
+            return out_orig, max_logit, (None, None, None, d_softmax_offset)
     else:
         if is_training:
-            return out, (q.grad, k.grad, v.grad, d_softmax_offset)
+            return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
         else:
-            return out, (None, None, None, d_softmax_offset)
+            return out, max_logit, (None, None, None, d_softmax_offset)


 model_configs_te_layer = {
...
@@ -1527,8 +1582,7 @@ model_configs_fp8_extra_state = {
 }


-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
+@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
 @pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
 @pytest.mark.parametrize("model", ["large"])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
...
@@ -1690,8 +1744,7 @@ qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
 @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
+@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
 @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
 @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
 @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
...
@@ -1927,8 +1980,7 @@ def _run_mha_fp8_vs_f16(
 @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
+@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
 @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
 @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
 @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
...
@@ -2256,8 +2308,7 @@ models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
     ),
     reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
 )
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
+@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
 @pytest.mark.parametrize("dtype", param_types_fp8)
 @pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
 def test_custom_mha_fp8_vs_f16(dtype, model):
...
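The hunks above thread a new max_logit output through the attention tests. A minimal sketch of the calling pattern they exercise, assuming a DotProductAttention constructed with return_max_logit=True returns an (output, max_logit) pair, as the `out, max_logit = out` unpacking in _run_dot_product_attention suggests (the shapes and the return_max_logit keyword are assumptions for illustration, not confirmed API beyond what the diff shows):

import torch
from transformer_engine.pytorch import DotProductAttention

# Illustrative sizes: 16 heads of dim 64, seq 128, batch 2 (default "sbhd" layout).
dpa = DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attn_mask_type="causal",
    return_max_logit=True,  # assumption: same keyword the test passes through
).to(dtype=torch.bfloat16, device="cuda")

q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# With return_max_logit=True the forward result is unpacked into (out, max_logit),
# mirroring the new unpacking in _run_dot_product_attention.
out, max_logit = dpa(q, k, v)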
tests/pytorch/attention/test_attention_with_cp.py

@@ -138,8 +138,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
 model_configs_fused_attn = {
     # test: ModelConfig(b, sq, hq, dqk)
-    "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
-    "cp_1_1": ModelConfig(2, 4096, 12, 128),  # MHA
+    "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=True),  # MHA
+    "cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_logit=True),  # MHA
     "cp_1_2": ModelConfig(
         2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
     ),  # MHA
...
@@ -184,7 +184,7 @@ dtypes = ["bf16", "fp16", "fp8"]
 qkv_formats = ["bshd", "sbhd", "thd"]
 cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
 if test_essential:
-    configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
+    configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
     model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
     dtypes = ["bf16", "fp8"]
     qkv_formats = ["sbhd", "thd"]
...
tests/pytorch/debug/run_distributed.py

@@ -685,11 +685,12 @@ if __name__ == "__main__":
             pass
         else:
             test_log_expert_parallel()
-    for parallel_mode in ["column", "row"]:
-        for gather_weight in [True, False]:
-            test_log_distributed(parallel_mode, gather_weight)
+    if fp8_available:
+        for parallel_mode in ["column", "row"]:
+            for gather_weight in [True, False]:
+                test_log_distributed(parallel_mode, gather_weight)
     if fp8_available:
         for parallel_mode in ["row", "column"]:
             test_disable_fp8_layer(parallel_mode)
...
tests/pytorch/debug/test_configs/test_switch_to_nondebug_mode.yaml (new file, 0 → 100644)

test_switch_to_nondebug_mode:
  enabled: True
  layers:
    layer_name_regex_pattern: .*
  transformer_engine:
    TestDummyFeature:
      enabled: True
      inspect_only_once: True
      tensors: [weight, activation, gradient, output, wgrad, dgrad]
      gemms: [wgrad, dgrad, fprop]
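A minimal sketch of how a config like this is consumed, based on the calls that appear in the test_perf.py changes below (debug_api.initialize, debug_api.step, debug_api.end_debug); the config_file and feature_dirs paths are placeholders, not values taken from this commit:

import torch
import transformer_engine.pytorch as te
import nvdlfw_inspect.api as debug_api

debug_api.initialize(
    config_file="tests/pytorch/debug/test_configs/test_switch_to_nondebug_mode.yaml",
    feature_dirs=["/path/to/debug/features"],  # placeholder path
)
try:
    # Any TE layer created after initialize() is subject to the config above.
    model = te.Linear(256, 256, name="test_linear").cuda()
    y = model(torch.randn(8, 256, 256).cuda())
    y.sum().backward()
    debug_api.step()  # advance the debug-tool iteration counter
finally:
    debug_api.end_debug()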
tests/pytorch/debug/test_log.py

@@ -18,7 +18,11 @@ from transformer_engine.pytorch import (
 )
 from transformer_engine.pytorch.quantization import RecipeState
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
+from transformer_engine.debug.features.utils.stats_computation import (
+    compute_max_blockwise_dynamic_range,
+    BlockwiseDynamicRangeStat,
+)
+import math

 fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
 mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
...
@@ -154,7 +158,7 @@ fp8_recipes = [
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-def test_numerics(fp8_recipe, feature_dirs):
+def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs):
     if not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling():
...
@@ -210,6 +214,107 @@ def test_numerics(fp8_recipe, feature_dirs):
     assert overflows == pytest.approx(expected.cpu(), abs=1e-4)


+LOG_HIGH_PRECISION_CONFIG = """
+log:
+  layers:
+    layer_name_regex_pattern: .*
+  enabled:
+    True
+  transformer_engine:
+    LogTensorStats:
+      enabled: True
+      stats:
+        - dynamic_range
+        - max_blockwise_dynamic_range:
+            block_size: 4
+            dims: 1
+        - max_blockwise_dynamic_range:
+            block_size: 4
+            dims: 2
+      tensors: [activation, gradient, weight]
+      freq: 2
+      start_step: 0
+      end_step: 10
+"""
+
+
+@pytest.mark.parametrize("tensor_name", ["activation", "weight", "gradient"])
+def test_log_stats_numerics(feature_dirs, tensor_name):
+    """Check correctness of dynamic range and max blockwise dynamic range stats.
+
+    Tests different tensor types:
+    - activation/weight: use both orientations (rowwise + columnwise), takes max
+    - gradient/dgrad: use single orientation (rowwise only)
+    """
+    log_only_bare_stats_config = LOG_HIGH_PRECISION_CONFIG
+
+    with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir:
+        # There is 1024 x 1024 tensor with very small epsilon values in almost all elements,
+        # one row of large value A and three rows of large value B.
+        epsilon = 1e-10
+        A = 1000
+        B = 50
+        tensor = torch.zeros(1024, 1024).cuda() + epsilon
+        tensor[0, :] = A
+        tensor[1:4, :] = B
+
+        debug_api.transformer_engine.inspect_tensor(
+            layer_name="layer_name",
+            tensor_name=tensor_name,
+            iteration=0,
+            tp_group=None,
+            tensor=tensor,
+            quantizer=None,
+            rowwise_quantized_tensor=None,
+            columnwise_quantized_tensor=None,
+        )
+        debug_api.step()
+
+    output = read_log(log_dir)
+
+    max_over_orientations = tensor_name in ["activation", "weight"]
+    max_over_orientations_suffix = "_max_over_orientations" if max_over_orientations else ""
+
+    # Track which stats were found to ensure all are present
+    found_dims_1 = False
+    found_dims_2 = False
+    found_dynamic_range = False
+
+    for line in output.splitlines():
+        if f"max_blockwise_dynamic_range_block_size_4_dims_1{max_over_orientations_suffix}" in line:
+            max_blockwise_dynamic_range_block_size_4_dims_1 = float(line.split("value=")[1])
+            if max_over_orientations:
+                # Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B)
+                expected = math.log2(A) - math.log2(B)
+            else:
+                # Rowwise blocks have uniform values -> dynamic_range = 0
+                expected = 0
+            assert max_blockwise_dynamic_range_block_size_4_dims_1 == pytest.approx(expected, abs=1e-4)
+            found_dims_1 = True
+        elif (
+            f"max_blockwise_dynamic_range_block_size_4_dims_2{max_over_orientations_suffix}" in line
+        ):
+            max_blockwise_dynamic_range_block_size_4_dims_2 = float(line.split("value=")[1])
+            # For 2D blocks (4x4 tiles), blocks always contain mixed values from different rows
+            expected = math.log2(A) - math.log2(B)
+            assert max_blockwise_dynamic_range_block_size_4_dims_2 == pytest.approx(expected, abs=1e-4)
+            found_dims_2 = True
+        elif "_dynamic_range" in line and "max_blockwise_dynamic_range" not in line:
+            dynamic_range = float(line.split("value=")[1])
+            expected = math.log2(A) - math.log2(epsilon)
+            assert dynamic_range == pytest.approx(expected, abs=1e-4)
+            found_dynamic_range = True
+
+    # Ensure all expected stats were found in the output
+    assert found_dims_1, "max_blockwise_dynamic_range (dims=1) not found in output"
+    assert found_dims_2, "max_blockwise_dynamic_range (dims=2) not found in output"
+    assert found_dynamic_range, "dynamic_range not found in output"
+
+
 @pytest.mark.parametrize("layer", ["linear", "transformer"])
 def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
     if not fp8_available:
...
@@ -256,3 +361,92 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
         debug_api.end_debug()
         TEDebugState._reset()
+
+
+def test_compute_max_blockwise_dynamic_range_direct():
+    """Direct unit test for compute_max_blockwise_dynamic_range function.
+
+    Tests the function with various configurations to ensure correct behavior
+    for different block sizes, dimensions, and orientation settings.
+    """
+    # Create test tensor with uniform rows but mixed columns
+    # Row 0: all 1000, Row 1-3: all 50, remaining: all 0.01
+    epsilon = 0.01
+    A = 1000.0
+    B = 50.0
+    tensor = torch.zeros(1024, 1024).cuda() + epsilon
+    tensor[0, :] = A
+    tensor[1:4, :] = B
+
+    # Test 1: dims=1, max_over_orientations=False (rowwise only)
+    # Rowwise blocks have uniform values -> dynamic_range should be 0
+    stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=False)
+    result = compute_max_blockwise_dynamic_range(tensor, stat_config)
+    assert result.item() == pytest.approx(
+        0.0, abs=1e-4
+    ), "Rowwise 1D blocks with uniform values should have dynamic_range=0"
+
+    # Test 2: dims=1, max_over_orientations=True (max of rowwise and columnwise)
+    # Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B)
+    stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True)
+    result = compute_max_blockwise_dynamic_range(tensor, stat_config)
+    expected = math.log2(A) - math.log2(B)
+    assert result.item() == pytest.approx(expected, abs=1e-4), (
+        f"Max over orientations should capture columnwise dynamic_range, expected {expected}, got"
+        f" {result.item()}"
+    )
+
+    # Test 3: dims=2, block_size=4 (4x4 tiles)
+    # 2D blocks span multiple rows -> always have mixed values
+    stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=2, max_over_orientations=False)
+    result = compute_max_blockwise_dynamic_range(tensor, stat_config)
+    expected = math.log2(A) - math.log2(B)
+    assert result.item() == pytest.approx(expected, abs=1e-4), (
+        f"2D blocks should capture mixed values from different rows, expected {expected}, got"
+        f" {result.item()}"
+    )
+
+    # Test 4: Different block size
+    # With block_size=8, columnwise blocks contain [A, B, B, B, epsilon, epsilon, epsilon, epsilon]
+    # So max=A, min=epsilon (not B anymore)
+    stat_config = BlockwiseDynamicRangeStat(block_size=8, dims=1, max_over_orientations=True)
+    result = compute_max_blockwise_dynamic_range(tensor, stat_config)
+    expected = math.log2(A) - math.log2(epsilon)  # min is epsilon, not B
+    assert result.item() == pytest.approx(
+        expected, abs=1e-4
+    ), f"Block size 8 should work correctly, expected {expected}, got {result.item()}"
+
+    # Test 5: Tensor with all uniform values -> dynamic_range should be 0
+    uniform_tensor = torch.ones(64, 64).cuda() * 42.0
+    stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True)
+    result = compute_max_blockwise_dynamic_range(uniform_tensor, stat_config)
+    assert result.item() == pytest.approx(0.0, abs=1e-4), "Uniform tensor should have dynamic_range=0"
+
+    # Test 6: 3D tensor flattening validation using 2D/3D comparison
+    # Create a 4x4 tensor with distinct 2x2 blocks, compute with dims=2, block_size=2
+    # Then reshape to 3D and compute again - results should match if flattening is correct
+    tensor_2d = torch.tensor(
+        [
+            [1.0, 1.0, 10.0, 10.0],
+            [1.0, 1.0, 10.0, 10.0],
+            [100.0, 100.0, 1000.0, 1000.0],
+            [100.0, 100.0, 1000.0, 1000.0],
+        ]
+    ).cuda()
+
+    # Compute on 2D tensor: 4 blocks of 2x2, max range is log2(1000/100)
+    stat_config = BlockwiseDynamicRangeStat(block_size=2, dims=2, max_over_orientations=False)
+    result_2d = compute_max_blockwise_dynamic_range(tensor_2d, stat_config)
+
+    # Reshape to 3D [2, 2, 4] and compute - should give same result if flattening is correct
+    tensor_3d = tensor_2d.reshape(2, 2, 4)
+    result_3d = compute_max_blockwise_dynamic_range(tensor_3d, stat_config)
+
+    assert result_2d.item() == pytest.approx(result_3d.item(), abs=1e-6), (
+        "3D tensor [2,2,4] flattened to [4,4] must give same result as original 2D, got"
+        f" 2D={result_2d.item()}, 3D={result_3d.item()}"
+    )
+
+    print("All direct tests for compute_max_blockwise_dynamic_range passed!")
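The quantity these new tests check is the base-2 dynamic range, log2(max/min) over the non-zero magnitudes of a tensor, optionally taken per block and maximised over blocks. A small self-contained sketch of that arithmetic (this is not the library's stats_computation implementation, just the same math on the same example tensor as above):

import math
import torch

def dynamic_range(t: torch.Tensor) -> float:
    """log2(max/min) over the non-zero magnitudes of a tensor."""
    a = t.abs()
    a = a[a > 0]
    return math.log2(a.max().item()) - math.log2(a.min().item())

def max_blockwise_dynamic_range(t: torch.Tensor, block_size: int = 4) -> float:
    """Max of log2(max/min) over contiguous 1D row-wise blocks."""
    blocks = t.reshape(-1, block_size).abs()
    mx = blocks.amax(dim=1)
    mn = blocks.amin(dim=1).clamp_min(torch.finfo(t.dtype).tiny)
    return torch.log2(mx / mn).amax().item()

x = torch.full((1024, 1024), 1e-10)
x[0, :] = 1000.0
x[1:4, :] = 50.0
print(dynamic_range(x))                   # log2(1000) - log2(1e-10), as the test expects
print(max_blockwise_dynamic_range(x, 4))  # 0.0: row-wise blocks of a row-uniform tensor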
tests/pytorch/debug/test_perf.py

@@ -6,71 +6,70 @@
 import pytest
 import torch
 import transformer_engine.pytorch as te
-import time
 import nvdlfw_inspect.api as debug_api

 from transformer_engine.debug.pytorch.debug_state import TEDebugState


-def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs):
-    debug_api.end_debug()
-    TEDebugState._reset()
-    if debug_tools_initialized:
-        # This config log stats starting from 0, every N iterations for huge N >> NUM_ITERS.
-        # So after 1 warm-up iteration, this layers should work in non-debug mode.
-        debug_api.initialize(config_file=configs_dir + "/perf_config.yaml", feature_dirs=feature_dirs)
-    try:
-        if layer == "linear":
-            model = torch.nn.Sequential(
-                te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2")
-            ).cuda()
-            NUM_ITERS = 18000
-        elif layer == "transformer":
-            model = torch.nn.Sequential(
-                te.TransformerLayer(1, 1, 1, name="transformer1"),
-                te.TransformerLayer(1, 1, 1, name="transformer2"),
-            ).cuda()
-            NUM_ITERS = 2000
-        x = torch.randn(1, 1, 1).cuda()
-
-        y = model(x)
-        y.sum().backward()
-        debug_api.step()
-        torch.cuda.synchronize()
-        time_start = time.time()
-        for i in range(NUM_ITERS):
-            y = model(x)
-            y.sum().backward()
-            if debug_tools_initialized:
-                debug_api.step()
-        torch.cuda.synchronize()
-        time_end = time.time()
-    finally:
-        if debug_tools_initialized:
-            debug_api.end_debug()
-        TEDebugState._reset()
-    return time_end - time_start
-
-
-@pytest.mark.parametrize("layer", ["linear", "transformer"])
-def test_cpu_overhead(layer, configs_dir, feature_dirs):
-    # runs one layer many times on very small tensor
-    # - gpu time should be negligible, so time should be dominated by cpu time.
-    # if layers does not invoke any feature in current iteration,
-    # then it changed into non-debug mode and should not have any non-negligible cpu overhead
-    # compared to layer without debug tools initialized.
-    with_debug_tools = _run_cpu_overhead(True, layer, configs_dir, feature_dirs)
-    without_debug_tools = _run_cpu_overhead(False, layer, configs_dir, feature_dirs)
-
-    print(f"with_debug_tools: {with_debug_tools} s")
-    print(f"without_debug_tools: {without_debug_tools} s")
-
-    assert with_debug_tools < without_debug_tools * 1.25  # 25% overhead margin
+@pytest.mark.parametrize("use_microbatching", [False, True])
+def test_layer_switches_to_nondebug_mode(configs_dir, feature_dirs, use_microbatching):
+    """
+    Test that layers switch to non-debug mode when no features are active.
+
+    Uses TestDummyFeature with inspect_only_once=True, which makes inspect_tensor_enabled return (False, None).
+
+    The TE should:
+    1. Call inspect_tensor_enabled to check if feature is needed
+    2. Never call inspect_tensor
+    3. Allow layers to switch to non-debug mode for optimal performance,
+       so that inspect_tensor_enabled is never called again.
+
+    Tests both with and without microbatching to ensure proper behavior in both scenarios.
+    """
+    try:
+        debug_api.initialize(
+            config_file=configs_dir + "/test_switch_to_nondebug_mode.yaml",
+            feature_dirs=feature_dirs,
+        )
+
+        import transformer_engine.debug.features._test_dummy_feature as dummy_feature
+
+        # Reset counters
+        dummy_feature._inspect_tensor_enabled_call_count = 0
+        dummy_feature._inspect_tensor_call_count = 0
+
+        model = te.Linear(256, 256, name="test_linear").cuda()
+        x = torch.randn(8, 256, 256).cuda()
+
+        # Run multiple iterations
+        for i in range(20):
+            if use_microbatching:
+                # Alternate between first and non-first microbatch
+                is_first_microbatch = i % 2 == 0
+                y = model(x, is_first_microbatch=is_first_microbatch)
+            else:
+                # Run without specifying is_first_microbatch
+                y = model(x)
+            y.sum().backward()
+            debug_api.step()
+
+        # Verify inspect_tensor_enabled was called only once per tensor
+        # (activation, weight, gradient, output, wgrad, dgrad)
+        enabled_call_count = dummy_feature._inspect_tensor_enabled_call_count
+        microbatch_info = "with microbatching" if use_microbatching else "without microbatching"
+        assert enabled_call_count == 6, (
+            f"inspect_tensor_enabled was called {enabled_call_count} times ({microbatch_info}), "
+            "but should be called 6 times to check if feature is needed for each tensor "
+            "(activation, weight, gradient, output, wgrad, dgrad)"
+        )
+
+        # Verify inspect_tensor was never called - it should not be called if inspect_tensor_enabled returns (False, None)
+        inspect_call_count = dummy_feature._inspect_tensor_call_count
+        assert inspect_call_count == 0, (
+            f"inspect_tensor was called {inspect_call_count} times ({microbatch_info}), "
+            "but should never be called when inspect_tensor_enabled returns (False, None)"
+        )
+    finally:
+        debug_api.end_debug()
+        TEDebugState._reset()
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py (deleted, 100644 → 0)
#!/usr/bin/python3
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import
argparse
import
datetime
import
os
import
sys
import
torch
from
torch
import
nn
import
torch.distributed
as
dist
from
transformer_engine.common.recipe
import
(
DelayedScaling
,
Float8CurrentScaling
,
Float8BlockScaling
,
Format
,
Recipe
,
)
import
transformer_engine.pytorch
as
te
from
transformer_engine.pytorch
import
(
QuantizedTensor
,
Float8Tensor
,
Float8BlockwiseQTensor
,
)
from
transformer_engine.pytorch.tensor
import
cast_master_weights_to_fp8
from
transformer_engine.pytorch.tensor.utils
import
replace_raw_data
def
_get_raw_data
(
quantized_tensor
):
"""Get the underlying data of a quantized tensor, used in zero-1 optimizer"""
if
isinstance
(
quantized_tensor
,
Float8Tensor
):
assert
hasattr
(
quantized_tensor
,
"_data"
),
"Float8Tensor does not have _data attribute"
assert
quantized_tensor
.
_data
.
dtype
==
torch
.
uint8
,
"Float8Tensor _data must be uint8"
return
quantized_tensor
.
_data
elif
isinstance
(
quantized_tensor
,
Float8BlockwiseQTensor
):
assert
hasattr
(
quantized_tensor
,
"_rowwise_data"
),
"Float8BlockwiseQTensor does not have _rowwise_data attribute"
assert
(
quantized_tensor
.
_rowwise_data
.
dtype
==
torch
.
uint8
),
"Float8BlockwiseQTensor _rowwise_data must be uint8"
return
quantized_tensor
.
_rowwise_data
else
:
raise
ValueError
(
f
"Unsupported quantized tensor type:
{
type
(
quantized_tensor
)
}
"
)
class
MiniZero_1
:
"""A mini zero-1 optimizer implementation, just used for this test"""
def
__init__
(
self
,
weights
,
lr
,
dp_group
):
self
.
rank
=
dist
.
get_rank
(
dp_group
)
self
.
world_size
=
dist
.
get_world_size
(
dp_group
)
self
.
weights
=
weights
self
.
lr
=
lr
self
.
dp_group
=
dp_group
# [self.offsets[i], self.offsets[i+1]) is the range of weights[i] in the global buffer
self
.
offsets
=
[
0
]
for
weight
in
self
.
weights
:
self
.
offsets
.
append
(
self
.
offsets
[
-
1
]
+
weight
.
numel
())
# Padding to avoid global buffer cannot be divided by world size, so the offsets[-1] may
# not be the end range of the last weight.
if
self
.
offsets
[
-
1
]
%
self
.
world_size
!=
0
:
self
.
offsets
[
-
1
]
+=
self
.
world_size
-
self
.
offsets
[
-
1
]
%
self
.
world_size
self
.
master_weights
=
[]
# The start offset of the master weight in the weight
self
.
start_offsets
=
[]
# The overlapping area of the weight and this rank's local buffer
self
.
overlapping_areas
=
[]
# The start and end of this rank's local buffer in the global buffer
rank_start
=
self
.
offsets
[
-
1
]
//
self
.
world_size
*
self
.
rank
rank_end
=
rank_start
+
self
.
offsets
[
-
1
]
//
self
.
world_size
for
weight
,
offset
in
zip
(
self
.
weights
,
self
.
offsets
[:
-
1
]):
if
offset
>=
rank_end
or
(
offset
+
weight
.
numel
())
<=
rank_start
:
# This weight is not in this rank's local buffer
master_weight
=
None
start_offset
=
None
overlapping_area
=
None
else
:
overlapping_start
=
max
(
rank_start
,
offset
)
overlapping_end
=
min
(
rank_end
,
offset
+
weight
.
numel
())
length
=
overlapping_end
-
overlapping_start
start_offset
=
overlapping_start
-
offset
if
isinstance
(
weight
,
QuantizedTensor
):
# If weight is a FP8 tensor, we need to use the original high precision version
# to initialize the master weight.
high_precision_init_val
=
weight
.
get_high_precision_init_val
().
view
(
-
1
)
master_weight
=
high_precision_init_val
.
to
(
weight
.
device
).
float
()[
start_offset
:
start_offset
+
length
]
else
:
master_weight
=
(
weight
.
detach
().
view
(
-
1
).
float
()[
start_offset
:
start_offset
+
length
]
)
overlapping_area
=
(
overlapping_start
,
overlapping_end
)
self
.
master_weights
.
append
(
master_weight
)
self
.
start_offsets
.
append
(
start_offset
)
self
.
overlapping_areas
.
append
(
overlapping_area
)
# Create global buffer for grads reduce-scatter
self
.
grad_buffer
=
torch
.
empty
(
[
self
.
offsets
[
-
1
]],
dtype
=
torch
.
float32
,
device
=
weights
[
0
].
device
)
self
.
grad_buffer_slice
=
self
.
grad_buffer
[
rank_start
:
rank_end
]
# Create global buffer for weights all-gather
if
isinstance
(
self
.
weights
[
0
],
QuantizedTensor
):
weight_buffer_dtype
=
torch
.
uint8
else
:
weight_buffer_dtype
=
weights
[
0
].
dtype
self
.
weight_buffer
=
torch
.
empty
(
[
self
.
offsets
[
-
1
]],
dtype
=
weight_buffer_dtype
,
device
=
weights
[
0
].
device
)
self
.
weight_buffer_slice
=
self
.
weight_buffer
[
rank_start
:
rank_end
]
def
step
(
self
):
# -----------------------------------------------------------------------------------------
# Step 1: Copy grads to the grad buffer
# -----------------------------------------------------------------------------------------
for
weight
,
offset
in
zip
(
self
.
weights
,
self
.
offsets
[:
-
1
]):
start
=
offset
end
=
offset
+
weight
.
numel
()
self
.
grad_buffer
[
start
:
end
].
copy_
(
weight
.
main_grad
.
view
(
-
1
))
# -----------------------------------------------------------------------------------------
# Step 2: Grads reduce-scatter
# -----------------------------------------------------------------------------------------
# Don't use reduce_scatter directly to explicitly control the reduce order.
# dist.reduce_scatter_tensor(self.grad_buffer_slice, self.grad_buffer, op=dist.ReduceOp.AVG,
# group=self.dp_group)
buffers
=
[
torch
.
empty_like
(
self
.
grad_buffer
)
for
_
in
range
(
self
.
world_size
)]
dist
.
all_gather
(
buffers
,
self
.
grad_buffer
,
group
=
self
.
dp_group
)
for
i
in
range
(
1
,
self
.
world_size
):
buffers
[
0
]
+=
buffers
[
i
]
rank_start
=
self
.
offsets
[
-
1
]
//
self
.
world_size
*
self
.
rank
rank_end
=
rank_start
+
self
.
offsets
[
-
1
]
//
self
.
world_size
self
.
grad_buffer_slice
.
copy_
(
buffers
[
0
][
rank_start
:
rank_end
])
self
.
grad_buffer_slice
/=
self
.
world_size
# -----------------------------------------------------------------------------------------
# Step 3: Update master weights
# -----------------------------------------------------------------------------------------
for
master_weight
,
overlapping_area
in
zip
(
self
.
master_weights
,
self
.
overlapping_areas
):
if
master_weight
is
None
:
# This weight's master weight is in other rank.
continue
grad
=
self
.
grad_buffer
[
overlapping_area
[
0
]
:
overlapping_area
[
1
]]
master_weight
-=
grad
*
self
.
lr
# -----------------------------------------------------------------------------------------
# Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight
# -----------------------------------------------------------------------------------------
if
isinstance
(
self
.
weights
[
0
],
QuantizedTensor
):
# FP8 weights case
for
i
in
range
(
1
,
len
(
self
.
weights
)):
assert
isinstance
(
self
.
weights
[
i
],
QuantizedTensor
)
cast_master_weights_to_fp8
(
self
.
weights
,
self
.
master_weights
,
self
.
start_offsets
,
self
.
dp_group
)
else
:
# BF16 weights case
for
weight
,
master_weight
,
start_offset
in
zip
(
self
.
weights
,
self
.
master_weights
,
self
.
start_offsets
):
if
master_weight
is
None
:
continue
start
=
start_offset
end
=
start_offset
+
master_weight
.
numel
()
weight
.
data
.
view
(
-
1
)[
start
:
end
].
copy_
(
master_weight
)
# -----------------------------------------------------------------------------------------
# Step 5: Copy the updated weights (not all weights) to the weight buffer
# -----------------------------------------------------------------------------------------
for
i
in
range
(
len
(
self
.
weights
)):
master_weight
=
self
.
master_weights
[
i
]
if
master_weight
is
None
:
continue
start_offset
=
self
.
start_offsets
[
i
]
if
isinstance
(
self
.
weights
[
i
],
QuantizedTensor
):
weight
=
_get_raw_data
(
self
.
weights
[
i
])
else
:
weight
=
self
.
weights
[
i
]
weight_slice
=
weight
.
view
(
-
1
)[
start_offset
:
start_offset
+
master_weight
.
numel
()]
overlapping_start
,
overlapping_end
=
self
.
overlapping_areas
[
i
]
self
.
weight_buffer
[
overlapping_start
:
overlapping_end
].
copy_
(
weight_slice
)
# -----------------------------------------------------------------------------------------
# Step 6: Weight all-gather (FP8 or BF16)
# -----------------------------------------------------------------------------------------
dist
.
all_gather_into_tensor
(
self
.
weight_buffer
,
self
.
weight_buffer_slice
,
group
=
self
.
dp_group
)
# -----------------------------------------------------------------------------------------
# Step 7: Copy the gathered weights from weight buffer to the actual weights
# -----------------------------------------------------------------------------------------
for
weight
,
offset
in
zip
(
self
.
weights
,
self
.
offsets
[:
-
1
]):
start
=
offset
end
=
offset
+
weight
.
numel
()
if
isinstance
(
weight
,
QuantizedTensor
):
weight
=
_get_raw_data
(
weight
)
weight
.
view
(
-
1
).
data
.
copy_
(
self
.
weight_buffer
[
start
:
end
])
class
MiniOptimizer
:
def
__init__
(
self
,
weights
,
lr
,
dp_group
):
self
.
world_size
=
dist
.
get_world_size
(
dp_group
)
self
.
weights
=
weights
self
.
lr
=
lr
self
.
dp_group
=
dp_group
master_weights
=
[]
for
weight
in
self
.
weights
:
master_weights
.
append
(
weight
.
detach
().
float
())
self
.
master_weights
=
master_weights
def
step
(
self
):
for
weight
,
master_weight
in
zip
(
self
.
weights
,
self
.
master_weights
):
main_grad
=
weight
.
main_grad
# Don't use all-reduce directly to explicitly control the reduce order.
# dist.all_reduce(main_grad, op=dist.ReduceOp.AVG, group=self.dp_group)
buffers
=
[
torch
.
empty_like
(
main_grad
)
for
_
in
range
(
self
.
world_size
)]
dist
.
all_gather
(
buffers
,
main_grad
,
group
=
self
.
dp_group
)
for
i
in
range
(
1
,
self
.
world_size
):
buffers
[
0
]
+=
buffers
[
i
]
main_grad
.
copy_
(
buffers
[
0
])
main_grad
/=
self
.
world_size
master_weight
-=
main_grad
*
self
.
lr
weight
.
data
.
copy_
(
master_weight
)
class
MiniFSDP
:
def
__init__
(
self
,
weights
,
lr
,
dp_group
):
rank
=
dist
.
get_rank
(
dp_group
)
world_size
=
dist
.
get_world_size
(
dp_group
)
self
.
weights
=
weights
self
.
lr
=
lr
self
.
dp_group
=
dp_group
# Flatten the weights and pad to align with world size
raw_data_list
=
[
_get_raw_data
(
w
).
view
(
-
1
)
if
isinstance
(
w
,
QuantizedTensor
)
else
w
.
view
(
-
1
)
for
w
in
weights
]
if
isinstance
(
weights
[
0
],
QuantizedTensor
):
raw_data_list
=
[
_get_raw_data
(
w
).
view
(
-
1
)
for
w
in
weights
]
else
:
raw_data_list
=
[
w
.
view
(
-
1
)
for
w
in
weights
]
self
.
flatten_weight
,
original_length
=
self
.
_flatten_tensors_with_pad
(
raw_data_list
)
# Split flattened weights into shards
self
.
local_weight_shard
=
torch
.
chunk
(
self
.
flatten_weight
,
world_size
)[
rank
]
self
.
local_main_grad_shard
=
torch
.
zeros_like
(
self
.
local_weight_shard
)
shard_size
=
self
.
flatten_weight
.
size
(
0
)
//
world_size
# Map original tensors to flattened indices
tensor_indices
=
[]
cumulative_length
=
0
for
tensor
in
raw_data_list
:
length
=
tensor
.
size
(
0
)
tensor_indices
.
append
((
cumulative_length
,
cumulative_length
+
length
))
cumulative_length
+=
length
# Build shard index mappings
self
.
weight_indices
=
[]
self
.
shard_indices
=
[]
for
idx
,
(
start
,
end
)
in
enumerate
(
tensor_indices
):
shard_start
=
rank
*
shard_size
shard_end
=
shard_start
+
shard_size
adjusted_end
=
min
(
shard_end
,
original_length
)
if
start
<=
adjusted_end
and
end
>=
shard_start
:
start_idx
=
max
(
start
,
shard_start
)
end_idx
=
min
(
end
,
adjusted_end
)
self
.
weight_indices
.
append
((
start_idx
-
start
,
end_idx
-
start
))
self
.
shard_indices
.
append
((
start_idx
-
shard_start
,
end_idx
-
shard_start
))
else
:
self
.
weight_indices
.
append
((
None
,
None
))
self
.
shard_indices
.
append
((
None
,
None
))
if
isinstance
(
weights
[
idx
],
QuantizedTensor
):
replace_raw_data
(
weights
[
idx
],
self
.
flatten_weight
[
start
:
end
].
view
(
weights
[
idx
].
shape
)
)
else
:
weights
[
idx
].
data
=
self
.
flatten_weight
[
start
:
end
].
view
(
weights
[
idx
].
shape
)
# Initialize local model weights and high-precision master weights
self
.
local_weights
=
[]
self
.
master_weights
=
[]
for
i
,
weight
in
enumerate
(
self
.
weights
):
weight_start
,
weight_end
=
self
.
weight_indices
[
i
]
shard_start
,
shard_end
=
self
.
shard_indices
[
i
]
if
shard_start
is
not
None
and
shard_end
is
not
None
:
local_weight_shard
=
self
.
local_weight_shard
[
shard_start
:
shard_end
]
self
.
local_weights
.
append
(
local_weight_shard
)
if
isinstance
(
weight
,
QuantizedTensor
):
high_precision_init_val
=
weight
.
get_high_precision_init_val
().
view
(
-
1
)
master_weight_shard
=
high_precision_init_val
.
to
(
weight
.
device
).
float
()[
weight_start
:
weight_end
]
else
:
master_weight_shard
=
weight
.
detach
().
view
(
-
1
).
float
()[
weight_start
:
weight_end
]
self
.
master_weights
.
append
(
master_weight_shard
)
else
:
self
.
local_weights
.
append
(
None
)
self
.
master_weights
.
append
(
None
)
setattr
(
weight
,
"main_grad"
,
torch
.
zeros_like
(
weight
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
)
def
_flatten_tensors_with_pad
(
self
,
tensors
):
"""
Flatten the list of tensors and pad them to align with the world size.
Args:
tensors (list): List of tensors to flatten.
Returns:
tuple: Flattened tensor and its original length before padding.
"""
world_size
=
dist
.
get_world_size
(
self
.
dp_group
)
flatten_tensor
=
torch
.
cat
(
tensors
)
original_length
=
flatten_tensor
.
size
(
0
)
padding_needed
=
(
world_size
-
original_length
%
world_size
)
%
world_size
if
padding_needed
>
0
:
flatten_tensor
=
torch
.
cat
(
[
flatten_tensor
,
torch
.
zeros
(
padding_needed
,
dtype
=
flatten_tensor
.
dtype
)]
)
return
flatten_tensor
,
original_length
def
zero_grad
(
self
):
for
weight
in
self
.
weights
:
weight
.
grad
=
None
weight
.
main_grad
.
zero_
()
def
step
(
self
):
"""
Perform an optimization step for the distributed sharded model.
This method includes:
1. Gradient reduce-scatter: Synchronize gradients across all processes.
2. Master weight update: Update high-precision master weights using local gradients.
3. Precision casting: Cast updated master weights to FP8 or BF16 precision.
4. Weight synchronization: All-gather updated weights across all processes.
Returns:
None
"""
# Step 1: Reduce-scatter the gradients
main_grad_buffer
,
_
=
self
.
_flatten_tensors_with_pad
(
[
weight
.
main_grad
.
view
(
-
1
)
for
weight
in
self
.
weights
]
)
main_grad_buffer
=
main_grad_buffer
.
to
(
self
.
local_main_grad_shard
.
dtype
)
dist
.
reduce_scatter_tensor
(
self
.
local_main_grad_shard
,
main_grad_buffer
,
group
=
self
.
dp_group
)
# Step 2: Update the master weights
for
weight
,
master_weight
,
(
shard_start
,
shard_end
)
in
zip
(
self
.
weights
,
self
.
master_weights
,
self
.
shard_indices
):
if
master_weight
is
None
:
continue
# Extract the local gradient shard for this weight
grad
=
self
.
local_main_grad_shard
[
shard_start
:
shard_end
]
# Update the master weight using gradient descent
master_weight
-=
grad
*
self
.
lr
# Step 3: Cast master weights to FP8 or BF16 precision
if
isinstance
(
self
.
weights
[
0
],
QuantizedTensor
):
local_weights
=
[]
for
local_weight
in
self
.
local_weights
:
if
local_weight
is
None
:
local_weights
.
append
(
None
)
continue
local_weights
.
append
(
local_weight
)
cast_master_weights_to_fp8
(
self
.
weights
,
self
.
master_weights
,
[
idx
[
0
]
for
idx
in
self
.
weight_indices
],
self
.
dp_group
,
local_weights
,
)
else
:
for
weight
,
master_weight
in
zip
(
self
.
local_weights
,
self
.
master_weights
):
if
master_weight
is
None
:
continue
# Copy updated master weights to local weights
weight
.
data
.
copy_
(
master_weight
)
# Step 4: All-gather updated weights across processes
dist
.
all_gather_into_tensor
(
self
.
flatten_weight
,
self
.
local_weight_shard
,
group
=
self
.
dp_group
)
def
_test_fsdp_cast_master_weights_to_fp8
(
quantization
,
dp_group
):
rank
=
dist
.
get_rank
(
dp_group
)
world_size
=
dist
.
get_world_size
(
dp_group
)
# Configuration constants
NUM_STEPS
=
100
SEED
=
12345
torch
.
manual_seed
(
SEED
)
torch
.
cuda
.
manual_seed
(
SEED
)
mock_groups
=
[
dist
.
new_group
(
ranks
=
[
i
])
for
i
in
range
(
world_size
)]
mock_group
=
mock_groups
[
rank
]
linear_kwargs
=
{
"params_dtype"
:
torch
.
bfloat16
,
"bias"
:
False
,
"fuse_wgrad_accumulation"
:
False
,
}
# Create model with FP8 weights
with
te
.
quantized_model_init
(
enabled
=
quantization
is
not
None
,
recipe
=
quantization_recipe
(
quantization
),
preserve_high_precision_init_val
=
True
,
):
model_fp8
=
nn
.
Sequential
(
te
.
Linear
(
128
,
256
+
16
,
**
linear_kwargs
),
te
.
Linear
(
256
+
16
,
256
*
3
,
**
linear_kwargs
),
te
.
Linear
(
256
*
3
,
128
,
**
linear_kwargs
),
)
# Create model with BF16 weights
model
=
nn
.
Sequential
(
te
.
Linear
(
128
,
256
+
16
,
**
linear_kwargs
),
te
.
Linear
(
256
+
16
,
256
*
3
,
**
linear_kwargs
),
te
.
Linear
(
256
*
3
,
128
,
**
linear_kwargs
),
)
# Make sure the BF16 model and FP8 model have the same initial weights
for
w_fp8
,
w
in
zip
(
model_fp8
.
parameters
(),
model
.
parameters
()):
high_precision_init_val
=
w_fp8
.
get_high_precision_init_val
()
w
.
data
.
copy_
(
high_precision_init_val
)
optimizer_fp8
=
MiniFSDP
([
w
for
w
in
model_fp8
.
parameters
()],
10.0
,
dp_group
)
optimizer
=
MiniFSDP
([
w
for
w
in
model
.
parameters
()],
10.0
,
dp_group
)
for
_
in
range
(
100
):
optimizer_fp8
.
zero_grad
()
optimizer
.
zero_grad
()
inputs
=
[
torch
.
randn
(
16
,
128
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
for
_
in
range
(
world_size
)
]
# Choose based on rank to make sure the inputs of different ranks are different.
x
=
inputs
[
rank
]
with
te
.
autocast
(
enabled
=
quantization
is
not
None
,
recipe
=
quantization_recipe
(
quantization
),
amax_reduction_group
=
mock_group
,
):
y_fp8
=
model_fp8
(
x
)
with
te
.
autocast
(
enabled
=
quantization
is
not
None
,
recipe
=
quantization_recipe
(
quantization
),
amax_reduction_group
=
mock_group
,
):
y
=
model
(
x
)
targets
=
[
torch
.
randn_like
(
y
)
for
_
in
range
(
world_size
)]
# Choose based on rank to make sure the targets of different ranks are different.
target
=
targets
[
rank
]
loss_fp8
=
nn
.
MSELoss
()(
y_fp8
,
target
)
loss
=
nn
.
MSELoss
()(
y
,
target
)
loss_fp8
.
backward
()
loss
.
backward
()
optimizer_fp8
.
step
()
optimizer
.
step
()
torch
.
testing
.
assert_close
(
loss_fp8
,
loss
,
atol
=
0
,
rtol
=
0
)
print
(
f
"✅ Successfully validated FSDP
{
NUM_STEPS
}
training steps with"
f
"
{
quantization
}
quantization"
)
def
_test_zero_1
(
dp_group
):
"""Make sure the implementation of zero-1 optimizer is correct"""
rank
=
dist
.
get_rank
(
dp_group
)
world_size
=
dist
.
get_world_size
(
dp_group
)
torch
.
manual_seed
(
12345
)
torch
.
cuda
.
manual_seed
(
12345
)
weights
=
[
torch
.
randn
(
256
*
256
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
torch
.
randn
(
256
*
256
*
3
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
torch
.
randn
(
256
*
256
*
2
-
1
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
]
weights_1
=
weights
weights_2
=
[
weight
.
clone
()
for
weight
in
weights
]
lr
=
1.0
optimizer_1
=
MiniZero_1
(
weights_1
,
lr
,
dp_group
)
optimizer_2
=
MiniOptimizer
(
weights_2
,
lr
,
dp_group
)
for
_
in
range
(
100
):
for
w1
,
w2
in
zip
(
weights_1
,
weights_2
):
main_grads
=
[
torch
.
randn_like
(
w1
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
for
_
in
range
(
world_size
)
]
# Choose based on rank to make sure the grads of different ranks are different.
main_grad
=
main_grads
[
rank
]
w1
.
main_grad
=
main_grad
w2
.
main_grad
=
main_grad
optimizer_1
.
step
()
optimizer_2
.
step
()
for
w1
,
w2
in
zip
(
weights_1
,
weights_2
):
torch
.
testing
.
assert_close
(
w1
,
w2
,
atol
=
0
,
rtol
=
0
)
def
quantization_recipe
(
quantization
)
->
Recipe
:
"""Quantization recipe setup"""
fp8_format
=
Format
.
HYBRID
if
quantization
==
"fp8"
:
return
DelayedScaling
(
fp8_format
=
fp8_format
,
amax_history_len
=
32
,
amax_compute_algo
=
"max"
)
elif
quantization
==
"fp8_cs"
:
return
Float8CurrentScaling
(
fp8_format
=
fp8_format
)
elif
quantization
==
"fp8_block"
:
return
Float8BlockScaling
(
fp8_format
=
fp8_format
)
else
:
raise
ValueError
(
f
"Unsupported quantization:
{
quantization
}
"
)
def _test_cast_master_weights_to_fp8(quantization, dp_group):
    rank = dist.get_rank(dp_group)
    world_size = dist.get_world_size(dp_group)

    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
    mock_group = mock_groups[rank]

    linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": False}

    # Create model with FP8 weights
    with te.quantized_model_init(
        enabled=quantization is not None,
        recipe=quantization_recipe(quantization),
        preserve_high_precision_init_val=True,
    ):
        model_fp8 = nn.Sequential(
            te.Linear(128, 256 + 16, **linear_kwargs),
            te.Linear(256 + 16, 256 * 3, **linear_kwargs),
            te.Linear(256 * 3, 128, **linear_kwargs),
        )

    # Create model with BF16 weights
    model = nn.Sequential(
        te.Linear(128, 256 + 16, **linear_kwargs),
        te.Linear(256 + 16, 256 * 3, **linear_kwargs),
        te.Linear(256 * 3, 128, **linear_kwargs),
    )

    # Make sure the BF16 model and FP8 model have the same initial weights
    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
        high_precision_init_val = w_fp8.get_high_precision_init_val()
        w.data.copy_(high_precision_init_val)

    # Allocate main_grads for each weight
    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
        w_fp8.main_grad = torch.zeros_like(w_fp8, dtype=torch.float32, device="cuda")
        w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda")

    optimizer_fp8 = MiniZero_1([w for w in model_fp8.parameters()], 10.0, dp_group)
    optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group)

    for i in range(100):
        for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
            w_fp8.main_grad.zero_()
            w.main_grad.zero_()

        inputs = [
            torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
        ]
        # Choose based on rank to make sure the inputs of different ranks are different.
        x = inputs[rank]

        with te.autocast(
            enabled=quantization is not None,
            recipe=quantization_recipe(quantization),
            amax_reduction_group=mock_group,
        ):
            y_fp8 = model_fp8(x)
        with te.autocast(
            enabled=quantization is not None,
            recipe=quantization_recipe(quantization),
            amax_reduction_group=mock_group,
        ):
            y = model(x)

        targets = [torch.randn_like(y) for _ in range(world_size)]
        # Choose based on rank to make sure the targets of different ranks are different.
        target = targets[rank]

        loss_fp8 = nn.MSELoss()(y_fp8, target)
        loss = nn.MSELoss()(y, target)

        loss_fp8.backward()
        loss.backward()

        optimizer_fp8.step()
        optimizer.step()

        torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
def main(argv=None, namespace=None):
    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
    LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))

    assert WORLD_SIZE == LOCAL_SIZE  # this test supports only 1 node
    assert LOCAL_SIZE <= torch.cuda.device_count()

    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
        "timeout": datetime.timedelta(seconds=30),
    }
    dist_init_kwargs["init_method"] = "env://"
    dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
    assert dist.is_nccl_available()
    torch.cuda.set_device(LOCAL_RANK)
    dist.init_process_group(**dist_init_kwargs)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--quantization", type=str, default=None, choices=["fp8", "fp8_cs", "fp8_block"]
    )
    args = parser.parse_args(argv, namespace)

    dp_group = dist.new_group(backend="nccl")

    _test_zero_1(dp_group)
    _test_cast_master_weights_to_fp8(args.quantization, dp_group)
    _test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group)

    dist.destroy_process_group()
    return 0


if __name__ == "__main__":
    sys.exit(main())
tests/pytorch/distributed/run_fsdp2_model.py  View file @ c1a1c04e
...
@@ -8,58 +8,74 @@ import os
import sys
import argparse

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import (
    Format,
    DelayedScaling,
    Float8CurrentScaling,
    MXFP8BlockScaling,
)

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor
import torch.nn.functional as F
from torch import nn, optim
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from transformer_engine.pytorch import QuantizedTensor
from contextlib import nullcontext

LOCAL_RANK = None

from transformer_engine.common.recipe import Format, DelayedScaling
class SimpleNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNet, self).__init__()
        self.fc1 = te.Linear(input_size, hidden_size)
        self.fc2 = te.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


def save_custom_attrs(module):
    custom_attrs = {}
    for name, param in module.named_parameters():
        attrs = vars(param)
        custom_attrs[name] = {k: v for k, v in attrs.items()}
    return custom_attrs


def restore_custom_attrs(module, custom_attrs):
    for name, param in module.named_parameters():
        if name in custom_attrs:
            for attr_name, attr_value in custom_attrs[name].items():
                setattr(param, attr_name, attr_value)


def dist_print(msg):
    if LOCAL_RANK == 0:
        print(msg)
def _parse_args(argv=None, namespace=None):
    parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
    parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
    parser.add_argument("--num-heads", type=int, default=8, help="Number of attn. heads")
    parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
    parser.add_argument("--head-dim", type=int, default=64, help="Attention head size")
    parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size of input")
    parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
    parser.add_argument("--seq-length", type=int, default=128, help="Sequence length of input")
    parser.add_argument("--params-dtype", type=str, default="float32", help="Parameter dtype.")
    parser.add_argument(
        "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
    )
    parser.add_argument(
        "--recipe",
        type=str,
        default="mx_fp8_block_scaling",
        help="Quantizer type.",
        choices=["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"],
    )
    parser.add_argument(
        "--layer-type",
        type=str,
        default="TransformerLayer",
        choices=[
            "Linear",
            "LayerNormLinear",
            "LayerNormMLP",
            "MultiheadAttention",
            "TransformerLayer",
        ],
        help="Transformer Engine layer type",
    )
    parser.add_argument("--num-layers", type=int, default=4, help="Number of layers in the model")
    parser.add_argument(
        "--iter", type=int, default=10, help="Number of iterations for forward pass"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="meta",
        help="Device to run the model on.",
        choices=["cuda", "meta"],
    )
    parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
    # Adding hsdp_dim as a list argument, comma-separated
    parser.add_argument(
...
@@ -74,10 +90,170 @@ def _parse_args(argv=None, namespace=None):
    return args
sub_modules_to_wrap = [te.Linear]


## Methods to help initialize the TE model in an FSDP2 setting
## with required configurations based on command line args
def get_te_layer_from_string(layer_name):
    te_layer_types = [
        te.Linear,
        te.LayerNormLinear,
        te.LayerNormMLP,
        te.MultiheadAttention,
        te.TransformerLayer,
    ]
    te_layer_names = [layer.__name__ for layer in te_layer_types]
    te_layer_map = dict(zip([name.lower() for name in te_layer_names], te_layer_types))
    if layer_name.lower() not in te_layer_map.keys():
        raise argparse.ArgumentTypeError(
            f'"{layer_name}" is not a valid Transformer Engine layer, '
            f"please choose layer from {te_layer_names}."
        )
    return te_layer_map[layer_name.lower()]
def get_recipe_from_string(recipe, fp8_format=Format.HYBRID):
    if recipe == "delayed_scaling":
        return DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
    elif recipe == "current_scaling":
        return Float8CurrentScaling(fp8_format=fp8_format)
    elif recipe == "mx_fp8_block_scaling":
        return MXFP8BlockScaling(fp8_format=fp8_format)
    else:
        raise ValueError(f"Unknown quantizer type: {recipe}")
def init_te_model(config):
    hidden_size = config.num_heads * config.head_dim
    args = [hidden_size, hidden_size]
    inp_shape = [config.seq_length, config.batch_size, hidden_size]
    out_shape = [config.seq_length, config.batch_size, hidden_size]

    if config.params_dtype == "float16":
        params_dtype = torch.float16
    elif config.params_dtype == "bfloat16":
        params_dtype = torch.bfloat16
    else:
        params_dtype = torch.float32

    kwargs = {
        "params_dtype": params_dtype,
    }
    kwargs["device"] = config.device

    layer_type = get_te_layer_from_string(config.layer_type)
    # We are creating the model in a way that lets us test both
    # reshard_after_forward=True/False cases; more details below.
    if layer_type in [te.MultiheadAttention, te.TransformerLayer]:
        # For this case, we are creating a model that resembles production use-cases
        # wherein there are multiple TransformerLayers in the model. And we would need
        # to shard each transformer layer. Since each transformer layer is not a root module,
        # FSDP2's fully_shard assigns reshard_after_forward=False for all parameters of the model.
        args[1] *= 4  # FFN hidden size
        args.append(config.num_heads)
        kwargs["fuse_qkv_params"] = True
        if layer_type is te.MultiheadAttention:
            kwargs["input_layernorm"] = True
        model = nn.Sequential(*[layer_type(*args, **kwargs) for _ in range(config.num_layers)])
    elif layer_type == te.LayerNormLinear:
        # For this case, we are creating a model with just one LayerNormLinear layer
        # so that the model itself is a root module, and FSDP2's fully_shard assigns
        # reshard_after_forward=True for the parameters of this model.
        args[1] *= 3  # QKV projection
        out_shape[-1] *= 3
        model = layer_type(*args, **kwargs)
    else:
        model = layer_type(*args, **kwargs)
    return model, inp_shape, out_shape
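A quick arithmetic check of the shapes produced above, using only the default command-line values from _parse_args (num_heads=8, head_dim=64, batch_size=16, seq_length=128); these are not additional test configurations, just the defaults worked through:

# hidden_size = num_heads * head_dim = 8 * 64 = 512
# TransformerLayer / MultiheadAttention path: args becomes [512, 512 * 4, 8]
#   (FFN hidden size 2048), and inp_shape == out_shape == [128, 16, 512]
# LayerNormLinear path: args becomes [512, 512 * 3] (fused QKV projection),
#   so out_shape == [128, 16, 1536]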
def get_device_mesh(world_size, sharding_dims):
    dist_print(f"sharding-dims:{sharding_dims}")
    device_ids = list(range(world_size))
    if sharding_dims is None:  # FSDP
        mesh = DeviceMesh("cuda", device_ids)
    elif len(sharding_dims) == 1:
        assert sharding_dims[0] == world_size
        mesh = DeviceMesh("cuda", device_ids)
    elif len(sharding_dims) == 2:  # HSDP
        assert sharding_dims[0] * sharding_dims[1] == world_size
        mesh = init_device_mesh(
            "cuda",
            (sharding_dims[0], sharding_dims[1]),
            mesh_dim_names=("replicate", "shard"),
        )
    else:
        assert False
    return mesh
def shard_model_with_fsdp2(model, mesh):
    for child in model.children():
        fully_shard(child, mesh=mesh)
    fully_shard(model, mesh=mesh)
    return model
#### Methods to save the custom attributes of QuantizedTensors before sharding
#### them with FSDP2, and restore them after sharding.
def save_custom_attrs(module):
    custom_attrs = {}
    for name, param in module.named_parameters():
        if isinstance(param, QuantizedTensor):
            # Ignore FP8 metadata attributes. Otherwise we will save duplicate copies
            # for data/transpose FP8 tensors on top of FP8 tensors that FSDP2 will save.
            ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")]
        else:
            ignore_keys = []
        attrs = vars(param)
        custom_attrs[name] = {k: v for k, v in attrs.items() if k not in ignore_keys}
    return custom_attrs
def restore_custom_attrs(module, custom_attrs):
    for name, param in module.named_parameters():
        if name in custom_attrs:
            for attr_name, attr_value in custom_attrs[name].items():
                setattr(param, attr_name, attr_value)
@torch.no_grad()
def test_fp8_fsdp2_allgather(model):
    # Do manual allgather in fp32 and match against fp8 allgather done
    # with fsdp2

    # FP32 manual weight allgather
    fp32_allgathered_params = {}
    for name, param in model.named_parameters():
        assert isinstance(param, DTensor)
        local_tensor = param._local_tensor
        device_mesh = param.device_mesh
        dist_group = (
            device_mesh.get_group(mesh_dim="shard")
            if device_mesh.ndim > 1
            else device_mesh.get_group()
        )
        # Perform manual allgather on local_tensor. zeros_like will create hp tensor since
        # torch_dispatch for local_tensor will go down the dequantization route.
        gathered_tensor = [
            torch.zeros_like(local_tensor) for _ in range(dist.get_world_size(group=dist_group))
        ]
        dist.all_gather(gathered_tensor, local_tensor.dequantize(), group=dist_group)
        full_tensor = torch.cat(gathered_tensor, dim=0)
        fp32_allgathered_params[name] = full_tensor

    # FP8 allgather using FSDP2
    for module in model.modules():
        # Not all modules are wrapped/sharded with FSDP2.
        if hasattr(module, "unshard"):
            module.unshard()

    # Make sure allgathered parameters match exactly
    for name, param in model.named_parameters():
        assert torch.allclose(param.dequantize(), fp32_allgathered_params[name])

    # Revert model to original sharded state
    for module in model.modules():
        # Not all modules are wrapped/sharded with FSDP2.
        if hasattr(module, "reshard"):
            module.reshard()
def _train(args):
    global LOCAL_RANK
    assert "TORCHELASTIC_RUN_ID" in os.environ
    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
...
@@ -103,77 +279,69 @@ def _train(args):
    # FP8 Configuration
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
    fp8_recipe = get_recipe_from_string(args.recipe, fp8_format)

    build_model_context_args = {}
    if not args.fp8_init:
        # Build model context (FP8 init)
        build_model_context = nullcontext
        build_model_context_args = {}
    else:
        from transformer_engine.pytorch import fp8_model_init
        from transformer_engine.pytorch import quantized_model_init

        build_model_context = quantized_model_init
        build_model_context = fp8_model_init
        build_model_context_args["enabled"] = True
        build_model_context_args["recipe"] = fp8_recipe

    # Build the model with the specified context
    dist_print(f"Memory before model init: {torch.cuda.memory_allocated(device)/1e6} MB")
    # Create the model on the meta/cuda device as per args
    with build_model_context(**build_model_context_args):
        model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
        model, inp_shape, out_shape = init_te_model(args)
    dist_print(
        f"Memory after model init on device {args.device}:"
        f" {torch.cuda.memory_allocated(device)/1e6} MB"
    )

    # Move the model to the correct device
    model.to(device)
    if LOCAL_RANK == 0:
        print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")

    # Creating a DeviceMesh for fully_shard
    world_size = int(WORLD_SIZE)
    device_ids = list(range(world_size))
    if LOCAL_RANK == 0:
        print(f"sharding-dims: {args.sharding_dims}")
    # Setup the sharding mesh for FSDP/HSDP
    mesh = get_device_mesh(world_size, args.sharding_dims)
    if args.sharding_dims == None:  # FSDP
        mesh = DeviceMesh("cuda", device_ids)
    elif len(args.sharding_dims) == 1:
        assert args.sharding_dims[0] == device_ids[-1] + 1
        mesh = DeviceMesh("cuda", device_ids)
    elif len(args.sharding_dims) == 2:  # HSDP
        assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1
        mesh = init_device_mesh(
            "cuda",
            (args.sharding_dims[0], args.sharding_dims[1]),
            mesh_dim_names=("replicate", "shard"),
        )
    else:
        assert False

    # Apply FSDP/HSDP
    custom_attrs = save_custom_attrs(model)
    model = shard_model_with_fsdp2(model, mesh)
    for sub_module in model.modules():
        if any(
            isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
        ):
            fully_shard(sub_module, mesh=mesh)
    fully_shard(model, mesh=mesh)
    restore_custom_attrs(model, custom_attrs)

    # model now has DTensors as its parameters
    if args.device == "meta":
        # After FSDP2 has been applied, materialize and initialize the sharded parameters
        # TE base.py's reset_parameters() handles DTensors with FP8 initialization
        for module in model.modules():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()
        dist_print(f" Sharded parameters materialized and initialized on cuda device.")
    dist_print(f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB")

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for iteration in range(args.iter):
        # Zero the parameter gradients
        optimizer.zero_grad()
        input_data = torch.randn(args.batch_size, args.input_size).to(device)
        input_data = torch.randn(inp_shape).to(device)
        with te.autocast(enabled=True, recipe=fp8_recipe):
            output = model(input_data)
        target = torch.randn(args.batch_size, args.output_size).to(device)
        target = torch.randn(out_shape).to(device)
        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
        if LOCAL_RANK == 0:
            print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")
        dist_print(f"Iteration {iteration} completed with loss {loss.item()}")

    # Some of the FSDP states are lazy initialized during FSDP forward pass
    # so testing fp8 allgather at the end of the training loop.
    if args.fp8_init:
        test_fp8_fsdp2_allgather(model)

    dist.destroy_process_group()
    if LOCAL_RANK == 0:
        print(f"Rank {LOCAL_RANK}: Done...")
    return 0
...
tests/pytorch/distributed/run_numerics_exact.py  View file @ c1a1c04e
...
@@ -22,8 +22,8 @@ from transformer_engine.common.recipe import (
)
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.experimental import quantization_nvfp4
from transformer_engine.pytorch.custom_recipes import quantization_nvfp4
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.custom_recipes import utils
from run_layer_with_overlap import _compare_tensors
...
@@ -486,7 +486,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
        sequence_parallel (bool): Enable sequence parallelism if True.
        kwargs (dict): Additional arguments for the linear layer.
    QUANTIZATION options: nvfp4 <=> experimental nvfp4 as a reference
    QUANTIZATION options: nvfp4 <=> custom nvfp4 as a reference
    """
    params_dtype = torch.bfloat16
    use_bias = kwargs.get("bias", True)
...
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py  View file @ c1a1c04e
...
@@ -2,39 +2,746 @@
#
# See LICENSE for license information.

import argparse
import datetime
import os
import subprocess
from pathlib import Path
import sys
import pathlib
import pytest
import torch
from torch import nn
import torch.distributed as dist

from transformer_engine.pytorch import is_fp8_available, is_fp8_block_scaling_available

# NVTE_DISABLE_NVRTC=1 NVTE_INT8_SIM_FP8=1 torchrun --nproc_per_node=4 run_cast_master_weights_to_fp8.py --quantization fp8_block

if torch.cuda.device_count() < 2:
    pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.")

fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
    return_reason=True
)

from transformer_engine.common.recipe import (
    DelayedScaling,
    Float8CurrentScaling,
    Float8BlockScaling,
    Format,
    Recipe,
)
import transformer_engine.pytorch as te
from transformer_engine.pytorch import (
    is_fp8_available,
    is_fp8_block_scaling_available,
    QuantizedTensor,
    Float8Tensor,
    Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8
from transformer_engine.pytorch.tensor.utils import post_all_gather_processing, replace_raw_data
_get_quantization_recipe
(
quantization
)
->
Recipe
:
"""Quantization recipe setup"""
fp8_format
=
Format
.
HYBRID
if
quantization
==
"fp8"
:
return
DelayedScaling
(
fp8_format
=
fp8_format
,
amax_history_len
=
32
,
amax_compute_algo
=
"max"
)
elif
quantization
==
"fp8_cs"
:
return
Float8CurrentScaling
(
fp8_format
=
fp8_format
)
elif
quantization
==
"fp8_block"
:
return
Float8BlockScaling
(
fp8_format
=
fp8_format
)
else
:
raise
ValueError
(
f
"Unsupported quantization:
{
quantization
}
"
)
def _get_raw_data(quantized_tensor):
    """Get the underlying data of a quantized tensor, used in zero-1 optimizer"""
    if isinstance(quantized_tensor, Float8Tensor):
        assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute"
        assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8"
        return quantized_tensor._data
    elif isinstance(quantized_tensor, Float8BlockwiseQTensor):
        assert hasattr(
            quantized_tensor, "_rowwise_data"
        ), "Float8BlockwiseQTensor does not have _rowwise_data attribute"
        assert (
            quantized_tensor._rowwise_data.dtype == torch.uint8
        ), "Float8BlockwiseQTensor _rowwise_data must be uint8"
        return quantized_tensor._rowwise_data
    else:
        raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")
class MiniOptimizer:
    def __init__(self, weights, lr, dp_group):
        self.world_size = dist.get_world_size(dp_group)
        self.weights = weights
        self.lr = lr
        self.dp_group = dp_group

        master_weights = []
        for weight in self.weights:
            master_weights.append(weight.detach().float())
        self.master_weights = master_weights

    def step(self):
        for weight, master_weight in zip(self.weights, self.master_weights):
            main_grad = weight.main_grad
            # Don't use all-reduce directly to explicitly control the reduce order.
            # dist.all_reduce(main_grad, op=dist.ReduceOp.AVG, group=self.dp_group)
            buffers = [torch.empty_like(main_grad) for _ in range(self.world_size)]
            dist.all_gather(buffers, main_grad, group=self.dp_group)
            for i in range(1, self.world_size):
                buffers[0] += buffers[i]
            main_grad.copy_(buffers[0])
            main_grad /= self.world_size
            master_weight -= main_grad * self.lr
            weight.data.copy_(master_weight)
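The all-gather-plus-ordered-sum above exists so that every rank (and both optimizer implementations) reduces the gradients in exactly the same order, which is what makes the atol=0 / rtol=0 comparisons in these tests meaningful. A standalone illustration of why summation order matters in floating point (the values below are arbitrary, not taken from the test):

import torch

a = torch.tensor(1e8, dtype=torch.float32)
b = torch.tensor(-1e8, dtype=torch.float32)
c = torch.tensor(1.0, dtype=torch.float32)
# Addition is not associative in float32, so a reduction that is free to
# reorder operands (e.g. NCCL all-reduce) may not be bitwise reproducible.
print((a + b) + c)  # tensor(1.)
print(a + (b + c))  # tensor(0.)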
class MiniZero_1:
    """A mini zero-1 optimizer implementation, just used for this test"""

    def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=False):
        self.rank = dist.get_rank(dp_group)
        self.world_size = dist.get_world_size(dp_group)
        self.weights = weights
        self.lr = lr
        self.dp_group = dp_group
        self.manual_post_all_gather_processing = manual_post_all_gather_processing

        # [self.offsets[i], self.offsets[i+1]) is the range of weights[i] in the global buffer
        self.offsets = [0]
        for weight in self.weights:
            self.offsets.append(self.offsets[-1] + weight.numel())
        # Pad so that the global buffer can be divided evenly by the world size; because of this,
        # offsets[-1] may not be the end range of the last weight.
        if self.offsets[-1] % self.world_size != 0:
            self.offsets[-1] += self.world_size - self.offsets[-1] % self.world_size

        self.master_weights = []
        # The start offset of the master weight in the weight
        self.start_offsets = []
        # The overlapping area of the weight and this rank's local buffer
        self.overlapping_areas = []

        # The start and end of this rank's local buffer in the global buffer
        rank_start = self.offsets[-1] // self.world_size * self.rank
        rank_end = rank_start + self.offsets[-1] // self.world_size

        for weight, offset in zip(self.weights, self.offsets[:-1]):
            if offset >= rank_end or (offset + weight.numel()) <= rank_start:
                # This weight is not in this rank's local buffer
                master_weight = None
                start_offset = None
                overlapping_area = None
            else:
                overlapping_start = max(rank_start, offset)
                overlapping_end = min(rank_end, offset + weight.numel())
                length = overlapping_end - overlapping_start
                start_offset = overlapping_start - offset
                if isinstance(weight, QuantizedTensor):
                    # If weight is a FP8 tensor, we need to use the original high precision version
                    # to initialize the master weight.
                    high_precision_init_val = weight.get_high_precision_init_val().view(-1)
                    master_weight = high_precision_init_val.to(weight.device).float()[
                        start_offset : start_offset + length
                    ]
                else:
                    master_weight = (
                        weight.detach().view(-1).float()[start_offset : start_offset + length]
                    )
                overlapping_area = (overlapping_start, overlapping_end)
            self.master_weights.append(master_weight)
            self.start_offsets.append(start_offset)
            self.overlapping_areas.append(overlapping_area)

        # Create global buffer for grads reduce-scatter
        self.grad_buffer = torch.empty(
            [self.offsets[-1]], dtype=torch.float32, device=weights[0].device
        )
        self.grad_buffer_slice = self.grad_buffer[rank_start:rank_end]

        # Create global buffer for weights all-gather
        if isinstance(self.weights[0], QuantizedTensor):
            weight_buffer_dtype = torch.uint8
        else:
            weight_buffer_dtype = weights[0].dtype
        self.weight_buffer = torch.empty(
            [self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device
        )
        self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end]

    def step(self):
        # -----------------------------------------------------------------------------------------
        # Step 1: Copy grads to the grad buffer
        # -----------------------------------------------------------------------------------------
        for weight, offset in zip(self.weights, self.offsets[:-1]):
            start = offset
            end = offset + weight.numel()
            self.grad_buffer[start:end].copy_(weight.main_grad.view(-1))

        # -----------------------------------------------------------------------------------------
        # Step 2: Grads reduce-scatter
        # -----------------------------------------------------------------------------------------
        # Don't use reduce_scatter directly to explicitly control the reduce order.
        # dist.reduce_scatter_tensor(self.grad_buffer_slice, self.grad_buffer, op=dist.ReduceOp.AVG,
        #                            group=self.dp_group)
        buffers = [torch.empty_like(self.grad_buffer) for _ in range(self.world_size)]
        dist.all_gather(buffers, self.grad_buffer, group=self.dp_group)
        for i in range(1, self.world_size):
            buffers[0] += buffers[i]
        rank_start = self.offsets[-1] // self.world_size * self.rank
        rank_end = rank_start + self.offsets[-1] // self.world_size
        self.grad_buffer_slice.copy_(buffers[0][rank_start:rank_end])
        self.grad_buffer_slice /= self.world_size

        # -----------------------------------------------------------------------------------------
        # Step 3: Update master weights
        # -----------------------------------------------------------------------------------------
        for master_weight, overlapping_area in zip(self.master_weights, self.overlapping_areas):
            if master_weight is None:
                # This weight's master weight is in other rank.
                continue
            grad = self.grad_buffer[overlapping_area[0] : overlapping_area[1]]
            master_weight -= grad * self.lr

        # -----------------------------------------------------------------------------------------
        # Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight
        # -----------------------------------------------------------------------------------------
        if isinstance(self.weights[0], QuantizedTensor):
            # FP8 weights case
            for i in range(1, len(self.weights)):
                assert isinstance(self.weights[i], QuantizedTensor)
            cast_master_weights_to_fp8(
                self.weights,
                self.master_weights,
                self.start_offsets,
                self.dp_group,
                manual_post_all_gather_processing=self.manual_post_all_gather_processing,
            )
        else:
            # BF16 weights case
            for weight, master_weight, start_offset in zip(
                self.weights, self.master_weights, self.start_offsets
            ):
                if master_weight is None:
                    continue
                start = start_offset
                end = start_offset + master_weight.numel()
                weight.data.view(-1)[start:end].copy_(master_weight)

        # -----------------------------------------------------------------------------------------
        # Step 5: Copy the updated weights (not all weights) to the weight buffer
        # -----------------------------------------------------------------------------------------
        for i in range(len(self.weights)):
            master_weight = self.master_weights[i]
            if master_weight is None:
                continue
            start_offset = self.start_offsets[i]
            if isinstance(self.weights[i], QuantizedTensor):
                weight = _get_raw_data(self.weights[i])
            else:
                weight = self.weights[i]
            weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()]
            overlapping_start, overlapping_end = self.overlapping_areas[i]
            self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice)

        # -----------------------------------------------------------------------------------------
        # Step 6: Weight all-gather (FP8 or BF16)
        # -----------------------------------------------------------------------------------------
        dist.all_gather_into_tensor(
            self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
        )

        # -----------------------------------------------------------------------------------------
        # Step 7: Copy the gathered weights from weight buffer to the actual weights
        # -----------------------------------------------------------------------------------------
        for weight, offset in zip(self.weights, self.offsets[:-1]):
            start = offset
            end = offset + weight.numel()
            if isinstance(weight, QuantizedTensor):
                weight = _get_raw_data(weight)
            weight.view(-1).data.copy_(self.weight_buffer[start:end])

        if self.manual_post_all_gather_processing:
            quantized_weights = [
                weight for weight in self.weights if isinstance(weight, QuantizedTensor)
            ]
            post_all_gather_processing(quantized_weights)
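A tiny worked example of the bookkeeping MiniZero_1 performs above, using hypothetical weight sizes (5 and 6 elements) and world_size=2 rather than the shapes used in the tests:

world_size = 2
numels = [5, 6]                      # hypothetical weight sizes
offsets = [0]
for n in numels:
    offsets.append(offsets[-1] + n)  # offsets == [0, 5, 11]
if offsets[-1] % world_size != 0:
    offsets[-1] += world_size - offsets[-1] % world_size  # pad 11 -> 12
for rank in range(world_size):
    rank_start = offsets[-1] // world_size * rank
    rank_end = rank_start + offsets[-1] // world_size
    print(rank, (rank_start, rank_end))
# rank 0 owns buffer [0, 6): all of weight 0 plus element 0 of weight 1
# rank 1 owns buffer [6, 12): elements 1..5 of weight 1 plus one padding slot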
class MiniFSDP:
    def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=False):
        rank = dist.get_rank(dp_group)
        world_size = dist.get_world_size(dp_group)
        self.weights = weights
        self.lr = lr
        self.dp_group = dp_group
        self.manual_post_all_gather_processing = manual_post_all_gather_processing

        # Flatten the weights and pad to align with world size
        if isinstance(weights[0], QuantizedTensor):
            raw_data_list = [_get_raw_data(w).view(-1) for w in weights]
        else:
            raw_data_list = [w.view(-1) for w in weights]
        self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list)

        # Split flattened weights into shards
        self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
        self.local_main_grad_shard = torch.zeros_like(
            self.local_weight_shard, dtype=torch.float32, device="cuda"
        )
        shard_size = self.flatten_weight.size(0) // world_size

        # Map original tensors to flattened indices
        tensor_indices = []
        cumulative_length = 0
        for tensor in raw_data_list:
            length = tensor.size(0)
            tensor_indices.append((cumulative_length, cumulative_length + length))
            cumulative_length += length

        # Build shard index mappings
        self.weight_indices = []
        self.shard_indices = []
        for idx, (start, end) in enumerate(tensor_indices):
            shard_start = rank * shard_size
            shard_end = shard_start + shard_size
            adjusted_end = min(shard_end, original_length)
            if start <= adjusted_end and end >= shard_start:
                start_idx = max(start, shard_start)
                end_idx = min(end, adjusted_end)
                self.weight_indices.append((start_idx - start, end_idx - start))
                self.shard_indices.append((start_idx - shard_start, end_idx - shard_start))
            else:
                self.weight_indices.append((None, None))
                self.shard_indices.append((None, None))
            if isinstance(weights[idx], QuantizedTensor):
                replace_raw_data(
                    weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
                )
            else:
                weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape)

        # Initialize local model weights and high-precision master weights
        self.local_weights = []
        self.master_weights = []
        for i, weight in enumerate(self.weights):
            weight_start, weight_end = self.weight_indices[i]
            shard_start, shard_end = self.shard_indices[i]
            if shard_start is not None and shard_end is not None:
                local_weight_shard = self.local_weight_shard[shard_start:shard_end]
                self.local_weights.append(local_weight_shard)
                if isinstance(weight, QuantizedTensor):
                    high_precision_init_val = weight.get_high_precision_init_val().view(-1)
                    master_weight_shard = high_precision_init_val.to(weight.device).float()[
                        weight_start:weight_end
                    ]
                else:
                    master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end]
                self.master_weights.append(master_weight_shard)
            else:
                self.local_weights.append(None)
                self.master_weights.append(None)
            setattr(
                weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda")
            )

    def _flatten_tensors_with_pad(self, tensors):
        """
        Flatten the list of tensors and pad them to align with the world size.

        Args:
            tensors (list): List of tensors to flatten.

        Returns:
            tuple: Flattened tensor and its original length before padding.
        """
        world_size = dist.get_world_size(self.dp_group)
        flatten_tensor = torch.cat(tensors)
        original_length = flatten_tensor.size(0)
        padding_needed = (world_size - original_length % world_size) % world_size
        if padding_needed > 0:
            zeros = torch.zeros(padding_needed, dtype=flatten_tensor.dtype, device="cuda")
            flatten_tensor = torch.cat([flatten_tensor, zeros])
        return flatten_tensor, original_length

    def zero_grad(self):
        for weight in self.weights:
            weight.grad = None
            weight.main_grad.zero_()

    def step(self):
        """
        Perform an optimization step for the distributed sharded model.

        This method includes:
        1. Gradient reduce-scatter: Synchronize gradients across all processes.
        2. Master weight update: Update high-precision master weights using local gradients.
        3. Precision casting: Cast updated master weights to FP8 or BF16 precision.
        4. Weight synchronization: All-gather updated weights across all processes.

        Returns:
            None
        """
        # Step 1: Reduce-scatter the gradients
        main_grad_buffer, _ = self._flatten_tensors_with_pad(
            [weight.main_grad.view(-1) for weight in self.weights]
        )
        dist.reduce_scatter_tensor(self.local_main_grad_shard, main_grad_buffer, group=self.dp_group)
        self.local_main_grad_shard /= dist.get_world_size(self.dp_group)

        # Step 2: Update the master weights
        for weight, master_weight, (shard_start, shard_end) in zip(
            self.weights, self.master_weights, self.shard_indices
        ):
            if master_weight is None:
                continue
            # Extract the local gradient shard for this weight
            grad = self.local_main_grad_shard[shard_start:shard_end]
            # Update the master weight using gradient descent
            master_weight -= grad * self.lr

        # Step 3: Cast master weights to FP8 or BF16 precision
        if isinstance(self.weights[0], QuantizedTensor):
            local_weights = []
            for local_weight in self.local_weights:
                if local_weight is None:
                    local_weights.append(None)
                    continue
                local_weights.append(local_weight)
            cast_master_weights_to_fp8(
                self.weights,
                self.master_weights,
                [idx[0] for idx in self.weight_indices],
                self.dp_group,
                local_weights,
                manual_post_all_gather_processing=self.manual_post_all_gather_processing,
            )
        else:
            for weight, master_weight in zip(self.local_weights, self.master_weights):
                if master_weight is None:
                    continue
                # Copy updated master weights to local weights
                weight.data.copy_(master_weight)

        # Step 4: All-gather updated weights across processes
        dist.all_gather_into_tensor(self.flatten_weight, self.local_weight_shard, group=self.dp_group)
        if self.manual_post_all_gather_processing:
            quantized_weights = [
                weight for weight in self.weights if isinstance(weight, QuantizedTensor)
            ]
            post_all_gather_processing(quantized_weights)


TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(2, torch.cuda.device_count())
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
def _test_mini_optimizer(dp_group):
    """Make sure the implementation of MiniZero_1 and MiniFSDP is correct"""
    rank = dist.get_rank(dp_group)
    world_size = dist.get_world_size(dp_group)

    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    weights = [
        torch.randn(256 * 256, dtype=torch.bfloat16, device="cuda"),
        torch.randn(256 * 256 * 3, dtype=torch.bfloat16, device="cuda"),
        torch.randn(256 * 256 * 2 - 1, dtype=torch.bfloat16, device="cuda"),
    ]
    weights_1 = weights
    weights_2 = [weight.clone() for weight in weights]
    weights_3 = [weight.clone() for weight in weights]

    lr = 1.0
    optimizer_1 = MiniZero_1(weights_1, lr, dp_group)
    optimizer_2 = MiniOptimizer(weights_2, lr, dp_group)
    optimizer_3 = MiniFSDP(weights_3, lr, dp_group)

    for _ in range(100):
        for w1, w2, w3 in zip(weights_1, weights_2, weights_3):
            main_grads = [
                torch.randn_like(w1, dtype=torch.float32, device="cuda")
                for _ in range(world_size)
            ]
            # Choose based on rank to make sure the grads of different ranks are different.
            main_grad = main_grads[rank]
            w1.main_grad = main_grad
            w2.main_grad = main_grad
            w3.main_grad = main_grad

        optimizer_1.step()
        optimizer_2.step()
        optimizer_3.step()

        for w1, w2 in zip(weights_1, weights_2):
            torch.testing.assert_close(w1, w2, atol=0, rtol=0)
        for w1, w3 in zip(weights_1, weights_3):
            torch.testing.assert_close(w1, w3, atol=0, rtol=0)
def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gather_processing):
    rank = dist.get_rank(dp_group)
    world_size = dist.get_world_size(dp_group)

    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
    mock_group = mock_groups[rank]

    linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True}

    # Create model with FP8 weights
    with te.quantized_model_init(
        enabled=quantization is not None,
        recipe=_get_quantization_recipe(quantization),
        preserve_high_precision_init_val=True,
    ):
        model_fp8 = nn.Sequential(
            te.Linear(128, 256 + 16, **linear_kwargs),
            te.Linear(256 + 16, 256 * 3, **linear_kwargs),
            te.Linear(256 * 3, 128, **linear_kwargs),
        )

    # Create model with BF16 weights
    model = nn.Sequential(
        te.Linear(128, 256 + 16, **linear_kwargs),
        te.Linear(256 + 16, 256 * 3, **linear_kwargs),
        te.Linear(256 * 3, 128, **linear_kwargs),
    )

    # Make sure the BF16 model and FP8 model have the same initial weights
    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
        high_precision_init_val = w_fp8.get_high_precision_init_val()
        w.data.copy_(high_precision_init_val)

    # Allocate main_grads for each weight
    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
        w_fp8.main_grad = torch.zeros_like(w_fp8, dtype=torch.float32, device="cuda")
        w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda")

    optimizer_fp8 = MiniZero_1(
        [w for w in model_fp8.parameters()], 10.0, dp_group, manual_post_all_gather_processing
    )
    optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group)

    for i in range(100):
        for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
            w_fp8.main_grad.zero_()
            w.main_grad.zero_()

        inputs = [
            torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
        ]
        # Choose based on rank to make sure the inputs of different ranks are different.
        x = inputs[rank]

        with te.autocast(
            enabled=quantization is not None,
            recipe=_get_quantization_recipe(quantization),
            amax_reduction_group=mock_group,
        ):
            y_fp8 = model_fp8(x)
        with te.autocast(
            enabled=quantization is not None,
            recipe=_get_quantization_recipe(quantization),
            amax_reduction_group=mock_group,
        ):
            y = model(x)

        targets = [torch.randn_like(y) for _ in range(world_size)]
        # Choose based on rank to make sure the targets of different ranks are different.
        target = targets[rank]

        loss_fp8 = nn.MSELoss()(y_fp8, target)
        loss = nn.MSELoss()(y, target)

        loss_fp8.backward()
        loss.backward()

        optimizer_fp8.step()
        optimizer.step()

        torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gather_processing):
    rank = dist.get_rank(dp_group)
    world_size = dist.get_world_size(dp_group)

    # Configuration constants
    NUM_STEPS = 100
    SEED = 12345

    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)

    mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
    mock_group = mock_groups[rank]

    linear_kwargs = {
        "params_dtype": torch.bfloat16,
        "bias": False,
        "fuse_wgrad_accumulation": True,
    }

    # Create model with FP8 weights
    with te.quantized_model_init(
        enabled=quantization is not None,
        recipe=_get_quantization_recipe(quantization),
        preserve_high_precision_init_val=True,
    ):
        model_fp8 = nn.Sequential(
            te.Linear(128, 256 + 16, **linear_kwargs),
            te.Linear(256 + 16, 256 * 3, **linear_kwargs),
            te.Linear(256 * 3, 128, **linear_kwargs),
        )

    # Create model with BF16 weights
    model = nn.Sequential(
        te.Linear(128, 256 + 16, **linear_kwargs),
        te.Linear(256 + 16, 256 * 3, **linear_kwargs),
        te.Linear(256 * 3, 128, **linear_kwargs),
    )

    # Make sure the BF16 model and FP8 model have the same initial weights
    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
        high_precision_init_val = w_fp8.get_high_precision_init_val()
        w.data.copy_(high_precision_init_val)

    optimizer_fp8 = MiniFSDP(
        [w for w in model_fp8.parameters()], 10.0, dp_group, manual_post_all_gather_processing
    )
    optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group)

    for _ in range(100):
        optimizer_fp8.zero_grad()
        optimizer.zero_grad()

        inputs = [
            torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
        ]
        # Choose based on rank to make sure the inputs of different ranks are different.
        x = inputs[rank]

        with te.autocast(
            enabled=quantization is not None,
            recipe=_get_quantization_recipe(quantization),
            amax_reduction_group=mock_group,
        ):
            y_fp8 = model_fp8(x)
        with te.autocast(
            enabled=quantization is not None,
            recipe=_get_quantization_recipe(quantization),
            amax_reduction_group=mock_group,
        ):
            y = model(x)

        targets = [torch.randn_like(y) for _ in range(world_size)]
        # Choose based on rank to make sure the targets of different ranks are different.
        target = targets[rank]

        loss_fp8 = nn.MSELoss()(y_fp8, target)
        loss = nn.MSELoss()(y, target)

        loss_fp8.backward()
        loss.backward()

        optimizer_fp8.step()
        optimizer.step()

        torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
def run_parallel_tests() -> None:
    """Run parallel tests"""
    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
    LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))

    assert WORLD_SIZE == LOCAL_SIZE  # this test supports only 1 node
    assert LOCAL_SIZE <= torch.cuda.device_count()

    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
        "timeout": datetime.timedelta(seconds=30),
    }
    dist_init_kwargs["init_method"] = "env://"
    dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
    assert dist.is_nccl_available()
    torch.cuda.set_device(LOCAL_RANK)
    dist.init_process_group(**dist_init_kwargs)

    dp_group = dist.new_group(backend="nccl")

    quantizations = []
    if is_fp8_available():
        quantizations.extend(["fp8", "fp8_cs"])
    if is_fp8_block_scaling_available():
        quantizations.append("fp8_block")
    manual_post_all_gather_processings = [False, True]

    _test_mini_optimizer(dp_group)
    for quantization in quantizations:
        for post_ag_processing in manual_post_all_gather_processings:
            _test_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing)
            _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing)

    dist.destroy_process_group()
@pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="cast_master_weights_to_fp8 test needs at least 2 GPUs.",
)
@pytest.mark.parametrize("world_size", [2])
def test_cast_master_weights_to_fp8(world_size: int) -> None:
    """Launch parallel job that runs parallel tests"""
    python_exe = pathlib.Path(sys.executable).resolve()
    current_file = pathlib.Path(__file__).resolve()
    command = [
        python_exe,
        "-m",
        "torch.distributed.run",
        f"--nproc_per_node={world_size}",
        current_file,
        "--parallel",
    ]
    result = subprocess.run(
        command,
        check=True,
    )
def _run_test(quantization):
    test_path = TEST_ROOT / "run_cast_master_weights_to_fp8.py"
    test_cmd = LAUNCH_CMD + [str(test_path)] + ["--quantization", quantization]
    result = subprocess.run(test_cmd, env=os.environ, check=False)
    assert result.returncode == 0


@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs", "fp8_block"])
def test_cast_master_weights_to_fp8(quantization):
    if quantization in ("fp8", "fp8_cs") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "fp8_block" and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    _run_test(quantization)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--parallel", action="store_true", help="Run parallel tests")
    args = parser.parse_args()
    if args.parallel:
        run_parallel_tests()


if __name__ == "__main__":
    main()
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py  View file @ c1a1c04e
...
@@ -34,6 +34,7 @@ from transformer_engine.pytorch import (
    Float8Tensor,
)

# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
...
tests/pytorch/distributed/test_numerics_exact.py  View file @ c1a1c04e
...
@@ -14,7 +14,7 @@ import transformer_engine.pytorch as te
Distributed numerics tests

This numerical test aims for a zero-tolerance check for absolute confidence in numerics.
In the case of NVFP4, with the experimental NVFP4 quantization, we matched bitwise
In the case of NVFP4, with the custom NVFP4 quantization, we matched bitwise
result with the native silicon. For distributed test cases, we can do the same thing
by comparing BF16 AG results with the low precision AG results at layer level.
"""
...
tests/pytorch/distributed/test_torch_fsdp2.py  View file @ c1a1c04e
...
@@ -12,22 +12,26 @@ import transformer_engine.pytorch as te
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)

NUM_PROCS: int = torch.cuda.device_count()


def _run_test(fp_init, sharding_dims):
def _run_test(fp_init, sharding_dims, recipe, layer_type):
    test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py"
    test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)]

    if fp_init:
        test_cmd += ["--fp8-init"]
    if len(sharding_dims) == 1:
        test_cmd += ["--sharding-dims", str(sharding_dims[0])]
    elif len(sharding_dims) == 2:
        test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])]
    else:
        assert False
    test_cmd += ["--recipe", recipe]
    test_cmd += ["--layer-type", layer_type]
    result = subprocess.run(test_cmd, env=os.environ, check=True)
...
@@ -36,16 +40,20 @@ def _run_test(fp_init, sharding_dims):
@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
@pytest.mark.parametrize("fp8_init", (False, True))
@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
@pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer"))
def test_distributed(fp8_init, sharding_dims):
def test_distributed(fp8_init, sharding_dims, recipe, layer_type):
    # Skip invalid configurations
    if torch.cuda.device_count() < 4:
        pytest.skip("FSDP2 test requires at least 4 GPUs")
    if fp8_init and not fp8_available:
    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    elif not fp8_available:
        pytest.skip(reason_for_no_fp8)
    _run_test(fp8_init, sharding_dims)
    _run_test(fp8_init, sharding_dims, recipe, layer_type)


def test_dummy() -> None:
...
tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py  View file @ c1a1c04e
...
@@ -8,8 +8,8 @@ import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.custom_recipes import utils

recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
...
tests/pytorch/nvfp4/test_nvfp4_module_exact.py  View file @ c1a1c04e
...
@@ -6,8 +6,8 @@ import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.experimental import quantization_nvfp4
from transformer_engine.pytorch.custom_recipes import quantization_nvfp4
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.custom_recipes import utils

recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
...
tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py  View file @ c1a1c04e
...
@@ -7,10 +7,10 @@ import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.experimental import utils

recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
...
tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py
View file @ c1a1c04e
...
@@ -12,10 +12,10 @@
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.constants import TE_DType

import pytest
import torch
...
tests/pytorch/test_cpu_offloading.py
View file @ c1a1c04e
...
@@ -2,27 +2,41 @@
#
# See LICENSE for license information.
import random
import contextlib
import pytest
import os
import torch
from typing import Optional, List

from transformer_engine.pytorch.cpu_offload import (
    get_cpu_offload_context,
    OffloadableLayerState,
    DefaultOffloadSynchronizer,
    start_offload,
    mark_not_offload,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from utils import ModelConfig
import transformer_engine_torch as tex

# Check supported quantization schemes
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
nvfp4_available, _ = FP8GlobalStateManager.is_nvfp4_available()

quantization_recipes: List[Optional[recipe.Recipe]] = [None]
if fp8_available:
    quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))
if fp8_block_scaling_available:
    quantization_recipes.append(recipe.Float8BlockScaling())
if mxfp8_available:
    quantization_recipes.append(recipe.MXFP8BlockScaling())
if nvfp4_available:
    quantization_recipes.append(recipe.NVFP4BlockScaling())

model_config = {
    "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
...
@@ -32,181 +46,709 @@ NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps

# Disable garbage collection to tests if there are reference cycles.
# We do not want them, because they can result in CUDA out of memory errors.
import gc

gc.disable()


class Utils:
    tensor1 = torch.randn((1024, 1024), device="cuda", dtype=torch.bfloat16)
    _B = 64
    _S = 256
    _H = 4
    _D = 256

    @staticmethod
    def long_job(stream: Optional[torch.cuda.Stream] = None):
        NUM_ITERS = 6000
        if stream is None:
            stream = torch.cuda.current_stream()
        with torch.cuda.stream(stream):
            for i in range(NUM_ITERS):
                Utils.tensor1.normal_()

    @staticmethod
    def measure_time(func):
        import time

        torch.cuda.synchronize()
        start = time.time()
        func()
        torch.cuda.synchronize()
        end = time.time()
        return (end - start) * 1000

    @staticmethod
    def get_cuda_memory_mb():
        return torch.cuda.memory_allocated() / (1024**2)

    @staticmethod
    def get_max_cuda_memory_mb():
        return torch.cuda.max_memory_allocated() / (1024**2)

    @staticmethod
    def get_cpu_memory_mb() -> float:
        import psutil, os

        return psutil.Process(os.getpid()).memory_info().rss / (1024**2)

    @staticmethod
    def get_layer_names():
        return [
            "linear",
            "layernorm_linear",
            "layernorm_mlp",
            "grouped_linear",
            "multihead_attention",
            "transformer_layer",
            "linear_op",
            "layernorm_mlp_ops",
        ]

    @staticmethod
    def create_layer(layer_type: str):
        if layer_type == "linear":
            return te.Linear(Utils._D, Utils._D, params_dtype=torch.bfloat16)
        elif layer_type == "layernorm_linear":
            return te.LayerNormLinear(Utils._D, Utils._D, params_dtype=torch.bfloat16)
        elif layer_type == "layernorm_mlp":
            return te.LayerNormMLP(Utils._D, Utils._D, params_dtype=torch.bfloat16)
        elif layer_type == "multihead_attention":
            return te.MultiheadAttention(
                Utils._D, Utils._H, attention_dropout=0.0, params_dtype=torch.bfloat16
            )
        elif layer_type == "grouped_linear":
            return te.GroupedLinear(Utils._H, Utils._D, Utils._D, params_dtype=torch.bfloat16)
        elif layer_type == "transformer_layer":
            return te.TransformerLayer(
                Utils._D,
                Utils._D,
                Utils._H,
                attention_dropout=0.0,
                hidden_dropout=0.0,
                params_dtype=torch.bfloat16,
            )
        elif layer_type == "linear_op":
            return te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16)
        elif layer_type == "layernorm_mlp_ops":
            return te.ops.Sequential(
                te.ops.LayerNorm(Utils._D, dtype=torch.bfloat16),
                te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16),
                te.ops.GELU(),
                te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16),
            )
        else:
            raise ValueError(f"Unknown layer type: {layer_type}")

    @staticmethod
    def create_tensor(recipe: Optional[recipe.Recipe], requires_grad: bool = False) -> torch.Tensor:
        shape = (Utils._B, Utils._S, Utils._D)
        tensor = torch.randn(shape, device="cuda", dtype=torch.bfloat16)
        if recipe is None:
            tensor = tensor.requires_grad_() if requires_grad else tensor
            return tensor
        elif recipe.delayed():
            quantizer = te.tensor.float8_tensor.Float8Quantizer(
                fp8_dtype=tex.DType.kFloat8E4M3,
                scale=torch.tensor([1.0], device="cuda"),
                amax=torch.tensor([1.0], device="cuda"),
            )
            return quantizer(tensor)
        elif recipe.float8_current_scaling():
            quantizer = te.tensor.float8_tensor.Float8CurrentScalingQuantizer(
                fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"
            )
            return quantizer(tensor)
        elif recipe.float8_block_scaling():
            quantizer = te.tensor.float8_blockwise_tensor.Float8BlockQuantizer(
                fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True
            )
            return quantizer(tensor)
        elif recipe.mxfp8():
            quantizer = te.tensor.mxfp8_tensor.MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
            return quantizer(tensor)
        elif recipe.nvfp4():
            quantizer = te.tensor.nvfp4_tensor.NVFP4Quantizer()
            return quantizer(tensor)

    @staticmethod
    def create_recipe_ctx(recipe: Optional[recipe.Recipe]):
        if recipe is None:
            return lambda: contextlib.nullcontext()
        else:
            return lambda: te.fp8_autocast(fp8_recipe=recipe)

    @staticmethod
    def get_tensor_size_mb(tensor):
        if tensor is None:
            return 0
        if isinstance(tensor, te.quantized_tensor.QuantizedTensorStorage):
            return sum(Utils.get_tensor_size_mb(t) for t in tensor.get_data_tensors())
        else:
            return tensor.numel() * tensor.element_size() / (1024**2)

    @staticmethod
    def memory_leak_check():
        # Should be called before each test.
        # Only cublas workspaces and some global tensors are allowed to be allocated.
        # All other allocations should be released.
        # This is a simple check to catch memory leaks.
        if Utils.get_cuda_memory_mb() > 1000:
            memory_num = Utils.get_cuda_memory_mb()
            import gc

            gc.collect()
            # We want next test to be run with clean state.
            gc.disable()
            raise RuntimeError(f"Memory leak: {memory_num} MB")


class TestsOffloadableLayerState:
    @pytest.mark.parametrize("random_num_tensors", [True, False])
    @pytest.mark.parametrize("recipe", quantization_recipes)
    def test_general(self, random_num_tensors, recipe):
        """
        Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers,
        for each layer offload random number of random tensors.
        Then do backward pass for each layer, and check if reloaded tensors are equal to original tensors.
        """
        Utils.memory_leak_check()
        NUM_ITERATIONS = 10
        stream = torch.cuda.Stream()
        offload_layer_state = OffloadableLayerState(
            offload_stream=stream,
        )
        for _ in range(NUM_ITERATIONS):
            original_tensors = []
            tensors_ids = []
            NUM_TENSORS = random.choice([1, 20]) if random_num_tensors else 1
            for _ in range(NUM_TENSORS):
                tensor = Utils.create_tensor(recipe)
                original_tensors.append(tensor)
                tensor_id = offload_layer_state.push_tensor(tensor)
                assert tensor.device.type == "cuda"
                tensors_ids.append(tensor_id)
            offload_layer_state.start_offload()
            offload_layer_state.release_activation_forward_gpu_memory()
            offload_layer_state.start_reload()
            for j in range(len(tensors_ids)):
                tensor_gpu = offload_layer_state.pop_tensor(tensors_ids[j])
                assert tensor_gpu.device.type == "cuda"
                assert tensor_gpu.shape == original_tensors[j].shape
                assert tensor_gpu.dtype == original_tensors[j].dtype
                torch.testing.assert_close(tensor_gpu, original_tensors[j])
            offload_layer_state.release_all_memory()
        torch.cuda.synchronize()

    def test_offload_base_tensor(self):
        Utils.memory_leak_check()
        stream = torch.cuda.Stream()
        offload_layer_state = OffloadableLayerState(
            offload_stream=stream,
        )
        init_cuda_memory = Utils.get_cuda_memory_mb()
        x = Utils.create_tensor(None)
        x_size = Utils.get_tensor_size_mb(x)
        x_1 = x[::2]
        x_2 = x[1::2]
        start_offload(x_1, offload_base_tensor=True)
        start_offload(x_2, offload_base_tensor=True)
        x1_id = offload_layer_state.push_tensor(x_1)
        x2_id = offload_layer_state.push_tensor(x_2)
        del x_1, x_2
        offload_layer_state.start_offload()
        offload_layer_state.release_activation_forward_gpu_memory()
        assert offload_layer_state.get_offloaded_total_size_mb() == pytest.approx(x_size, 0.1)
        offload_layer_state.start_reload()
        x_1 = offload_layer_state.pop_tensor(x1_id)
        x_2 = offload_layer_state.pop_tensor(x2_id)
        assert x_1.device.type == "cuda"
        assert x_2.device.type == "cuda"
        assert torch.allclose(x_1, x[::2])
        assert torch.allclose(x_2, x[1::2])
        del x
        assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + x_size, 0.1)


class TestsDefaultOffloadSynchronizer:
    @pytest.mark.parametrize("random_num_tensors", [True, False])
    @pytest.mark.parametrize("recipe", quantization_recipes)
    def test_general(self, random_num_tensors, recipe):
        """
        Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers,
        for each layer offload random number of random tensors.
        Then do backward pass for each layer, and check if reloaded tensors are equal to original tensors.
        """
        Utils.memory_leak_check()
        NUM_LAYERS = 10
        NUM_ITERATIONS = 10
        offload_synchronizer = DefaultOffloadSynchronizer(
            num_layers=NUM_LAYERS,
            num_offloaded_layers=NUM_LAYERS - 1,
        )
        for _ in range(NUM_ITERATIONS):
            original_tensors = []
            tensors_ids = []
            layer_ids = []
            for i in range(NUM_LAYERS):
                NUM_LAYER_TENSORS = random.randint(1, 10) if random_num_tensors else 1
                layer_tensors = []
                layer_tensors_ids = []
                layer_id = offload_synchronizer.fwd_step()
                for _ in range(NUM_LAYER_TENSORS):
                    tensor = Utils.create_tensor(recipe)
                    layer_tensors.append(tensor)
                    tensor_id = offload_synchronizer.push_tensor(tensor)
                    assert tensor.device.type == "cuda"
                    layer_tensors_ids.append(tensor_id)
                layer_ids.append(layer_id)
                tensors_ids.append(layer_tensors_ids)
                original_tensors.append(layer_tensors)
            for i in range(NUM_LAYERS - 1, -1, -1):
                offload_synchronizer.bwd_step(layer_ids[i])
                for j in range(len(tensors_ids[i])):
                    tensor_gpu = offload_synchronizer.pop_tensor(tensors_ids[i][j])
                    assert tensor_gpu.device.type == "cuda"
                    assert tensor_gpu.shape == original_tensors[i][j].shape
                    assert tensor_gpu.dtype == original_tensors[i][j].dtype
                    torch.testing.assert_close(tensor_gpu, original_tensors[i][j])
            offload_synchronizer.finish_part_of_bwd()
        torch.cuda.synchronize()

    @pytest.mark.parametrize("recipe", quantization_recipes)
    def test_memory(self, recipe):
        torch.cuda.synchronize()
        Utils.memory_leak_check()
        NUM_LAYERS = 10
        torch.cuda.reset_peak_memory_stats()
        offload_synchronizer = DefaultOffloadSynchronizer(
            num_layers=NUM_LAYERS,
            num_offloaded_layers=NUM_LAYERS - 1,
        )
        init_cuda_memory = Utils.get_cuda_memory_mb()
        tensor_ids = []
        torch.cuda.synchronize()
        for _ in range(NUM_LAYERS):
            offload_synchronizer.fwd_step()
            tensor = Utils.create_tensor(recipe)
            tensor_size = Utils.get_tensor_size_mb(tensor)
            tensor_id = offload_synchronizer.push_tensor(tensor)
            assert tensor.device.type == "cuda"
            tensor_ids.append(tensor_id)
            del tensor, tensor_id
        torch.cuda.synchronize()
        if recipe is None:
            assert Utils.get_max_cuda_memory_mb() == pytest.approx(
                init_cuda_memory + tensor_size, 0.1
            )
            assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + tensor_size, 0.1)
        for i in range(NUM_LAYERS - 1, -1, -1):
            offload_synchronizer.bwd_step(i)
            tensor_gpu = offload_synchronizer.pop_tensor(tensor_ids[i])
            assert tensor_gpu.device.type == "cuda"
            del tensor_gpu, tensor_ids[i]
        offload_synchronizer.finish_part_of_bwd()
        del tensor_ids
        torch.cuda.synchronize()
        if recipe is None:
            assert Utils.get_max_cuda_memory_mb() == pytest.approx(
                init_cuda_memory + tensor_size, 0.1
            )
            assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1)

    @pytest.mark.parametrize("recipe", quantization_recipes)
    def test_multiple_tensor_offload(self, recipe):
        Utils.memory_leak_check()
        init_cpu_memory = Utils.get_cpu_memory_mb()
        init_cuda_memory = Utils.get_cuda_memory_mb()
        offload_synchronizer = DefaultOffloadSynchronizer(
            num_layers=2,
            num_offloaded_layers=1,
        )
        x1 = Utils.create_tensor(recipe)
        x_size = Utils.get_tensor_size_mb(x1)
        offload_synchronizer.fwd_step()
        offload_synchronizer.push_tensor(x1)
        offload_synchronizer.push_tensor(x1)
        offload_synchronizer.push_tensor(x1)
        offload_synchronizer.fwd_step()
        # Only one copy of tensor on cpu is allocated.
        assert Utils.get_cpu_memory_mb() == pytest.approx(init_cpu_memory + 1 * x_size, 0.1)
        del x1
        offload_synchronizer.bwd_step(1)
        offload_synchronizer.bwd_step(0)
        offload_synchronizer.finish_part_of_bwd()
        assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1)


class TestTELayers:
    @pytest.mark.parametrize("layer_type", Utils.get_layer_names())
    @pytest.mark.parametrize("recipe", quantization_recipes)
    def test_sanity(self, layer_type, recipe):
        Utils.memory_leak_check()
        # Skip ops-based layers with Float8BlockScaling recipe
        if (
            layer_type in ["linear_op", "layernorm_mlp_ops"]
            and recipe is not None
            and recipe.float8_block_scaling()
        ):
            pytest.skip("Fusible operations do not support FP8 block scaling recipe")
        recipe_ctx = Utils.create_recipe_ctx(recipe)
        init_cuda_memory = Utils.get_cuda_memory_mb()
        OFFLOAD_LAYERS = 6
        NUM_LAYERS = 10
        offload_ctx, sync_function = get_cpu_offload_context(
            enabled=True,
            num_layers=OFFLOAD_LAYERS,
            model_layers=NUM_LAYERS,
        )
        layers = [Utils.create_layer(layer_type) for _ in range(NUM_LAYERS)]
        inp = Utils.create_tensor(None)
        m_splits = (
            {"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H}
            if layer_type == "grouped_linear"
            else {}
        )
        out = inp
        for i in range(NUM_LAYERS):
            with offload_ctx, recipe_ctx():
                # Ops-based layers don't support is_first_microbatch parameter
                if layer_type in ["linear_op", "layernorm_mlp_ops"]:
                    out = layers[i](out, **m_splits)
                else:
                    out = layers[i](out, is_first_microbatch=False, **m_splits)
            out = sync_function(out)
        out.sum().backward()
        torch.cuda.synchronize()
        del out, inp, layers

    @pytest.mark.parametrize("layer_type", Utils.get_layer_names())
    @pytest.mark.parametrize("recipe", quantization_recipes)
    def test_memory(self, layer_type, recipe):
        Utils.memory_leak_check()
        # Skip ops-based layers with Float8BlockScaling recipe
        if (
            layer_type in ["linear_op", "layernorm_mlp_ops"]
            and recipe is not None
            and recipe.float8_block_scaling()
        ):
            pytest.skip("Fusible operations do not support FP8 block scaling recipe")
        offload_ctx, sync_function = get_cpu_offload_context(
            enabled=True,
            num_layers=1,
            model_layers=2,
            offload_activations=True,
            offload_weights=False,
        )
        recipe_ctx = Utils.create_recipe_ctx(recipe)
        layer = Utils.create_layer(layer_type)
        inp = Utils.create_tensor(None)
        m_splits = (
            {"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H}
            if layer_type == "grouped_linear"
            else {}
        )
        # Ops-based layers don't support is_first_microbatch parameter
        is_ops_layer = layer_type in ["linear_op", "layernorm_mlp_ops"]
        with recipe_ctx():
            if is_ops_layer:
                out = layer(inp, **m_splits)
            else:
                out = layer(inp, is_first_microbatch=True, **m_splits)
        out.sum().backward()
        del inp
        init_cuda_memory = Utils.get_cuda_memory_mb()
        # run layer without offload
        inp = Utils.create_tensor(None)
        with recipe_ctx():
            if is_ops_layer:
                out = layer(inp, **m_splits)
            else:
                out = layer(inp, is_first_microbatch=False, **m_splits)
        with recipe_ctx():
            out = out + 1
        del inp
        cuda_memory_no_offload = Utils.get_cuda_memory_mb()
        out.sum().backward()
        # run layer with offload
        inp = Utils.create_tensor(None)
        with offload_ctx, recipe_ctx():
            if is_ops_layer:
                out = layer(inp, **m_splits)
            else:
                out = layer(inp, is_first_microbatch=False, **m_splits)
        out = sync_function(out)
        with offload_ctx, recipe_ctx():
            out = out + 1
        out = sync_function(out)
        del inp
        assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1)
        offloaded_memory_cpu = offload_ctx.offload_synchronizer.get_offloaded_total_size_mb()
        # This assertion verifies that the memory used by tensors on the CPU matches the memory saved from a layer.
        # It helps catch cases where an offloaded tensor still has a live pointer, which would
        # cause an unnecessary copy to the CPU and prevent GPU memory from being released.
        assert Utils.get_cuda_memory_mb() + offloaded_memory_cpu == pytest.approx(
            cuda_memory_no_offload, 0.1
        )
        out.sum().backward()

    @pytest.mark.parametrize("layer_type", Utils.get_layer_names())
    @pytest.mark.parametrize("recipe", quantization_recipes)
    def test_manual_synchronization(self, recipe, layer_type):
        Utils.memory_leak_check()
        # Skip ops-based layers with Float8BlockScaling recipe
        if (
            layer_type in ["linear_op", "layernorm_mlp_ops"]
            and recipe is not None
            and recipe.float8_block_scaling()
        ):
            pytest.skip("Fusible operations do not support FP8 block scaling recipe")
        offload_ctx, sync_function, manual_controller = get_cpu_offload_context(
            enabled=True,
            model_layers=6,
            offload_activations=True,
            manual_synchronization=True,
        )
        layer_1 = Utils.create_layer(layer_type)
        layer_2 = Utils.create_layer(layer_type)
        inp1 = Utils.create_tensor(None)
        inp2 = Utils.create_tensor(None)
        recipe_ctx = Utils.create_recipe_ctx(recipe)
        m_splits = (
            {"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H}
            if layer_type == "grouped_linear"
            else {}
        )
        init_cuda_memory = Utils.get_cuda_memory_mb()
        # 1 fwd
        with offload_ctx, recipe_ctx():
            out_1 = layer_1(inp1, **m_splits)
        out_1 = sync_function(out_1)
        with offload_ctx, recipe_ctx():
            out_2 = layer_2(inp2, **m_splits)
        out_2 = sync_function(out_2)
        mark_not_offload(out_1, out_2)
        del inp1, inp2
        memory_before_offload = Utils.get_cuda_memory_mb()
        manual_controller.start_offload_layer(0)
        manual_controller.release_activation_forward_gpu_memory(0)
        manual_controller.start_offload_layer(1)
        manual_controller.release_activation_forward_gpu_memory(1)
        memory_after_offload = Utils.get_cuda_memory_mb()
        assert memory_after_offload + EPSILON < memory_before_offload
        manual_controller.start_reload_layer(0)
        manual_controller.start_reload_layer(1)
        memory_after_reload = Utils.get_cuda_memory_mb()
        assert memory_after_reload == pytest.approx(memory_before_offload, 0.1)
        out_1.sum().backward()
        out_2.sum().backward()

    @pytest.mark.parametrize("recipe", quantization_recipes)
    @pytest.mark.parametrize("layer_type", Utils.get_layer_names())
    @pytest.mark.parametrize("use_cuda_graphs", [True, False])
    @pytest.mark.parametrize("retain_pinned_cpu_buffers", [True, False])
    @pytest.mark.parametrize("backend", ["FlashAttention", "FusedAttention", "UnfusedAttention"])
    def test_numerics(
        self,
        recipe,
        layer_type,
        use_cuda_graphs,
        backend,
        retain_pinned_cpu_buffers,
    ):
        # Skip ops-based layers with Float8BlockScaling recipe
        if (
            layer_type in ["linear_op", "layernorm_mlp_ops"]
            and recipe is not None
            and recipe.float8_block_scaling()
        ):
            pytest.skip("Fusible operations do not support FP8 block scaling recipe")
        recipe_ctx = Utils.create_recipe_ctx(recipe)
        if use_cuda_graphs and not retain_pinned_cpu_buffers:
            pytest.skip(
                "Cuda graphs are not yet supported with cpu offloading when"
                " retain_pinned_cpu_buffers is False."
            )
        if backend == "FusedAttention" and use_cuda_graphs:
            pytest.skip(
                "Fused attention + cuda graphs is temporarily broken, not because of cpu offloading"
            )
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        if backend == "FlashAttention":
            os.environ["NVTE_FLASH_ATTN"] = "1"
        elif backend == "FusedAttention":
            os.environ["NVTE_FUSED_ATTN"] = "1"
        elif backend == "UnfusedAttention":
            os.environ["NVTE_UNFUSED_ATTN"] = "1"
        offload_ctx, sync_function = get_cpu_offload_context(
            enabled=True,
            num_layers=1,
            model_layers=2,
            offload_activations=True,
            offload_weights=False,
            retain_pinned_cpu_buffers=retain_pinned_cpu_buffers,
        )

        class Callable(torch.nn.Module):
            def __init__(self, offload_ctx=None, sync_function=None):
                super().__init__()
                self.layers = torch.nn.ModuleList(
                    [Utils.create_layer(layer_type) for _ in range(2)]
                )
                self.offload_ctx = offload_ctx
                self.sync_function = sync_function

            def forward(self, x):
                m_splits = (
                    {"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H}
                    if layer_type == "grouped_linear"
                    else {}
                )
                is_ops_layer = layer_type in ["linear_op", "layernorm_mlp_ops"]
                for layer in self.layers:
                    with self.offload_ctx, recipe_ctx():
                        if is_ops_layer:
                            x = layer(x, **m_splits)
                        else:
                            x = layer(x, is_first_microbatch=False, **m_splits)
                    if self.sync_function is not None:
                        x = self.sync_function(x)
                return x

        callable_offload = Callable(offload_ctx=offload_ctx, sync_function=sync_function)
        callable_no_offload = Callable(offload_ctx=contextlib.nullcontext(), sync_function=None)
        # copy parameters
        for param_offload, param_no_offload in zip(
            callable_offload.parameters(), callable_no_offload.parameters()
        ):
            param_offload.data.copy_(param_no_offload.data)

        x = Utils.create_tensor(None)
        if use_cuda_graphs:
            callable_offload = te.make_graphed_callables(
                callable_offload,
                (x,),
                enabled=recipe is not None,
                recipe=(Utils.create_recipe_ctx(recipe) if recipe is not None else None),
            )
        # warm up (for example to compute sf for delayed scaling)
        for _ in range(4):
            out = callable_offload(x)
            out.sum().backward()
            out = callable_no_offload(x)
            out.sum().backward()
        callable_offload.zero_grad(set_to_none=True)
        out_offload = callable_offload(x)
        out_offload.sum().backward()
        # save out and gradients
        offload_outs = [out_offload]
        for param in callable_offload.parameters():
            offload_outs.append(param.detach().clone())
        torch.cuda.reset_peak_memory_stats()
        out_no_offload = callable_no_offload(x)
        out_no_offload.sum().backward()
        # collect gradients
        no_offload_outs = [out_no_offload]
        for param in callable_no_offload.parameters():
            no_offload_outs.append(param.detach().clone())
        # check if tensors are the same
        for i in range(len(offload_outs)):
            assert torch.allclose(offload_outs[i], no_offload_outs[i]), f"Error in tensor {i}."
        torch.cuda.synchronize()

    def test_example_from_doc(self):
        offload_stream = torch.cuda.Stream()
        num_layers = 10
        layers = [Utils.create_layer("transformer_layer") for _ in range(num_layers)]
        inp = [Utils.create_tensor(None) for _ in range(num_layers)]
        out = [None] * num_layers
        cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context(
            enabled=True,
            model_layers=num_layers,
            manual_synchronization=True,
            offload_stream=offload_stream,
        )
        for i in range(num_layers):
            with cpu_offload_context:
                out[i] = layers[i].forward(inp[i])
            out[i] = sync_function(out[i])
            manual_controller.start_offload_layer(i)
        offload_stream.synchronize()
        for i in range(num_layers):
            manual_controller.release_activation_forward_gpu_memory(i)
        for i in range(num_layers - 1, -1, -1):
            # these calls are intended to be done in the backward pass
            manual_controller.start_reload_layer(i)
        offload_stream.synchronize()
        for i in range(num_layers):
            out[i].sum().backward()
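Condensed, the automatic offload pattern exercised by TestTELayers above looks roughly like the sketch below; the layer type, sizes, and layer counts are illustrative rather than taken verbatim from any single test.

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context

# Offload activations of 6 out of 10 layers; every layer output must pass through sync_function.
offload_ctx, sync_function = get_cpu_offload_context(enabled=True, num_layers=6, model_layers=10)
layers = [te.Linear(256, 256, params_dtype=torch.bfloat16) for _ in range(10)]
out = torch.randn(64, 256, 256, device="cuda", dtype=torch.bfloat16)
for layer in layers:
    with offload_ctx:
        out = layer(out)
    out = sync_function(out)  # schedules offload/reload for the activations saved by the layer just run
out.sum().backward()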
tests/pytorch/test_cpu_offloading_v1.py
0 → 100644
View file @ c1a1c04e
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import contextlib
import gc
import os
from typing import Iterable, Optional

import pytest
import torch

import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends

# Check supported quantization schemes
fp8_available = te.is_fp8_available()
mxfp8_available = te.is_mxfp8_available()
quantization_recipes: Optional[recipe.Recipe] = [None]
if fp8_available:
    quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))

model_config = {
    "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
}
SIZE = model_config["small"].hidden_size
NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps

# Flash attention saves some internal tensor for the backward pass
# that cannot be offloaded to CPU.
assert os.getenv("NVTE_FLASH_ATTN") == "0"

# CPU offload v1 code path is enabled
assert os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1"

# Offloading is supported for attention only for fused and flash attention backends,
# so the use of bfloat16 is required.
#
# For the TransformerLayer, activation offloading with dropout is not supported,
# so we set hidden_dropout to 0.0.
model_types = {
    "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16),
    "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16),
    "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16),
    "multihead_attention": lambda: te.MultiheadAttention(
        SIZE, NUM_HEADS, params_dtype=torch.bfloat16
    ),
    "transformer_layer": lambda: te.TransformerLayer(
        SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
    ),
    "linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
    "layernorm_mlp_ops": lambda: te.ops.Sequential(
        te.ops.LayerNorm(SIZE, dtype=torch.bfloat16),
        te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
        te.ops.GELU(),
        te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
    ),
}


def _make_input() -> torch.Tensor:
    """Generate random input tensor."""
    return torch.randn(
        (128, SIZE, SIZE),
        dtype=torch.bfloat16,
        device="cuda",
        requires_grad=True,
    )


def _warmup_model(
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
) -> None:
    """Perform forward and backward pass"""
    tensor = _make_input()
    for module in modules:
        with te.autocast(
            enabled=quantization_recipe is not None,
            recipe=quantization_recipe,
        ):
            tensor = module(tensor)
    tensor.sum().backward()


def _estimate_cached_weight_size(
    model_name: str,
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
) -> float:
    """Calculate the memory (in MiB) needed for weight caching."""

    # The weight params are cached directly for unquantized compute
    if quantization_recipe is None:
        return 0

    # Count number of weight param elements
    param_elements = 0
    for module in modules:
        for param in module.parameters():
            if param.dim() == 2:
                param_elements += param.numel()

    # FP8 tensor-scaling caches one byte per element
    if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling():
        if not is_non_tn_fp8_gemm_supported() and model_name not in (
            "linear_op",
            "layernorm_mlp_ops",
        ):
            # Modules do not deallocate FP8 transpose for weights
            return 2 * param_elements / 1024**2
        return param_elements / 1024**2

    # MXFP8 caches one data byte per element and one scale byte per 32
    # elements
    if quantization_recipe.mxfp8():
        if model_name not in ("linear_op", "layernorm_mlp_ops"):
            # Modules do not deallocate column-wise MXFP8 data for weights
            return 2 * param_elements * (1 + 1 / 32) / 1024**2
        return param_elements * (1 + 1 / 32) / 1024**2

    raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})")


def _measure_cached_memory(
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
    cpu_offload: bool,
) -> float:
    """Measure the growth in allocated GPU memory in MiB after a model forward pass.

    Memory measurement excludes the input and output tensors.
    """

    # Reset memory
    gc.collect()
    torch.cuda.empty_cache()

    # Context and sync function for CPU offloading
    if cpu_offload:
        offload_context, sync_function = te.get_cpu_offload_context(
            enabled=True,
            num_layers=len(modules),
            model_layers=len(modules) + 1,
            offload_activations=True,
            offload_weights=False,
        )
    else:
        offload_context = contextlib.nullcontext()
        sync_function = lambda x: x

    # Forward pass, with dummy step to trigger offload for last module
    inp = _make_input()
    tensor = inp
    memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
    for module in modules:
        with te.autocast(
            enabled=quantization_recipe is not None, recipe=quantization_recipe
        ), offload_context:
            tensor = module(tensor)
        tensor = sync_function(tensor)
    with offload_context:
        tensor = tensor.clone()
    tensor = sync_function(tensor)
    memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2)

    # Backward pass
    tensor.sum().backward()
    torch.cuda.synchronize()

    # Memory usage in MiB
    return memory_after_forward - memory_before_forward


@pytest.mark.parametrize("quantization_recipe", quantization_recipes)
@pytest.mark.parametrize("model_name", model_types.keys())
def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None:
    """Check that CPU offloading runs and has expected memory usage."""

    # Construct model
    modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)]
    if model_name in ["multihead_attention", "transformer_layer"]:
        available_backends, *_ = get_available_attention_backends(
            model_config["small"],
            qkv_dtype=torch.bfloat16,
            qkv_layout="sbhd_sbhd_sbhd",
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("Fused attention backend not available.")
        os.environ["NVTE_FLASH_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True

    # Warmup
    _warmup_model(modules_list, quantization_recipe)

    # Measure cached memory after forward pass
    memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False)
    memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True)

    # Check for expected memory usage
    assert memory_with_offload < memory_without_offload
    memory_from_cached_weights = _estimate_cached_weight_size(
        model_name,
        modules_list,
        quantization_recipe,
    )
    assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON
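As a worked example of the arithmetic in _estimate_cached_weight_size, assuming the small config resolves to SIZE = 512 and the "linear" model type (five layers, one 512x512 weight each):

param_elements = 5 * 512 * 512                                   # 1,310,720 weight elements
fp8_cached_mib = 2 * param_elements / 1024**2                    # 2.5 MiB: FP8 data plus the retained transpose
mxfp8_cached_mib = 2 * param_elements * (1 + 1 / 32) / 1024**2   # ~2.58 MiB: data, column-wise copy, and scales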
tests/pytorch/test_cuda_graphs.py
View file @ c1a1c04e
...
@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
from typing import Callable, Dict, Iterable, List, Tuple, Union

import pytest
import torch
...
@@ -173,6 +173,20 @@ def get_outputs(
    return values


def reset_graphs(
    graphed_callables: Union[Callable, Tuple[Callable, ...], Dict[Tuple[int, int], Callable]],
) -> None:
    """Reset CUDA graphs."""
    if isinstance(graphed_callables, tuple) or isinstance(graphed_callables, list):
        for callable in graphed_callables:
            callable.reset()
    elif isinstance(graphed_callables, dict):
        for callable in graphed_callables.values():
            callable.reset()
    else:
        graphed_callables.reset()


class _Sequential(torch.nn.Sequential):
    """Sequential model that forwards keyword arguments to modules"""
...
@@ -335,7 +349,12 @@ def _test_cuda_graphs(
        output.backward(grad_output)
        optimizer.step()
    outputs = get_outputs(model, output)
    if graph_mode == "full":
        reset_graphs(model)
    elif graph_mode == "individual":
        reset_graphs(modules)
    return outputs


@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
...
@@ -487,7 +506,10 @@ def _test_cuda_graphs_with_dot_product_attention(
        output = model(*inputs)
        output.backward(grad_output)
    outputs = get_outputs(model, output)
    if with_graph:
        reset_graphs(model)
    return outputs


@pytest.mark.parametrize("dtype", dtypes)
...
@@ -572,7 +594,10 @@ def _test_cuda_graphs_with_kwargs(
        output.backward(grad_output)
        optimizer.step()
    outputs = get_outputs(model, output)
    if with_graph:
        reset_graphs(model)
    return outputs


def test_make_graphed_callables_with_kwargs(
...
@@ -687,7 +712,10 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        optimizer.step()
    outputs = [y for _, y in sorted(outputs.items())]
    outputs = get_outputs(model, outputs)
    if with_graph:
        reset_graphs(layer_forwards)
    return outputs


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
...