OpenDAS / TransformerEngine / Commits

Commit 9df0c4a3, authored Mar 03, 2026 by wenjh

    Merge branch 'nv_main'

Parents: 0d874a4e, f122b07d

Changes: 221 files in total; showing 20 changed files with 1890 additions and 66 deletions.
tests/jax/test_permutation.py                       +2    -2
tests/pytorch/attention/test_attention.py           +55   -4
tests/pytorch/attention/test_attention_with_cp.py   +19   -5
tests/pytorch/debug/test_log.py                     +120  -0
tests/pytorch/debug/test_sanity.py                  +19   -6
tests/pytorch/test_checkpoint.py                    +1    -1
tests/pytorch/test_fusible_ops.py                   +819  -21
tests/pytorch/test_grouped_tensor.py                +385  -0
tests/pytorch/test_numerics.py                      +2    -0
tests/pytorch/test_onnx_export.py                   +17   -5
tests/pytorch/test_sanity.py                        +147  -4
tests/pytorch/utils.py                              +53   -8
transformer_engine/common/CMakeLists.txt            +11   -6
transformer_engine/common/__init__.py               +6    -4
transformer_engine/common/activation/gelu.cu        +73   -0
transformer_engine/common/activation/glu.cu         +24   -0
transformer_engine/common/activation/relu.cu        +73   -0
transformer_engine/common/activation/swiglu.cu      +36   -0
transformer_engine/common/cast/cast.cu              +22   -0
transformer_engine/common/cast/core/common.cuh      +6    -0
tests/jax/test_permutation.py

@@ -23,7 +23,7 @@ ALL_DISPATCH_COMBINE_CASES = [
     (128, 5, 128, 3),
     (1024, 8, 128, 8),
     (4096, 32, 1280, 2),
-    (4096, 256, 4096, 6),
+    (4096, 64, 4096, 6),
 ]
 DISPATCH_COMBINE_CASES = {
     "L0": ALL_DISPATCH_COMBINE_CASES[0:2],
@@ -44,7 +44,7 @@ ALL_DISPATCH_COMBINE_PADDING_CASES = [
     (128, 5, 128, 3, 8),
     (1024, 8, 128, 8, 16),
     (4096, 32, 1280, 2, 128),
-    (4096, 256, 4096, 6, 16),
+    (4096, 64, 4096, 6, 16),
 ]
 DISPATCH_COMBINE_PADDING_CASES = {
     "L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:2],
tests/pytorch/attention/test_attention.py

@@ -74,6 +74,14 @@ if not IS_HIP_EXTENSION:
         f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
     )

+# Get determinism
+_deterministic = (
+    not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
+    or torch.are_deterministic_algorithms_enabled()
+)
+
 # Reset RNG seed and states
 seed = 1234
 reset_rng_states()
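The `_deterministic` flag added above folds an NVTE environment variable together with PyTorch's global determinism switch. A minimal standalone sketch of the same logic (the helper name `is_deterministic` is hypothetical; only `os` and `torch` are assumed):

    import os
    import torch

    def is_deterministic() -> bool:
        # Deterministic when NVTE_ALLOW_NONDETERMINISTIC_ALGO is "0"
        # (the default is "1"), or when torch's global deterministic
        # mode has been switched on.
        allow_nondet = bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
        return (not allow_nondet) or torch.are_deterministic_algorithms_enabled()

    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
    assert is_deterministic()  # the env var alone forces the deterministic path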
@@ -147,6 +155,7 @@ def test_dot_product_attention(
     if config.window_size == (-1, -1) and swa:
         config.window_size = [2, 2]
     config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
     qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
     if qkv_format == "thd" and "padding" not in config.attn_mask_type:
@@ -162,8 +171,10 @@ def test_dot_product_attention(
         qkv_layout=qkv_layout,
         pad_between_seqs=pad_between_seqs,
         is_training=is_training,
+        deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if not fused_attn_supported:
         is_training = False
     available_backends, _, fused_attn_backends = get_available_attention_backends(
@@ -172,6 +183,7 @@ def test_dot_product_attention(
         qkv_layout=qkv_layout,
         pad_between_seqs=pad_between_seqs,
         is_training=is_training,
+        deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
@@ -421,6 +433,15 @@ def test_dpa_softmax(dtype, model_configs, model):
 )
+@pytest.mark.skipif(get_cudnn_version() < (9, 18, 0), reason="cuDNN 9.18.0+ is required.")
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("model_configs", [model_configs_softmax])
+@pytest.mark.parametrize("model", model_configs_softmax.keys())
+def test_dpa_softmax_thd(dtype, model_configs, model):
+    """Test DotProductAttention module with different softmax types"""
+    test_dot_product_attention(dtype, model_configs, model, True, True, "thd_thd_thd", False, False)
+
+
 model_configs_mla = {
     #TODO:FlashAttention on ROCm only support MLA with head_dim_qk = head_dim_v
     # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
@@ -685,9 +706,10 @@ model_configs_swa = {
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_swa])
 @pytest.mark.parametrize("model", model_configs_swa.keys())
-def test_dpa_sliding_window(dtype, model_configs, model):
+@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "sbhd_sbhd_sbhd"])
+def test_dpa_sliding_window(dtype, model_configs, model, qkv_layout):
     """Test DotProductAttention module with sliding window attention"""
-    test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)
+    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False)


 model_configs_alibi_slopes = {
@@ -889,11 +911,14 @@ def _run_dot_product_attention(
     reset_rng_states()
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
     if backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
     if backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
+    if backend == "UnfusedDotProductAttention":
+        os.environ["NVTE_UNFUSED_ATTN"] = "1"
     _attention_backends["backend_selection_requires_update"] = True

     # Create seqlens
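The backend-selection pattern above recurs throughout this file: clear all three NVTE flags, set exactly one, then mark the cached selection stale. A condensed sketch of just that logic (`_attention_backends` below is a plain dict standing in for transformer_engine's module-level cache):

    import os

    # Hypothetical stand-in for transformer_engine's module-level cache.
    _attention_backends = {"backend_selection_requires_update": False}

    def select_backend(backend: str) -> None:
        """Enable exactly one NVTE attention backend via env vars."""
        for flag in ("NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN"):
            os.environ[flag] = "0"
        flag = {
            "FlashAttention": "NVTE_FLASH_ATTN",
            "FusedAttention": "NVTE_FUSED_ATTN",
            "UnfusedDotProductAttention": "NVTE_UNFUSED_ATTN",
        }[backend]
        os.environ[flag] = "1"
        _attention_backends["backend_selection_requires_update"] = True

    select_backend("FusedAttention")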
@@ -1295,6 +1320,7 @@ def test_transformer_layer(
             qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
         ),
         is_training=is_training,
+        deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if not fused_attn_supported:
@@ -1308,6 +1334,7 @@ def test_transformer_layer(
             else qkv_format.replace("hd", "3hd")
         ),
         is_training=is_training,
+        deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
@@ -1435,10 +1462,13 @@ def _run_transformer_layer(
     reset_rng_states()
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
     if backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
     if backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
+    if backend == "UnfusedDotProductAttention":
+        os.environ["NVTE_UNFUSED_ATTN"] = "1"
     _attention_backends["backend_selection_requires_update"] = True

     # Create input tensor
@@ -1632,6 +1662,7 @@ def test_dpa_fp8_extra_state(model, dtype):
         qkv_dtype=torch.float8_e4m3fn,
         qkv_layout="sb3hd",
         is_training=is_training,
+        deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if not fused_attn_supported and not flash_attn_supported:
@@ -1822,6 +1853,7 @@ def test_mha_fp8_vs_f16(
         fp8=True,
         fp8_meta=fp8_meta,
         is_training=is_training,
+        deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
     if flash_attn_supported + fused_attn_supported_fp8 < 1:
@@ -1833,6 +1865,7 @@ def test_mha_fp8_vs_f16(
         qkv_dtype=dtype,
         qkv_layout=qkv_format.replace("hd", "h3d"),
         is_training=is_training,
+        deterministic=_deterministic,
     )
     _, fused_attn_supported_f16, _ = available_backends
     if not fused_attn_supported_f16:
@@ -1841,6 +1874,7 @@ def test_mha_fp8_vs_f16(
     if flash_attn_supported:
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
+        os.environ["NVTE_UNFUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
         logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
         flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
@@ -1850,6 +1884,7 @@ def test_mha_fp8_vs_f16(
     if fused_attn_supported_fp8:
         os.environ["NVTE_FLASH_ATTN"] = "0"
         os.environ["NVTE_FUSED_ATTN"] = "1"
+        os.environ["NVTE_UNFUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
         logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
         fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
@@ -1859,6 +1894,7 @@ def test_mha_fp8_vs_f16(
     if fused_attn_supported_f16:
         os.environ["NVTE_FLASH_ATTN"] = "0"
         os.environ["NVTE_FUSED_ATTN"] = "1"
+        os.environ["NVTE_UNFUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
         logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
         fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
@@ -2071,6 +2107,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
         fp8=True,
         fp8_meta=fp8_meta,
         is_training=is_training,
+        deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if flash_attn_supported + fused_attn_supported < 1:
@@ -2081,6 +2118,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
         is_training=is_training,
+        deterministic=_deterministic,
     )
     _, fused_attn_supported, _ = available_backends
     if not fused_attn_supported:
@@ -2091,6 +2129,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
     if flash_attn_supported:
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
+        os.environ["NVTE_UNFUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
         logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)")
         flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
@@ -2100,6 +2139,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
     if unfused_attn_supported:
         os.environ["NVTE_FLASH_ATTN"] = "0"
         os.environ["NVTE_FUSED_ATTN"] = "0"
+        os.environ["NVTE_UNFUSED_ATTN"] = "1"
         _attention_backends["backend_selection_requires_update"] = True
         logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)")
         unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
@@ -2108,6 +2148,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
         os.environ["NVTE_FLASH_ATTN"] = "0"
         os.environ["NVTE_FUSED_ATTN"] = "1"
+        os.environ["NVTE_UNFUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
         logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
         fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
@@ -2116,6 +2157,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
         os.environ["NVTE_FLASH_ATTN"] = "0"
         os.environ["NVTE_FUSED_ATTN"] = "1"
+        os.environ["NVTE_UNFUSED_ATTN"] = "0"
         if config.dropout_p == 0.0:
             # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell
             logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
@@ -2370,13 +2412,16 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
         qkv_dtype=torch.float8_e4m3fn,
         qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
         is_training=is_training,
+        deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if not (fused_attn_backends and unfused_attn_supported):
         pytest.skip("Not enough backends to run this test with.")
     fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
-    unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
+    unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(
+        dtype, config, "UnfusedDotProductAttention"
+    )

     atol = 5e-1
     rtol = 5e-1
@@ -2409,10 +2454,13 @@ def _run_custom_mha_fp8(dtype, config, backend):
     reset_rng_states()
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
     if backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
     if backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
+    if backend == "UnfusedDotProductAttention":
+        os.environ["NVTE_UNFUSED_ATTN"] = "1"
     _attention_backends["backend_selection_requires_update"] = True

     inp = 0.0001 * torch.randint(
@@ -2463,10 +2511,13 @@ def _run_ref_mha_f16(dtype, config, backend):
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
     if backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
     if backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
+    if backend == "UnfusedDotProductAttention":
+        os.environ["NVTE_UNFUSED_ATTN"] = "1"
     _attention_backends["backend_selection_requires_update"] = True

     inp = torch.load("qkv.pt").to(device="cuda")
@@ -2754,7 +2805,7 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
         cu_seqlens,
         max_s,
     ) -> torch.Tensor:
-        with self.prepare_forward(inp, num_gemms=3) as inp:
+        with self.prepare_forward_ctx(inp, num_gemms=3) as inp:
             out = _custom_mha_fp8.apply(
                 inp,
                 self.qkv_weight,
tests/pytorch/attention/test_attention_with_cp.py

@@ -148,7 +148,7 @@ model_configs_fused_attn = {
         2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
     ),  # MHA
     "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"),  # MHA
-    "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)),  # MHA
+    "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)),  # MHA
     "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"),  # GQA
     "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2),  # GQA
     "cp_2_2": ModelConfig(
@@ -164,7 +164,7 @@ model_configs_fused_attn = {
         2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
     ),  # GQA
     "cp_2_4": ModelConfig(
-        2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
+        2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512)
     ),  # GQA
     "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64),  # MLA
     "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64),  # MLA
@@ -188,7 +188,16 @@ dtypes = ["bf16", "fp16", "fp8"]
 qkv_formats = ["bshd", "sbhd", "thd"]
 cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
 if test_essential:
-    configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
+    configs = [
+        "cp_1_0",
+        "cp_1_1",
+        "cp_1_4",
+        "cp_2_0",
+        "cp_2_2",
+        "cp_2_4",
+        "cp_3_2",
+        "cp_4_2",
+    ]
     model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
     dtypes = ["bf16", "fp8"]
     qkv_formats = ["sbhd", "thd"]
@@ -284,9 +293,14 @@ def test_cp_with_fused_attention(
         pytest.skip(
             "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!"
         )
-    if config.softmax_type != "vanilla" and qkv_format == "thd":
+    if (
+        get_cudnn_version() < (9, 18, 0)
+        and config.softmax_type != "vanilla"
+        and qkv_format == "thd"
+    ):
         pytest.skip(
-            "CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
+            "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for"
+            " non-vanilla softmax types!"
        )

     dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
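For context on the `window_size` changes above: the pair is the usual (left, right) sliding window, so moving from `(512, 0)` to `(512, 512)` additionally lets each query attend to the 512 following tokens rather than only past ones. A small mask sketch under that convention (pure PyTorch, independent of the test harness):

    import torch

    def sliding_window_mask(seqlen: int, window: tuple[int, int]) -> torch.Tensor:
        """True marks allowed (query, key) pairs for a (left, right) window."""
        left, right = window
        q = torch.arange(seqlen).unsqueeze(1)
        k = torch.arange(seqlen).unsqueeze(0)
        return (k >= q - left) & (k <= q + right)

    # (1, 0) attends only backwards; (1, 1) also looks one token ahead.
    print(sliding_window_mask(4, (1, 0)))
    print(sliding_window_mask(4, (1, 1)))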
tests/pytorch/debug/test_log.py

@@ -15,6 +15,7 @@ from transformer_engine.pytorch import (
     is_fp8_available,
     is_mxfp8_available,
     is_fp8_block_scaling_available,
+    is_nvfp4_available,
 )
 from transformer_engine.pytorch.quantization import RecipeState
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
@@ -29,6 +30,7 @@ mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
 fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
     return_reason=True
 )
+nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True)

 LOG_QUANTIZED_CONFIG_BASE = """
 log:
@@ -363,6 +365,124 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
         TEDebugState._reset()


+# NVFP4 tests
+LOG_NVFP4_CONFIG_BASE = """
+log:
+  layers:
+    layer_name_regex_pattern: .*
+  enabled: True
+  transformer_engine:
+    LogNvfp4TensorStats:
+      enabled: True
+      stats: [{stats}]
+      tensors: [activation, gradient, weight]
+      freq: 2
+      start_step: 0
+      end_step: 10
+"""
+
+
+def test_nvfp4_numeric(feature_dirs):
+    """Test that NVFP4 underflows% and MSE stats are computed correctly with known values."""
+    if not nvfp4_available:
+        pytest.skip(reason_for_no_nvfp4)
+
+    log_nvfp4_config = LOG_NVFP4_CONFIG_BASE.format(stats="underflows%, mse")
+
+    with debug_session(log_nvfp4_config, feature_dirs) as log_dir:
+        from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
+        from transformer_engine.pytorch.quantization import RecipeState
+
+        recipe_state = RecipeState.create(
+            recipe.NVFP4BlockScaling(),
+            mode="forward",
+            num_quantizers=3,
+        )
+
+        # Create test tensor with known distribution
+        torch.manual_seed(42)
+        tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
+        # Add some small values that should underflow to zero in FP4
+        tensor[0, :16] = 0.0001
+
+        quantizer = recipe_state.make_quantizers()[0]
+        quantized_tensor = quantizer(tensor)
+
+        debug_api.transformer_engine.inspect_tensor(
+            layer_name="test_layer",
+            tensor_name="activation",
+            iteration=0,
+            tp_group=None,
+            tensor=tensor,
+            quantizer=quantizer,
+            rowwise_quantized_tensor=quantized_tensor,
+            columnwise_quantized_tensor=quantized_tensor,
+        )
+        debug_api.step()
+
+        dequantized_tensor = quantized_tensor.dequantize()
+        output = read_log(log_dir)
+
+    # Validate both stats are present
+    assert "nvfp4_underflows%" in output, "underflows% stat missing"
+    assert "nvfp4_mse" in output, "mse stat missing"
+
+    # Extract values and validate numerics
+    underflows_value = None
+    mse_value = None
+    for line in output.splitlines():
+        if "nvfp4_underflows%" in line and "value=" in line:
+            underflows_value = float(line.split("value=")[1].split()[0])
+        if "nvfp4_mse" in line and "value=" in line:
+            mse_value = float(line.split("value=")[1].split()[0])
+
+    # Compute expected underflows: non-zero elements that became zero after quantization
+    orig_nonzero_mask = tensor != 0
+    dequant_zero_mask = dequantized_tensor == 0
+    expected_underflows = (
+        (orig_nonzero_mask & dequant_zero_mask).sum().float() / tensor.numel() * 100
+    )
+    # Allow some tolerance
+    assert underflows_value == pytest.approx(expected_underflows.cpu().item(), abs=1e-4)
+
+    # Compute expected MSE
+    expected_mse = torch.nn.functional.mse_loss(
+        dequantized_tensor.float(), tensor.float(), reduction="mean"
+    )
+    assert mse_value == pytest.approx(expected_mse.cpu().item(), abs=1e-4)
+
+
+def test_fp8_stats_allows_nvfp4_with_recipe_prefix(feature_dirs):
+    """Test that LogFp8TensorStats allows recipe-prefixed stats with NVFP4 for what-if analysis."""
+    if not nvfp4_available:
+        pytest.skip(reason_for_no_nvfp4)
+
+    # Use recipe-prefixed stat with NVFP4 - should work (computes MXFP8 separately)
+    log_fp8_config = LOG_QUANTIZED_CONFIG_BASE.format(stats="mxfp8_mse")
+
+    with debug_session(log_fp8_config, feature_dirs) as log_dir:
+        model = te.Linear(128, 128, params_dtype=torch.bfloat16)
+        inp = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
+
+        # Should work - recipe-prefixed stats compute MXFP8 separately for comparison
+        for _ in range(2):
+            with te.autocast(recipe=recipe.NVFP4BlockScaling()):
+                output = model(inp)
+            loss = output.sum()
+            loss.backward()
+            debug_api.step()
+
+        output = read_log(log_dir)
+
+    # Should have logged MXFP8 MSE stat (what-if scenario)
+    assert "mxfp8_mse" in output
+
+
 def test_log_grouped_gemm(feature_dirs):
     if not fp8_available:
         pytest.skip(reason_for_no_fp8)
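The numeric checks in the new `test_nvfp4_numeric` reduce to two plain tensor reductions over the original tensor and its dequantized round trip. A standalone sketch of those formulas (the thresholding "quantizer" below is a fake stand-in, used only to exercise the stats):

    import torch

    def underflow_pct(orig: torch.Tensor, dequant: torch.Tensor) -> float:
        """Percentage of originally non-zero values flushed to zero."""
        flushed = (orig != 0) & (dequant == 0)
        return (flushed.sum().float() / orig.numel() * 100).item()

    def mse(orig: torch.Tensor, dequant: torch.Tensor) -> float:
        return torch.nn.functional.mse_loss(
            dequant.float(), orig.float(), reduction="mean"
        ).item()

    x = torch.randn(128, 128)
    x_q = torch.where(x.abs() < 1e-3, torch.zeros_like(x), x)  # fake quantize
    print(underflow_pct(x, x_q), mse(x, x_q))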
tests/pytorch/debug/test_sanity.py

@@ -30,10 +30,17 @@ configs = {
       stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
       start_step : 0
       end_step: 1
+""",
+    "log_fp8": """log_fp8:
+  layers:
+    layer_types: [linear]
+  enabled: True
+  transformer_engine:
     LogFp8TensorStats:
       enabled: True
       tensors: [activation, gradient, weight]
-      stats: [underflows, overflows]
+      stats: [underflows%]
       start_step : 0
       end_step: 1
 """,
@@ -46,22 +53,26 @@ fake_quant_config:
     FakeQuant:
       enabled: True
       gemms: [fprop, dgrad, wgrad]
+      tensors: [activation, weight, gradient]
       quant_format: FP8E5M2
 """,
 }

+# Configs that require FP8 to be enabled
+fp8_required_configs = {"log_fp8"}
+

 def _get_model(model_key):
     if model_key == "linear":
-        return te.Linear(D, D)
+        return te.Linear(D, D, name="layer")
     if model_key == "layernorm_linear":
-        return te.LayerNormLinear(D, D)
+        return te.LayerNormLinear(D, D, name="layer")
     if model_key == "layernorm_mlp":
-        return te.LayerNormMLP(D, D, D)
+        return te.LayerNormMLP(D, D, D, name="layer")
     if model_key == "mha_attention":
-        return te.MultiheadAttention(D, H)
+        return te.MultiheadAttention(D, H, name="layer")
     if model_key == "transformer_layer":
-        return te.TransformerLayer(D, D, H)
+        return te.TransformerLayer(D, D, H, name="layer")


 def _run_forward_backward(model, fp8):
@@ -95,4 +106,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
 def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
+    if not fp8 and config_key in fp8_required_configs:
+        pytest.skip(f"Config '{config_key}' requires FP8")
     _run_test(model_key, fp8, configs[config_key], feature_dirs)
tests/pytorch/test_checkpoint.py

@@ -101,7 +101,7 @@ class TestLoadCheckpoint:
         # Path to save checkpoint
         if checkpoint_dir is None:
             checkpoint_dir = TestLoadCheckpoint._checkpoint_dir()
-            checkpoint_dir.mkdir(exist_ok=True)
+            checkpoint_dir.mkdir(parents=True, exist_ok=True)
         checkpoint_file = checkpoint_dir / f"{name}.pt"

         # Create module and save checkpoint
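The one-word fix above matters when the checkpoint directory's parents do not yet exist: `Path.mkdir(exist_ok=True)` still raises `FileNotFoundError` in that case, while `parents=True` creates the whole chain. A quick illustration:

    import tempfile
    from pathlib import Path

    root = Path(tempfile.mkdtemp())
    nested = root / "a" / "b" / "checkpoints"

    try:
        nested.mkdir(exist_ok=True)  # fails: parent "a" does not exist
    except FileNotFoundError:
        nested.mkdir(parents=True, exist_ok=True)  # creates a/b/checkpoints
    assert nested.is_dir()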
tests/pytorch/test_fusible_ops.py

@@ -5,8 +5,10 @@
 from __future__ import annotations

 from collections.abc import Iterable
+import functools
 import io
 import math
+import random
 from typing import Optional

 import pytest
@@ -37,7 +39,14 @@ from transformer_engine.pytorch import (
 import transformer_engine_torch as tex

 # Import utility functions
-from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states
+from utils import (
+    assert_close,
+    assert_close_grads,
+    dtype_tols,
+    make_recipe,
+    quantization_tols,
+    reset_rng_states,
+)

 if IS_HIP_EXTENSION:
     import os
@@ -116,6 +125,9 @@ def maybe_skip_quantization(
 @torch.no_grad()
 def make_reference_and_test_tensors(
     shape: int | Iterable[int],
+    *,
+    min: float = 0.0,
+    max: float = 1.0,
     quantization: Optional[str] = None,
     ref_dtype: torch.dtype = torch.float64,
     ref_device: torch.device = "cpu",
@@ -136,7 +148,8 @@ def make_reference_and_test_tensors(
     """

     # Random reference tensor
-    ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
+    ref = torch.empty(shape, dtype=ref_dtype, device=ref_device)
+    ref.uniform_(min, max)

     # Construct test tensor from reference tensor
     test = ref.to(device=test_device, dtype=test_dtype)
@@ -1569,7 +1582,19 @@ class TestBasicOps:
     @pytest.mark.parametrize(
         "activation",
-        ("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"),
+        (
+            "gelu",
+            "geglu",
+            "qgelu",
+            "qgeglu",
+            "relu",
+            "reglu",
+            "glu",
+            "srelu",
+            "sreglu",
+            "silu",
+            "swiglu",
+        ),
     )
     @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
     @pytest.mark.parametrize("dtype", _dtypes)
@@ -1589,7 +1614,7 @@ class TestBasicOps:
         # Tensor dimensions
         in_shape = list(out_shape)
-        if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"):
+        if activation in ("geglu", "glu", "qgeglu", "reglu", "sreglu", "swiglu"):
             in_shape[-1] *= 2

         # Skip invalid configurations
@@ -1629,6 +1654,13 @@ class TestBasicOps:
         elif activation == "reglu":
             x1, x2 = x_ref.chunk(2, dim=-1)
             y_ref = torch.nn.functional.relu(x1) * x2
+        elif activation == "sigmoid":
+            y_ref = torch.nn.functional.sigmoid(x_ref)
+        elif activation == "glu":
+            x = x_ref.reshape(*in_shape[:-1], 2, in_shape[-1] // 2)
+            x = x.flip(-2)  # PyTorch GLU swaps gate and linear unit
+            x = x.reshape(in_shape)
+            y_ref = torch.nn.functional.glu(x)
         elif activation == "srelu":
             y_ref = torch.nn.functional.relu(x_ref) ** 2
         elif activation == "sreglu":
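The `flip` in the new `glu` reference handles a convention mismatch: `torch.nn.functional.glu` takes the linear unit from the first half of the input and the sigmoid gate from the second, while the comment suggests the TE op gates the other way around. The equivalence the test relies on, in isolation (using `cat` instead of the reshape/flip dance for clarity):

    import torch

    x = torch.randn(4, 8)
    x1, x2 = x.chunk(2, dim=-1)

    # Gate-first GLU: sigmoid of the first half times the second half.
    gate_first = torch.sigmoid(x1) * x2

    # torch.nn.functional.glu is linear-first; swapping halves reconciles them.
    swapped = torch.cat([x2, x1], dim=-1)
    assert torch.allclose(gate_first, torch.nn.functional.glu(swapped, dim=-1))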
@@ -1648,6 +1680,7 @@ class TestBasicOps:
         make_op = dict(
             gelu=te_ops.GELU,
             geglu=te_ops.GEGLU,
+            glu=te_ops.GLU,
             qgelu=te_ops.QGELU,
             qgeglu=te_ops.QGEGLU,
             relu=te_ops.ReLU,
@@ -1692,6 +1725,7 @@ class TestBasicOps:
         quantization: Optional[str],
         quantize_forward: bool,
         quantize_backward: bool,
+        glu_interleave_size: Optional[int] = None,
     ):

         # Tensor dimensions
@@ -1718,7 +1752,17 @@ class TestBasicOps:
         )

         # Plain PyTorch implementation
-        x1, x2 = x_ref.chunk(2, dim=-1)
+        x = x_ref
+        if glu_interleave_size is not None:
+            x = x.reshape(
+                *in_shape[:-1],
+                in_shape[-1] // (2 * glu_interleave_size),
+                2,
+                glu_interleave_size,
+            )
+            x = x.transpose(-3, -2)
+            x = x.reshape(in_shape)
+        x1, x2 = x.chunk(2, dim=-1)
         y_ref = torch.nn.functional.silu(x1) * x2
         y_ref.backward(dy_ref)
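The new `glu_interleave_size` path in the reference undoes a block-interleaved layout, where gate and linear channels alternate in blocks of that width, so that `chunk(2, dim=-1)` sees the usual two contiguous halves. Just that rearrangement, on a toy tensor:

    import torch

    def deinterleave(x: torch.Tensor, block: int) -> torch.Tensor:
        """Turn alternating [g-block, l-block, ...] into [g..., l...]."""
        shape = x.shape
        x = x.reshape(*shape[:-1], shape[-1] // (2 * block), 2, block)
        x = x.transpose(-3, -2)
        return x.reshape(shape)

    x = torch.arange(8.0)      # blocks of width 2: [0 1] [2 3] [4 5] [6 7]
    print(deinterleave(x, 2))  # tensor([0., 1., 4., 5., 2., 3., 6., 7.])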
@@ -1726,7 +1770,7 @@ class TestBasicOps:
         recipe = make_recipe(quantization)
         forward = te_ops.Sequential(
             te_ops.Quantize(forward=False, backward=quantize_backward),
-            te_ops.SwiGLU(),
+            te_ops.SwiGLU(glu_interleave_size=glu_interleave_size),
             te_ops.Quantize(forward=quantize_forward, backward=False),
         )
         with te.autocast(enabled=quantized_compute, recipe=recipe):
@@ -1739,10 +1783,19 @@ class TestBasicOps:
             tols = quantization_tols(quantization)

         # Check results
-        y_test = y_test.to(dtype=torch.float64, device="cpu")
-        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
-        torch.testing.assert_close(y_test, y_ref, **tols)
-        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+        assert_close(y_test, y_ref, **tols)
+        assert_close_grads(x_test, x_ref, **tols)
+
+    def test_interleaved_swiglu(self):
+        """SwiGLU with block interleaved input format"""
+        self.test_swiglu(
+            out_shape=(32, 192),
+            dtype=torch.float32,
+            quantization=None,
+            quantize_forward=False,
+            quantize_backward=False,
+            glu_interleave_size=32,
+        )

     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("quantization", _quantization_list)
@@ -1752,6 +1805,7 @@ class TestBasicOps:
         self,
         *,
         out_shape: Iterable[int] = (32, 32),
+        glu_interleave_size: Optional[int] = None,
         dtype: torch.dtype,
         device: torch.device = "cuda",
         quantization: Optional[str],
@@ -1760,7 +1814,7 @@ class TestBasicOps:
         limit: float = 0.75,
         alpha: float = 1.702,
     ):
-        # Test SwiGLU variant used in GPT OSS.
+        """SwiGLU variant used in GPT-OSS"""

         # Tensor dimensions
         in_shape = list(out_shape)
         in_shape[-1] *= 2
@@ -1785,7 +1839,17 @@ class TestBasicOps:
         )

         # Plain PyTorch implementation
-        x_glu, x_linear = x_ref.chunk(2, dim=-1)
+        x = x_ref
+        if glu_interleave_size is not None:
+            x = x.reshape(
+                *in_shape[:-1],
+                in_shape[-1] // (2 * glu_interleave_size),
+                2,
+                glu_interleave_size,
+            )
+            x = x.transpose(-3, -2)
+            x = x.reshape(in_shape)
+        x_glu, x_linear = x.chunk(2, dim=-1)
         x_glu = x_glu.clamp(min=None, max=limit)
         x_linear = x_linear.clamp(min=-limit, max=limit)
         out_glu = x_glu * torch.sigmoid(alpha * x_glu)
@@ -1797,7 +1861,11 @@ class TestBasicOps:
         forward = te_ops.Sequential(
             te_ops.Quantize(forward=False, backward=quantize_backward),
-            te_ops.ClampedSwiGLU(limit=limit, alpha=alpha),
+            te_ops.ClampedSwiGLU(
+                limit=limit,
+                alpha=alpha,
+                glu_interleave_size=glu_interleave_size,
+            ),
             te_ops.Quantize(forward=quantize_forward, backward=False),
         )
         with te.autocast(enabled=quantized_compute, recipe=recipe):
@@ -1813,10 +1881,19 @@ class TestBasicOps:
             tols = dtype_tols(tex.DType.kFloat8E4M3)

         # Check results
-        y_test = y_test.to(dtype=torch.float64, device="cpu")
-        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
-        torch.testing.assert_close(y_test, y_ref, **tols)
-        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+        assert_close(y_test, y_ref, **tols)
+        assert_close_grads(x_test, x_ref, **tols)
+
+    def test_interleaved_clamped_swiglu(self):
+        """GPT-OSS SwiGLU with block interleaved input format"""
+        self.test_clamped_swiglu(
+            out_shape=(32, 192),
+            dtype=torch.float32,
+            quantization=None,
+            quantize_forward=False,
+            quantize_backward=False,
+            glu_interleave_size=32,
+        )

     @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
     @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2)))
@@ -1936,6 +2013,231 @@ class TestBasicOps:
             abs(z_score) < 2.5758
         ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})"

+    @pytest.mark.parametrize("bias", (False, True))
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    @pytest.mark.parametrize("quantized_compute", (False, True))
+    @pytest.mark.parametrize("quantized_weight", (False, True))
+    @pytest.mark.parametrize("input_requires_grad", (False, True))
+    @pytest.mark.parametrize("weight_requires_grad", (False, True))
+    def test_grouped_linear(
+        self,
+        *,
+        group_size: int = 4,
+        bias: bool,
+        weight_shape: tuple[int, int] = (128, 128),
+        split_alignment: int = 128,
+        dtype: torch.dtype,
+        device: torch.device = "cuda",
+        quantization: Optional[str],
+        quantized_compute: bool,
+        quantized_weight: bool,
+        input_requires_grad: bool,
+        weight_requires_grad: bool,
+    ) -> None:
+        """Grouped GEMM"""
+
+        # Split sizes
+        split_sizes = [split_alignment * i for i in range(group_size)]
+        random.shuffle(split_sizes)
+        split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device)
+
+        # Make input and weight shapes consistent
+        out_features, in_features = weight_shape
+        in_shape = (split_sizes.sum().item(), in_features)
+        out_shape = (in_shape[0], out_features)
+
+        # Skip invalid configurations
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
+        maybe_skip_quantization(quantization, dims=out_shape)
+        if quantization is None and (quantized_compute or quantized_weight):
+            pytest.skip("Quantization scheme is not specified")
+        if quantization is not None and not (quantized_compute or quantized_weight):
+            pytest.skip("Quantization scheme is not used")
+        if quantization is not None and dtype not in (torch.bfloat16, torch.float16):
+            pytest.skip("Quantized group GEMM is only supported with BF16/FP16")
+
+        # Random data
+        x_ref, x_test = make_reference_and_test_tensors(
+            in_shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=input_requires_grad,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            out_shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+        ws_ref, ws_test = [], []
+        bs_ref, bs_test = [], []
+        for _ in range(group_size):
+            w_ref, w_test = make_reference_and_test_tensors(
+                (out_features, in_features),
+                quantization=quantization,
+                test_dtype=dtype,
+                test_device=device,
+                requires_grad=weight_requires_grad,
+            )
+            b_ref, b_test = None, None
+            if bias:
+                b_ref, b_test = make_reference_and_test_tensors(
+                    out_features,
+                    test_dtype=dtype,
+                    test_device=device,
+                    requires_grad=weight_requires_grad,
+                )
+            ws_ref.append(w_ref)
+            ws_test.append(w_test)
+            bs_ref.append(b_ref)
+            bs_test.append(b_test)
+
+        # Plain PyTorch implementation
+        xs_ref = torch.split(x_ref, split_sizes.tolist())
+        ys_ref = []
+        for x, w, b in zip(xs_ref, ws_ref, bs_ref):
+            ys_ref.append(torch.nn.functional.linear(x, w, bias=b))
+        y_ref = torch.cat(ys_ref)
+        if input_requires_grad or weight_requires_grad:
+            y_ref.backward(dy_ref)
+
+        # Construct fusible operation
+        recipe = make_recipe(quantization)
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
+            op = te_ops.GroupedLinear(
+                group_size,
+                in_features,
+                out_features,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+            )
+        with torch.no_grad():
+            for group_idx in range(group_size):
+                getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx])
+                if bias:
+                    getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx])
+            del ws_test, bs_test
+        for param in op.parameters():
+            param.requires_grad_(requires_grad=weight_requires_grad)
+
+        # Forward and backward pass with op
+        with te.autocast(enabled=quantized_compute, recipe=recipe):
+            y_test = op(x_test, split_sizes)
+        if input_requires_grad or weight_requires_grad:
+            y_test.backward(dy_test)
+
+        # Expected numerical error
+        tols = dtype_tols(dtype)
+        if dtype == torch.float32:
+            tols = dtype_tols(torch.float16)  # TF32 GEMM
+        if quantized_compute:
+            tols = quantization_tols(quantization)
+
+        # Check results
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        if input_requires_grad:
+            dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+            torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+        else:
+            assert x_test.grad is None
+        for group_idx in range(group_size):
+            w_test = getattr(op, f"weight{group_idx}")
+            if weight_requires_grad:
+                dw_test = w_test.grad.to(dtype=torch.float64, device="cpu")
+                torch.testing.assert_close(dw_test, ws_ref[group_idx].grad, **tols)
+            else:
+                assert w_test.grad is None
+            if bias:
+                b_test = getattr(op, f"bias{group_idx}")
+                if weight_requires_grad:
+                    db_test = b_test.grad.to(dtype=torch.float64, device="cpu")
+                    torch.testing.assert_close(db_test, bs_ref[group_idx].grad, **tols)
+                else:
+                    assert b_test.grad is None
+
+    @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
+    @pytest.mark.parametrize("input_requires_grad", (False, True))
+    @pytest.mark.parametrize("scales_requires_grad", (False, True))
+    def test_scaled_swiglu(
+        self,
+        *,
+        in_shape: Iterable[int],
+        glu_interleave_size: Optional[int] = None,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = "cuda",
+        input_requires_grad: bool,
+        scales_requires_grad: bool,
+    ) -> None:
+        """SwiGLU with post-scale"""
+
+        # Tensor dims
+        out_shape = list(in_shape)
+        out_shape[-1] //= 2
+
+        # Random data
+        x_ref, x_test = make_reference_and_test_tensors(
+            in_shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=input_requires_grad,
+        )
+        scales_ref, scales_test = make_reference_and_test_tensors(
+            in_shape[:-1],
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=scales_requires_grad,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            out_shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Plain PyTorch implementation
+        x = x_ref
+        if glu_interleave_size is not None:
+            x = x.reshape(
+                -1,
+                in_shape[-1] // (2 * glu_interleave_size),
+                2,
+                glu_interleave_size,
+            )
+            x = x.transpose(1, 2)
+            x = x.reshape(in_shape)
+        x1, x2 = x.chunk(2, dim=-1)
+        y = torch.nn.functional.silu(x1) * x2
+        y_ref = scales_ref.unsqueeze(-1) * y
+        if input_requires_grad or scales_requires_grad:
+            y_ref.backward(dy_ref)
+
+        # Implementation with fusible operation
+        op = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
+        y_test = op(x_test, scales_test)
+        if input_requires_grad or scales_requires_grad:
+            y_test.backward(dy_test)
+
+        # Check results
+        tols = dtype_tols(dtype)
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        assert_close(y_test, y_ref, **tols)
+        assert_close_grads(x_test, x_ref, **tols)
+        assert_close_grads(scales_test, scales_ref, **tols)
+
+    def test_interleaved_scaled_swiglu(self):
+        """SwiGLU with post-scale and block interleaved input format"""
+        self.test_scaled_swiglu(
+            in_shape=(32, 192),
+            glu_interleave_size=32,
+            input_requires_grad=True,
+            scales_requires_grad=True,
+        )
+

 class TestFusedOps:
     """Tests for fused operations"""
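The reference path inside the new `test_grouped_linear` above is nothing more than a per-group dense GEMM over contiguous row splits, which is the semantics `te_ops.GroupedLinear` is checked against. A compact sketch of that reference (shapes are arbitrary):

    import torch

    def grouped_linear_ref(x, split_sizes, weights, biases=None):
        """Apply a different weight matrix to each contiguous row group."""
        ys = []
        for i, xg in enumerate(torch.split(x, split_sizes)):
            b = biases[i] if biases is not None else None
            ys.append(torch.nn.functional.linear(xg, weights[i], bias=b))
        return torch.cat(ys)

    splits = [128, 0, 256, 384]  # zero-sized groups are allowed
    x = torch.randn(sum(splits), 64)
    ws = [torch.randn(32, 64) for _ in splits]
    print(grouped_linear_ref(x, splits, ws).shape)  # torch.Size([768, 32])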
@@ -2342,13 +2644,13 @@ class TestFusedOps:
         backward_ops = model._module_groups[0]._backward_ops
         if with_quantization:
             assert len(backward_ops) == 2
-            assert isinstance(backward_ops[0][0], BackwardActivationBias)
-            assert isinstance(backward_ops[1][0], te_ops.Quantize)
+            assert isinstance(backward_ops[0][0], te_ops.Quantize)
+            assert isinstance(backward_ops[1][0], BackwardActivationBias)
         else:
             assert len(backward_ops) == 3
-            assert isinstance(backward_ops[0][0], act_type)
+            assert isinstance(backward_ops[0][0], te_ops.Quantize)
             assert isinstance(backward_ops[1][0], te_ops.Bias)
-            assert isinstance(backward_ops[2][0], te_ops.Quantize)
+            assert isinstance(backward_ops[2][0], act_type)

         # Expected numerical error
         tols = dtype_tols(dtype)
@@ -2944,3 +3246,499 @@ class TestSequentialModules:
...
@@ -2944,3 +3246,499 @@ class TestSequentialModules:
if
bias
:
if
bias
:
torch
.
testing
.
assert_close
(
to_cpu
(
ffn1
.
bias
.
grad
),
b1_ref
.
grad
,
**
tols
)
torch
.
testing
.
assert_close
(
to_cpu
(
ffn1
.
bias
.
grad
),
b1_ref
.
grad
,
**
tols
)
torch
.
testing
.
assert_close
(
to_cpu
(
ffn2
.
bias
.
grad
),
b2_ref
.
grad
,
**
tols
)
torch
.
testing
.
assert_close
(
to_cpu
(
ffn2
.
bias
.
grad
),
b2_ref
.
grad
,
**
tols
)
    @pytest.mark.parametrize("bias", (False, True))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("glu_interleave_size", (None, 32))
    def test_grouped_mlp(
        self,
        *,
        group_size: int = 4,
        bias: bool,
        hidden_size: int = 256,
        dtype: torch.dtype,
        quantization: Optional[str],
        device: torch.device = "cuda",
        split_alignment: int = 256,
        glu_interleave_size: Optional[int],
    ) -> None:
        """GroupedLinear + ScaledSwiGLU + GroupedLinear"""

        # Split sizes
        split_sizes = [split_alignment * i for i in range(group_size)]
        random.shuffle(split_sizes)
        split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device)

        # Make input shape
        in_shape = (split_sizes.sum().item(), hidden_size)
        out_shape = in_shape

        # Skip invalid configurations
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        if with_quantization and dtype not in (torch.bfloat16, torch.float16):
            pytest.skip("Quantized group GEMM is only supported with BF16/FP16")

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            min=-0.25,
            max=0.25,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            min=-0.25,
            max=0.25,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )
        probs_ref, probs_test = make_reference_and_test_tensors(
            (in_shape[0],),
            test_dtype=dtype,
            test_device=device,
        )
        fc1_ws_ref, fc1_ws_test = [], []
        fc1_bs_ref, fc1_bs_test = [], []
        fc2_ws_ref, fc2_ws_test = [], []
        fc2_bs_ref, fc2_bs_test = [], []
        for _ in range(group_size):
            fc1_w_ref, fc1_w_test = make_reference_and_test_tensors(
                (2 * hidden_size, hidden_size),
                min=-0.25,
                max=0.25,
                quantization=quantization,
                test_dtype=dtype,
                test_device=device,
            )
            fc2_w_ref, fc2_w_test = make_reference_and_test_tensors(
                (hidden_size, hidden_size),
                min=-0.25,
                max=0.25,
                quantization=quantization,
                test_dtype=dtype,
                test_device=device,
            )
            fc1_b_ref, fc1_b_test = None, None
            fc2_b_ref, fc2_b_test = None, None
            if bias:
                fc1_b_ref, fc1_b_test = make_reference_and_test_tensors(
                    (2 * hidden_size,),
                    min=-0.5,
                    max=0.5,
                    test_dtype=dtype,
                    test_device=device,
                )
                fc2_b_ref, fc2_b_test = make_reference_and_test_tensors(
                    (hidden_size,),
                    min=-0.5,
                    max=0.5,
                    test_dtype=dtype,
                    test_device=device,
                )
            fc1_ws_ref.append(fc1_w_ref)
            fc1_bs_ref.append(fc1_b_ref)
            fc1_ws_test.append(fc1_w_test)
            fc1_bs_test.append(fc1_b_test)
            fc2_ws_ref.append(fc2_w_ref)
            fc2_bs_ref.append(fc2_b_ref)
            fc2_ws_test.append(fc2_w_test)
            fc2_bs_test.append(fc2_b_test)

        # Reference implementation
        xs = torch.split(x_ref, split_sizes.tolist())
        probs = torch.split(probs_ref, split_sizes.tolist())
        ys = []
        for group_idx in range(group_size):
            x = xs[group_idx]
            x = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx])
            if glu_interleave_size is not None:
                x = x.reshape(
                    -1,
                    2 * hidden_size // (2 * glu_interleave_size),
                    2,
                    glu_interleave_size,
                )
                x = x.transpose(1, 2)
                x = x.reshape(-1, 2 * hidden_size)
            x1, x2 = x.chunk(2, dim=-1)
            x = torch.nn.functional.silu(x1) * x2
            x = x * probs[group_idx].unsqueeze(-1)
            x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx], bias=fc2_bs_ref[group_idx])
            ys.append(x)
        y_ref = torch.cat(ys)
        y_ref.backward(dy_ref)

        # Construct operations
        recipe = make_recipe(quantization)
        with te.quantized_model_init(enabled=with_quantization, recipe=recipe):
            fc1 = te_ops.GroupedLinear(
                group_size,
                hidden_size,
                2 * hidden_size,
                bias=bias,
                device=device,
                dtype=dtype,
            )
            fc2 = te_ops.GroupedLinear(
                group_size,
                hidden_size,
                hidden_size,
                bias=bias,
                device=device,
                dtype=dtype,
            )
            module = te_ops.Sequential(
                fc1,
                te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size),
                fc2,
            )

        # Copy weights
        with torch.no_grad():
            for group_idx in range(group_size):
                getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx])
                getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx])
                if bias:
                    getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx])
                    getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx])
        del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test

        # Fuse ops and perform forward and backward pass
        with te.autocast(enabled=with_quantization, recipe=recipe):
            y_test = module(x_test, split_sizes, probs_test, split_sizes)
        y_test.backward(dy_test)

        # Loose tols for sanity checking
        tols = {"rtol": 0.125, "atol": 0.25}
        if quantization == "nvfp4":
            tols = {"rtol": 0.25, "atol": 0.5}

        # Check values
        assert_close(y_test, y_ref, **tols)
        assert_close_grads(x_test, x_ref, **tols)
        assert_close_grads(probs_test, probs_ref, **tols)
        for group_idx in range(group_size):
            assert_close_grads(getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols)
            assert_close_grads(getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols)
            assert_close_grads(getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols)
            assert_close_grads(getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols)
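
Reviewer note: for readers who want the routing semantics without the TE machinery, here is a plain-PyTorch analogue of the grouped MLP computed above, where split_sizes partitions the token dimension across per-group weights (a sketch; all names and sizes below are illustrative):

import torch

def grouped_mlp_reference(x, probs, split_sizes, fc1_ws, fc2_ws):
    ys = []
    for xg, pg, w1, w2 in zip(
        torch.split(x, split_sizes), torch.split(probs, split_sizes), fc1_ws, fc2_ws
    ):
        h = torch.nn.functional.linear(xg, w1)
        h1, h2 = h.chunk(2, dim=-1)
        h = torch.nn.functional.silu(h1) * h2 * pg.unsqueeze(-1)
        ys.append(torch.nn.functional.linear(h, w2))
    return torch.cat(ys)

hidden, sizes = 8, [2, 3, 1]
x = torch.randn(sum(sizes), hidden)
probs = torch.rand(sum(sizes))
fc1_ws = [torch.randn(2 * hidden, hidden) for _ in sizes]
fc2_ws = [torch.randn(hidden, hidden) for _ in sizes]
assert grouped_mlp_reference(x, probs, sizes, fc1_ws, fc2_ws).shape == (sum(sizes), hidden)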
class TestCustomOps:
    """Test with ops that are defined externally"""

    def test_custom_basic_op(
        self,
        *,
        shape: Iterable[int] = (7, 5),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
    ) -> None:
        """Custom basic op"""

        class CustomScaleOp(te.ops.BasicOperation):
            """Custom op that applies a learnable scale"""

            def __init__(self) -> None:
                super().__init__()
                self.scale: torch.nn.Parameter
                scale = torch.ones((), dtype=dtype, device=device)
                scale = torch.nn.Parameter(scale)
                self.register_parameter("scale", scale)

            def op_forward(
                self,
                ctx: OperationContext,
                input_: torch.Tensor,
                prev_op_grad_output_quantizer: Optional[Quantizer],
                next_op_input_quantizer: Optional[Quantizer],
            ) -> torch.Tensor:
                ctx.save_for_backward(self.scale, input_)
                return self.scale * input_

            def op_backward(
                self,
                ctx: OperationContext,
                grad_output: torch.Tensor,
            ) -> torch.Tensor:
                (
                    scale,
                    input_,
                ) = ctx.saved_tensors
                grad_scale = torch.inner(input_.reshape(-1), grad_output.reshape(-1))
                grad_scale = grad_scale.reshape(())
                grad_input = scale * grad_output
                return grad_input, (grad_scale,)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (),
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = w_ref * x_ref
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        op = CustomScaleOp()
        forward = te.ops.Sequential(te.ops.Identity(), op, te.ops.Identity())
        with torch.no_grad():
            op.scale.copy_(w_test)
        del w_test
        y_test = forward(x_test)
        y_test.backward(dy_test)

        # Check results
        tols = dtype_tols(dtype)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = op.scale.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
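
Reviewer note: the hand-written op_backward above follows directly from y = s * x: dL/ds = <x, dL/dy> and dL/dx = s * dL/dy. A quick autograd cross-check of that derivation (plain PyTorch, independent of the TE op machinery):

import torch

x = torch.randn(7, 5, dtype=torch.float64, requires_grad=True)
s = torch.randn((), dtype=torch.float64, requires_grad=True)
dy = torch.randn(7, 5, dtype=torch.float64)

(s * x).backward(dy)

# Manual gradients, mirroring CustomScaleOp.op_backward
grad_scale = torch.inner(x.detach().reshape(-1), dy.reshape(-1))
grad_input = s.detach() * dy
torch.testing.assert_close(s.grad, grad_scale)
torch.testing.assert_close(x.grad, grad_input)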
    def test_custom_forward_fused_op(
        self,
        *,
        shape: Iterable[int] = (7, 11),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
    ):
        """Custom fused op in forward pass"""

        class CustomForwardLinearSiLU(te.ops.FusedOperation):
            """Custom fused op for GEMM + SiLU"""

            _enabled = True

            def __init__(self, *, linear, silu) -> None:
                super().__init__((linear, silu))

            def fuser_forward(
                self,
                basic_op_ctxs: list[OperationContext],
                input_: torch.Tensor,
                **unused,
            ) -> torch.Tensor:
                weight = self.basic_ops[0].weight
                dtype = weight.dtype
                device = weight.device

                # Perform compute on CPU, because why not?
                x = input_.cpu()
                w = weight.cpu()
                y = torch.matmul(x, w.T)
                z = torch.nn.functional.silu(y)
                out = z.to(device=device)

                # Save state for linear backward
                linear_op_ctx = basic_op_ctxs[0]
                linear_op_ctx.save_for_backward(input_, weight)
                linear_op_ctx.with_quantized_compute = False
                linear_op_ctx.input_quantizer = None
                linear_op_ctx.weight_quantizer = None
                linear_op_ctx.grad_output_quantizer = None
                linear_op_ctx.grad_input_quantizer = None
                linear_op_ctx.dtype = dtype
                linear_op_ctx.input_requires_grad = True
                linear_op_ctx.weight_requires_grad = True

                # Save state for SiLU backward
                silu_op_ctx = basic_op_ctxs[1]
                silu_op_ctx.save_for_backward(y.to(device=device))
                silu_op_ctx.dtype = dtype
                silu_op_ctx.prev_op_grad_output_quantizer = None

                return out, [(), ()]

            @staticmethod
            def fuse_ops(
                ops: list[FusibleOperation],
                **unused,
            ) -> list[FusibleOperation]:
                """Apply fusion the first time this function is called"""
                if CustomForwardLinearSiLU._enabled:
                    CustomForwardLinearSiLU._enabled = False
                    op = CustomForwardLinearSiLU(linear=ops[0], silu=ops[1])
                    return [op] + ops[2:]
                return ops

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (shape[-1], shape[-1]),
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x_ref, w_ref)
        y_ref = torch.nn.functional.silu(y_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        te.ops.register_forward_fusion(CustomForwardLinearSiLU.fuse_ops)
        model = te.ops.Sequential(
            te.ops.Linear(shape[-1], shape[-1], bias=False),
            te.ops.SiLU(),
        )
        with torch.no_grad():
            model[0].weight.copy_(w_test)
        del w_test
        y_test = model(x_test)
        y_test.backward(dy_test)

        # Check that forward operations have been fused
        forward_ops = model._module_groups[0]._forward_ops
        assert len(forward_ops) == 1
        assert isinstance(forward_ops[0][0], CustomForwardLinearSiLU)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
    def test_custom_backward_fused_op(
        self,
        *,
        shape: Iterable[int] = (13, 5),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
    ):
        """Custom fused op in backward pass"""

        class CustomBackwardLinearScale(te.ops.FusedOperation):
            """Custom fused op for backward linear + scale"""

            _enabled: bool = True

            def __init__(self, *, scale, linear) -> None:
                super().__init__((scale, linear))

            def fuser_backward(
                self,
                basic_op_ctxs: list[OperationContext],
                grad_output: torch.Tensor,
                **unused,
            ) -> torch.Tensor:

                # Load state from linear forward
                linear_op_ctx = basic_op_ctxs[1]
                x, w = linear_op_ctx.saved_tensors
                dtype = linear_op_ctx.dtype
                device = w.device

                # Perform compute in FP64 and apply scale before dgrad
                # GEMM instead of after
                scale = self.basic_ops[0].scale
                dy = grad_output.double()
                x = x.double()
                w = w.double()
                dx = torch.matmul(dy, scale * w)
                dw = torch.matmul(dy.T, x)
                dx = dx.to(dtype=dtype)
                dw = dw.to(dtype=dtype)
                return dx, [(), (dw,)], [(), ()]

            @staticmethod
            def fuse_ops(
                ops: list[FusibleOperation],
                **unused,
            ) -> list[FusibleOperation]:
                """Apply fusion the first time this function is called"""
                if CustomBackwardLinearScale._enabled:
                    CustomBackwardLinearScale._enabled = False
                    op = CustomBackwardLinearScale(scale=ops[0], linear=ops[1])
                    return [op] + ops[2:]
                return ops

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (shape[-1], shape[-1]),
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )
        scale = 1.234

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(scale * x_ref, w_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        te.ops.register_backward_fusion(CustomBackwardLinearScale.fuse_ops, prepend=True)
        model = te.ops.Sequential(
            te.ops.ConstantScale(scale),
            te.ops.Linear(shape[-1], shape[-1], bias=False),
        )
        with torch.no_grad():
            model[1].weight.copy_(w_test)
        del w_test
        y_test = model(x_test)
        y_test.backward(dy_test)

        # Check that backward operations have been fused
        backward_ops = model._module_groups[0]._backward_ops
        assert len(backward_ops) == 1
        assert isinstance(backward_ops[0][0], CustomBackwardLinearScale)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
tests/pytorch/test_grouped_tensor.py 0 → 100644
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for GroupedTensor class"""

from typing import List, Tuple

import pytest
import torch

import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from transformer_engine.pytorch import (
    Quantizer,
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
    Float8BlockQuantizer,
    MXFP8Quantizer,
    NVFP4Quantizer,
)
from transformer_engine.pytorch.constants import TE_DType_To_Torch
import transformer_engine_torch as tex

# Check available recipes
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
    return_reason=True
)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)

_quantization_params = [
    pytest.param(
        "fp8_delayed_scaling",
        marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
    ),
    pytest.param(
        "fp8_current_scaling",
        marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
    ),
    pytest.param(
        "fp8_blockwise",
        marks=pytest.mark.skipif(
            not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling
        ),
    ),
    pytest.param(
        "mxfp8",
        marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8),
    ),
    pytest.param(
        "nvfp4",
        marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4),
    ),
]


def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer:
    """Create quantizers for given quantization scheme"""
    if quantization == "fp8_delayed_scaling":
        quantizer = Float8Quantizer(
            scale=torch.ones(1, dtype=torch.float32, device="cuda"),
            amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
    elif quantization == "fp8_current_scaling":
        quantizer = Float8CurrentScalingQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            device="cuda",
        )
        quantizer.set_usage(rowwise=True, columnwise=False)
    elif quantization == "fp8_blockwise":
        quantizer = Float8BlockQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            rowwise=True,
            columnwise=False,
            force_pow_2_scales=True,
            amax_epsilon=0.0,
            block_scaling_dim=1,
        )
    elif quantization == "mxfp8":
        quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    elif quantization == "nvfp4":
        quantizer = NVFP4Quantizer(
            with_rht=False,
            with_post_rht_amax=False,
            with_2d_quantization=False,
            stochastic_rounding=False,
            with_random_sign_mask=False,
        )
    else:
        raise ValueError(f"Unknown quantization scheme: {quantization}")
    quantizer.internal = False
    return quantizer


def _get_rowwise_data_tensor(qtensor, quantization: str) -> torch.Tensor:
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling"):
        return qtensor._data
    if quantization in ("fp8_blockwise", "mxfp8", "nvfp4"):
        return qtensor._rowwise_data
    raise ValueError(f"Unknown quantization scheme: {quantization}")


def _rowwise_offset_bytes(numel: int, quantization: str) -> int:
    if quantization == "nvfp4":
        return numel // 2
    return numel


class TestGroupedTensor:

    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    def test_basic_construction_all_same_shape(self) -> None:
        """Test GroupedTensor construction with all tensors having same shape"""
        num_tensors = 4
        shape = [(256, 512) for _ in range(num_tensors)]
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )
        assert grouped_tensor.num_tensors == num_tensors
        assert grouped_tensor.all_same_shape()
        assert grouped_tensor.all_same_first_dim()
        assert grouped_tensor.all_same_last_dim()
        assert grouped_tensor.logical_shape == (num_tensors * 256, 512)
        assert grouped_tensor.get_common_first_dim() == 256
        assert grouped_tensor.get_common_last_dim() == 512
        assert grouped_tensor.has_data()

    def test_basic_construction_varying_first_dim(self) -> None:
        """Test GroupedTensor construction with varying first dimension"""
        num_tensors = 3
        shape = [(128, 512), (256, 512), (384, 512)]
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )
        assert grouped_tensor.num_tensors == num_tensors
        assert not grouped_tensor.all_same_shape()
        assert not grouped_tensor.all_same_first_dim()
        assert grouped_tensor.all_same_last_dim()
        assert grouped_tensor.get_common_last_dim() == shape[0][1]
        assert grouped_tensor.logical_shape == (
            sum(v for v, _ in shape),  # sum of first dims
            shape[0][1],
        )

    def test_split_into_quantized_tensors_no_quantization(self) -> None:
        """Test split_into_quantized_tensors for unquantized tensors"""
        num_tensors = 3
        shape = [(256, 512) for _ in range(num_tensors)]
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )

        # Get the original data pointer
        original_data_ptr = grouped_tensor.data.data_ptr()

        # Split into tensors
        tensors = grouped_tensor.split_into_quantized_tensors()
        assert len(tensors) == num_tensors

        # Verify each tensor has correct shape and shares storage
        for i, tensor in enumerate(tensors):
            assert tensor.shape == shape[i]
            assert isinstance(tensor, torch.Tensor)
            assert not hasattr(tensor, "_data")  # Not a quantized tensor
            # Verify data pointer is within the original grouped tensor storage
            # The tensor should be a view of the original data
            assert tensor.data_ptr() >= original_data_ptr
            # Calculate expected offset
            expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size()
            assert tensor.data_ptr() == original_data_ptr + expected_offset

    @pytest.mark.parametrize("quantization", _quantization_params)
    def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None:
        """Test split_into_quantized_tensors for quantized tensors"""
        num_tensors = 3
        shape = [(512, 512) for _ in range(num_tensors)]
        quantizers = make_quantizer(quantization, num_tensors, shape)
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=quantizers,
            device="cuda",
        )

        # Get the original data pointer
        original_data_ptr = grouped_tensor.data.data_ptr()

        # Split into tensors
        tensors = grouped_tensor.split_into_quantized_tensors()
        assert len(tensors) == num_tensors

        # Verify each tensor shares storage with the grouped tensor
        for i, tensor in enumerate(tensors):
            rowwise_data = _get_rowwise_data_tensor(tensor, quantization)
            assert rowwise_data is not None
            assert rowwise_data.data_ptr() >= original_data_ptr
            numel = shape[i][0] * shape[i][1]
            expected_offset = _rowwise_offset_bytes(i * numel, quantization)
            assert rowwise_data.data_ptr() == original_data_ptr + expected_offset

    def test_split_varying_shapes(self) -> None:
        """Test split_into_quantized_tensors with varying shapes"""
        num_tensors = 3
        shape = [(128, 512), (256, 512), (384, 512)]
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )
        original_data_ptr = grouped_tensor.data.data_ptr()
        tensors = grouped_tensor.split_into_quantized_tensors()
        assert len(tensors) == num_tensors

        # Verify shapes and storage
        cumulative_offset = 0
        for i, tensor in enumerate(tensors):
            assert tensor.shape == shape[i]
            expected_offset = cumulative_offset * tensor.element_size()
            assert tensor.data_ptr() == original_data_ptr + expected_offset
            cumulative_offset += shape[i][0] * shape[i][1]

    @pytest.mark.parametrize("quantization", _quantization_params)
    def test_quantize_inplace(self, quantization: str) -> None:
        """Test that quantize is done in-place for all recipes"""
        num_tensors = 3
        shape = [(512, 512) for _ in range(num_tensors)]
        quantizers = make_quantizer(quantization, num_tensors, shape)
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=quantizers,
            device="cuda",
        )

        # Get original data pointers before quantization
        original_data_ptr = grouped_tensor.data.data_ptr()
        original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr()
        original_scale_ptr = (
            grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None
        )

        # Create input tensors
        input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]

        # Quantize in place
        quantized_tensors = grouped_tensor.quantize(input_tensors)

        # Verify data pointers haven't changed (in-place operation)
        assert grouped_tensor.data.data_ptr() == original_data_ptr
        assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr
        if original_scale_ptr is not None:
            assert grouped_tensor.scale.data_ptr() == original_scale_ptr

        # Verify returned tensors point to the same storage
        for i, qtensor in enumerate(quantized_tensors):
            rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
            numel = shape[i][0] * shape[i][1]
            expected_offset = _rowwise_offset_bytes(i * numel, quantization)
            assert rowwise_data.data_ptr() == original_data_ptr + expected_offset

    @pytest.mark.parametrize("quantization", _quantization_params)
    def test_quantize_varying_shapes(self, quantization: str) -> None:
        """Test quantize with varying shapes"""
        num_tensors = 3
        shape = [(256, 512), (512, 512), (768, 512)]
        quantizers = make_quantizer(quantization, num_tensors, shape)
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=quantizers,
            device="cuda",
        )

        # Get original data pointers
        original_data_ptr = grouped_tensor.data.data_ptr()

        # Create input tensors with varying shapes
        input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]

        # Quantize in place
        quantized_tensors = grouped_tensor.quantize(input_tensors)

        # Verify data pointer hasn't changed
        assert grouped_tensor.data.data_ptr() == original_data_ptr

        # Verify each tensor points to correct location
        cumulative_numel = 0
        for qtensor, tensor_shape in zip(quantized_tensors, shape):
            rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
            expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization)
            assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
            cumulative_numel += tensor_shape[0] * tensor_shape[1]

    @pytest.mark.parametrize("quantization", _quantization_params)
    def test_static_quantize_method(self, quantization: str) -> None:
        """Test the static quantize method"""
        num_tensors = 3
        shape = [(512, 512) for _ in range(num_tensors)]
        quantizers = make_quantizer(quantization, num_tensors, shape)

        # Create input tensors
        input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]

        # Use static quantize method
        grouped_tensor = GroupedTensor.create_and_quantize(
            tensors=input_tensors,
            quantizer=quantizers,
            device="cuda",
        )

        # Verify the grouped tensor was created correctly
        assert grouped_tensor.num_tensors == num_tensors
        assert grouped_tensor.has_data()

        # Verify quantized_tensors were created and point to same storage
        assert grouped_tensor.quantized_tensors is not None
        assert len(grouped_tensor.quantized_tensors) == num_tensors
        original_data_ptr = grouped_tensor.data.data_ptr()
        for i, qtensor in enumerate(grouped_tensor.quantized_tensors):
            rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
            numel = shape[i][0] * shape[i][1]
            expected_offset = _rowwise_offset_bytes(i * numel, quantization)
            assert rowwise_data.data_ptr() == original_data_ptr + expected_offset

    def test_clear(self) -> None:
        """Test clear method"""
        num_tensors = 3
        shape = [(256, 512) for _ in range(num_tensors)]
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )
        assert grouped_tensor.has_data()
        assert grouped_tensor.num_tensors == num_tensors
        grouped_tensor.clear()
        assert not grouped_tensor.has_data()
        assert grouped_tensor.num_tensors == 0
        assert grouped_tensor.data is None
        assert grouped_tensor.logical_shape == (0, 0)
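
Reviewer note: the _rowwise_offset_bytes helper encodes the storage packing these tests rely on: the FP8 formats store one element per byte, while NVFP4 packs two 4-bit elements per byte, halving every byte offset. A small sketch of the cumulative offsets it produces for three (512, 512) tensors (pure Python, no GPU needed):

def rowwise_offset_bytes(numel: int, quantization: str) -> int:
    return numel // 2 if quantization == "nvfp4" else numel

shapes = [(512, 512)] * 3
for scheme in ("fp8_delayed_scaling", "nvfp4"):
    offsets, total = [], 0
    for rows, cols in shapes:
        offsets.append(rowwise_offset_bytes(total, scheme))
        total += rows * cols
    print(scheme, offsets)
# fp8_delayed_scaling [0, 262144, 524288]
# nvfp4 [0, 131072, 262144]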
tests/pytorch/test_numerics.py
...
@@ -94,6 +94,7 @@ all_boolean = [True, False]
all_activations = [
    "gelu",
    "geglu",
    "glu",
    "qgelu",
    "qgeglu",
    "relu",
...
@@ -484,6 +485,7 @@ class TorchGroupedLinearWithPadding(nn.Module):
_supported_act = {
    "gelu": nn.GELU(approximate="tanh"),
    "geglu": nn.GELU(approximate="tanh"),
    "glu": nn.Sigmoid(),
    "qgelu": TorchQuickGELU(),
    "qgeglu": TorchQuickGELU(),
    "relu": nn.ReLU(),
...
tests/pytorch/test_onnx_export.py
...
@@ -745,6 +745,14 @@ def test_export_layernorm_mlp_activation(seed_default_rng, activation):
    _test_export_layernorm_mlp(activation=activation)


# Quantization recipes with fp8_dpa=True for attention emulation export test
dpa_quantization_recipes = [None]  # None = no quantization
if fp8_available:
    dpa_quantization_recipes.append(recipe.DelayedScaling(fp8_dpa=True))
    dpa_quantization_recipes.append(recipe.Float8CurrentScaling(fp8_dpa=True))


@pytest.mark.parametrize("fp8_recipe", dpa_quantization_recipes)
@pytest.mark.parametrize(
    "precision, use_mask, attn_mask_type",
    [
...
@@ -762,6 +770,7 @@ def test_export_core_attention(
    precision: torch.dtype,
    use_mask: bool,
    attn_mask_type: str,
    fp8_recipe: recipe.Recipe,
):
    if IS_HIP_EXTENSION:
        pytest.skip("ONNX is not currently required in hip")
...
@@ -783,22 +792,25 @@ def test_export_core_attention(
    mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    high_prec_str = dtype2str(precision)
-   fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"
    fp8_str = "_fp8_dpa" if fp8_recipe is not None else ""
    fname = f"te.core_attention{fp8_str}{mask_str}{high_prec_str}.onnx"
    is_fp8 = fp8_recipe is not None

    model = te.attention.DotProductAttention(
        num_attention_heads=num_attention_heads,
        kv_channels=kv_channels,
-       attention_dropout=0.5,
        qkv_format=qkv_format,
        attn_mask_type=attn_mask_type,
    ).to(device="cuda")
-   do_export(model, inp, fname, input_names=input_names, fp8_recipe=None)
-   te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None)
    do_export(model, inp, fname, input_names=input_names, fp8_recipe=fp8_recipe)
    te_outputs = te_infer(model, inp, is_fp8=is_fp8, fp8_recipe=fp8_recipe)
    serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
    if precision in (torch.bfloat16,):
        return
    atol = 5e-1 if is_fp8 else 1e-2
    validate_result(
-       fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs
        fname, inp, model, is_fp8=True, atol=atol, input_names=input_names, te_outputs=te_outputs
    )
...
tests/pytorch/test_sanity.py
...
@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
-from typing import Optional
from typing import Optional, List

import torch
import pytest
...
@@ -114,6 +114,7 @@ batch_sizes_with_zero = [0, 1, 2]
all_activations = [
    "gelu",
    "geglu",
    "glu",
    "qgelu",
    "qgeglu",
    "relu",
...
@@ -138,6 +139,117 @@ def reset_global_fp8_state():
    FP8GlobalStateManager.reset()


def check_grouped_tensor_pointers_helper(tensors, num_elems_in_byte=1, tensor_name="tensor"):
    """
    Verify that tensors are stored in contiguous memory.

    Args:
        tensors: List or iterable of tensors to check
        num_elems_in_byte: Number of elements packed per byte (1 for normal, 2 for NVFP4)
        tensor_name: Name to use in error messages
    """
    tensor_list = list(tensors)
    if len(tensor_list) < 2:
        return  # Nothing to check
    for i in range(1, len(tensor_list)):
        prev_tensor = tensor_list[i - 1]
        curr_tensor = tensor_list[i]
        # Calculate expected offset based on previous tensor size
        prev_numel = prev_tensor.numel()
        expected_offset = (prev_numel // num_elems_in_byte) * prev_tensor.element_size()
        # Verify current tensor's data pointer is correctly offset
        expected_ptr = prev_tensor.data_ptr() + expected_offset
        actual_ptr = curr_tensor.data_ptr()
        assert (
            actual_ptr == expected_ptr
        ), f"{tensor_name} {i} data pointer mismatch: expected {expected_ptr}, got {actual_ptr}"


def check_grouped_tensor_pointers(
    weights: List[torch.Tensor], fp8_recipe: Optional[recipe.Recipe] = None
):
    """
    Verify that the pointers of the weights are in contiguous memory for GroupedTensor.
    TODO(ksivaman): This check can be made way more efficient but for now leaving the
    brute force approach.
    """
    num_elems_in_a_data_byte = 1 if fp8_recipe is None else 2 if fp8_recipe.nvfp4() else 1
    # Check data.
    if hasattr(weights[0], "_data") and weights[0]._data is not None:
        data_tensors = [w._data for w in weights]
        check_grouped_tensor_pointers_helper(data_tensors, num_elems_in_byte=1, tensor_name="data")
    # Check transpose.
    if hasattr(weights[0], "_transpose") and weights[0]._transpose is not None:
        transpose_tensors = [w._transpose for w in weights]
        check_grouped_tensor_pointers_helper(
            transpose_tensors, num_elems_in_byte=1, tensor_name="transpose"
        )
    # Check scale_inv.
    if hasattr(weights[0], "_scale_inv") and weights[0]._scale_inv is not None:
        scale_inv_tensors = [w._scale_inv for w in weights]
        check_grouped_tensor_pointers_helper(
            scale_inv_tensors, num_elems_in_byte=1, tensor_name="scale_inv"
        )
    # Check rowwise scale_inv.
    if hasattr(weights[0], "_rowwise_scale_inv") and weights[0]._rowwise_scale_inv is not None:
        scale_inv_tensors = [w._rowwise_scale_inv for w in weights]
        check_grouped_tensor_pointers_helper(
            scale_inv_tensors, num_elems_in_byte=1, tensor_name="rowwise_scale_inv"
        )
    # Check columnwise scale_inv.
    if (
        hasattr(weights[0], "_columnwise_scale_inv")
        and weights[0]._columnwise_scale_inv is not None
    ):
        columnwise_scale_inv_tensors = [w._columnwise_scale_inv for w in weights]
        check_grouped_tensor_pointers_helper(
            columnwise_scale_inv_tensors,
            num_elems_in_byte=1,
            tensor_name="columnwise scale_inv",
        )
    # Check rowwise amax.
    if hasattr(weights[0], "_rowwise_amax") and weights[0]._rowwise_amax is not None:
        rowwise_amax_tensors = [w._rowwise_amax for w in weights]
        check_grouped_tensor_pointers_helper(
            rowwise_amax_tensors, num_elems_in_byte=1, tensor_name="rowwise amax"
        )
    # Check columnwise amax.
    if hasattr(weights[0], "_columnwise_amax") and weights[0]._columnwise_amax is not None:
        columnwise_amax_tensors = [w._columnwise_amax for w in weights]
        check_grouped_tensor_pointers_helper(
            columnwise_amax_tensors, num_elems_in_byte=1, tensor_name="columnwise amax"
        )
    # Check rowwise data.
    if hasattr(weights[0], "_rowwise_data") and weights[0]._rowwise_data is not None:
        rowwise_data_tensors = [w._rowwise_data for w in weights]
        check_grouped_tensor_pointers_helper(
            rowwise_data_tensors,
            num_elems_in_byte=num_elems_in_a_data_byte,
            tensor_name="rowwise data",
        )
    # Check columnwise data.
    if hasattr(weights[0], "_columnwise_data") and weights[0]._columnwise_data is not None:
        columnwise_data_tensors = [w._columnwise_data for w in weights]
        check_grouped_tensor_pointers_helper(
            columnwise_data_tensors,
            num_elems_in_byte=num_elems_in_a_data_byte,
            tensor_name="columnwise data",
        )


def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
...
@@ -486,10 +598,19 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("single_param", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
-   dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, single_param, num_gemms, empty_split,
):
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("FP8 model parameters are not supported in debug mode.")
...
@@ -499,6 +620,9 @@ def test_sanity_grouped_linear(
    bs = bs * 16
    num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
    if single_param:
        os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"
    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
...
@@ -508,9 +632,19 @@ def test_sanity_grouped_linear(
    use_fp8 = fp8_recipe is not None
    with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_grouped_linear = GroupedLinear(
-           num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
            num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype,
        ).cuda()
    # Verify that weights are stored in contiguous GroupedTensor storage.
    weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
    if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()):
        if single_param:
            check_grouped_tensor_pointers(weights, fp8_recipe)
    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
...
@@ -528,6 +662,9 @@ def test_sanity_grouped_linear(
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)
    if single_param:
        del os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
...
@@ -1005,7 +1142,13 @@ def test_replace_raw_data_for_float8tensor():
    random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
    fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)
-   attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"]
    attrs_to_check = [
        "_quantizer",
        "_fp8_dtype",
        "_scale_inv",
        "_transpose",
        "_transpose_invalid",
    ]
    attrs = {}
    for attr in attrs_to_check:
        attrs[attr] = getattr(fp8_tensor, attr)
...
tests/pytorch/utils.py
...
@@ -15,7 +15,7 @@ import torch
import transformer_engine
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
-from transformer_engine.pytorch import InferenceParams
from transformer_engine.pytorch import InferenceParams, QuantizedTensor
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
    get_attention_backend,
...
@@ -353,11 +353,56 @@ def get_available_attention_backends(
    backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
    if AttentionLogging._is_logging_setup is False:
        AttentionLogging.setup_logging()
    with logging_context(highest_level=AttentionLogging._log_level):
        for i in range(3):
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
            _attention_backends["backend_selection_requires_update"] = True
            available_backends, flash_attention_backend, fused_attention_backend = test()
            if fused_attention_backend == FusedAttnBackend[backends[i]]:
                fused_attn_backends.append(fused_attention_backend)
    return available_backends, flash_attention_backend, fused_attn_backends


@torch.no_grad
def assert_close(
    actual: Optional[torch.Tensor],
    expected: Optional[torch.Tensor],
    *,
    check_device: bool = False,
    check_dtype: bool = False,
    check_layout: bool = False,
    **kwargs,
) -> None:
    """Assert that two tensors are close.

    This function is a wrapper around torch.testing.assert_close. It
    changes the defaults for device and dtype checks (useful when the
    reference implementation is computed in high precision on CPU) and
    it can handle quantized tensors.

    """
    if isinstance(actual, QuantizedTensor):
        actual = actual.dequantize()
    if isinstance(expected, QuantizedTensor):
        expected = expected.dequantize()
    torch.testing.assert_close(
        actual,
        expected,
        check_device=check_device,
        check_dtype=check_dtype,
        check_layout=check_layout,
        **kwargs,
    )


def assert_close_grads(
    actual: Optional[torch.Tensor],
    expected: Optional[torch.Tensor],
    **kwargs,
) -> None:
    """Assert that two tensors have close gradients."""
    if actual is None and expected is None:
        return
    assert actual is not None
    assert expected is not None
    assert_close(actual.grad, expected.grad, **kwargs)
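
Reviewer note: a short usage sketch for the two helpers added above (values are illustrative; assumes they are imported from this module). Because assert_close defaults check_device and check_dtype to False, a float64 CPU reference can be compared directly against a lower-precision result:

import torch
# from tests.pytorch.utils import assert_close, assert_close_grads

ref = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
test = ref.detach().to(torch.float32).requires_grad_()

(ref * 2).sum().backward()
(test * 2).sum().backward()

assert_close(test, ref, rtol=1e-3, atol=1e-5)
assert_close_grads(test, ref, rtol=1e-3, atol=1e-5)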
transformer_engine/common/CMakeLists.txt
...
@@ -202,6 +202,7 @@ if(USE_CUDA)
    fused_attn/fused_attn_fp8.cu
    fused_attn/utils.cu
    gemm/cublaslt_gemm.cu
    gemm/cublaslt_grouped_gemm.cu
    normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
    normalization/layernorm/ln_fwd_cuda_kernel.cu
    normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
...
@@ -225,15 +226,18 @@ if(USE_CUDA)
  list(APPEND transformer_engine_cuda_arch_specific_sources
    activation/gelu.cu
    activation/glu.cu
    activation/relu.cu
    activation/swiglu.cu
    cast/cast.cu
    gemm/cutlass_grouped_gemm.cu
    hadamard_transform/group_hadamard_transform.cu
    hadamard_transform/graph_safe_group_hadamard_transform.cu
    hadamard_transform/hadamard_transform.cu
    hadamard_transform/hadamard_transform_cast_fusion.cu
    hadamard_transform/group_hadamard_transform_cast_fusion.cu
    hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
    hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu
    multi_tensor/compute_scale.cu
    recipe/mxfp8_scaling.cu
    transpose/quantize_transpose_square_blockwise.cu
...
@@ -357,6 +361,7 @@ else()
    fused_attn/kv_cache.cu
    fused_attn/utils.cu
    gemm/cublaslt_gemm.cu
    gemm/cublaslt_grouped_gemm.cu
    gemm/hipblas_gemm.cu
    normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
    normalization/layernorm/ln_fwd_cuda_kernel.cu
...
@@ -381,6 +386,7 @@ else()
  list(APPEND transformer_engine_cuda_arch_specific_sources
    activation/gelu.cu
    activation/glu.cu
    activation/relu.cu
    activation/swiglu.cu
    cast/cast.cu
...
@@ -476,20 +482,18 @@ endif()
option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF)
if (NVTE_WITH_CUBLASMP)
  target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP)
- target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include)
  target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include)
  find_library(CUBLASMP_LIB
    NAMES cublasmp libcublasmp
    PATHS ${CUBLASMP_DIR}
    PATH_SUFFIXES lib
    REQUIRED)
- find_library(NVSHMEM_HOST_LIB
-   NAMES nvshmem_host libnvshmem_host.so.3
-   PATHS ${NVSHMEM_DIR}
-   PATH_SUFFIXES lib
-   REQUIRED)
  find_library(NCCL_LIB
    NAMES nccl libnccl
    PATH_SUFFIXES lib
    REQUIRED)
- target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB})
  target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB} ${CUBLASMP_LIB})
  message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}")
- message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}")
endif()

if(USE_CUDA)
...
@@ -561,6 +565,7 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
  list(APPEND nvte_sources_with_fast_math activation/gelu.cu
                                          activation/glu.cu
                                          activation/relu.cu
                                          activation/swiglu.cu)
endif()
...
transformer_engine/common/__init__.py
...
@@ -246,11 +246,13 @@ def _nvidia_cudart_include_dir() -> str:
        return ""

    # Installing some nvidia-* packages, like nvshmem, create nvidia name, so "import nvidia"
-   # above doesn't through. However, they don't set "__file__" attribute.
    # above doesn't throw. However, they don't set "__file__" attribute.
-   if nvidia.__file__ is None:
-       return ""
    if nvidia.__file__ is not None:
        nvidia_root = Path(nvidia.__file__).parent
    else:
        nvidia_root = Path(nvidia.__path__[0])  # namespace package

-   include_dir = Path(nvidia.__file__).parent / "cuda_runtime"
    include_dir = nvidia_root / "cuda_runtime"
    return str(include_dir) if include_dir.exists() else ""
...
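
Reviewer note on the fix above: regular packages set __file__ (via their __init__.py), while PEP 420 namespace packages, like the bare nvidia/ directory some wheels create, leave __file__ as None but still populate __path__. A sketch of that resolution logic in isolation (using a stand-in package name, since "nvidia" may not be installed):

import importlib
from pathlib import Path

pkg = importlib.import_module("email")  # stand-in for "nvidia"
if pkg.__file__ is not None:
    root = Path(pkg.__file__).parent
else:
    root = Path(pkg.__path__[0])  # namespace package
print(root / "cuda_runtime")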
transformer_engine/common/activation/gelu.cu
...
@@ -13,6 +13,14 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
...
@@ -13,6 +13,14 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn
<
fp32
,
Empty
,
gelu
<
fp32
,
fp32
>>
(
input
,
output
,
stream
);
act_fn
<
fp32
,
Empty
,
gelu
<
fp32
,
fp32
>>
(
input
,
output
,
stream
);
}
}
void
nvte_group_gelu
(
const
NVTEGroupedTensor
input
,
NVTEGroupedTensor
output
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_group_gelu
);
using
namespace
transformer_engine
;
constexpr
bool
IS_ACT
=
true
;
dispatch
::
group_quantize_fwd_helper
<
IS_ACT
,
Empty
,
gelu
<
fp32
,
fp32
>>
(
input
,
output
,
nullptr
,
stream
);
}
void
nvte_dgelu
(
const
NVTETensor
grad
,
const
NVTETensor
input
,
NVTETensor
output
,
void
nvte_dgelu
(
const
NVTETensor
grad
,
const
NVTETensor
input
,
NVTETensor
output
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_dgelu
);
NVTE_API_CALL
(
nvte_dgelu
);
...
@@ -20,6 +28,20 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
...
@@ -20,6 +28,20 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn
<
fp32
,
Empty
,
dgelu
<
fp32
,
fp32
>>
(
grad
,
input
,
output
,
stream
);
dact_fn
<
fp32
,
Empty
,
dgelu
<
fp32
,
fp32
>>
(
grad
,
input
,
output
,
stream
);
}
}
void
nvte_group_dgelu
(
const
NVTEGroupedTensor
grad
,
const
NVTEGroupedTensor
input
,
NVTEGroupedTensor
output
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_group_dgelu
);
using
namespace
transformer_engine
;
NVTETensor
dbias
=
nullptr
;
NVTETensor
workspace
=
nullptr
;
constexpr
bool
IS_DBIAS
=
false
;
constexpr
bool
IS_DACT
=
true
;
dispatch
::
group_quantize_bwd_helper
<
IS_DBIAS
,
IS_DACT
,
Empty
,
dgelu
<
fp32
,
fp32
>>
(
grad
,
input
,
output
,
dbias
,
workspace
,
nullptr
,
stream
);
}
void
nvte_quantize_dbias_dgelu
(
const
NVTETensor
input
,
const
NVTETensor
activation_input
,
void
nvte_quantize_dbias_dgelu
(
const
NVTETensor
input
,
const
NVTETensor
activation_input
,
NVTETensor
output
,
NVTETensor
dbias
,
NVTETensor
workspace
,
NVTETensor
output
,
NVTETensor
dbias
,
NVTETensor
workspace
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
...
@@ -33,6 +55,20 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati
...
@@ -33,6 +55,20 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati
input
,
activation_input
,
output
,
dbias
,
workspace
,
nullptr
,
stream
);
input
,
activation_input
,
output
,
dbias
,
workspace
,
nullptr
,
stream
);
}
}
void
nvte_group_quantize_dbias_dgelu
(
const
NVTEGroupedTensor
input
,
const
NVTEGroupedTensor
activation_input
,
NVTEGroupedTensor
output
,
NVTETensor
dbias
,
NVTETensor
workspace
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_group_quantize_dbias_dgelu
);
using
namespace
transformer_engine
;
constexpr
bool
IS_DBIAS
=
true
;
constexpr
bool
IS_DACT
=
true
;
dispatch
::
group_quantize_bwd_helper
<
IS_DBIAS
,
IS_DACT
,
Empty
,
dgelu
<
fp32
,
fp32
>>
(
input
,
activation_input
,
output
,
dbias
,
workspace
,
nullptr
,
stream
);
}
void
nvte_geglu
(
const
NVTETensor
input
,
NVTETensor
output
,
cudaStream_t
stream
)
{
void
nvte_geglu
(
const
NVTETensor
input
,
NVTETensor
output
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_geglu
);
NVTE_API_CALL
(
nvte_geglu
);
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
...
@@ -54,6 +90,15 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
...
@@ -54,6 +90,15 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
act_fn
<
fp32
,
Empty
,
qgelu
<
fp32
,
fp32
>>
(
input
,
output
,
stream
);
act_fn
<
fp32
,
Empty
,
qgelu
<
fp32
,
fp32
>>
(
input
,
output
,
stream
);
}
}
void
nvte_group_qgelu
(
const
NVTEGroupedTensor
input
,
NVTEGroupedTensor
output
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_group_qgelu
);
using
namespace
transformer_engine
;
constexpr
bool
IS_ACT
=
true
;
dispatch
::
group_quantize_fwd_helper
<
IS_ACT
,
Empty
,
qgelu
<
fp32
,
fp32
>>
(
input
,
output
,
nullptr
,
stream
);
}
void
nvte_dqgelu
(
const
NVTETensor
grad
,
const
NVTETensor
input
,
NVTETensor
output
,
void
nvte_dqgelu
(
const
NVTETensor
grad
,
const
NVTETensor
input
,
NVTETensor
output
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_dqgelu
);
NVTE_API_CALL
(
nvte_dqgelu
);
...
@@ -61,6 +106,20 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
...
@@ -61,6 +106,20 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
dact_fn
<
fp32
,
Empty
,
dqgelu
<
fp32
,
fp32
>>
(
grad
,
input
,
output
,
stream
);
dact_fn
<
fp32
,
Empty
,
dqgelu
<
fp32
,
fp32
>>
(
grad
,
input
,
output
,
stream
);
}
}
void
nvte_group_dqgelu
(
const
NVTEGroupedTensor
grad
,
const
NVTEGroupedTensor
input
,
NVTEGroupedTensor
output
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_group_dqgelu
);
using
namespace
transformer_engine
;
NVTETensor
dbias
=
nullptr
;
NVTETensor
workspace
=
nullptr
;
constexpr
bool
IS_DBIAS
=
false
;
constexpr
bool
IS_DACT
=
true
;
dispatch
::
group_quantize_bwd_helper
<
IS_DBIAS
,
IS_DACT
,
Empty
,
dqgelu
<
fp32
,
fp32
>>
(
grad
,
input
,
output
,
dbias
,
workspace
,
nullptr
,
stream
);
}
void
nvte_quantize_dbias_dqgelu
(
const
NVTETensor
input
,
const
NVTETensor
activation_input
,
void
nvte_quantize_dbias_dqgelu
(
const
NVTETensor
input
,
const
NVTETensor
activation_input
,
NVTETensor
output
,
NVTETensor
dbias
,
NVTETensor
workspace
,
NVTETensor
output
,
NVTETensor
dbias
,
NVTETensor
workspace
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
...
@@ -74,6 +133,20 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat
...
@@ -74,6 +133,20 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat
      input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
                                      const NVTEGroupedTensor activation_input,
                                      NVTEGroupedTensor output, NVTETensor dbias,
                                      NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_qgeglu);
  using namespace transformer_engine;
...
transformer_engine/common/activation/glu.cu
0 → 100644
View file @ 9df0c4a3
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../util/math.h"
#include "./activation_template.h"
void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_glu);
  using namespace transformer_engine;
  Empty e = {};
  gated_act_fn<fp32, Empty, sigmoid<fp32, fp32>>(input, output, e, stream);
}
void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
               cudaStream_t stream) {
  NVTE_API_CALL(nvte_dglu);
  using namespace transformer_engine;
  Empty e = {};
  dgated_act_fn<fp32, Empty, sigmoid<fp32, fp32>, dsigmoid<fp32, fp32>>(grad, input, output, e,
                                                                        stream);
}
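Note: this new file adds plain GLU alongside the existing gated activations. GLU splits its input in half along the last dimension and uses a sigmoid gate; nvte_dglu is its derivative. A rough scalar sketch of the forward math (which half carries the gate is an assumption here; gated_act_fn defines the authoritative layout):

#include <cmath>
#include <cstdio>
#include <vector>

// Scalar sketch of GLU: out[i] = sigmoid(x[i]) * x[i + H], where the input
// row has width 2*H.
std::vector<float> glu_sketch(const std::vector<float> &x) {
  const size_t H = x.size() / 2;
  std::vector<float> out(H);
  for (size_t i = 0; i < H; ++i) {
    const float s = 1.f / (1.f + std::exp(-x[i]));  // sigmoid gate
    out[i] = s * x[i + H];
  }
  return out;
}

int main() {
  std::vector<float> row = {0.f, 1.f, 2.f, 3.f};  // H = 2
  for (float v : glu_sketch(row)) std::printf("%g ", v);  // sigmoid(0)*2, sigmoid(1)*3
}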
transformer_engine/common/activation/relu.cu
View file @
9df0c4a3
...
@@ -13,6 +13,14 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
}
void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                     cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_relu);
  using namespace transformer_engine;
  constexpr bool IS_ACT = true;
  dispatch::group_quantize_fwd_helper<IS_ACT, Empty, relu<fp32, fp32>>(input, output, nullptr, stream);
}
void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_drelu);
...
@@ -20,6 +28,20 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
  dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_drelu);
  using namespace transformer_engine;
  NVTETensor dbias = nullptr;
  NVTETensor workspace = nullptr;
  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
      grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
...
@@ -33,6 +55,20 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati
      input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input,
                                     const NVTEGroupedTensor activation_input,
                                     NVTEGroupedTensor output, NVTETensor dbias,
                                     NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_drelu);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_reglu);
  using namespace transformer_engine;
...
@@ -54,6 +90,15 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
  act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
}
void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                      cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_srelu);
  using namespace transformer_engine;
  constexpr bool IS_ACT = true;
  dispatch::group_quantize_fwd_helper<IS_ACT, Empty, srelu<fp32, fp32>>(input, output, nullptr, stream);
}
void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dsrelu);
...
@@ -61,6 +106,20 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
  dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                       NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_dsrelu);
  using namespace transformer_engine;
  NVTETensor dbias = nullptr;
  NVTETensor workspace = nullptr;
  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
      grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
...
@@ -74,6 +133,20 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
      input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input,
                                      const NVTEGroupedTensor activation_input,
                                      NVTEGroupedTensor output, NVTETensor dbias,
                                      NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_sreglu);
  using namespace transformer_engine;
...
transformer_engine/common/activation/swiglu.cu
View file @ 9df0c4a3
...
@@ -13,6 +13,14 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
}
void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                     cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_silu);
  using namespace transformer_engine;
  constexpr bool IS_ACT = true;
  dispatch::group_quantize_fwd_helper<IS_ACT, Empty, silu<fp32, fp32>>(input, output, nullptr, stream);
}
void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dsilu);
...
@@ -20,6 +28,20 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
  dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_dsilu);
  using namespace transformer_engine;
  NVTETensor dbias = nullptr;
  NVTETensor workspace = nullptr;
  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
      grad, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
...
@@ -33,6 +55,20 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati
      input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input,
                                     const NVTEGroupedTensor activation_input,
                                     NVTEGroupedTensor output, NVTETensor dbias,
                                     NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_dsilu);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_swiglu);
  using namespace transformer_engine;
...
transformer_engine/common/cast/cast.cu
View file @ 9df0c4a3
...
@@ -28,6 +28,15 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
  dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}
void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                         cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize);
  using namespace transformer_engine;
  constexpr bool IS_ACT = false;
  dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}
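Note: the nvte_group_* entry points operate on NVTEGroupedTensor, which batches many member tensors so one launch can process the whole set instead of one call per tensor. Conceptually this is a packed buffer plus per-member shape metadata; a toy illustration of the idea (this struct is a stand-in, not the actual NVTEGroupedTensor layout, which is defined by the C API):

#include <cstdio>
#include <vector>

// Toy grouped buffer: members packed back to back in one flat allocation,
// with per-member row counts and a shared column width.
struct GroupedBuf {
  std::vector<float> data;  // members packed contiguously
  std::vector<int> rows;    // rows per member
  int cols;                 // common width
};

// One pass over the whole group, instead of one call per member tensor.
void scale_group(GroupedBuf &g, float s) {
  size_t off = 0;
  for (size_t m = 0; m < g.rows.size(); ++m) {
    const size_t count = static_cast<size_t>(g.rows[m]) * g.cols;
    for (size_t i = 0; i < count; ++i) g.data[off + i] *= s;
    off += count;
  }
}

int main() {
  GroupedBuf g{{1, 2, 3, 4, 5, 6}, {1, 2}, 2};  // members: 1x2 and 2x2
  scale_group(g, 2.f);
  for (float v : g.data) std::printf("%g ", v);  // 2 4 6 8 10 12
}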
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
                        cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_noop);
...
@@ -62,6 +71,19 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
      input, activation_input, output, dbias, workspace, nullptr, stream);
}
void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                               NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = false;
  constexpr const NVTEGroupedTensor activation_input = nullptr;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, nullptr>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}
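Note: as a reminder of what the fused dbias output is: for a bias that was broadcast over rows in the forward pass, its gradient is the column-wise sum of the incoming gradient. A minimal sketch of that reduction (the actual kernel tiles the input and quantizes in the same pass):

#include <cstdio>
#include <vector>

// dbias[j] = sum over rows i of grad[i][j] -- the gradient of a bias that
// was broadcast across rows in the forward pass.
std::vector<float> dbias_sketch(const std::vector<float> &grad, int rows, int cols) {
  std::vector<float> dbias(cols, 0.f);
  for (int i = 0; i < rows; ++i)
    for (int j = 0; j < cols; ++j) dbias[j] += grad[i * cols + j];
  return dbias;
}

int main() {
  // 2x3 gradient; expected dbias = {5, 7, 9}
  std::vector<float> grad = {1, 2, 3, 4, 5, 6};
  for (float v : dbias_sketch(grad, 2, 3)) std::printf("%g ", v);
}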
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_dequantize);
  using namespace transformer_engine;
...
transformer_engine/common/cast/core/common.cuh
View file @ 9df0c4a3
...
@@ -37,6 +37,12 @@ inline bool dimensions_supported_by_TMA(const Tensor *const t) {
  return cols % alignment_requirement == 0;
}
__device__ __forceinline__ unsigned char *align_smem_ptr_per_TMA_requirements(unsigned char *p) {
  size_t addr = reinterpret_cast<size_t>(p);
  addr = (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1);
  return reinterpret_cast<unsigned char *>(addr);
}
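Note: this helper rounds a shared-memory pointer up to the TMA alignment with the standard power-of-two round-up (addr + A - 1) & ~(A - 1), which is only valid when A is a power of two. A quick host-side check of the trick (128 here is a stand-in for TMA_SHMEM_ALIGNMENT):

#include <cstddef>
#include <cstdio>

// Round addr up to the next multiple of A. Requires A to be a power of two:
// adding A-1 carries past the low bits, and masking with ~(A-1) clears them.
constexpr size_t align_up(size_t addr, size_t A) {
  return (addr + A - 1) & ~(A - 1);
}

int main() {
  constexpr size_t A = 128;  // stand-in for TMA_SHMEM_ALIGNMENT
  static_assert(align_up(0, A) == 0);
  static_assert(align_up(1, A) == 128);
  static_assert(align_up(128, A) == 128);  // already aligned: unchanged
  static_assert(align_up(129, A) == 256);
  std::printf("all round-up checks passed\n");
}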
namespace kernel {
constexpr size_t THREADS_PER_BLOCK = 256;
...