OpenDAS / TransformerEngine — Commit 9df0c4a3
Authored Mar 03, 2026 by wenjh
Merge branch 'nv_main'
Parents: 0d874a4e, f122b07d
Changes: 221 files. Showing 20 changed files with 1890 additions and 66 deletions (+1890, -66).
tests/jax/test_permutation.py (+2, -2)
tests/pytorch/attention/test_attention.py (+55, -4)
tests/pytorch/attention/test_attention_with_cp.py (+19, -5)
tests/pytorch/debug/test_log.py (+120, -0)
tests/pytorch/debug/test_sanity.py (+19, -6)
tests/pytorch/test_checkpoint.py (+1, -1)
tests/pytorch/test_fusible_ops.py (+819, -21)
tests/pytorch/test_grouped_tensor.py (+385, -0)
tests/pytorch/test_numerics.py (+2, -0)
tests/pytorch/test_onnx_export.py (+17, -5)
tests/pytorch/test_sanity.py (+147, -4)
tests/pytorch/utils.py (+53, -8)
transformer_engine/common/CMakeLists.txt (+11, -6)
transformer_engine/common/__init__.py (+6, -4)
transformer_engine/common/activation/gelu.cu (+73, -0)
transformer_engine/common/activation/glu.cu (+24, -0)
transformer_engine/common/activation/relu.cu (+73, -0)
transformer_engine/common/activation/swiglu.cu (+36, -0)
transformer_engine/common/cast/cast.cu (+22, -0)
transformer_engine/common/cast/core/common.cuh (+6, -0)
tests/jax/test_permutation.py
...
@@ -23,7 +23,7 @@ ALL_DISPATCH_COMBINE_CASES = [
     (128, 5, 128, 3),
     (1024, 8, 128, 8),
     (4096, 32, 1280, 2),
-    (4096, 256, 4096, 6),
+    (4096, 64, 4096, 6),
 ]
 DISPATCH_COMBINE_CASES = {
     "L0": ALL_DISPATCH_COMBINE_CASES[0:2],
...
@@ -44,7 +44,7 @@ ALL_DISPATCH_COMBINE_PADDING_CASES = [
     (128, 5, 128, 3, 8),
     (1024, 8, 128, 8, 16),
     (4096, 32, 1280, 2, 128),
-    (4096, 256, 4096, 6, 16),
+    (4096, 64, 4096, 6, 16),
 ]
 DISPATCH_COMBINE_PADDING_CASES = {
     "L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:2],
...
tests/pytorch/attention/test_attention.py
...
@@ -74,6 +74,14 @@ if not IS_HIP_EXTENSION:
         f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
     )

+# Get determinism
+_deterministic = (
+    not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
+    or torch.are_deterministic_algorithms_enabled()
+)
+
 # Reset RNG seed and states
 seed = 1234
 reset_rng_states()
...
@@ -147,6 +155,7 @@ def test_dot_product_attention(
     if config.window_size == (-1, -1) and swa:
         config.window_size = [2, 2]
     config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
+    qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
     if qkv_format == "thd" and "padding" not in config.attn_mask_type:
...
@@ -162,8 +171,10 @@ def test_dot_product_attention(
         qkv_layout=qkv_layout,
         pad_between_seqs=pad_between_seqs,
         is_training=is_training,
+        deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if not fused_attn_supported:
         is_training = False
     available_backends, _, fused_attn_backends = get_available_attention_backends(
...
@@ -172,6 +183,7 @@ def test_dot_product_attention(
         qkv_layout=qkv_layout,
         pad_between_seqs=pad_between_seqs,
         is_training=is_training,
+        deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
...
@@ -421,6 +433,15 @@ def test_dpa_softmax(dtype, model_configs, model):
     )

+@pytest.mark.skipif(get_cudnn_version() < (9, 18, 0), reason="cuDNN 9.18.0+ is required.")
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("model_configs", [model_configs_softmax])
+@pytest.mark.parametrize("model", model_configs_softmax.keys())
+def test_dpa_softmax_thd(dtype, model_configs, model):
+    """Test DotProductAttention module with different softmax types"""
+    test_dot_product_attention(dtype, model_configs, model, True, True, "thd_thd_thd", False, False)
+
+
 model_configs_mla = {
     # TODO: FlashAttention on ROCm only supports MLA with head_dim_qk = head_dim_v
     # test: b, h, hg, dqk, sq, skv, p, mask, bias  # attn, backend
...
@@ -685,9 +706,10 @@ model_configs_swa = {
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_swa])
 @pytest.mark.parametrize("model", model_configs_swa.keys())
-def test_dpa_sliding_window(dtype, model_configs, model):
+@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "sbhd_sbhd_sbhd"])
+def test_dpa_sliding_window(dtype, model_configs, model, qkv_layout):
     """Test DotProductAttention module with sliding window attention"""
-    test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)
+    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False)

 model_configs_alibi_slopes = {
...
@@ -889,11 +911,14 @@ def _run_dot_product_attention(
     reset_rng_states()
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
     if backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
     if backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
+    if backend == "UnfusedDotProductAttention":
+        os.environ["NVTE_UNFUSED_ATTN"] = "1"
     _attention_backends["backend_selection_requires_update"] = True

     # Create seqlens
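
For context, a minimal standalone sketch (not part of the commit; the helper name force_backend is hypothetical) of the reset-then-enable pattern this hunk extends with the new NVTE_UNFUSED_ATTN variable:

import os

def force_backend(backend: str) -> None:
    # Zero out all three selectors, then enable exactly one,
    # mirroring the pattern in the hunk above.
    for var in ("NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN"):
        os.environ[var] = "0"
    os.environ[{
        "FlashAttention": "NVTE_FLASH_ATTN",
        "FusedAttention": "NVTE_FUSED_ATTN",
        "UnfusedDotProductAttention": "NVTE_UNFUSED_ATTN",
    }[backend]] = "1"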
...
@@ -1295,6 +1320,7 @@ def test_transformer_layer(
             qkv_format.replace("hd", "h3d")
             if fused_qkv_params
             else qkv_format.replace("hd", "3hd")
         ),
         is_training=is_training,
+        deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if not fused_attn_supported:
...
@@ -1308,6 +1334,7 @@ def test_transformer_layer(
             else qkv_format.replace("hd", "3hd")
         ),
         is_training=is_training,
+        deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
...
@@ -1435,10 +1462,13 @@ def _run_transformer_layer(
     reset_rng_states()
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
     if backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
     if backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
+    if backend == "UnfusedDotProductAttention":
+        os.environ["NVTE_UNFUSED_ATTN"] = "1"
     _attention_backends["backend_selection_requires_update"] = True

     # Create input tensor
...
@@ -1632,6 +1662,7 @@ def test_dpa_fp8_extra_state(model, dtype):
         qkv_dtype=torch.float8_e4m3fn,
         qkv_layout="sb3hd",
         is_training=is_training,
+        deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if not fused_attn_supported and not flash_attn_supported:
...
@@ -1822,6 +1853,7 @@ def test_mha_fp8_vs_f16(
         fp8=True,
         fp8_meta=fp8_meta,
         is_training=is_training,
+        deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
     if flash_attn_supported + fused_attn_supported_fp8 < 1:
...
@@ -1833,6 +1865,7 @@ def test_mha_fp8_vs_f16(
         qkv_dtype=dtype,
         qkv_layout=qkv_format.replace("hd", "h3d"),
         is_training=is_training,
+        deterministic=_deterministic,
     )
     _, fused_attn_supported_f16, _ = available_backends
     if not fused_attn_supported_f16:
...
@@ -1841,6 +1874,7 @@ def test_mha_fp8_vs_f16(
     if flash_attn_supported:
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
+        os.environ["NVTE_UNFUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
         logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
         flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
...
@@ -1850,6 +1884,7 @@ def test_mha_fp8_vs_f16(
     if fused_attn_supported_fp8:
         os.environ["NVTE_FLASH_ATTN"] = "0"
         os.environ["NVTE_FUSED_ATTN"] = "1"
+        os.environ["NVTE_UNFUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
         logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
         fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
...
@@ -1859,6 +1894,7 @@ def test_mha_fp8_vs_f16(
     if fused_attn_supported_f16:
         os.environ["NVTE_FLASH_ATTN"] = "0"
         os.environ["NVTE_FUSED_ATTN"] = "1"
+        os.environ["NVTE_UNFUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
         logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
         fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
...
@@ -2071,6 +2107,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
         fp8=True,
         fp8_meta=fp8_meta,
         is_training=is_training,
+        deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if flash_attn_supported + fused_attn_supported < 1:
...
@@ -2081,6 +2118,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
         is_training=is_training,
+        deterministic=_deterministic,
     )
     _, fused_attn_supported, _ = available_backends
     if not fused_attn_supported:
...
@@ -2091,6 +2129,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
     if flash_attn_supported:
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
+        os.environ["NVTE_UNFUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
         logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)")
         flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
...
@@ -2100,6 +2139,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
     if unfused_attn_supported:
         os.environ["NVTE_FLASH_ATTN"] = "0"
         os.environ["NVTE_FUSED_ATTN"] = "0"
+        os.environ["NVTE_UNFUSED_ATTN"] = "1"
         _attention_backends["backend_selection_requires_update"] = True
         logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)")
         unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
...
@@ -2108,6 +2148,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "1"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
     _attention_backends["backend_selection_requires_update"] = True
     logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
     fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
...
@@ -2116,6 +2157,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "1"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
     if config.dropout_p == 0.0:
         # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell
         logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
...
@@ -2370,13 +2412,16 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
         qkv_dtype=torch.float8_e4m3fn,
         qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
         is_training=is_training,
+        deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if not (fused_attn_backends and unfused_attn_supported):
         pytest.skip("Not enough backends to run this test with.")
     fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
-    unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
+    unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(
+        dtype, config, "UnfusedDotProductAttention"
+    )
     atol = 5e-1
     rtol = 5e-1
...
@@ -2409,10 +2454,13 @@ def _run_custom_mha_fp8(dtype, config, backend):
     reset_rng_states()
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
     if backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
     if backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
+    if backend == "UnfusedDotProductAttention":
+        os.environ["NVTE_UNFUSED_ATTN"] = "1"
     _attention_backends["backend_selection_requires_update"] = True

     inp = 0.0001 * torch.randint(
...
@@ -2463,10 +2511,13 @@ def _run_ref_mha_f16(dtype, config, backend):
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
     if backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
     if backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
+    if backend == "UnfusedDotProductAttention":
+        os.environ["NVTE_UNFUSED_ATTN"] = "1"
     _attention_backends["backend_selection_requires_update"] = True

     inp = torch.load("qkv.pt").to(device="cuda")
...
@@ -2754,7 +2805,7 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
         cu_seqlens,
         max_s,
     ) -> torch.Tensor:
-        with self.prepare_forward(inp, num_gemms=3) as inp:
+        with self.prepare_forward_ctx(inp, num_gemms=3) as inp:
             out = _custom_mha_fp8.apply(
                 inp,
                 self.qkv_weight,
...
tests/pytorch/attention/test_attention_with_cp.py
...
@@ -148,7 +148,7 @@ model_configs_fused_attn = {
         2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
     ),  # MHA
     "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"),  # MHA
-    "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)),  # MHA
+    "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)),  # MHA
     "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"),  # GQA
     "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2),  # GQA
     "cp_2_2": ModelConfig(
...
@@ -164,7 +164,7 @@ model_configs_fused_attn = {
         2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
     ),  # GQA
     "cp_2_4": ModelConfig(
-        2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
+        2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512)
     ),  # GQA
     "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64),  # MLA
     "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64),  # MLA
...
@@ -188,7 +188,16 @@ dtypes = ["bf16", "fp16", "fp8"]
 qkv_formats = ["bshd", "sbhd", "thd"]
 cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
 if test_essential:
-    configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
+    configs = [
+        "cp_1_0",
+        "cp_1_1",
+        "cp_1_4",
+        "cp_2_0",
+        "cp_2_2",
+        "cp_2_4",
+        "cp_3_2",
+        "cp_4_2",
+    ]
     model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
     dtypes = ["bf16", "fp8"]
     qkv_formats = ["sbhd", "thd"]
...
@@ -284,9 +293,14 @@ def test_cp_with_fused_attention(
         pytest.skip(
             "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!"
         )
-    if config.softmax_type != "vanilla" and qkv_format == "thd":
+    if (
+        get_cudnn_version() < (9, 18, 0)
+        and config.softmax_type != "vanilla"
+        and qkv_format == "thd"
+    ):
         pytest.skip(
-            "CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
+            "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for"
+            " non-vanilla softmax types!"
         )
     dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
...
tests/pytorch/debug/test_log.py
...
@@ -15,6 +15,7 @@ from transformer_engine.pytorch import (
     is_fp8_available,
     is_mxfp8_available,
     is_fp8_block_scaling_available,
+    is_nvfp4_available,
 )
 from transformer_engine.pytorch.quantization import RecipeState
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
...
@@ -29,6 +30,7 @@ mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
 fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
     return_reason=True
 )
+nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True)

 LOG_QUANTIZED_CONFIG_BASE = """
 log:
...
@@ -363,6 +365,124 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
    TEDebugState._reset()


# NVFP4 tests
LOG_NVFP4_CONFIG_BASE = """
log:
  layers:
    layer_name_regex_pattern: .*
  enabled: True
  transformer_engine:
    LogNvfp4TensorStats:
      enabled: True
      stats: [{stats}]
      tensors: [activation, gradient, weight]
      freq: 2
      start_step: 0
      end_step: 10
"""


def test_nvfp4_numeric(feature_dirs):
    """Test that NVFP4 underflows% and MSE stats are computed correctly with known values."""
    if not nvfp4_available:
        pytest.skip(reason_for_no_nvfp4)

    log_nvfp4_config = LOG_NVFP4_CONFIG_BASE.format(stats="underflows%, mse")
    with debug_session(log_nvfp4_config, feature_dirs) as log_dir:
        from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
        from transformer_engine.pytorch.quantization import RecipeState

        recipe_state = RecipeState.create(
            recipe.NVFP4BlockScaling(),
            mode="forward",
            num_quantizers=3,
        )

        # Create test tensor with known distribution
        torch.manual_seed(42)
        tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
        # Add some small values that should underflow to zero in FP4
        tensor[0, :16] = 0.0001

        quantizer = recipe_state.make_quantizers()[0]
        quantized_tensor = quantizer(tensor)

        debug_api.transformer_engine.inspect_tensor(
            layer_name="test_layer",
            tensor_name="activation",
            iteration=0,
            tp_group=None,
            tensor=tensor,
            quantizer=quantizer,
            rowwise_quantized_tensor=quantized_tensor,
            columnwise_quantized_tensor=quantized_tensor,
        )
        debug_api.step()
        dequantized_tensor = quantized_tensor.dequantize()

    output = read_log(log_dir)

    # Validate both stats are present
    assert "nvfp4_underflows%" in output, "underflows% stat missing"
    assert "nvfp4_mse" in output, "mse stat missing"

    # Extract values and validate numerics
    underflows_value = None
    mse_value = None
    for line in output.splitlines():
        if "nvfp4_underflows%" in line and "value=" in line:
            underflows_value = float(line.split("value=")[1].split()[0])
        if "nvfp4_mse" in line and "value=" in line:
            mse_value = float(line.split("value=")[1].split()[0])

    # Compute expected underflows: non-zero elements that became zero after quantization
    orig_nonzero_mask = tensor != 0
    dequant_zero_mask = dequantized_tensor == 0
    expected_underflows = (
        (orig_nonzero_mask & dequant_zero_mask).sum().float() / tensor.numel() * 100
    )
    # Allow some tolerance
    assert underflows_value == pytest.approx(expected_underflows.cpu().item(), abs=1e-4)

    # Compute expected MSE
    expected_mse = torch.nn.functional.mse_loss(
        dequantized_tensor.float(), tensor.float(), reduction="mean"
    )
    assert mse_value == pytest.approx(expected_mse.cpu().item(), abs=1e-4)


def test_fp8_stats_allows_nvfp4_with_recipe_prefix(feature_dirs):
    """Test that LogFp8TensorStats allows recipe-prefixed stats with NVFP4 for what-if analysis."""
    if not nvfp4_available:
        pytest.skip(reason_for_no_nvfp4)

    # Use recipe-prefixed stat with NVFP4 - should work (computes MXFP8 separately)
    log_fp8_config = LOG_QUANTIZED_CONFIG_BASE.format(stats="mxfp8_mse")
    with debug_session(log_fp8_config, feature_dirs) as log_dir:
        model = te.Linear(128, 128, params_dtype=torch.bfloat16)
        inp = torch.randn(128, 128, dtype=torch.bfloat16).cuda()

        # Should work - recipe-prefixed stats compute MXFP8 separately for comparison
        for _ in range(2):
            with te.autocast(recipe=recipe.NVFP4BlockScaling()):
                output = model(inp)
            loss = output.sum()
            loss.backward()
            debug_api.step()

    output = read_log(log_dir)

    # Should have logged MXFP8 MSE stat (what-if scenario)
    assert "mxfp8_mse" in output


def test_log_grouped_gemm(feature_dirs):
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
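
A minimal standalone sketch (not part of the commit; helper names are hypothetical) of the two statistics test_nvfp4_numeric validates, written against plain tensors:

import torch

def underflow_pct(orig: torch.Tensor, dequant: torch.Tensor) -> float:
    # Share of originally non-zero elements that quantize-dequantize to exactly
    # zero, as a percentage (same definition as expected_underflows above).
    mask = (orig != 0) & (dequant == 0)
    return (mask.sum().float() / orig.numel() * 100).item()

def quantization_mse(orig: torch.Tensor, dequant: torch.Tensor) -> float:
    # Mean squared quantization error (same definition as expected_mse above).
    return torch.nn.functional.mse_loss(dequant.float(), orig.float()).item()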
...
tests/pytorch/debug/test_sanity.py
...
@@ -30,10 +30,17 @@ configs = {
       stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
       start_step : 0
       end_step: 1
 """,
     "log_fp8": """log_fp8:
   layers:
     layer_types: [linear]
   enabled: True
   transformer_engine:
     LogFp8TensorStats:
       enabled: True
       tensors: [activation, gradient, weight]
-      stats: [underflows, overflows]
+      stats: [underflows%]
       start_step : 0
       end_step: 1
 """,
...
@@ -46,22 +53,26 @@ fake_quant_config:
     FakeQuant:
       enabled: True
       gemms: [fprop, dgrad, wgrad]
       tensors: [activation, weight, gradient]
       quant_format: FP8E5M2
 """,
 }

+# Configs that require FP8 to be enabled
+fp8_required_configs = {"log_fp8"}


 def _get_model(model_key):
     if model_key == "linear":
-        return te.Linear(D, D)
+        return te.Linear(D, D, name="layer")
     if model_key == "layernorm_linear":
-        return te.LayerNormLinear(D, D)
+        return te.LayerNormLinear(D, D, name="layer")
     if model_key == "layernorm_mlp":
-        return te.LayerNormMLP(D, D, D)
+        return te.LayerNormMLP(D, D, D, name="layer")
     if model_key == "mha_attention":
-        return te.MultiheadAttention(D, H)
+        return te.MultiheadAttention(D, H, name="layer")
     if model_key == "transformer_layer":
-        return te.TransformerLayer(D, D, H)
+        return te.TransformerLayer(D, D, H, name="layer")


 def _run_forward_backward(model, fp8):
...
@@ -95,4 +106,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
 def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
+    if not fp8 and config_key in fp8_required_configs:
+        pytest.skip(f"Config '{config_key}' requires FP8")
     _run_test(model_key, fp8, configs[config_key], feature_dirs)
tests/pytorch/test_checkpoint.py
...
@@ -101,7 +101,7 @@ class TestLoadCheckpoint:
         # Path to save checkpoint
         if checkpoint_dir is None:
             checkpoint_dir = TestLoadCheckpoint._checkpoint_dir()
-        checkpoint_dir.mkdir(exist_ok=True)
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
         checkpoint_file = checkpoint_dir / f"{name}.pt"

         # Create module and save checkpoint
...
tests/pytorch/test_fusible_ops.py
...
@@ -5,8 +5,10 @@
from __future__ import annotations

from collections.abc import Iterable
import functools
import io
import math
import random
from typing import Optional

import pytest
...
@@ -37,7 +39,14 @@ from transformer_engine.pytorch import (
 import transformer_engine_torch as tex

 # Import utility functions
-from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states
+from utils import (
+    assert_close,
+    assert_close_grads,
+    dtype_tols,
+    make_recipe,
+    quantization_tols,
+    reset_rng_states,
+)

 if IS_HIP_EXTENSION:
     import os
...
@@ -116,6 +125,9 @@ def maybe_skip_quantization(
 @torch.no_grad()
 def make_reference_and_test_tensors(
     shape: int | Iterable[int],
+    *,
+    min: float = 0.0,
+    max: float = 1.0,
     quantization: Optional[str] = None,
     ref_dtype: torch.dtype = torch.float64,
     ref_device: torch.device = "cpu",
...
@@ -136,7 +148,8 @@ def make_reference_and_test_tensors(
     """

     # Random reference tensor
-    ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
+    ref = torch.empty(shape, dtype=ref_dtype, device=ref_device)
+    ref.uniform_(min, max)

     # Construct test tensor from reference tensor
     test = ref.to(device=test_device, dtype=test_dtype)
...
@@ -1569,7 +1582,19 @@ class TestBasicOps:
     @pytest.mark.parametrize(
         "activation",
-        ("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"),
+        (
+            "gelu",
+            "geglu",
+            "qgelu",
+            "qgeglu",
+            "relu",
+            "reglu",
+            "glu",
+            "srelu",
+            "sreglu",
+            "silu",
+            "swiglu",
+        ),
     )
     @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
     @pytest.mark.parametrize("dtype", _dtypes)
...
@@ -1589,7 +1614,7 @@
         # Tensor dimensions
         in_shape = list(out_shape)
-        if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"):
+        if activation in ("geglu", "glu", "qgeglu", "reglu", "sreglu", "swiglu"):
             in_shape[-1] *= 2

         # Skip invalid configurations
...
@@ -1629,6 +1654,13 @@ class TestBasicOps:
         elif activation == "reglu":
             x1, x2 = x_ref.chunk(2, dim=-1)
             y_ref = torch.nn.functional.relu(x1) * x2
         elif activation == "sigmoid":
             y_ref = torch.nn.functional.sigmoid(x_ref)
+        elif activation == "glu":
+            x = x_ref.reshape(*in_shape[:-1], 2, in_shape[-1] // 2)
+            x = x.flip(-2)  # PyTorch GLU swaps gate and linear unit
+            x = x.reshape(in_shape)
+            y_ref = torch.nn.functional.glu(x)
         elif activation == "srelu":
             y_ref = torch.nn.functional.relu(x_ref) ** 2
         elif activation == "sreglu":
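
A small aside on the flip in the glu branch above: torch.nn.functional.glu gates with the second half of its input, so the reference swaps the two halves to match TE's gate-first ordering. A minimal sketch (not from the commit) of the PyTorch convention:

import torch

x1, x2 = torch.randn(4), torch.randn(4)
x = torch.cat([x1, x2])
# F.glu(x) = x1 * sigmoid(x2): the *second* half is the gate in PyTorch.
assert torch.allclose(torch.nn.functional.glu(x, dim=-1), x1 * torch.sigmoid(x2))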
...
@@ -1648,6 +1680,7 @@ class TestBasicOps:
         make_op = dict(
             gelu=te_ops.GELU,
             geglu=te_ops.GEGLU,
+            glu=te_ops.GLU,
             qgelu=te_ops.QGELU,
             qgeglu=te_ops.QGEGLU,
             relu=te_ops.ReLU,
...
@@ -1692,6 +1725,7 @@ class TestBasicOps:
         quantization: Optional[str],
         quantize_forward: bool,
         quantize_backward: bool,
+        glu_interleave_size: Optional[int] = None,
     ):
         # Tensor dimensions
...
@@ -1718,7 +1752,17 @@ class TestBasicOps:
         )

         # Plain PyTorch implementation
-        x1, x2 = x_ref.chunk(2, dim=-1)
+        x = x_ref
+        if glu_interleave_size is not None:
+            x = x.reshape(
+                *in_shape[:-1],
+                in_shape[-1] // (2 * glu_interleave_size),
+                2,
+                glu_interleave_size,
+            )
+            x = x.transpose(-3, -2)
+            x = x.reshape(in_shape)
+        x1, x2 = x.chunk(2, dim=-1)
         y_ref = torch.nn.functional.silu(x1) * x2
         y_ref.backward(dy_ref)
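
To make the reshape/transpose above concrete, a tiny sketch (not from the commit) of the de-interleave it performs for glu_interleave_size=2 on a width-8 row, where g marks gate values and l linear-unit values:

import torch

x = torch.tensor([0.0, 1.0, 10.0, 11.0, 2.0, 3.0, 12.0, 13.0])  # [g g l l g g l l]
y = x.reshape(8 // (2 * 2), 2, 2).transpose(-3, -2).reshape(8)
# Gate blocks are gathered into the first half, linear blocks into the second.
assert torch.equal(y, torch.tensor([0.0, 1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 13.0]))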
...
@@ -1726,7 +1770,7 @@ class TestBasicOps:
         recipe = make_recipe(quantization)
         forward = te_ops.Sequential(
             te_ops.Quantize(forward=False, backward=quantize_backward),
-            te_ops.SwiGLU(),
+            te_ops.SwiGLU(glu_interleave_size=glu_interleave_size),
             te_ops.Quantize(forward=quantize_forward, backward=False),
         )
         with te.autocast(enabled=quantized_compute, recipe=recipe):
...
@@ -1739,10 +1783,19 @@ class TestBasicOps:
             tols = quantization_tols(quantization)

         # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
-        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
-        torch.testing.assert_close(y_test, y_ref, **tols)
-        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+        assert_close(y_test, y_ref, **tols)
+        assert_close_grads(x_test, x_ref, **tols)

+    def test_interleaved_swiglu(self):
+        """SwiGLU with block interleaved input format"""
+        self.test_swiglu(
+            out_shape=(32, 192),
+            dtype=torch.float32,
+            quantization=None,
+            quantize_forward=False,
+            quantize_backward=False,
+            glu_interleave_size=32,
+        )
+
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("quantization", _quantization_list)
...
@@ -1752,6 +1805,7 @@ class TestBasicOps:
         self,
         *,
         out_shape: Iterable[int] = (32, 32),
+        glu_interleave_size: Optional[int] = None,
         dtype: torch.dtype,
         device: torch.device = "cuda",
         quantization: Optional[str],
...
@@ -1760,7 +1814,7 @@ class TestBasicOps:
         limit: float = 0.75,
         alpha: float = 1.702,
     ):
-        # Test SwiGLU variant used in GPT OSS.
+        """SwiGLU variant used in GPT-OSS"""

         # Tensor dimensions
         in_shape = list(out_shape)
         in_shape[-1] *= 2
...
@@ -1785,7 +1839,17 @@ class TestBasicOps:
         )

         # Plain PyTorch implementation
-        x_glu, x_linear = x_ref.chunk(2, dim=-1)
+        x = x_ref
+        if glu_interleave_size is not None:
+            x = x.reshape(
+                *in_shape[:-1],
+                in_shape[-1] // (2 * glu_interleave_size),
+                2,
+                glu_interleave_size,
+            )
+            x = x.transpose(-3, -2)
+            x = x.reshape(in_shape)
+        x_glu, x_linear = x.chunk(2, dim=-1)
         x_glu = x_glu.clamp(min=None, max=limit)
         x_linear = x_linear.clamp(min=-limit, max=limit)
         out_glu = x_glu * torch.sigmoid(alpha * x_glu)
...
@@ -1797,7 +1861,11 @@ class TestBasicOps:
         forward = te_ops.Sequential(
             te_ops.Quantize(forward=False, backward=quantize_backward),
-            te_ops.ClampedSwiGLU(limit=limit, alpha=alpha),
+            te_ops.ClampedSwiGLU(
+                limit=limit,
+                alpha=alpha,
+                glu_interleave_size=glu_interleave_size,
+            ),
             te_ops.Quantize(forward=quantize_forward, backward=False),
         )
         with te.autocast(enabled=quantized_compute, recipe=recipe):
...
@@ -1813,10 +1881,19 @@ class TestBasicOps:
             tols = dtype_tols(tex.DType.kFloat8E4M3)

         # Check results
         y_test = y_test.to(dtype=torch.float64, device="cpu")
-        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
-        torch.testing.assert_close(y_test, y_ref, **tols)
-        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+        assert_close(y_test, y_ref, **tols)
+        assert_close_grads(x_test, x_ref, **tols)

+    def test_interleaved_clamped_swiglu(self):
+        """GPT-OSS SwiGLU with block interleaved input format"""
+        self.test_clamped_swiglu(
+            out_shape=(32, 192),
+            dtype=torch.float32,
+            quantization=None,
+            quantize_forward=False,
+            quantize_backward=False,
+            glu_interleave_size=32,
+        )
+
     @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
     @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2)))
...
@@ -1936,6 +2013,231 @@ class TestBasicOps:
            abs(z_score) < 2.5758
        ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})"

    @pytest.mark.parametrize("bias", (False, True))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_compute", (False, True))
    @pytest.mark.parametrize("quantized_weight", (False, True))
    @pytest.mark.parametrize("input_requires_grad", (False, True))
    @pytest.mark.parametrize("weight_requires_grad", (False, True))
    def test_grouped_linear(
        self,
        *,
        group_size: int = 4,
        bias: bool,
        weight_shape: tuple[int, int] = (128, 128),
        split_alignment: int = 128,
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_compute: bool,
        quantized_weight: bool,
        input_requires_grad: bool,
        weight_requires_grad: bool,
    ) -> None:
        """Grouped GEMM"""

        # Split sizes
        split_sizes = [split_alignment * i for i in range(group_size)]
        random.shuffle(split_sizes)
        split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device)

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = (split_sizes.sum().item(), in_features)
        out_shape = (in_shape[0], out_features)

        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantization is None and (quantized_compute or quantized_weight):
            pytest.skip("Quantization scheme is not specified")
        if quantization is not None and not (quantized_compute or quantized_weight):
            pytest.skip("Quantization scheme is not used")
        if quantization is not None and dtype not in (torch.bfloat16, torch.float16):
            pytest.skip("Quantized group GEMM is only supported with BF16/FP16")

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=input_requires_grad,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )
        ws_ref, ws_test = [], []
        bs_ref, bs_test = [], []
        for _ in range(group_size):
            w_ref, w_test = make_reference_and_test_tensors(
                (out_features, in_features),
                quantization=quantization,
                test_dtype=dtype,
                test_device=device,
                requires_grad=weight_requires_grad,
            )
            b_ref, b_test = None, None
            if bias:
                b_ref, b_test = make_reference_and_test_tensors(
                    out_features,
                    test_dtype=dtype,
                    test_device=device,
                    requires_grad=weight_requires_grad,
                )
            ws_ref.append(w_ref)
            ws_test.append(w_test)
            bs_ref.append(b_ref)
            bs_test.append(b_test)

        # Plain PyTorch implementation
        xs_ref = torch.split(x_ref, split_sizes.tolist())
        ys_ref = []
        for x, w, b in zip(xs_ref, ws_ref, bs_ref):
            ys_ref.append(torch.nn.functional.linear(x, w, bias=b))
        y_ref = torch.cat(ys_ref)
        if input_requires_grad or weight_requires_grad:
            y_ref.backward(dy_ref)

        # Construct fusible operation
        recipe = make_recipe(quantization)
        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
            op = te_ops.GroupedLinear(
                group_size,
                in_features,
                out_features,
                bias=bias,
                device=device,
                dtype=dtype,
            )
        with torch.no_grad():
            for group_idx in range(group_size):
                getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx])
                if bias:
                    getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx])
        del ws_test, bs_test
        for param in op.parameters():
            param.requires_grad_(requires_grad=weight_requires_grad)

        # Forward and backward pass with op
        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = op(x_test, split_sizes)
        if input_requires_grad or weight_requires_grad:
            y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = quantization_tols(quantization)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        if input_requires_grad:
            dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        else:
            assert x_test.grad is None
        for group_idx in range(group_size):
            w_test = getattr(op, f"weight{group_idx}")
            if weight_requires_grad:
                dw_test = w_test.grad.to(dtype=torch.float64, device="cpu")
                torch.testing.assert_close(dw_test, ws_ref[group_idx].grad, **tols)
            else:
                assert w_test.grad is None
            if bias:
                b_test = getattr(op, f"bias{group_idx}")
                if weight_requires_grad:
                    db_test = b_test.grad.to(dtype=torch.float64, device="cpu")
                    torch.testing.assert_close(db_test, bs_ref[group_idx].grad, **tols)
                else:
                    assert b_test.grad is None

    @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
    @pytest.mark.parametrize("input_requires_grad", (False, True))
    @pytest.mark.parametrize("scales_requires_grad", (False, True))
    def test_scaled_swiglu(
        self,
        *,
        in_shape: Iterable[int],
        glu_interleave_size: Optional[int] = None,
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        input_requires_grad: bool,
        scales_requires_grad: bool,
    ) -> None:
        """SwiGLU with post-scale"""

        # Tensor dims
        out_shape = list(in_shape)
        out_shape[-1] //= 2

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=input_requires_grad,
        )
        scales_ref, scales_test = make_reference_and_test_tensors(
            in_shape[:-1],
            test_dtype=dtype,
            test_device=device,
            requires_grad=scales_requires_grad,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        x = x_ref
        if glu_interleave_size is not None:
            x = x.reshape(
                -1,
                in_shape[-1] // (2 * glu_interleave_size),
                2,
                glu_interleave_size,
            )
            x = x.transpose(1, 2)
            x = x.reshape(in_shape)
        x1, x2 = x.chunk(2, dim=-1)
        y = torch.nn.functional.silu(x1) * x2
        y_ref = scales_ref.unsqueeze(-1) * y
        if input_requires_grad or scales_requires_grad:
            y_ref.backward(dy_ref)

        # Implementation with fusible operation
        op = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
        y_test = op(x_test, scales_test)
        if input_requires_grad or scales_requires_grad:
            y_test.backward(dy_test)

        # Check results
        tols = dtype_tols(dtype)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        assert_close(y_test, y_ref, **tols)
        assert_close_grads(x_test, x_ref, **tols)
        assert_close_grads(scales_test, scales_ref, **tols)

    def test_interleaved_scaled_swiglu(self):
        """SwiGLU with post-scale and block interleaved input format"""
        self.test_scaled_swiglu(
            in_shape=(32, 192),
            glu_interleave_size=32,
            input_requires_grad=True,
            scales_requires_grad=True,
        )
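
For reference, a self-contained sketch (hypothetical helper, not from the commit) of the grouped-GEMM semantics these tests check: a ragged batch is split by per-group row counts and each chunk is multiplied by its own weight matrix:

import torch

def grouped_linear_ref(x, split_sizes, weights):
    # x: (sum(split_sizes), in_features); weights[i]: (out_features, in_features)
    chunks = torch.split(x, split_sizes)
    return torch.cat([torch.nn.functional.linear(xi, w) for xi, w in zip(chunks, weights)])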
class TestFusedOps:
    """Tests for fused operations"""
...
@@ -2342,13 +2644,13 @@ class TestFusedOps:
         backward_ops = model._module_groups[0]._backward_ops
         if with_quantization:
             assert len(backward_ops) == 2
-            assert isinstance(backward_ops[0][0], BackwardActivationBias)
-            assert isinstance(backward_ops[1][0], te_ops.Quantize)
+            assert isinstance(backward_ops[0][0], te_ops.Quantize)
+            assert isinstance(backward_ops[1][0], BackwardActivationBias)
         else:
             assert len(backward_ops) == 3
-            assert isinstance(backward_ops[0][0], act_type)
+            assert isinstance(backward_ops[0][0], te_ops.Quantize)
             assert isinstance(backward_ops[1][0], te_ops.Bias)
-            assert isinstance(backward_ops[2][0], te_ops.Quantize)
+            assert isinstance(backward_ops[2][0], act_type)

         # Expected numerical error
         tols = dtype_tols(dtype)
...
@@ -2944,3 +3246,499 @@ class TestSequentialModules:
        if bias:
            torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols)
            torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols)

    @pytest.mark.parametrize("bias", (False, True))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("glu_interleave_size", (None, 32))
    def test_grouped_mlp(
        self,
        *,
        group_size: int = 4,
        bias: bool,
        hidden_size: int = 256,
        dtype: torch.dtype,
        quantization: Optional[str],
        device: torch.device = "cuda",
        split_alignment: int = 256,
        glu_interleave_size: Optional[int],
    ) -> None:
        """GroupedLinear + ScaledSwiGLU + GroupedLinear"""

        # Split sizes
        split_sizes = [split_alignment * i for i in range(group_size)]
        random.shuffle(split_sizes)
        split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device)

        # Make input shape
        in_shape = (split_sizes.sum().item(), hidden_size)
        out_shape = in_shape

        # Skip invalid configurations
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
        if with_quantization and dtype not in (torch.bfloat16, torch.float16):
            pytest.skip("Quantized group GEMM is only supported with BF16/FP16")

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            min=-0.25,
            max=0.25,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            min=-0.25,
            max=0.25,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )
        probs_ref, probs_test = make_reference_and_test_tensors(
            (in_shape[0],),
            test_dtype=dtype,
            test_device=device,
        )
        fc1_ws_ref, fc1_ws_test = [], []
        fc1_bs_ref, fc1_bs_test = [], []
        fc2_ws_ref, fc2_ws_test = [], []
        fc2_bs_ref, fc2_bs_test = [], []
        for _ in range(group_size):
            fc1_w_ref, fc1_w_test = make_reference_and_test_tensors(
                (2 * hidden_size, hidden_size),
                min=-0.25,
                max=0.25,
                quantization=quantization,
                test_dtype=dtype,
                test_device=device,
            )
            fc2_w_ref, fc2_w_test = make_reference_and_test_tensors(
                (hidden_size, hidden_size),
                min=-0.25,
                max=0.25,
                quantization=quantization,
                test_dtype=dtype,
                test_device=device,
            )
            fc1_b_ref, fc1_b_test = None, None
            fc2_b_ref, fc2_b_test = None, None
            if bias:
                fc1_b_ref, fc1_b_test = make_reference_and_test_tensors(
                    (2 * hidden_size,),
                    min=-0.5,
                    max=0.5,
                    test_dtype=dtype,
                    test_device=device,
                )
                fc2_b_ref, fc2_b_test = make_reference_and_test_tensors(
                    (hidden_size,),
                    min=-0.5,
                    max=0.5,
                    test_dtype=dtype,
                    test_device=device,
                )
            fc1_ws_ref.append(fc1_w_ref)
            fc1_bs_ref.append(fc1_b_ref)
            fc1_ws_test.append(fc1_w_test)
            fc1_bs_test.append(fc1_b_test)
            fc2_ws_ref.append(fc2_w_ref)
            fc2_bs_ref.append(fc2_b_ref)
            fc2_ws_test.append(fc2_w_test)
            fc2_bs_test.append(fc2_b_test)

        # Reference implementation
        xs = torch.split(x_ref, split_sizes.tolist())
        probs = torch.split(probs_ref, split_sizes.tolist())
        ys = []
        for group_idx in range(group_size):
            x = xs[group_idx]
            x = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx])
            if glu_interleave_size is not None:
                x = x.reshape(
                    -1,
                    2 * hidden_size // (2 * glu_interleave_size),
                    2,
                    glu_interleave_size,
                )
                x = x.transpose(1, 2)
                x = x.reshape(-1, 2 * hidden_size)
            x1, x2 = x.chunk(2, dim=-1)
            x = torch.nn.functional.silu(x1) * x2
            x = x * probs[group_idx].unsqueeze(-1)
            x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx], bias=fc2_bs_ref[group_idx])
            ys.append(x)
        y_ref = torch.cat(ys)
        y_ref.backward(dy_ref)

        # Construct operations
        recipe = make_recipe(quantization)
        with te.quantized_model_init(enabled=with_quantization, recipe=recipe):
            fc1 = te_ops.GroupedLinear(
                group_size,
                hidden_size,
                2 * hidden_size,
                bias=bias,
                device=device,
                dtype=dtype,
            )
            fc2 = te_ops.GroupedLinear(
                group_size,
                hidden_size,
                hidden_size,
                bias=bias,
                device=device,
                dtype=dtype,
            )
        module = te_ops.Sequential(
            fc1,
            te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size),
            fc2,
        )

        # Copy weights
        with torch.no_grad():
            for group_idx in range(group_size):
                getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx])
                getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx])
                if bias:
                    getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx])
                    getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx])
        del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test

        # Fuse ops and perform forward and backward pass
        with te.autocast(enabled=with_quantization, recipe=recipe):
            y_test = module(x_test, split_sizes, probs_test, split_sizes)
        y_test.backward(dy_test)

        # Loose tols for sanity checking
        tols = {"rtol": 0.125, "atol": 0.25}
        if quantization == "nvfp4":
            tols = {"rtol": 0.25, "atol": 0.5}

        # Check values
        assert_close(y_test, y_ref, **tols)
        assert_close_grads(x_test, x_ref, **tols)
        assert_close_grads(probs_test, probs_ref, **tols)
        for group_idx in range(group_size):
            assert_close_grads(getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols)
            assert_close_grads(getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols)
            assert_close_grads(getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols)
            assert_close_grads(getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols)


class TestCustomOps:
    """Test with ops that are defined externally"""

    def test_custom_basic_op(
        self,
        *,
        shape: Iterable[int] = (7, 5),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
    ) -> None:
        """Custom basic op"""

        class CustomScaleOp(te.ops.BasicOperation):
            """Custom op that applies a learnable scale"""

            def __init__(self) -> None:
                super().__init__()
                self.scale: torch.nn.Parameter
                scale = torch.ones((), dtype=dtype, device=device)
                scale = torch.nn.Parameter(scale)
                self.register_parameter("scale", scale)

            def op_forward(
                self,
                ctx: OperationContext,
                input_: torch.Tensor,
                prev_op_grad_output_quantizer: Optional[Quantizer],
                next_op_input_quantizer: Optional[Quantizer],
            ) -> torch.Tensor:
                ctx.save_for_backward(self.scale, input_)
                return self.scale * input_

            def op_backward(
                self,
                ctx: OperationContext,
                grad_output: torch.Tensor,
            ) -> torch.Tensor:
                (scale, input_) = ctx.saved_tensors
                grad_scale = torch.inner(input_.reshape(-1), grad_output.reshape(-1))
                grad_scale = grad_scale.reshape(())
                grad_input = scale * grad_output
                return grad_input, (grad_scale,)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (),
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = w_ref * x_ref
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        op = CustomScaleOp()
        forward = te.ops.Sequential(te.ops.Identity(), op, te.ops.Identity())
        with torch.no_grad():
            op.scale.copy_(w_test)
        del w_test
        y_test = forward(x_test)
        y_test.backward(dy_test)

        # Check results
        tols = dtype_tols(dtype)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = op.scale.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)

    def test_custom_forward_fused_op(
        self,
        *,
        shape: Iterable[int] = (7, 11),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
    ):
        """Custom fused op in forward pass"""

        class CustomForwardLinearSiLU(te.ops.FusedOperation):
            """Custom fused op for GEMM + SiLU"""

            _enabled = True

            def __init__(self, *, linear, silu) -> None:
                super().__init__((linear, silu))

            def fuser_forward(
                self,
                basic_op_ctxs: list[OperationContext],
                input_: torch.Tensor,
                **unused,
            ) -> torch.Tensor:
                weight = self.basic_ops[0].weight
                dtype = weight.dtype
                device = weight.device

                # Perform compute on CPU, because why not?
                x = input_.cpu()
                w = weight.cpu()
                y = torch.matmul(x, w.T)
                z = torch.nn.functional.silu(y)
                out = z.to(device=device)

                # Save state for linear backward
                linear_op_ctx = basic_op_ctxs[0]
                linear_op_ctx.save_for_backward(input_, weight)
                linear_op_ctx.with_quantized_compute = False
                linear_op_ctx.input_quantizer = None
                linear_op_ctx.weight_quantizer = None
                linear_op_ctx.grad_output_quantizer = None
                linear_op_ctx.grad_input_quantizer = None
                linear_op_ctx.dtype = dtype
                linear_op_ctx.input_requires_grad = True
                linear_op_ctx.weight_requires_grad = True

                # Save state for SiLU backward
                silu_op_ctx = basic_op_ctxs[1]
                silu_op_ctx.save_for_backward(y.to(device=device))
                silu_op_ctx.dtype = dtype
                silu_op_ctx.prev_op_grad_output_quantizer = None

                return out, [(), ()]

            @staticmethod
            def fuse_ops(
                ops: list[FusibleOperation],
                **unused,
            ) -> list[FusibleOperation]:
                """Apply fusion the first time this function is called"""
                if CustomForwardLinearSiLU._enabled:
                    CustomForwardLinearSiLU._enabled = False
                    op = CustomForwardLinearSiLU(linear=ops[0], silu=ops[1])
                    return [op] + ops[2:]
                return ops

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (shape[-1], shape[-1]),
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x_ref, w_ref)
        y_ref = torch.nn.functional.silu(y_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        te.ops.register_forward_fusion(CustomForwardLinearSiLU.fuse_ops)
        model = te.ops.Sequential(
            te.ops.Linear(shape[-1], shape[-1], bias=False),
            te.ops.SiLU(),
        )
        with torch.no_grad():
            model[0].weight.copy_(w_test)
        del w_test
        y_test = model(x_test)
        y_test.backward(dy_test)

        # Check that forward operations have been fused
        forward_ops = model._module_groups[0]._forward_ops
        assert len(forward_ops) == 1
        assert isinstance(forward_ops[0][0], CustomForwardLinearSiLU)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)

    def test_custom_backward_fused_op(
        self,
        *,
        shape: Iterable[int] = (13, 5),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
    ):
        """Custom fused op in backward pass"""

        class CustomBackwardLinearScale(te.ops.FusedOperation):
            """Custom fused op for backward linear + scale"""

            _enabled: bool = True

            def __init__(self, *, scale, linear) -> None:
                super().__init__((scale, linear))

            def fuser_backward(
                self,
                basic_op_ctxs: list[OperationContext],
                grad_output: torch.Tensor,
                **unused,
            ) -> torch.Tensor:
                # Load state from linear forward
                linear_op_ctx = basic_op_ctxs[1]
                x, w = linear_op_ctx.saved_tensors
                dtype = linear_op_ctx.dtype
                device = w.device

                # Perform compute in FP64 and apply scale before dgrad
                # GEMM instead of after
                scale = self.basic_ops[0].scale
                dy = grad_output.double()
                x = x.double()
                w = w.double()
                dx = torch.matmul(dy, scale * w)
                dw = torch.matmul(dy.T, x)
                dx = dx.to(dtype=dtype)
                dw = dw.to(dtype=dtype)
                return dx, [(), (dw,)], [(), ()]

            @staticmethod
            def fuse_ops(
                ops: list[FusibleOperation],
                **unused,
            ) -> list[FusibleOperation]:
                """Apply fusion the first time this function is called"""
                if CustomBackwardLinearScale._enabled:
                    CustomBackwardLinearScale._enabled = False
                    op = CustomBackwardLinearScale(scale=ops[0], linear=ops[1])
                    return [op] + ops[2:]
                return ops

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (shape[-1], shape[-1]),
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )
        scale = 1.234

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(scale * x_ref, w_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        te.ops.register_backward_fusion(CustomBackwardLinearScale.fuse_ops, prepend=True)
        model = te.ops.Sequential(
            te.ops.ConstantScale(scale),
            te.ops.Linear(shape[-1], shape[-1], bias=False),
        )
        with torch.no_grad():
            model[1].weight.copy_(w_test)
        del w_test
        y_test = model(x_test)
        y_test.backward(dy_test)

        # Check that backward operations have been fused
        backward_ops = model._module_groups[0]._backward_ops
        assert len(backward_ops) == 1
        assert isinstance(backward_ops[0][0], CustomBackwardLinearScale)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
tests/pytorch/test_grouped_tensor.py (new file, 0 → 100644)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Tests for GroupedTensor class"""

from typing import List, Tuple

import pytest
import torch

import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from transformer_engine.pytorch import (
    Quantizer,
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
    Float8BlockQuantizer,
    MXFP8Quantizer,
    NVFP4Quantizer,
)
from transformer_engine.pytorch.constants import TE_DType_To_Torch
import transformer_engine_torch as tex

# Check available recipes
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
    return_reason=True
)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)

_quantization_params = [
    pytest.param(
        "fp8_delayed_scaling",
        marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
    ),
    pytest.param(
        "fp8_current_scaling",
        marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
    ),
    pytest.param(
        "fp8_blockwise",
        marks=pytest.mark.skipif(
            not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling
        ),
    ),
    pytest.param(
        "mxfp8",
        marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8),
    ),
    pytest.param(
        "nvfp4",
        marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4),
    ),
]


def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer:
    """Create quantizers for given quantization scheme"""
    if quantization == "fp8_delayed_scaling":
        quantizer = Float8Quantizer(
            scale=torch.ones(1, dtype=torch.float32, device="cuda"),
            amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
    elif quantization == "fp8_current_scaling":
        quantizer = Float8CurrentScalingQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            device="cuda",
        )
        quantizer.set_usage(rowwise=True, columnwise=False)
    elif quantization == "fp8_blockwise":
        quantizer = Float8BlockQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            rowwise=True,
            columnwise=False,
            force_pow_2_scales=True,
            amax_epsilon=0.0,
            block_scaling_dim=1,
        )
    elif quantization == "mxfp8":
        quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    elif quantization == "nvfp4":
        quantizer = NVFP4Quantizer(
            with_rht=False,
            with_post_rht_amax=False,
            with_2d_quantization=False,
            stochastic_rounding=False,
            with_random_sign_mask=False,
        )
    else:
        raise ValueError(f"Unknown quantization scheme: {quantization}")
    quantizer.internal = False
    return quantizer
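For reference, a standalone call to this factory looks as follows; the scheme and shapes here are arbitrary examples and assume a CUDA device with the corresponding recipe support:

    # Hypothetical usage of the factory above; any entry of _quantization_params works.
    quantizer = make_quantizer("fp8_delayed_scaling", num_tensors=3, shape=[(512, 512)] * 3)
    assert quantizer.internal is False  # the factory always exposes the quantizer to Python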
def _get_rowwise_data_tensor(qtensor, quantization: str) -> torch.Tensor:
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling"):
        return qtensor._data
    if quantization in ("fp8_blockwise", "mxfp8", "nvfp4"):
        return qtensor._rowwise_data
    raise ValueError(f"Unknown quantization scheme: {quantization}")
def _rowwise_offset_bytes(numel: int, quantization: str) -> int:
    if quantization == "nvfp4":
        return numel // 2
    return numel
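The helper above encodes the only packing difference among the recipes: NVFP4 stores two 4-bit elements per byte, so byte offsets are half the element count, while the FP8 recipes use one byte per element. A quick sanity check of that arithmetic in plain Python (no TE dependencies, shapes chosen to match the tests):

    # Worked example of the offset arithmetic above (assumed packing:
    # 1 byte per FP8 element, 2 NVFP4 elements per byte).
    numel = 512 * 512  # elements in one (512, 512) member tensor
    assert _rowwise_offset_bytes(2 * numel, "fp8_delayed_scaling") == 524288  # 2 FP8 tensors
    assert _rowwise_offset_bytes(2 * numel, "nvfp4") == 262144  # same elements, half the bytes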
class TestGroupedTensor:

    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    def test_basic_construction_all_same_shape(self) -> None:
        """Test GroupedTensor construction with all tensors having same shape"""
        num_tensors = 4
        shape = [(256, 512) for _ in range(num_tensors)]
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )
        assert grouped_tensor.num_tensors == num_tensors
        assert grouped_tensor.all_same_shape()
        assert grouped_tensor.all_same_first_dim()
        assert grouped_tensor.all_same_last_dim()
        assert grouped_tensor.logical_shape == (num_tensors * 256, 512)
        assert grouped_tensor.get_common_first_dim() == 256
        assert grouped_tensor.get_common_last_dim() == 512
        assert grouped_tensor.has_data()

    def test_basic_construction_varying_first_dim(self) -> None:
        """Test GroupedTensor construction with varying first dimension"""
        num_tensors = 3
        shape = [(128, 512), (256, 512), (384, 512)]
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )
        assert grouped_tensor.num_tensors == num_tensors
        assert not grouped_tensor.all_same_shape()
        assert not grouped_tensor.all_same_first_dim()
        assert grouped_tensor.all_same_last_dim()
        assert grouped_tensor.get_common_last_dim() == shape[0][1]
        assert grouped_tensor.logical_shape == (
            sum(v for v, _ in shape),  # sum of first dims
            shape[0][1],
        )

    def test_split_into_quantized_tensors_no_quantization(self) -> None:
        """Test split_into_quantized_tensors for unquantized tensors"""
        num_tensors = 3
        shape = [(256, 512) for _ in range(num_tensors)]
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )

        # Get the original data pointer
        original_data_ptr = grouped_tensor.data.data_ptr()

        # Split into tensors
        tensors = grouped_tensor.split_into_quantized_tensors()
        assert len(tensors) == num_tensors

        # Verify each tensor has correct shape and shares storage
        for i, tensor in enumerate(tensors):
            assert tensor.shape == shape[i]
            assert isinstance(tensor, torch.Tensor)
            assert not hasattr(tensor, "_data")  # Not a quantized tensor

            # Verify data pointer is within the original grouped tensor storage.
            # The tensor should be a view of the original data.
            assert tensor.data_ptr() >= original_data_ptr

            # Calculate expected offset
            expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size()
            assert tensor.data_ptr() == original_data_ptr + expected_offset

    @pytest.mark.parametrize("quantization", _quantization_params)
    def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None:
        """Test split_into_quantized_tensors for quantized tensors"""
        num_tensors = 3
        shape = [(512, 512) for _ in range(num_tensors)]
        quantizers = make_quantizer(quantization, num_tensors, shape)
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=quantizers,
            device="cuda",
        )

        # Get the original data pointer
        original_data_ptr = grouped_tensor.data.data_ptr()

        # Split into tensors
        tensors = grouped_tensor.split_into_quantized_tensors()
        assert len(tensors) == num_tensors

        # Verify each tensor shares storage with the grouped tensor
        for i, tensor in enumerate(tensors):
            rowwise_data = _get_rowwise_data_tensor(tensor, quantization)
            assert rowwise_data is not None
            assert rowwise_data.data_ptr() >= original_data_ptr
            numel = shape[i][0] * shape[i][1]
            expected_offset = _rowwise_offset_bytes(i * numel, quantization)
            assert rowwise_data.data_ptr() == original_data_ptr + expected_offset

    def test_split_varying_shapes(self) -> None:
        """Test split_into_quantized_tensors with varying shapes"""
        num_tensors = 3
        shape = [(128, 512), (256, 512), (384, 512)]
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )
        original_data_ptr = grouped_tensor.data.data_ptr()
        tensors = grouped_tensor.split_into_quantized_tensors()
        assert len(tensors) == num_tensors

        # Verify shapes and storage
        cumulative_offset = 0
        for i, tensor in enumerate(tensors):
            assert tensor.shape == shape[i]
            expected_offset = cumulative_offset * tensor.element_size()
            assert tensor.data_ptr() == original_data_ptr + expected_offset
            cumulative_offset += shape[i][0] * shape[i][1]

    @pytest.mark.parametrize("quantization", _quantization_params)
    def test_quantize_inplace(self, quantization: str) -> None:
        """Test that quantize is done in-place for all recipes"""
        num_tensors = 3
        shape = [(512, 512) for _ in range(num_tensors)]
        quantizers = make_quantizer(quantization, num_tensors, shape)
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=quantizers,
            device="cuda",
        )

        # Get original data pointers before quantization
        original_data_ptr = grouped_tensor.data.data_ptr()
        original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr()
        original_scale_ptr = (
            grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None
        )

        # Create input tensors
        input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]

        # Quantize in place
        quantized_tensors = grouped_tensor.quantize(input_tensors)

        # Verify data pointers haven't changed (in-place operation)
        assert grouped_tensor.data.data_ptr() == original_data_ptr
        assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr
        if original_scale_ptr is not None:
            assert grouped_tensor.scale.data_ptr() == original_scale_ptr

        # Verify returned tensors point to the same storage
        for i, qtensor in enumerate(quantized_tensors):
            rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
            numel = shape[i][0] * shape[i][1]
            expected_offset = _rowwise_offset_bytes(i * numel, quantization)
            assert rowwise_data.data_ptr() == original_data_ptr + expected_offset

    @pytest.mark.parametrize("quantization", _quantization_params)
    def test_quantize_varying_shapes(self, quantization: str) -> None:
        """Test quantize with varying shapes"""
        num_tensors = 3
        shape = [(256, 512), (512, 512), (768, 512)]
        quantizers = make_quantizer(quantization, num_tensors, shape)
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=quantizers,
            device="cuda",
        )

        # Get original data pointers
        original_data_ptr = grouped_tensor.data.data_ptr()

        # Create input tensors with varying shapes
        input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]

        # Quantize in place
        quantized_tensors = grouped_tensor.quantize(input_tensors)

        # Verify data pointer hasn't changed
        assert grouped_tensor.data.data_ptr() == original_data_ptr

        # Verify each tensor points to correct location
        cumulative_numel = 0
        for qtensor, tensor_shape in zip(quantized_tensors, shape):
            rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
            expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization)
            assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
            cumulative_numel += tensor_shape[0] * tensor_shape[1]

    @pytest.mark.parametrize("quantization", _quantization_params)
    def test_static_quantize_method(self, quantization: str) -> None:
        """Test the static quantize method"""
        num_tensors = 3
        shape = [(512, 512) for _ in range(num_tensors)]
        quantizers = make_quantizer(quantization, num_tensors, shape)

        # Create input tensors
        input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]

        # Use static quantize method
        grouped_tensor = GroupedTensor.create_and_quantize(
            tensors=input_tensors,
            quantizer=quantizers,
            device="cuda",
        )

        # Verify the grouped tensor was created correctly
        assert grouped_tensor.num_tensors == num_tensors
        assert grouped_tensor.has_data()

        # Verify quantized_tensors were created and point to same storage
        assert grouped_tensor.quantized_tensors is not None
        assert len(grouped_tensor.quantized_tensors) == num_tensors
        original_data_ptr = grouped_tensor.data.data_ptr()
        for i, qtensor in enumerate(grouped_tensor.quantized_tensors):
            rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
            numel = shape[i][0] * shape[i][1]
            expected_offset = _rowwise_offset_bytes(i * numel, quantization)
            assert rowwise_data.data_ptr() == original_data_ptr + expected_offset

    def test_clear(self) -> None:
        """Test clear method"""
        num_tensors = 3
        shape = [(256, 512) for _ in range(num_tensors)]
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )
        assert grouped_tensor.has_data()
        assert grouped_tensor.num_tensors == num_tensors

        grouped_tensor.clear()

        assert not grouped_tensor.has_data()
        assert grouped_tensor.num_tensors == 0
        assert grouped_tensor.data is None
        assert grouped_tensor.logical_shape == (0, 0)
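The invariant these tests assert can be reproduced with plain PyTorch: one flat backing allocation plus per-member views whose data pointers advance by each member's byte size. A minimal sketch of that layout (CPU tensors and hypothetical shapes, no TE types involved):

    import torch

    shapes = [(128, 512), (256, 512), (384, 512)]
    total = sum(r * c for r, c in shapes)
    backing = torch.empty(total, dtype=torch.float32)

    # Carve contiguous views out of the single allocation, as GroupedTensor does.
    views, offset = [], 0
    for r, c in shapes:
        views.append(backing[offset : offset + r * c].view(r, c))
        offset += r * c

    # Each view's pointer is the base pointer plus the cumulative byte offset.
    base = backing.data_ptr()
    cum = 0
    for (r, c), v in zip(shapes, views):
        assert v.data_ptr() == base + cum * backing.element_size()
        cum += r * c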
tests/pytorch/test_numerics.py
...
@@ -94,6 +94,7 @@ all_boolean = [True, False]
 all_activations = [
     "gelu",
     "geglu",
+    "glu",
     "qgelu",
     "qgeglu",
     "relu",
...
@@ -484,6 +485,7 @@ class TorchGroupedLinearWithPadding(nn.Module):
     _supported_act = {
         "gelu": nn.GELU(approximate="tanh"),
         "geglu": nn.GELU(approximate="tanh"),
+        "glu": nn.Sigmoid(),
         "qgelu": TorchQuickGELU(),
         "qgeglu": TorchQuickGELU(),
         "relu": nn.ReLU(),
...
tests/pytorch/test_onnx_export.py
...
@@ -745,6 +745,14 @@ def test_export_layernorm_mlp_activation(seed_default_rng, activation):
     _test_export_layernorm_mlp(activation=activation)


+# Quantization recipes with fp8_dpa=True for attention emulation export test
+dpa_quantization_recipes = [None]  # None = no quantization
+if fp8_available:
+    dpa_quantization_recipes.append(recipe.DelayedScaling(fp8_dpa=True))
+    dpa_quantization_recipes.append(recipe.Float8CurrentScaling(fp8_dpa=True))
+
+
+@pytest.mark.parametrize("fp8_recipe", dpa_quantization_recipes)
 @pytest.mark.parametrize(
     "precision, use_mask, attn_mask_type",
     [
...
@@ -762,6 +770,7 @@ def test_export_core_attention(
     precision: torch.dtype,
     use_mask: bool,
     attn_mask_type: str,
+    fp8_recipe: recipe.Recipe,
 ):
     if IS_HIP_EXTENSION:
         pytest.skip("ONNX is not currently required in hip")
...
@@ -783,22 +792,25 @@ def test_export_core_attention(
     mask_str = get_attn_mask_str(use_mask, attn_mask_type)
     high_prec_str = dtype2str(precision)
-    fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"
+    fp8_str = "_fp8_dpa" if fp8_recipe is not None else ""
+    fname = f"te.core_attention{fp8_str}{mask_str}{high_prec_str}.onnx"
+    is_fp8 = fp8_recipe is not None

     model = te.attention.DotProductAttention(
         num_attention_heads=num_attention_heads,
         kv_channels=kv_channels,
         attention_dropout=0.5,
         qkv_format=qkv_format,
         attn_mask_type=attn_mask_type,
     ).to(device="cuda")
-    do_export(model, inp, fname, input_names=input_names, fp8_recipe=None)
-    te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None)
+    do_export(model, inp, fname, input_names=input_names, fp8_recipe=fp8_recipe)
+    te_outputs = te_infer(model, inp, is_fp8=is_fp8, fp8_recipe=fp8_recipe)
     serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
     if precision in (torch.bfloat16,):
         return
+    atol = 5e-1 if is_fp8 else 1e-2
     validate_result(
-        fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs
+        fname, inp, model, is_fp8=True, atol=atol, input_names=input_names, te_outputs=te_outputs
     )
...
tests/pytorch/test_sanity.py
...
@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.

-from typing import Optional
+from typing import Optional, List

 import torch
 import pytest
...
@@ -114,6 +114,7 @@ batch_sizes_with_zero = [0, 1, 2]
 all_activations = [
     "gelu",
     "geglu",
+    "glu",
     "qgelu",
     "qgeglu",
     "relu",
...
@@ -138,6 +139,117 @@ def reset_global_fp8_state():
    FP8GlobalStateManager.reset()


def check_grouped_tensor_pointers_helper(tensors, num_elems_in_byte=1, tensor_name="tensor"):
    """
    Verify that tensors are stored in contiguous memory.

    Args:
        tensors: List or iterable of tensors to check
        num_elems_in_byte: Number of elements packed per byte (1 for normal, 2 for NVFP4)
        tensor_name: Name to use in error messages
    """
    tensor_list = list(tensors)
    if len(tensor_list) < 2:
        return  # Nothing to check
    for i in range(1, len(tensor_list)):
        prev_tensor = tensor_list[i - 1]
        curr_tensor = tensor_list[i]
        # Calculate expected offset based on previous tensor size
        prev_numel = prev_tensor.numel()
        expected_offset = (prev_numel // num_elems_in_byte) * prev_tensor.element_size()
        # Verify current tensor's data pointer is correctly offset
        expected_ptr = prev_tensor.data_ptr() + expected_offset
        actual_ptr = curr_tensor.data_ptr()
        assert (
            actual_ptr == expected_ptr
        ), f"{tensor_name} {i} data pointer mismatch: expected {expected_ptr}, got {actual_ptr}"
def check_grouped_tensor_pointers(
    weights: List[torch.Tensor], fp8_recipe: Optional[recipe.Recipe] = None
):
    """
    Verify that the pointers of the weights are in contiguous memory for GroupedTensor.

    TODO(ksivaman): This check can be made way more efficient but for now leaving the
    brute force approach.
    """
    num_elems_in_a_data_byte = 1 if fp8_recipe is None else 2 if fp8_recipe.nvfp4() else 1

    # Check data.
    if hasattr(weights[0], "_data") and weights[0]._data is not None:
        data_tensors = [w._data for w in weights]
        check_grouped_tensor_pointers_helper(data_tensors, num_elems_in_byte=1, tensor_name="data")

    # Check transpose.
    if hasattr(weights[0], "_transpose") and weights[0]._transpose is not None:
        transpose_tensors = [w._transpose for w in weights]
        check_grouped_tensor_pointers_helper(
            transpose_tensors, num_elems_in_byte=1, tensor_name="transpose"
        )

    # Check scale_inv.
    if hasattr(weights[0], "_scale_inv") and weights[0]._scale_inv is not None:
        scale_inv_tensors = [w._scale_inv for w in weights]
        check_grouped_tensor_pointers_helper(
            scale_inv_tensors, num_elems_in_byte=1, tensor_name="scale_inv"
        )

    # Check rowwise scale_inv.
    if hasattr(weights[0], "_rowwise_scale_inv") and weights[0]._rowwise_scale_inv is not None:
        scale_inv_tensors = [w._rowwise_scale_inv for w in weights]
        check_grouped_tensor_pointers_helper(
            scale_inv_tensors, num_elems_in_byte=1, tensor_name="rowwise_scale_inv"
        )

    # Check columnwise scale_inv.
    if (
        hasattr(weights[0], "_columnwise_scale_inv")
        and weights[0]._columnwise_scale_inv is not None
    ):
        columnwise_scale_inv_tensors = [w._columnwise_scale_inv for w in weights]
        check_grouped_tensor_pointers_helper(
            columnwise_scale_inv_tensors,
            num_elems_in_byte=1,
            tensor_name="columnwise scale_inv",
        )

    # Check rowwise amax.
    if hasattr(weights[0], "_rowwise_amax") and weights[0]._rowwise_amax is not None:
        rowwise_amax_tensors = [w._rowwise_amax for w in weights]
        check_grouped_tensor_pointers_helper(
            rowwise_amax_tensors, num_elems_in_byte=1, tensor_name="rowwise amax"
        )

    # Check columnwise amax.
    if hasattr(weights[0], "_columnwise_amax") and weights[0]._columnwise_amax is not None:
        columnwise_amax_tensors = [w._columnwise_amax for w in weights]
        check_grouped_tensor_pointers_helper(
            columnwise_amax_tensors, num_elems_in_byte=1, tensor_name="columnwise amax"
        )

    # Check rowwise data.
    if hasattr(weights[0], "_rowwise_data") and weights[0]._rowwise_data is not None:
        rowwise_data_tensors = [w._rowwise_data for w in weights]
        check_grouped_tensor_pointers_helper(
            rowwise_data_tensors,
            num_elems_in_byte=num_elems_in_a_data_byte,
            tensor_name="rowwise data",
        )

    # Check columnwise data.
    if hasattr(weights[0], "_columnwise_data") and weights[0]._columnwise_data is not None:
        columnwise_data_tensors = [w._columnwise_data for w in weights]
        check_grouped_tensor_pointers_helper(
            columnwise_data_tensors,
            num_elems_in_byte=num_elems_in_a_data_byte,
            tensor_name="columnwise data",
        )


def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
...
@@ -486,10 +598,19 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 @pytest.mark.parametrize("use_bias", all_boolean)
+@pytest.mark.parametrize("single_param", all_boolean)
 @pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
 @pytest.mark.parametrize("num_gemms", [4])
 def test_sanity_grouped_linear(
-    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
+    dtype,
+    bs,
+    model,
+    fp8_recipe,
+    fp8_model_params,
+    use_bias,
+    single_param,
+    num_gemms,
+    empty_split,
 ):
     if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
         pytest.skip("FP8 model parameters are not supported in debug mode.")
...
@@ -499,6 +620,9 @@ def test_sanity_grouped_linear(
     bs = bs * 16
     num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
+    if single_param:
+        os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"
     if fp8_recipe is not None:
         if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
...
@@ -508,9 +632,19 @@ def test_sanity_grouped_linear(
     use_fp8 = fp8_recipe is not None
     with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
         te_grouped_linear = GroupedLinear(
-            num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
+            num_gemms,
+            config.hidden_size,
+            ffn_hidden_size,
+            bias=use_bias,
+            params_dtype=dtype,
         ).cuda()
+    # Verify that weights are stored in contiguous GroupedTensor storage.
+    weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
+    if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()):
+        if single_param:
+            check_grouped_tensor_pointers(weights, fp8_recipe)
     inp_hidden_states = torch.randn(
         num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
...
@@ -528,6 +662,9 @@ def test_sanity_grouped_linear(
     loss.backward()
     assert out.shape == (num_tokens, ffn_hidden_size)
+    if single_param:
+        del os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"]


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
...
@@ -1005,7 +1142,13 @@ def test_replace_raw_data_for_float8tensor():
     random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
     fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)

-    attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"]
+    attrs_to_check = [
+        "_quantizer",
+        "_fp8_dtype",
+        "_scale_inv",
+        "_transpose",
+        "_transpose_invalid",
+    ]
     attrs = {}
     for attr in attrs_to_check:
         attrs[attr] = getattr(fp8_tensor, attr)
...
tests/pytorch/utils.py
...
@@ -15,7 +15,7 @@ import torch
 import transformer_engine
 import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
-from transformer_engine.pytorch import InferenceParams
+from transformer_engine.pytorch import InferenceParams, QuantizedTensor
 from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
 from transformer_engine.pytorch.attention.dot_product_attention.utils import (
     get_attention_backend,
...
@@ -353,7 +353,7 @@ def get_available_attention_backends(
    backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
    if AttentionLogging._is_logging_setup is False:
        AttentionLogging.setup_logging()
    with logging_context(highest_level=AttentionLogging._log_level):
        for i in range(3):
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
            _attention_backends["backend_selection_requires_update"] = True
...
@@ -361,3 +361,48 @@ def get_available_attention_backends(
            if fused_attention_backend == FusedAttnBackend[backends[i]]:
                fused_attn_backends.append(fused_attention_backend)
    return available_backends, flash_attention_backend, fused_attn_backends


@torch.no_grad
def assert_close(
    actual: Optional[torch.Tensor],
    expected: Optional[torch.Tensor],
    *,
    check_device: bool = False,
    check_dtype: bool = False,
    check_layout: bool = False,
    **kwargs,
) -> None:
    """Assert that two tensors are close.

    This function is a wrapper around torch.testing.assert_close. It
    changes the defaults for device and dtype checks (useful when the
    reference implementation is computed in high precision on CPU) and
    it can handle quantized tensors.
    """
    if isinstance(actual, QuantizedTensor):
        actual = actual.dequantize()
    if isinstance(expected, QuantizedTensor):
        expected = expected.dequantize()
    torch.testing.assert_close(
        actual,
        expected,
        check_device=check_device,
        check_dtype=check_dtype,
        check_layout=check_layout,
        **kwargs,
    )


def assert_close_grads(
    actual: Optional[torch.Tensor],
    expected: Optional[torch.Tensor],
    **kwargs,
) -> None:
    """Assert that two tensors have close gradients."""
    if actual is None and expected is None:
        return
    assert actual is not None
    assert expected is not None
    assert_close(actual.grad, expected.grad, **kwargs)
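A brief usage sketch for the two wrappers (the tensor names are illustrative): assert_close dequantizes any QuantizedTensor operand before comparing, and with its relaxed defaults a CUDA result can be checked directly against a float64 CPU reference; assert_close_grads compares the .grad attributes of a pair of leaf tensors.

    x = torch.randn(4, 4, device="cuda", requires_grad=True)
    y = x.double().cpu().detach().requires_grad_()  # high-precision CPU reference

    (x * 2).sum().backward()
    (y * 2).sum().backward()

    assert_close(x, y)            # device/dtype mismatches are tolerated by default
    assert_close_grads(x, y)      # compares x.grad against y.grad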
transformer_engine/common/CMakeLists.txt
...
@@ -202,6 +202,7 @@ if(USE_CUDA)
      fused_attn/fused_attn_fp8.cu
      fused_attn/utils.cu
      gemm/cublaslt_gemm.cu
+     gemm/cublaslt_grouped_gemm.cu
      normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
      normalization/layernorm/ln_fwd_cuda_kernel.cu
      normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
...
@@ -225,15 +226,18 @@
  list(APPEND transformer_engine_cuda_arch_specific_sources
       activation/gelu.cu
       activation/glu.cu
       activation/relu.cu
       activation/swiglu.cu
       cast/cast.cu
       gemm/cutlass_grouped_gemm.cu
       hadamard_transform/group_hadamard_transform.cu
       hadamard_transform/graph_safe_group_hadamard_transform.cu
       hadamard_transform/hadamard_transform.cu
       hadamard_transform/hadamard_transform_cast_fusion.cu
       hadamard_transform/group_hadamard_transform_cast_fusion.cu
       hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
       hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu
       multi_tensor/compute_scale.cu
       recipe/mxfp8_scaling.cu
       transpose/quantize_transpose_square_blockwise.cu
...
@@ -357,6 +361,7 @@ else()
      fused_attn/kv_cache.cu
      fused_attn/utils.cu
      gemm/cublaslt_gemm.cu
+     gemm/cublaslt_grouped_gemm.cu
      gemm/hipblas_gemm.cu
      normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
      normalization/layernorm/ln_fwd_cuda_kernel.cu
...
@@ -381,6 +386,7 @@
  list(APPEND transformer_engine_cuda_arch_specific_sources
       activation/gelu.cu
+      activation/glu.cu
       activation/relu.cu
       activation/swiglu.cu
       cast/cast.cu
...
@@ -476,20 +482,18 @@
 option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF)
 if(NVTE_WITH_CUBLASMP)
     target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP)
-    target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include)
+    target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include)
     find_library(CUBLASMP_LIB
         NAMES cublasmp libcublasmp
         PATHS ${CUBLASMP_DIR}
         PATH_SUFFIXES lib
         REQUIRED)
-    find_library(NVSHMEM_HOST_LIB
-        NAMES nvshmem_host libnvshmem_host.so.3
-        PATHS ${NVSHMEM_DIR}
+    find_library(NCCL_LIB
+        NAMES nccl libnccl
         PATH_SUFFIXES lib
         REQUIRED)
-    target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB})
+    target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB} ${CUBLASMP_LIB})
     message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}")
-    message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}")
 endif()

 if(USE_CUDA)
...
@@ -561,6 +565,7 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
 option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
 if(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
     list(APPEND nvte_sources_with_fast_math activation/gelu.cu
+                                            activation/glu.cu
                                             activation/relu.cu
                                             activation/swiglu.cu)
 endif()
...
transformer_engine/common/__init__.py
...
@@ -246,11 +246,13 @@ def _nvidia_cudart_include_dir() -> str:
         return ""

     # Installing some nvidia-* packages, like nvshmem, create nvidia name, so "import nvidia"
-    # above doesn't through. However, they don't set "__file__" attribute.
-    if nvidia.__file__ is None:
-        return ""
+    # above doesn't throw. However, they don't set "__file__" attribute.
+    if nvidia.__file__ is not None:
+        nvidia_root = Path(nvidia.__file__).parent
+    else:
+        nvidia_root = Path(nvidia.__path__[0])  # namespace package

-    include_dir = Path(nvidia.__file__).parent / "cuda_runtime"
+    include_dir = nvidia_root / "cuda_runtime"
     return str(include_dir) if include_dir.exists() else ""
...
transformer_engine/common/activation/gelu.cu
...
@@ -13,6 +13,14 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
}

void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                     cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_gelu);
  using namespace transformer_engine;
  constexpr bool IS_ACT = true;
  dispatch::group_quantize_fwd_helper<IS_ACT, Empty, gelu<fp32, fp32>>(input, output, nullptr,
                                                                       stream);
}

void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgelu);
...
@@ -20,6 +28,20 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
  dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_dgelu);
  using namespace transformer_engine;
  NVTETensor dbias = nullptr;
  NVTETensor workspace = nullptr;
  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
      grad, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
...
@@ -33,6 +55,20 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
                                     const NVTEGroupedTensor activation_input,
                                     NVTEGroupedTensor output, NVTETensor dbias,
                                     NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_dgelu);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_geglu);
  using namespace transformer_engine;
...
@@ -54,6 +90,15 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
  act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
}

void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                      cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_qgelu);
  using namespace transformer_engine;
  constexpr bool IS_ACT = true;
  dispatch::group_quantize_fwd_helper<IS_ACT, Empty, qgelu<fp32, fp32>>(input, output, nullptr,
                                                                        stream);
}

void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dqgelu);
...
@@ -61,6 +106,20 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
  dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                       NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_dqgelu);
  using namespace transformer_engine;
  NVTETensor dbias = nullptr;
  NVTETensor workspace = nullptr;
  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
      grad, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
...
@@ -74,6 +133,20 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
                                      const NVTEGroupedTensor activation_input,
                                      NVTEGroupedTensor output, NVTETensor dbias,
                                      NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_qgeglu);
  using namespace transformer_engine;
...
transformer_engine/common/activation/glu.cu (new file)
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "../util/math.h"
#include "./activation_template.h"

void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_glu);
  using namespace transformer_engine;
  Empty e = {};
  gated_act_fn<fp32, Empty, sigmoid<fp32, fp32>>(input, output, e, stream);
}

void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
               cudaStream_t stream) {
  NVTE_API_CALL(nvte_dglu);
  using namespace transformer_engine;
  Empty e = {};
  dgated_act_fn<fp32, Empty, sigmoid<fp32, fp32>, dsigmoid<fp32, fp32>>(grad, input, output, e,
                                                                        stream);
}
transformer_engine/common/activation/relu.cu
...
@@ -13,6 +13,14 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
}

void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                     cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_relu);
  using namespace transformer_engine;
  constexpr bool IS_ACT = true;
  dispatch::group_quantize_fwd_helper<IS_ACT, Empty, relu<fp32, fp32>>(input, output, nullptr,
                                                                       stream);
}

void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_drelu);
...
@@ -20,6 +28,20 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
  dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_drelu);
  using namespace transformer_engine;
  NVTETensor dbias = nullptr;
  NVTETensor workspace = nullptr;
  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
      grad, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
...
@@ -33,6 +55,20 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input,
                                     const NVTEGroupedTensor activation_input,
                                     NVTEGroupedTensor output, NVTETensor dbias,
                                     NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_drelu);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_reglu);
  using namespace transformer_engine;
...
@@ -54,6 +90,15 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
  act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
}

void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                      cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_srelu);
  using namespace transformer_engine;
  constexpr bool IS_ACT = true;
  dispatch::group_quantize_fwd_helper<IS_ACT, Empty, srelu<fp32, fp32>>(input, output, nullptr,
                                                                        stream);
}

void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dsrelu);
...
@@ -61,6 +106,20 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
  dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                       NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_dsrelu);
  using namespace transformer_engine;
  NVTETensor dbias = nullptr;
  NVTETensor workspace = nullptr;
  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
      grad, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
...
@@ -74,6 +133,20 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input,
                                      const NVTEGroupedTensor activation_input,
                                      NVTEGroupedTensor output, NVTETensor dbias,
                                      NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_sreglu);
  using namespace transformer_engine;
...
transformer_engine/common/activation/swiglu.cu
...
@@ -13,6 +13,14 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
}

void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                     cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_silu);
  using namespace transformer_engine;
  constexpr bool IS_ACT = true;
  dispatch::group_quantize_fwd_helper<IS_ACT, Empty, silu<fp32, fp32>>(input, output, nullptr,
                                                                       stream);
}

void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dsilu);
...
@@ -20,6 +28,20 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
  dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_dsilu);
  using namespace transformer_engine;
  NVTETensor dbias = nullptr;
  NVTETensor workspace = nullptr;
  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
      grad, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
...
@@ -33,6 +55,20 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input,
                                     const NVTEGroupedTensor activation_input,
                                     NVTEGroupedTensor output, NVTETensor dbias,
                                     NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_dsilu);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_swiglu);
  using namespace transformer_engine;
...
transformer_engine/common/cast/cast.cu
...
@@ -28,6 +28,15 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
  dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}

void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                         cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize);
  using namespace transformer_engine;
  constexpr bool IS_ACT = false;
  dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}

void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
                        cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_noop);
...
@@ -62,6 +71,19 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                               NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = false;
  constexpr const NVTEGroupedTensor activation_input = nullptr;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, nullptr>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_dequantize);
  using namespace transformer_engine;
...
transformer_engine/common/cast/core/common.cuh
...
@@ -37,6 +37,12 @@ inline bool dimensions_supported_by_TMA(const Tensor *const t) {
  return cols % alignment_requirement == 0;
}

__device__ __forceinline__ unsigned char *align_smem_ptr_per_TMA_requirements(unsigned char *p) {
  size_t addr = reinterpret_cast<size_t>(p);
  addr = (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1);
  return reinterpret_cast<unsigned char *>(addr);
}
namespace kernel {

constexpr size_t THREADS_PER_BLOCK = 256;
...