OpenDAS / TransformerEngine · Commits

Commit 87e3e56e, authored Aug 27, 2025 by yuguo

    Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine

Parents: 2f11bd2e, 734bcedd

Changes: 217 files in total; showing 20 changed files with 989 additions and 1104 deletions (+989 −1104).
tests/pytorch/test_numerics.py                                            +107  −310
tests/pytorch/test_onnx_export.py                                         +186  −199
tests/pytorch/test_parallel_cross_entropy.py                                +0    −1
tests/pytorch/test_qk_norm.py                                             +182   −35
tests/pytorch/test_recipe.py                                               +35   −37
tests/pytorch/test_sanity.py                                               +94  −461
tests/pytorch/utils.py                                                    +187    −0
transformer_engine/common/CMakeLists.txt                                    +7    −0
transformer_engine/common/__init__.py                                      +15    −3
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp           +35    −3
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu     +24    −0
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h       +8    −0
transformer_engine/common/common.cu                                         +5    −2
transformer_engine/common/common.h                                          +3    −2
transformer_engine/common/fused_attn/fused_attn.cpp                        +14   −10
transformer_engine/common/fused_router/fused_moe_aux_loss.cu                +2    −2
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu      +6    −3
transformer_engine/common/fused_router/fused_topk_with_score_function.cu    +3    −3
transformer_engine/common/fused_router/utils.h                             +48   −21
transformer_engine/common/gemm/cublaslt_gemm.cu                            +28   −12
tests/pytorch/test_numerics.py (view file @ 87e3e56e)
@@ -2,7 +2,6 @@
 #
 # See LICENSE for license information.

-from collections import OrderedDict
 import math
 import os
 from typing import Dict, List, Tuple, Optional
@@ -39,54 +38,39 @@ from transformer_engine.pytorch import (
     Fp8Unpadding,
 )
 from transformer_engine.pytorch import torch_version
-from transformer_engine.pytorch.attention.inference import InferenceParams
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
+from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
 from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
-from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version
+from transformer_engine.pytorch.utils import get_device_compute_capability
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
+from utils import ModelConfig, reset_rng_states, get_available_attention_backends

 # Only run FP8 tests on supported devices.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
-)
+mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
 sm_80plus = get_device_compute_capability() >= (8, 0)

 seed = 1234
-torch.manual_seed(seed)
-torch.cuda.manual_seed(seed)
-# Record initial RNG state from script run.
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
+# Reset RNG states.
+reset_rng_states()

 if torch_version() >= (2, 7, 0):
     torch._dynamo.config.recompile_limit = 16
 else:
     torch._dynamo.config.cache_size_limit = 16

-class ModelConfig:
-    def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
-        self.hidden_size = hidden_size
-        self.eps = eps
-        self.num_attention_heads = num_attention_heads
-        self.embed = embed
-        self.num_layers = num_layers
-        self.seq_len = seq_len

 model_configs = {
-    "small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
-    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
+    "small": ModelConfig(1, 128, 8, 16, num_layers=4),
+    "126m": ModelConfig(1, 2048, 12, 64, num_layers=12),
 }

 model_configs_inference = {
-    # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
-    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
+    "126m": ModelConfig(1, 256, 12, 64, num_layers=12),
 }
 backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
 module_inference = ["TransformerLayer", "MultiheadAttention"]
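The configs above now come from the shared ModelConfig in tests/pytorch/utils.py (added in this merge, +187 lines). A hedged sketch of its shape, inferred purely from the call sites and attribute names visible in this diff (max_seqlen_q, num_heads, kv_channels); the actual utils.py class is richer than this:

from dataclasses import dataclass

@dataclass
class ModelConfig:
    # Positional order inferred from calls like ModelConfig(1, 128, 8, 16,
    # num_layers=4); every field here is an illustrative assumption.
    batch_size: int
    max_seqlen_q: int
    num_heads: int
    kv_channels: int
    num_layers: int = 1
    eps: float = 1e-5
    max_seqlen_kv: int = 0  # 0 means "same as max_seqlen_q"

    def __post_init__(self):
        if self.max_seqlen_kv == 0:
            self.max_seqlen_kv = self.max_seqlen_q

    @property
    def hidden_size(self) -> int:
        # Consistent with the old explicit configs: 8 * 16 = 128 ("small"),
        # 12 * 64 = 768 ("126m").
        return self.num_heads * self.kv_channels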
@@ -120,12 +104,27 @@ if NVTE_TEST_NVINSPECT_ENABLED:
         feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
     )

-fp8_recipes = [
-    recipe.MXFP8BlockScaling(),
-    recipe.DelayedScaling(),
-    recipe.Float8CurrentScaling(),
-    recipe.Float8BlockScaling(),
-]
+fp8_recipes = []
+if mxfp8_available:
+    fp8_recipes.append(recipe.MXFP8BlockScaling())
+if fp8_block_scaling_available:
+    fp8_recipes.append(recipe.Float8BlockScaling())
+if fp8_available:
+    fp8_recipes.append(recipe.Float8CurrentScaling())
+    fp8_recipes.append(recipe.DelayedScaling())
+
+
+def is_fused_attn_available(
+    config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True
+):
+    _, _, fused_attn_backends = get_available_attention_backends(
+        config,
+        qkv_dtype=dtype,
+        qkv_layout=qkv_layout,
+        is_training=is_training,
+    )
+    return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends

 def get_causal_attn_mask(sq: int) -> torch.Tensor:
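This hunk carries the central idea of the test cleanup in this merge: instead of parametrizing every test over every recipe and skipping unsupported ones inside each test body, the recipe list is filtered by device capability once at import time, so pytest never collects a case it would only skip. A minimal self-contained sketch of the pattern (the names here are placeholders, not TE's API):

import pytest

def _mxfp8_supported() -> bool:
    # Stand-in for FP8GlobalStateManager.is_mxfp8_available(); assume a
    # device without MXFP8 support for this sketch.
    return False

RECIPES = ["current_scaling", "delayed_scaling"]
if _mxfp8_supported():
    RECIPES.append("mxfp8_block_scaling")

@pytest.mark.parametrize("recipe_name", RECIPES)
def test_recipe_runs(recipe_name):
    # Every collected case is runnable on this device, so the per-test
    # skip boilerplate removed elsewhere in this diff becomes unnecessary.
    assert recipe_name in RECIPES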
@@ -177,12 +176,6 @@ def assert_allclose(
         raise AssertionError(msg)

-def reset_rng_states() -> None:
-    """revert back to initial RNG state."""
-    torch.set_rng_state(_cpu_rng_state)
-    torch.cuda.set_rng_state(_cuda_rng_state)

 @pytest.fixture(autouse=True)
 def reset_global_fp8_state():
     yield
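reset_rng_states() is now imported from tests/pytorch/utils.py. A hedged sketch of what such a helper plausibly looks like, mirroring the removed local version (snapshot the RNG state once at import, restore it on each call); the real utils.py implementation may differ in detail:

import torch

_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

def reset_rng_states() -> None:
    """Revert CPU (and CUDA, if present) RNG to the state captured at import."""
    torch.set_rng_state(_cpu_rng_state)
    if _cuda_rng_state is not None:
        torch.cuda.set_rng_state(_cuda_rng_state)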
@@ -535,13 +528,13 @@ def _test_e2e_selective_recompute(
     block = TransformerLayer(
         config.hidden_size,
         4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
         layernorm_epsilon=config.eps,
         init_method=init_method,
         output_layer_init_method=output_layer_init_method,
         hidden_dropout=0.1,
         attention_dropout=0.1,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
         apply_residual_connection_post_layernorm=False,
         output_layernorm=False,
         params_dtype=dtype,
@@ -550,13 +543,13 @@ def _test_e2e_selective_recompute(
     )

     te_inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
         dtype=dtype,
         device="cuda",
         requires_grad=True,
     )
     te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+    te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)

     with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
         te_out = block(
@@ -582,14 +575,8 @@ def _test_e2e_selective_recompute(
 @pytest.mark.parametrize("recipe", fp8_recipes)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)

     config = model_configs[model]
@@ -630,13 +617,13 @@ def _test_e2e_full_recompute(
     block = TransformerLayer(
         config.hidden_size,
         4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
         layernorm_epsilon=config.eps,
         init_method=init_method,
         output_layer_init_method=output_layer_init_method,
         hidden_dropout=0.1,
         attention_dropout=0.1,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
         apply_residual_connection_post_layernorm=False,
         output_layernorm=False,
         params_dtype=dtype,
@@ -645,14 +632,14 @@ def _test_e2e_full_recompute(
     )

     te_inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
         dtype=dtype,
         device="cuda",
         requires_grad=use_reentrant,
     )
     if use_reentrant:
         te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+    te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)

     with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
         if recompute:
@@ -698,14 +685,8 @@ def _test_e2e_full_recompute(
 def test_gpt_full_activation_recompute(
     dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
 ):
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)

     config = model_configs[model]
@@ -761,13 +742,13 @@ def _test_e2e_checkpointing_get_model(config, dtype):
     return TransformerLayer(
         config.hidden_size,
         4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
         layernorm_epsilon=config.eps,
         init_method=init_method,
         output_layer_init_method=output_layer_init_method,
         hidden_dropout=0.1,
         attention_dropout=0.1,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
         apply_residual_connection_post_layernorm=False,
         output_layernorm=False,
         params_dtype=dtype,
@@ -779,7 +760,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
     reset_rng_states()

     te_inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
         dtype=dtype,
         device="cuda",
         requires_grad=True,
@@ -809,14 +790,14 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
             if p.requires_grad:
                 param_grads.append(p.grad.clone())

-        global _cpu_rng_state, _cuda_rng_state
         _cpu_rng_state = torch.get_rng_state()
         _cuda_rng_state = torch.cuda.get_rng_state()

         del block
         block = _test_e2e_checkpointing_get_model(config, dtype)
         block.load_state_dict(torch.load(path, weights_only=False))
-        reset_rng_states()
+        torch.set_rng_state(_cpu_rng_state)
+        torch.cuda.set_rng_state(_cuda_rng_state)

         for p in block.parameters():
             if p.requires_grad:
@@ -849,6 +830,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
 @pytest.mark.parametrize("model", ["126m"])
 def test_gpt_checkpointing(dtype, bs, model):
     config = model_configs[model]
+    if not is_fused_attn_available(config, dtype):
+        pytest.skip("No attention backend available.")
     outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
     outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
@@ -869,13 +852,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
     reset_rng_states()

     inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
         dtype=dtype,
         device="cuda",
         requires_grad=True,
     )
     inp_hidden_states.retain_grad()
-    inp_attn_mask = get_causal_attn_mask(config.seq_len)
+    inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)

     out = block(inp_hidden_states, attention_mask=inp_attn_mask)
     loss = out.sum()
@@ -895,11 +878,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
 @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
 def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
     config = model_configs[model]
+    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+        pytest.skip("No attention backend available.")

     te_gpt = TransformerLayer(
         hidden_size=config.hidden_size,
         ffn_hidden_size=4 * config.hidden_size,
-        num_attention_heads=config.num_attention_heads,
+        num_attention_heads=config.num_heads,
         layernorm_epsilon=config.eps,
         attention_dropout=0.1,
         hidden_dropout=0.1,
@@ -914,7 +899,7 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
         TorchGPT(
             config.hidden_size,
             config.eps,
-            config.num_attention_heads,
+            config.num_heads,
             parallel_attention_mlp=parallel_attention_mlp,
         )
         .to(dtype=dtype)
@@ -975,13 +960,13 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
     reset_rng_states()

     inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
         dtype=dtype,
         device="cuda",
         requires_grad=True,
     )
     inp_hidden_states.retain_grad()
-    inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None
+    inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) if mask_type == "causal" else None

     forward_kwargs = {}
     if te:
@@ -1006,10 +991,12 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
 @pytest.mark.parametrize("mask_type", mask_types)
 def test_mha_accuracy(dtype, bs, model, mask_type):
     config = model_configs[model]
+    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+        pytest.skip("No attention backend available.")

     te_mha = MultiheadAttention(
         config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
         fuse_qkv_params=True,
         params_dtype=dtype,
         qkv_weight_interleaved=False,
@@ -1020,7 +1007,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
     torch_mha = (
         TorchMHA(
             config.hidden_size,
-            config.num_attention_heads,
+            config.num_heads,
         )
         .to(dtype=dtype)
         .cuda()
@@ -1066,7 +1053,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False,
     FP8GlobalStateManager.reset()

     inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
         dtype=dtype,
         device="cuda",
         requires_grad=True,
@@ -1098,11 +1085,12 @@ def _test_dpa_accuracy(block, bs, dtype, config):
     reset_rng_states()

     mask = torch.triu(
-        torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1
+        torch.ones(config.max_seqlen_q, config.max_seqlen_kv, dtype=torch.bool, device="cuda"),
+        diagonal=1,
     )
     query, key, value = [
         torch.randn(
-            (config.seq_len, bs, config.num_attention_heads, config.embed),
+            (config.max_seqlen_q, bs, config.num_heads, config.kv_channels),
             dtype=dtype,
             device="cuda",
             requires_grad=True,
@@ -1131,8 +1119,8 @@ def test_dpa_accuracy(dtype, bs, model):
     te_dpa = (
         DotProductAttention(
-            config.num_attention_heads,
-            config.embed,
+            config.num_heads,
+            config.kv_channels,
             attention_dropout=0.0,  # disable dropout, FU uses rng differently
         )
         .to(dtype=dtype)
@@ -1141,7 +1129,7 @@ def test_dpa_accuracy(dtype, bs, model):
     torch_dpa = (
         TorchDotProductAttention(
-            config.embed,
+            config.kv_channels,
             0.0,  # dropout
         )
         .to(dtype=dtype)
@@ -1267,8 +1255,8 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_
         te_linear_ref, bs, dtype, config, delay_wgrad_compute=False
     )

-    # Shoule be bit-wise match
-    for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
+    # Should be bit-wise match
+    for _, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@@ -1280,17 +1268,12 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
     fuse_wgrad_accumulation = True
     fp8_model_params = False
     fp8 = recipe is not None
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8 and recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")

     config = model_configs[model]
-    if config.seq_len % 16 != 0 and fp8:
+    if config.max_seqlen_q % 16 != 0 and fp8:
         pytest.skip("FP8 requires sequence length to be divisible by 16.")

     with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
@@ -1653,14 +1636,12 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
 @pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("bs", [2])
 @pytest.mark.parametrize("model", ["small"])
-@pytest.mark.parametrize("activation", all_activations)
-@pytest.mark.parametrize("normalization", all_normalizations)
 @pytest.mark.parametrize("bias", all_boolean)
 @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
 def test_layernorm_mlp_accuracy_delay_wgrad_compute(
-    dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation
+    dtype, bs, model, bias, fuse_wgrad_accumulation
 ):
     config = model_configs[model]
@@ -1669,7 +1650,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
         ffn_hidden_size=4 * config.hidden_size,
         eps=config.eps,
         bias=bias,
-        normalization=normalization,
         params_dtype=dtype,
         device="cuda",
         delay_wgrad_compute=True,
@@ -1681,7 +1661,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
         ffn_hidden_size=4 * config.hidden_size,
         eps=config.eps,
         bias=bias,
-        normalization=normalization,
         params_dtype=dtype,
         device="cuda",
         delay_wgrad_compute=False,
@@ -1691,8 +1670,7 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
     # Share params
     with torch.no_grad():
         ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
-        if normalization != "RMSNorm":
-            ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
+        ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
         ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
         ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
         if bias:
@@ -1730,7 +1708,7 @@ def _test_grouped_linear_accuracy(
     FP8GlobalStateManager.reset()

     inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
         dtype=dtype,
         device="cuda",
         requires_grad=True,
@@ -1743,14 +1721,14 @@ def _test_grouped_linear_accuracy(
         split_size = 16
         if recipe.mxfp8():
             split_size = 128
-        m = config.seq_len // split_size
+        m = config.max_seqlen_q // split_size
         dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
         dist.append(dist[-1])  # Manually add a zero
         m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
         m_splits = m_splits * split_size
-        assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
+        assert m_splits.sum() == config.max_seqlen_q and len(m_splits) == num_gemms
     else:
-        m_splits = torch.tensor([config.seq_len])
+        m_splits = torch.tensor([config.max_seqlen_q])

     with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
         if isinstance(block, GroupedLinear):
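The m_splits construction in this hunk is a compact algorithm worth unpacking: sorted random breakpoints in [0, m) become adjacent differences that telescope back to m, then everything is scaled by split_size so each split is a multiple of split_size and the splits sum exactly to the sequence length; duplicating the last breakpoint deliberately produces one zero-sized split. A runnable standalone version:

import torch

torch.manual_seed(0)
max_seqlen_q, split_size, num_gemms = 2048, 16, 6
m = max_seqlen_q // split_size

dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1])  # duplicated breakpoint => one zero-length split
# Adjacent differences of [0, d1, ..., dk, m] telescope to m.
m_splits = (torch.tensor(dist + [m]) - torch.tensor([0] + dist)) * split_size

assert m_splits.sum() == max_seqlen_q and len(m_splits) == num_gemms
print(m_splits.tolist())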
@@ -1806,17 +1784,11 @@ def test_grouped_linear_accuracy(
     parallel_mode=None,
 ):
     fp8 = recipe is not None
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8 and recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
+    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
-    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)

     config = model_configs[model]
-    if config.seq_len % 16 != 0 and fp8:
+    if config.max_seqlen_q % 16 != 0 and fp8:
         pytest.skip("FP8 requires sequence length to be divisible by 16.")

     with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
@@ -1908,19 +1880,13 @@ def test_grouped_linear_accuracy_save_original_input(
     parallel_mode=None,
 ):
     fp8 = recipe is not None
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8 and recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
+    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
-    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")

     config = model_configs[model]
-    if config.seq_len % 16 != 0 and fp8:
+    if config.max_seqlen_q % 16 != 0 and fp8:
         pytest.skip("FP8 requires sequence length to be divisible by 16.")

     with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
@@ -2074,14 +2040,14 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
     FP8GlobalStateManager.reset()

     inp_hidden_states = torch.randn(
-        (config.seq_len * bs, config.hidden_size),
+        (config.max_seqlen_q * bs, config.hidden_size),
         dtype=dtype,
         device="cuda",
         requires_grad=True,
     )
     inp_hidden_states.retain_grad()

-    m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs)
+    m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs)

     with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
         if isinstance(block, TorchGroupedLinearWithPadding):
@@ -2124,17 +2090,11 @@ def test_padding_grouped_linear_accuracy(
     fp8_model_params,
     parallel_mode=None,
 ):
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)

     config = model_configs[model]
-    if config.seq_len % 16 != 0 and fp8:
+    if config.max_seqlen_q % 16 != 0 and fp8:
         pytest.skip("FP8 requires sequence length to be divisible by 16.")

     with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
@@ -2199,19 +2159,13 @@ def test_padding_grouped_linear_accuracy_save_original_input(
     fp8_model_params,
     parallel_mode=None,
 ):
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")

     config = model_configs[model]
-    if config.seq_len % 16 != 0 and fp8:
+    if config.max_seqlen_q % 16 != 0 and fp8:
         pytest.skip("FP8 requires sequence length to be divisible by 16.")

     with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
@@ -2268,9 +2222,11 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
     # Placeholders used for graph capture.
     static_input = torch.randn(
-        config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
+        config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
+    )
+    static_target = torch.randn(
+        config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype
     )
-    static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype)

     real_input = torch.rand_like(static_input)
     real_target = torch.rand_like(static_target)
@@ -2334,7 +2290,7 @@ def test_gpt_cuda_graph(dtype, bs, model):
     block_args = (
         config.hidden_size,
         4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
     )
     block_kwargs = dict(
         layernorm_epsilon=config.eps,
@@ -2342,7 +2298,7 @@ def test_gpt_cuda_graph(dtype, bs, model):
         output_layer_init_method=output_layer_init_method,
         hidden_dropout=0.1,
         attention_dropout=0.1,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
         params_dtype=dtype,
         apply_residual_connection_post_layernorm=False,
         output_layernorm=False,
@@ -2377,13 +2333,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
     block = TransformerLayer(
         config.hidden_size,
         4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
         layernorm_epsilon=config.eps,
         init_method=init_method,
         output_layer_init_method=output_layer_init_method,
         hidden_dropout=0.1,
         attention_dropout=0.1,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
         apply_residual_connection_post_layernorm=False,
         output_layernorm=False,
         params_dtype=dtype,
@@ -2392,13 +2348,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
     )

     te_inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
         dtype=dtype,
         device="cuda",
         requires_grad=True,
     )
     te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+    te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)

     with fp8_autocast(enabled=True, fp8_recipe=recipe):
         te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
@@ -2418,14 +2374,8 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
 @pytest.mark.parametrize("model", ["126m"])
 @pytest.mark.parametrize("recipe", fp8_recipes)
 def test_gpt_fp8_parameters(dtype, bs, model, recipe):
-    if not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     if NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)

     config = model_configs[model]
@@ -2461,13 +2411,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
     block_sbhd = TransformerLayer(
         config.hidden_size,
         4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
         layernorm_epsilon=config.eps,
         init_method=init_method,
         output_layer_init_method=output_layer_init_method,
         hidden_dropout=0,
         attention_dropout=0,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
         params_dtype=dtype,
         apply_residual_connection_post_layernorm=False,
         output_layernorm=False,
@@ -2482,13 +2432,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
     block_bshd = TransformerLayer(
         config.hidden_size,
         4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
         layernorm_epsilon=config.eps,
         init_method=init_method,
         output_layer_init_method=output_layer_init_method,
         hidden_dropout=0,
         attention_dropout=0,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
         params_dtype=dtype,
         apply_residual_connection_post_layernorm=False,
         output_layernorm=False,
@@ -2500,13 +2450,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
     block_thd = TransformerLayer(
         config.hidden_size,
         4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
         layernorm_epsilon=config.eps,
         init_method=init_method,
         output_layer_init_method=output_layer_init_method,
         hidden_dropout=0,
         attention_dropout=0,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
         params_dtype=dtype,
         apply_residual_connection_post_layernorm=False,
         output_layernorm=False,
@@ -2521,15 +2471,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
             assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical"

     x_sbhd = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
         dtype=dtype,
         device="cuda",
         requires_grad=True,
     )

     x_bshd = x_sbhd.transpose(0, 1).contiguous()
-    x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous()
-    x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len
+    x_thd = x_bshd.reshape(bs * config.max_seqlen_q, config.hidden_size).contiguous()
+    x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.max_seqlen_q

     # To make sure forward is also identical (just in case some module decides
     # to act fancy)
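For reference, the three hidden-state layouts this test converts between: sbhd is [seq, batch, hidden], bshd is [batch, seq, hidden], and thd flattens all tokens into [tokens, hidden] with batch boundaries given by cumulative sequence lengths, exactly how x_thd_cumsum is built above. A small CPU-only sketch:

import torch

bs, max_seqlen_q, hidden = 2, 4, 8
x_sbhd = torch.randn(max_seqlen_q, bs, hidden)
x_bshd = x_sbhd.transpose(0, 1).contiguous()        # [bs, seq, hidden]
x_thd = x_bshd.reshape(bs * max_seqlen_q, hidden)   # [tokens, hidden]
cu_seqlens = torch.arange(bs + 1, dtype=torch.int32) * max_seqlen_q

# Tokens of batch element b live at x_thd[cu_seqlens[b]:cu_seqlens[b + 1]].
assert torch.equal(x_thd[cu_seqlens[0] : cu_seqlens[1]], x_bshd[0])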
@@ -2556,167 +2506,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
             x_thd,
             cu_seqlens_q=x_thd_cumsum,
             cu_seqlens_kv=x_thd_cumsum,
-            max_seqlen_q=config.seq_len,
-            max_seqlen_kv=config.seq_len,
+            max_seqlen_q=config.max_seqlen_q,
+            max_seqlen_kv=config.max_seqlen_kv,
         )

     torch.testing.assert_close(
         y_bshd,
-        y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
+        y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(),
     )


-@pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model_key", model_configs_inference.keys())
-@pytest.mark.parametrize("use_RoPE", all_boolean)
-@pytest.mark.parametrize("input_format", input_formats_inference)
-@pytest.mark.parametrize("module", module_inference)
-@pytest.mark.parametrize("backend", backends_inference)
-@pytest.mark.parametrize("is_paged", [False, True])
-def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged):
-    reset_rng_states()
-    if backend in ["FusedAttention"]:
-        pytest.skip("Not support FusedAttention")
-    if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
-        pytest.skip("FusedAttention and FlashAttention do not support FP32")
-    if use_RoPE:
-        pytest.skip("KV cache does not support starting positions for RoPE")
-    if (
-        backend == "FusedAttention"
-        and get_device_compute_capability() == (8, 9)
-        and get_cudnn_version() < (9, 12, 0)
-    ):
-        pytest.skip("Skip KV cache for sm89 and cuDNN < 9.12")
-
-    os.environ["NVTE_FLASH_ATTN"] = "0"
-    os.environ["NVTE_FUSED_ATTN"] = "0"
-    os.environ["NVTE_UNFUSED_ATTN"] = "0"
-    if backend == "FlashAttention":
-        os.environ["NVTE_FLASH_ATTN"] = "1"
-    elif backend == "FusedAttention":
-        os.environ["NVTE_FUSED_ATTN"] = "1"
-    elif backend == "UnfusedAttention":
-        os.environ["NVTE_UNFUSED_ATTN"] = "1"
-
-    config = model_configs_inference[model_key]
-
-    S = config.seq_len
-    B = bs
-    H = config.num_attention_heads
-    D = config.hidden_size
-    head_size = config.embed
-    layer_number = 1
-    # Limits the max size of KV-cache
-    B_max = B
-    S_max = S
-
-    if module == "TransformerLayer":
-        model = TransformerLayer(
-            hidden_size=D,
-            ffn_hidden_size=4 * D,
-            num_attention_heads=H,
-            attn_input_format=input_format,
-            self_attn_mask_type="causal",
-            enc_dec_attn_mask_type="causal",
-            layer_number=layer_number,
-            attention_dropout=0.0,
-            params_dtype=dtype,
-            device="cuda",
-        ).eval()
-    else:
-        model = (
-            MultiheadAttention(
-                hidden_size=D,
-                num_attention_heads=H,
-                qkv_format=input_format,
-                layer_number=layer_number,
-                attention_dropout=0.0,
-                attn_mask_type="causal",
-                params_dtype=dtype,
-            )
-            .cuda()
-            .eval()
-        )
-
-    inference_params = InferenceParams(
-        max_batch_size=B_max,
-        max_sequence_length=S_max,
-        num_heads_kv=H,
-        head_dim_k=head_size,
-        dtype=dtype,
-        is_paged=is_paged,
-        total_num_pages=int(B_max * S_max / 256),
-        page_size=256,
-    )
-
-    rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
-
-    input = torch.randn((S, B, D), dtype=dtype, device="cuda")
-    if input_format == "bshd":
-        input = input.transpose(0, 1).contiguous()
-
-    incremental_output = torch.zeros_like(input)
-
-    # Generate output for the entire sequence
-    full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None)
-
-    # Incrementaly generate outputs using KV-cache
-    step_dict = OrderedDict(zip(list(range(B)), [1] * B))
-    for i in range(S):
-        inference_params.pre_step(step_dict)
-
-        if input_format == "sbhd":
-            incremental_input = input[i].view(1, B, D)
-        else:
-            incremental_input = input[:, i, :].view(B, 1, D)
-
-        seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda")
-        cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda")
-        cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
-        cu_seqlens_kv = cu_seqlens_q.clone()
-        mask_type = "padding"
-        kwargs = {}
-        if module == "TransformerLayer":
-            kwargs["self_attn_mask_type"] = mask_type
-        else:
-            kwargs["attn_mask_type"] = mask_type
-
-        line_output = model(
-            hidden_states=incremental_input,
-            inference_params=inference_params,
-            rotary_pos_emb=rotary_freqs if use_RoPE else None,
-            **kwargs,
-            max_seqlen_q=1,
-            max_seqlen_kv=S,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_kv=cu_seqlens_kv,
-        )
-
-        if input_format == "sbhd":
-            incremental_output[i, :, :] = line_output.view(B, D)
-        else:
-            incremental_output[:, i, :] = line_output.view(B, D)
-
-    if module == "TransformerLayer":
-        atol = {
-            torch.float32: 5e-3,
-            torch.half: 5e-3,
-            torch.bfloat16: 5e-2,
-        }
-    else:
-        atol = {
-            torch.float32: 1e-3,
-            torch.half: 1e-3,
-            torch.bfloat16: 1e-2,
-        }
-
-    # Check if the fully generated output matches the one generated incrementally
-    assert_allclose(full_output, incremental_output, atol[dtype])
-
-
 @pytest.mark.parametrize(
     "shape",
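The deleted test_kv_cache_accuracy checked that decoding one token at a time against a KV cache reproduces a single full-sequence forward pass. A generic, TE-free sketch of that validation pattern (single head, plain PyTorch, all names local to this example):

import torch

def full_causal_attn(q, k, v):
    # Causal single-head attention over a whole [S, D] sequence.
    S, D = q.shape
    mask = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)
    scores = (q @ k.T) / D**0.5
    scores = scores.masked_fill(mask, float("-inf"))
    return scores.softmax(dim=-1) @ v

torch.manual_seed(0)
S, D = 8, 16
q, k, v = (torch.randn(S, D) for _ in range(3))

full = full_causal_attn(q, k, v)

# Incremental decode: grow the KV cache one token per step.
k_cache, v_cache, steps = torch.empty(0, D), torch.empty(0, D), []
for i in range(S):
    k_cache = torch.cat([k_cache, k[i : i + 1]])
    v_cache = torch.cat([v_cache, v[i : i + 1]])
    scores = (q[i : i + 1] @ k_cache.T) / D**0.5
    steps.append(scores.softmax(dim=-1) @ v_cache)

# The incrementally generated output must match the full forward pass.
torch.testing.assert_close(full, torch.cat(steps))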
@@ -2815,9 +2613,8 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
         (16, 4096, 128, 512),
     ],
 )
-@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
 @pytest.mark.parametrize("accumulate", [False, True])
-def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
+def test_fp8_grouped_gemm(shape, accumulate):
     if not fp8_available:
         pytest.skip(reason_for_no_fp8)
...
tests/pytorch/test_onnx_export.py (view file @ 87e3e56e)
@@ -27,7 +27,6 @@ import warnings
 import numpy as np
 import onnxruntime as ort
 import torch
-import random
 from torch import nn as nn
 from typing import Optional, Union, Tuple, List
 from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
@@ -59,14 +58,13 @@ TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-skip_MXFP8 = pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)

-fp8_recipes = [
-    None,
-    recipe.DelayedScaling(),
-    recipe.MXFP8BlockScaling(),
-]
+fp8_recipes = []
+if mxfp8_available:
+    fp8_recipes.append(recipe.MXFP8BlockScaling())
+if fp8_available:
+    fp8_recipes.append(recipe.DelayedScaling())
+fp8_recipes.append(None)

 supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
@@ -369,14 +367,6 @@ def validate_result(
     )


-def create_meta(scale_factor: float, size: int = 1):
-    meta = tex.FP8TensorMeta()
-    meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
-    meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
-    meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
-    return meta
-
-
 def dtype2str(dtype: torch.dtype, fake_bf16_io=False):
     if fake_bf16_io:
         assert dtype == torch.bfloat16
@@ -413,36 +403,12 @@ Test cases begin here.
 """


-@pytest.mark.parametrize("scale_factor", [112])
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("return_bias", [True, False])
-@pytest.mark.parametrize(
-    "precision, use_bias",
-    [
-        (torch.float32, False),
-        (torch.float32, True),
-        (torch.float16, False),
-        (torch.float16, True),
-        # Todo: cannot configure BF16 when bias is disabled (ORT issue?)
-        (torch.bfloat16, False),
-        # Todo: cannot configure BF16 when bias is enabled (ORT issue?)
-        (torch.bfloat16, True),
-    ],
-)
-def test_export_linear(
-    seed_default_rng,
-    scale_factor: float,
-    fp8_recipe: recipe.Recipe,
-    use_bias: bool,
-    return_bias: bool,
-    precision: torch.dtype,
+def _test_export_linear(
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    # Returning the bias is a TE fusion optimization we don't care about.
+    use_bias: bool = True,
+    return_bias: bool = False,
+    precision: torch.dtype = torch.float32,
 ):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     if return_bias and not use_bias:
         pytest.skip("Cannot return bias when bias is disabled")
...
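This is the pattern applied through the rest of the file: each heavily cross-parametrized test becomes a private _test_export_* helper with keyword defaults, plus small named wrappers that vary one axis each. A schematic sketch of the shape (names are illustrative, not TE's):
import pytest

def _test_export_thing(precision="float32", use_bias=True):
    # Shared body: build the module, export it, validate the ONNX output.
    # Defaults cover the common case; wrappers override one knob at a time.
    assert precision in ("float32", "float16", "bfloat16")

@pytest.mark.parametrize("precision", ["float32", "float16", "bfloat16"])
def test_export_thing_precision(precision):
    _test_export_thing(precision=precision)

def test_export_thing_no_bias():
    _test_export_thing(use_bias=False)
This covers each knob once instead of running the full Cartesian product of options, which is why so many parametrize decorators disappear below.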
@@ -498,32 +464,28 @@ def test_export_linear(
     )
-@pytest.mark.parametrize("scale_factor", [112])
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize(
-    "precision",
-    [
-        torch.float32,
-        torch.float16,
-        torch.bfloat16,
-    ],
-)
-@pytest.mark.parametrize("zero_centered_gamma", [False, True])
-@pytest.mark.parametrize("normalization", all_normalizations)
-def test_export_layernorm(
-    seed_default_rng,
-    scale_factor: float,
-    fp8_recipe: recipe.Recipe,
-    precision: torch.dtype,
-    zero_centered_gamma: bool,
-    normalization: str,
-):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_linear_recipe(seed_default_rng, fp8_recipe, precision):
+    _test_export_linear(fp8_recipe=fp8_recipe, precision=precision)
+@pytest.mark.parametrize("use_bias", [True, False])
+def test_export_linear_use_bias(seed_default_rng, use_bias):
+    _test_export_linear(use_bias=use_bias)
+@pytest.mark.parametrize("return_bias", [True, False])
+def test_export_linear_return_bias(seed_default_rng, return_bias):
+    _test_export_linear(return_bias=return_bias)
+def _test_export_layernorm(
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    precision: torch.dtype = torch.float32,
+    zero_centered_gamma: bool = False,
+    normalization: str = all_normalizations[0],
+):
     # Set dimensions (these are arbitrary).
     batch_size = 4
     in_features = 64
...
@@ -564,39 +526,31 @@ def test_export_layernorm(
     )
-@pytest.mark.parametrize("scale_factor", [112])
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("return_bias", [True, False])
-@pytest.mark.parametrize("return_layernorm_output", [True, False])
-@pytest.mark.parametrize(
-    "precision, use_bias",
-    [
-        (torch.float32, False),
-        (torch.float32, True),
-        (torch.float16, True),
-        (torch.float16, False),
-        (torch.bfloat16, True),
-        (torch.bfloat16, False),
-    ],
-)
-@pytest.mark.parametrize("zero_centered_gamma", [False, True])
-@pytest.mark.parametrize("normalization", all_normalizations)
-def test_export_layernorm_linear(
-    seed_default_rng,
-    scale_factor: float,
-    fp8_recipe: recipe.Recipe,
-    use_bias: bool,
-    return_bias: bool,
-    return_layernorm_output: bool,
-    precision: torch.dtype,
-    zero_centered_gamma: bool,
-    normalization: str,
-):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_layernorm_recipe(seed_default_rng, fp8_recipe, precision):
+    _test_export_layernorm(fp8_recipe=fp8_recipe, precision=precision)
+def test_export_layernorm_zero_centered_gamma(seed_default_rng):
+    _test_export_layernorm(zero_centered_gamma=True)
+@pytest.mark.parametrize("normalization", all_normalizations)
+def test_export_layernorm_normalization(seed_default_rng, normalization):
+    _test_export_layernorm(normalization=normalization)
+def _test_export_layernorm_linear(
+    scale_factor: float = 112,
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_bias: bool = True,
+    return_bias: bool = False,
+    return_layernorm_output: bool = False,
+    precision: torch.dtype = torch.float32,
+    zero_centered_gamma: bool = False,
+    normalization: str = all_normalizations[0],
+):
     if return_bias and not use_bias:
         pytest.skip("Cannot return bias when bias is disabled")
...
@@ -644,41 +598,44 @@ def test_export_layernorm_linear(
     )
-@pytest.mark.parametrize("scale_factor", [112])
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("return_bias", [True, False])
-@pytest.mark.parametrize("return_layernorm_output", [True, False])
-@pytest.mark.parametrize(
-    "precision, use_bias",
-    [
-        (torch.float32, False),
-        (torch.float32, True),
-        (torch.float16, True),
-        (torch.float16, False),
-        (torch.bfloat16, True),
-        (torch.bfloat16, False),
-    ],
-)
-@pytest.mark.parametrize("zero_centered_gamma", [False, True])
-@pytest.mark.parametrize("activation", supported_activations)
-@pytest.mark.parametrize("normalization", all_normalizations)
-def test_export_layernorm_mlp(
-    seed_default_rng,
-    scale_factor: float,
-    fp8_recipe: recipe.Recipe,
-    use_bias: bool,
-    return_bias: bool,
-    return_layernorm_output: bool,
-    precision: torch.dtype,
-    zero_centered_gamma: bool,
-    activation: str,
-    normalization: str,
-):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_layernorm_linear_recipe(seed_default_rng, fp8_recipe, precision):
+    _test_export_layernorm_linear(fp8_recipe=fp8_recipe, precision=precision)
+def test_export_layernorm_linear_return_ln_out(seed_default_rng):
+    _test_export_layernorm_linear(return_layernorm_output=True)
+def test_export_layernorm_linear_zero_centered_gamma(seed_default_rng):
+    _test_export_layernorm_linear(zero_centered_gamma=True)
+@pytest.mark.parametrize("normalization", all_normalizations[1:])
+def test_export_layernorm_linear_normalization(seed_default_rng, normalization):
+    _test_export_layernorm_linear(normalization=normalization)
+def test_export_layernorm_linear_no_bias(seed_default_rng):
+    _test_export_layernorm_linear(use_bias=False)
+def test_export_layernorm_linear_return_bias(seed_default_rng):
+    _test_export_layernorm_linear(return_bias=True)
+def _test_export_layernorm_mlp(
+    scale_factor: float = 112,
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_bias: bool = True,
+    return_bias: bool = False,
+    return_layernorm_output: bool = False,
+    precision: torch.dtype = torch.float32,
+    zero_centered_gamma: bool = False,
+    activation: str = supported_activations[0],
+    normalization: str = all_normalizations[0],
+):
     if return_bias and not use_bias:
         pytest.skip("Cannot return bias when bias is disabled")
...
@@ -720,6 +677,38 @@ def test_export_layernorm_mlp(
     )
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_layernorm_mlp(seed_default_rng, fp8_recipe, precision):
+    _test_export_layernorm_mlp(fp8_recipe=fp8_recipe, precision=precision)
+def test_export_layernorm_mlp_return_layernorm_output(seed_default_rng):
+    _test_export_layernorm_mlp(return_layernorm_output=True)
+def test_export_layernorm_mlp_return_bias(seed_default_rng):
+    _test_export_layernorm_mlp(return_bias=True)
+def test_export_layernorm_mlp_no_bias(seed_default_rng):
+    _test_export_layernorm_mlp(use_bias=False)
+def test_export_layernorm_mlp_zero_centered_gamma(seed_default_rng):
+    _test_export_layernorm_mlp(zero_centered_gamma=True)
+@pytest.mark.parametrize("normalization", all_normalizations[1:])
+def test_export_layernorm_mlp_normalization(seed_default_rng, normalization):
+    _test_export_layernorm_mlp(normalization=normalization)
+@pytest.mark.parametrize("activation", supported_activations[1:])
+def test_export_layernorm_mlp_activation(seed_default_rng, activation):
+    _test_export_layernorm_mlp(activation=activation)
 @pytest.mark.parametrize(
     "precision, use_mask, attn_mask_type",
     [
...
@@ -734,8 +723,6 @@ def test_export_layernorm_mlp(
     ],
 )
 def test_export_core_attention(
-    seed_default_rng,
-    set_max_seq_len,
     precision: torch.dtype,
     use_mask: bool,
     attn_mask_type: str,
...
@@ -777,11 +764,6 @@ def test_export_core_attention(
     )
-test_configs_multihead_attention = [
-    # "use_mask, attn_mask_type"
-    (False, "no_mask"),  # calls ScaledSoftmax
-    (True, "arbitrary"),  # calls ScaledMaskedSoftmax
-]
 test_configs_attention_type = [
     # "input_layernorm, attention_type, fuse_qkv_params"
     (True, "self", True),
...
@@ -795,31 +777,14 @@ test_configs_attention_type = [
 ]
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
-@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("return_layernorm_output", [False])
-@pytest.mark.parametrize(
-    "input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type
-)
-def test_export_multihead_attention(
-    seed_default_rng,
-    set_max_seq_len,
-    fp8_recipe: recipe.Recipe,
-    use_mask: bool,
-    attn_mask_type: str,
-    precision: torch.dtype,
-    return_layernorm_output: bool,
-    input_layernorm: bool,
-    attention_type: str,
-    fuse_qkv_params: bool,
-):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
+def _test_export_multihead_attention(
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_mask: bool = True,
+    precision: torch.dtype = torch.float32,
+    input_layernorm: bool = True,
+    attention_type: str = "self",
+    fuse_qkv_params: bool = True,
+):
     hidden_size = 256
     sequence_length = 128
     batch_size = 4
...
@@ -837,6 +802,7 @@ def test_export_multihead_attention(
         init_method,
         output_layer_init_method,
     )
+    attn_mask_type = "arbitrary" if use_mask else "no_mask"
     hidden_states_context = torch.randn(
         sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
...
@@ -868,7 +834,7 @@ def test_export_multihead_attention(
         *attention_args,
         attn_mask_type=attn_mask_type,
         params_dtype=precision,
-        return_layernorm_output=return_layernorm_output,
+        return_layernorm_output=False,
         input_layernorm=input_layernorm,
         attention_type=attention_type,
         fuse_qkv_params=fuse_qkv_params,
...
@@ -960,30 +926,37 @@ def test_export_multihead_attention(
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
-@pytest.mark.parametrize("output_layernorm", [True, False])
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("fuse_qkv_params", [False, True])
-@pytest.mark.parametrize("zero_centered_gamma", [False, True])
-@pytest.mark.parametrize("activation", supported_activations)
-def test_export_transformer_layer(
-    seed_default_rng,
-    set_max_seq_len,
-    fp8_recipe: recipe.Recipe,
-    use_mask: bool,
-    attn_mask_type: str,
-    output_layernorm: bool,
-    precision: torch.dtype,
-    fuse_qkv_params: bool,
-    zero_centered_gamma: bool,
-    activation: str,
-):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
+def test_export_multihead_attention_recipe(fp8_recipe, precision):
+    _test_export_multihead_attention(fp8_recipe=fp8_recipe, precision=precision)
+def test_export_multihead_attention_no_mask():
+    _test_export_multihead_attention(use_mask=False)
+def test_export_multihead_attention_no_input_layernorm():
+    _test_export_multihead_attention(input_layernorm=False)
+def test_export_multihead_attention_cross_attn():
+    _test_export_multihead_attention(attention_type="cross")
+def test_export_multihead_attention_unfused_qkv_params():
+    _test_export_multihead_attention(fuse_qkv_params=False)
+def _test_export_transformer_layer(
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_mask: bool = True,
+    attn_mask_type: str = "arbitrary",
+    output_layernorm: bool = False,
+    precision: torch.dtype = torch.float32,
+    fuse_qkv_params: bool = True,
+    zero_centered_gamma: bool = False,
+    activation: str = supported_activations[0],
+):
     # Layer configuration
     hidden_size = 64
     sequence_length = 128
...
@@ -1043,28 +1016,43 @@ def test_export_transformer_layer(
     )
+@skip_FP8
+@skip_MXFP8
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_transformer_layer_recipe(fp8_recipe, precision):
+    _test_export_transformer_layer(fp8_recipe=fp8_recipe, precision=precision)
+def test_export_transformer_layer_no_mask():
+    _test_export_transformer_layer(use_mask=False)
+def test_export_transformer_layer_output_layernorm():
+    _test_export_transformer_layer(output_layernorm=True)
+def test_export_transformer_layer_unfused_qkv_params():
+    _test_export_transformer_layer(fuse_qkv_params=False)
+def test_export_transformer_layer_zero_centered_gamma():
+    _test_export_transformer_layer(zero_centered_gamma=True)
+@pytest.mark.parametrize("activation", supported_activations[1:])
+def test_export_transformer_layer_activation(activation):
+    _test_export_transformer_layer(activation=activation)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("zero_centered_gamma", [True])
 def test_export_gpt_generation(
-    seed_default_rng,
-    set_max_seq_len,
     fp8_recipe: recipe.Recipe,
     precision: torch.dtype,
-    zero_centered_gamma: bool,
 ):
     """Test that the ONNX model can correctly handle inputs with different shapes and that
     the attention mask is adjusted on-the-fly to different sequence lengths.
     """
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
     # Layer configuration
     hidden_size = 64
     sequence_length = 128
...
@@ -1091,7 +1079,6 @@ def test_export_gpt_generation(
         output_layernorm=output_layernorm,
         params_dtype=precision,
         fuse_qkv_params=fuse_qkv_params,
-        zero_centered_gamma=zero_centered_gamma,
     ).to(device="cuda")
     # "Context phase": use full input sequence length
...
tests/pytorch/test_parallel_cross_entropy.py View file @ 87e3e56e
...
@@ -3,7 +3,6 @@
 # See LICENSE for license information.
 import random
-import pytest
 import torch
 from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
...
tests/pytorch/test_qk_norm.py View file @ 87e3e56e
...
@@ -8,10 +8,10 @@ import pytest
 import torch
-@pytest.mark.parametrize("use_qk_norm", [False, True])
+@pytest.mark.parametrize("qk_norm_type", [None, "L2Normalization", "RMSNorm", "LayerNorm"])
 @pytest.mark.parametrize("attention_type", ["self", "cross"])
 @pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5])
-def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None:
+def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> None:
     """Test QK normalization functionality, module structure, and numerical behavior."""
     hidden_size = 256
     num_attention_heads = 8
...
@@ -22,25 +22,59 @@ def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None
         hidden_size=hidden_size,
         num_attention_heads=num_attention_heads,
         attention_type=attention_type,
-        use_qk_norm=use_qk_norm,
+        qk_norm_type=qk_norm_type,
         qk_norm_eps=qk_norm_eps,
         bias=False,
         device="cuda",
     ).cuda()
-    # Check module structure based on use_qk_norm parameter
-    if use_qk_norm:
-        assert hasattr(mha, "qk_norm"), "Should have qk_norm module when use_qk_norm=True"
-        assert not hasattr(mha, "q_l2norm"), "Should not have separate q_l2norm module"
-        assert not hasattr(mha, "k_l2norm"), "Should not have separate k_l2norm module"
-        # Check that the module is L2Norm type
-        from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
-        assert isinstance(
-            mha.qk_norm, L2Normalization
-        ), "qk_norm should be an L2Normalization module"
+    # Check module structure based on qk_norm_type parameter
+    if qk_norm_type is not None:
+        assert mha.q_norm is not None, "Should have q_norm module when qk_norm_type is not None"
+        assert mha.k_norm is not None, "Should have k_norm module when qk_norm_type is not None"
+        # Check that the modules are of the correct type
+        if qk_norm_type == "L2Normalization":
+            from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
+            assert isinstance(mha.q_norm, L2Normalization), "q_norm should be an L2Normalization module"
+            assert isinstance(mha.k_norm, L2Normalization), "k_norm should be an L2Normalization module"
+            # For L2 normalization, q_norm and k_norm should be the same instance (parameter-free)
+            assert (
+                mha.q_norm is mha.k_norm
+            ), "q_norm and k_norm should be the same instance for L2 normalization"
+        elif qk_norm_type == "RMSNorm":
+            from transformer_engine.pytorch.module.rmsnorm import RMSNorm
+            assert isinstance(mha.q_norm, RMSNorm), "q_norm should be an RMSNorm module"
+            assert isinstance(mha.k_norm, RMSNorm), "k_norm should be an RMSNorm module"
+            # For RMS normalization, q_norm and k_norm should be separate instances
+            assert (
+                mha.q_norm is not mha.k_norm
+            ), "q_norm and k_norm should be separate instances for RMS normalization"
+        elif qk_norm_type == "LayerNorm":
+            from transformer_engine.pytorch.module.layernorm import LayerNorm
+            assert isinstance(mha.q_norm, LayerNorm), "q_norm should be a LayerNorm module"
+            assert isinstance(mha.k_norm, LayerNorm), "k_norm should be a LayerNorm module"
+            # For LayerNorm, q_norm and k_norm should be separate instances
+            assert (
+                mha.q_norm is not mha.k_norm
+            ), "q_norm and k_norm should be separate instances for LayerNorm"
+        else:
+            # For extensibility - just ensure they exist
+            assert mha.q_norm is not None, f"q_norm should exist for qk_norm_type={qk_norm_type}"
+            assert mha.k_norm is not None, f"k_norm should exist for qk_norm_type={qk_norm_type}"
     else:
-        assert not hasattr(mha, "qk_norm"), "Should not have qk_norm module when use_qk_norm=False"
+        assert mha.q_norm is None, "Should not have q_norm module when qk_norm_type is None"
+        assert mha.k_norm is None, "Should not have k_norm module when qk_norm_type is None"
     # Create input tensors
     batch_size = 2  # Use a fixed batch size for testing
...
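For reference, the new interface these tests exercise: the old use_qk_norm boolean is replaced by a qk_norm_type string selecting the query/key normalization. A usage sketch, assuming a CUDA device and a TE build containing this merge:
import torch
from transformer_engine.pytorch import MultiheadAttention

# qk_norm_type picks the normalization applied to Q and K before attention;
# None disables it (the old use_qk_norm=False). "L2Normalization", "RMSNorm",
# and "LayerNorm" are the variants exercised by these tests.
mha = MultiheadAttention(
    hidden_size=256,
    num_attention_heads=8,
    qk_norm_type="RMSNorm",
    qk_norm_eps=1e-6,
    bias=False,
    device="cuda",
)
x = torch.randn(64, 2, 256, device="cuda")  # (seq, batch, hidden)
y = mha(x)
assert y.shape == x.shape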
@@ -89,17 +123,14 @@ def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> None
     assert not torch.isinf(output_with_rope).any(), "RoPE output contains Inf"
-def test_qk_norm_output_difference() -> None:
+@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
+def test_qk_norm_output_difference(qk_norm_type) -> None:
     """Test that QK normalization actually changes the output compared to no normalization."""
     hidden_size = 256
     num_attention_heads = 8
     seq_len = 128
     batch_size = 2
-    # Use same random seed to ensure identical weight initialization
-    current_rng_state = torch.get_rng_state()
-    current_cuda_rng_state = torch.cuda.get_rng_state()
     # Reset to a known seed for reproducible initialization
     torch.manual_seed(42)
     torch.cuda.manual_seed(42)
...
@@ -108,7 +139,7 @@ def test_qk_norm_output_difference() -> None:
     mha_with_norm = MultiheadAttention(
         hidden_size=hidden_size,
         num_attention_heads=num_attention_heads,
-        use_qk_norm=True,
+        qk_norm_type=qk_norm_type,
         bias=False,
         device="cuda",
     ).cuda()
...
@@ -121,7 +152,7 @@ def test_qk_norm_output_difference() -> None:
     mha_no_norm = MultiheadAttention(
         hidden_size=hidden_size,
         num_attention_heads=num_attention_heads,
-        use_qk_norm=False,
+        qk_norm_type=None,
         bias=False,
         device="cuda",
     ).cuda()
...
@@ -139,10 +170,11 @@ def test_qk_norm_output_difference() -> None:
     # Outputs should be different when QK normalization is enabled
     assert not torch.allclose(
         output_with_norm, output_no_norm, atol=1e-6
-    ), "QK normalization should change the output, but outputs are identical"
+    ), f"QK normalization ({qk_norm_type}) should change the output, but outputs are identical"
-def test_qk_norm_with_fused_qkv() -> None:
+@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
+def test_qk_norm_with_fused_qkv(qk_norm_type) -> None:
     """Test QK normalization works with fused QKV parameters."""
     hidden_size = 256
     num_attention_heads = 8
...
@@ -152,7 +184,7 @@ def test_qk_norm_with_fused_qkv() -> None:
         hidden_size=hidden_size,
         num_attention_heads=num_attention_heads,
         fuse_qkv_params=True,
-        use_qk_norm=True,
+        qk_norm_type=qk_norm_type,
         bias=False,
         device="cuda",
     ).cuda()
...
@@ -173,7 +205,8 @@ def test_qk_norm_with_fused_qkv() -> None:
     ), f"Output shape mismatch: {output.shape}"
-def test_qk_norm_transformer_layer_output_difference() -> None:
+@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
+def test_qk_norm_transformer_layer_output_difference(qk_norm_type) -> None:
     """Test that QK normalization actually changes TransformerLayer output compared to no normalization."""
     from transformer_engine.pytorch import TransformerLayer
...
@@ -183,10 +216,6 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
     seq_len = 128
     batch_size = 2
-    # Use same random seed to ensure identical weight initialization
-    current_rng_state = torch.get_rng_state()
-    current_cuda_rng_state = torch.cuda.get_rng_state()
     # Reset to a known seed for reproducible initialization
     torch.manual_seed(42)
     torch.cuda.manual_seed(42)
...
@@ -196,7 +225,7 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
         hidden_size=hidden_size,
         ffn_hidden_size=ffn_hidden_size,
         num_attention_heads=num_attention_heads,
-        use_qk_norm=True,
+        qk_norm_type=qk_norm_type,
         bias=False,
         device="cuda",
     ).cuda()
...
@@ -210,7 +239,7 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
         hidden_size=hidden_size,
         ffn_hidden_size=ffn_hidden_size,
         num_attention_heads=num_attention_heads,
-        use_qk_norm=False,
+        qk_norm_type=None,
         bias=False,
         device="cuda",
     ).cuda()
...
@@ -226,9 +255,10 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
     output_no_norm = transformer_no_norm(hidden_states)
     # Outputs should be different when QK normalization is enabled
     assert not torch.allclose(
         output_with_norm, output_no_norm, atol=1e-6
-    ), "QK normalization should change the TransformerLayer output, but outputs are identical"
+    ), (
+        f"QK normalization ({qk_norm_type}) should change the TransformerLayer output, but outputs"
+        " are identical"
+    )
     # Check that outputs have expected shapes and properties
     assert output_with_norm.shape == (
...
@@ -240,3 +270,120 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
     assert not torch.isinf(output_with_norm).any(), "Output with QK norm contains Inf"
     assert not torch.isnan(output_no_norm).any(), "Output without QK norm contains NaN"
     assert not torch.isinf(output_no_norm).any(), "Output without QK norm contains Inf"
+@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
+def test_qk_norm_before_after_rope(qk_norm_type) -> None:
+    """Test that QK normalization before and after RoPE works without errors."""
+    hidden_size = 256
+    num_attention_heads = 8
+    seq_len = 64
+    batch_size = 2
+    # Create model with QK norm after RoPE (default)
+    mha_after = MultiheadAttention(
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        qk_norm_type=qk_norm_type,
+        qk_norm_before_rope=False,
+        bias=False,
+        device="cuda",
+    ).cuda()
+    # Create model with QK norm before RoPE
+    mha_before = MultiheadAttention(
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        qk_norm_type=qk_norm_type,
+        qk_norm_before_rope=True,
+        bias=False,
+        device="cuda",
+    ).cuda()
+    hidden_states = torch.randn(seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32)
+    # Create RoPE embeddings
+    head_dim = hidden_size // num_attention_heads
+    rotary_dim = head_dim // 2
+    rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32)
+    with torch.no_grad():
+        output_after_rope = mha_after(hidden_states, rotary_pos_emb=rotary_pos_emb)
+        output_before_rope = mha_before(hidden_states, rotary_pos_emb=rotary_pos_emb)
+        output_after_no_rope = mha_after(hidden_states)
+        output_before_no_rope = mha_before(hidden_states)
+    # Check output shapes and properties
+    expected_shape = (seq_len, batch_size, hidden_size)
+    for output in [
+        output_after_rope,
+        output_before_rope,
+        output_after_no_rope,
+        output_before_no_rope,
+    ]:
+        assert output.shape == expected_shape, f"Output shape mismatch: {output.shape}"
+        assert not torch.isnan(output).any(), "Output contains NaN"
+        assert not torch.isinf(output).any(), "Output contains Inf"
+    assert output_after_rope.shape == output_before_rope.shape, "Outputs should have same shape"
+    assert mha_after.qk_norm_before_rope == False, "mha_after should have qk_norm_before_rope=False"
+    assert mha_before.qk_norm_before_rope == True, "mha_before should have qk_norm_before_rope=True"
+def test_different_qk_norm_types_produce_different_outputs() -> None:
+    """Test that different QK normalization types produce different outputs."""
+    hidden_size = 256
+    num_attention_heads = 8
+    seq_len = 128
+    batch_size = 2
+    # Use same random seed to ensure identical weight initialization
+    torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
+    # Create model with L2 normalization
+    mha_l2 = MultiheadAttention(
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        qk_norm_type="L2Normalization",
+        bias=False,
+        device="cuda",
+    ).cuda()
+    # Reset to same seed for identical initialization
+    torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
+    # Create model with RMS normalization
+    mha_rms = MultiheadAttention(
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        qk_norm_type="RMSNorm",
+        bias=False,
+        device="cuda",
+    ).cuda()
+    # Create input tensors
+    hidden_states = torch.randn(seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32)
+    # Compare outputs with identical weights but different QK norm types
+    with torch.no_grad():
+        output_l2 = mha_l2(hidden_states)
+        output_rms = mha_rms(hidden_states)
+    # Outputs should be different when using different normalization types
+    assert not torch.allclose(
+        output_l2, output_rms, atol=1e-6
+    ), "L2 and RMS normalization should produce different outputs, but outputs are identical"
+    # Check that outputs have expected shapes and properties
+    assert output_l2.shape == output_rms.shape, "L2 and RMS outputs should have same shape"
+    assert not torch.isnan(output_l2).any(), "L2 output contains NaN"
+    assert not torch.isinf(output_l2).any(), "L2 output contains Inf"
+    assert not torch.isnan(output_rms).any(), "RMS output contains NaN"
+    assert not torch.isinf(output_rms).any(), "RMS output contains Inf"
tests/pytorch/test_recipe.py View file @ 87e3e56e
...
@@ -192,12 +192,6 @@ class TestFP8Recipe:
             amax_compute_algo=amax_compute_algo,
         )
-        # Get FP8 meta tensors
-        with te.fp8_autocast(fp8_recipe=recipe):
-            x_fp8_meta = op.get_quantizer("forward", 0)
-            w_fp8_meta = op.get_quantizer("forward", 1)
-            dy_fp8_meta = op.get_quantizer("backward", 0)
         # Perform training steps
         x_history = []
         w_history = []
...
@@ -229,19 +223,30 @@ class TestFP8Recipe:
             y = op(x)
             y.backward(dy)
-            def check_amax_history(
-                fp8_meta: dict,
-                ref_amax_history: Iterable[float],
-            ) -> None:
-                """Check that amax history matches expected values"""
-                if len(ref_amax_history) > amax_history_len:
-                    ref_amax_history = ref_amax_history[-amax_history_len:]
+            def check_metas(
+                test_scale: float,
+                test_amax_history: torch.Tensor,
+                ref_amax_history_list: list[float],
+                stage: str,
+            ):
+                """Check that meta tensors match expected values"""
+                # Compute amax
+                if len(ref_amax_history_list) > amax_history_len:
+                    ref_amax_history_list = ref_amax_history_list[-(amax_history_len + 1) :]
                 ref_amax_history = torch.tensor(
-                    ref_amax_history,
+                    ref_amax_history_list,
                     dtype=torch.float32,
                     device=device,
                 )
-                test_amax_history = fp8_meta.amax_history[:, 0]
+                if amax_compute_algo == "max":
+                    ref_amax = max(ref_amax_history_list)
+                elif amax_compute_algo == "most_recent":
+                    ref_amax = ref_amax_history_list[-1]
+                else:
+                    raise RuntimeError(f"{amax_compute_algo=} is not supported")
+                # Compare amax history
                 tols = dict(rtol=0, atol=0)
                 torch.testing.assert_close(
                     test_amax_history[-(step + 1) :],
...
@@ -249,23 +254,6 @@ class TestFP8Recipe:
                     **tols,
                 )
-            def check_scale(
-                quantizer: Float8Quantizer,
-                ref_amax_history: Iterable[float],
-                stage: str,
-            ):
-                """Check that scale and scale reciprocal match expected values"""
-                # Compute amax
-                if len(ref_amax_history) > amax_history_len:
-                    ref_amax_history = ref_amax_history[-(amax_history_len + 1) :]
-                if amax_compute_algo == "max":
-                    ref_amax = max(ref_amax_history)
-                elif amax_compute_algo == "most_recent":
-                    ref_amax = ref_amax_history[-1]
-                else:
-                    raise RuntimeError(f"{amax_compute_algo=} is not supported")
                 # Compute scale
                 max_val = {
                     "forward": 448.0,
...
@@ -273,16 +261,26 @@ class TestFP8Recipe:
                 }[stage]
                 ref_scale = (max_val / ref_amax) / (2**margin)
-                # Check values in FP8 meta tensors
+                # Compare scale
                 torch.testing.assert_close(
-                    quantizer.scale.item(),
+                    test_scale,
                     ref_scale,
                 )
+            # Get scaling factors
+            x_test_scale = op.get_quantizer("forward", 0).scale.item()
+            w_test_scale = op.get_quantizer("forward", 1).scale.item()
+            dy_test_scale = op.get_quantizer("backward", 0).scale.item()
+            # Get amax histories
+            x_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 0]
+            w_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 1]
+            dy_test_history = op._fp8_metas["backward"][backward_key].amax_history[:, 0]
             # Check that results match expected values
-            check_scale(x_fp8_meta, x_history, "forward")
-            check_scale(w_fp8_meta, w_history, "forward")
-            check_scale(dy_fp8_meta, dy_history, "backward")
+            check_metas(x_test_scale, x_test_history, x_history, "forward")
+            check_metas(w_test_scale, w_test_history, w_history, "forward")
+            check_metas(dy_test_scale, dy_test_history, dy_history, "backward")
 @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])
 @pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"])
...
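For reference, the arithmetic these assertions reconstruct: delayed scaling reduces a reference amax from the rolling history ("max" or "most_recent"), then sets scale = fp8_max / amax / 2**margin, with fp8_max = 448.0 for the E4M3 forward pass as in the test above. A standalone sketch with illustrative values:
# Reference computation for a delayed-scaling update, mirroring the test's
# check_metas logic; history values and margin here are illustrative.
amax_history = [0.5, 2.0, 1.25]   # rolling per-tensor amax observations
amax_compute_algo = "max"          # or "most_recent"
margin = 0

ref_amax = max(amax_history) if amax_compute_algo == "max" else amax_history[-1]
fp8_max = 448.0                    # float8 e4m3 forward max used by the test
ref_scale = (fp8_max / ref_amax) / (2 ** margin)
assert abs(ref_scale - 224.0) < 1e-9  # 448 / 2.0 with margin 0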
tests/pytorch/test_sanity.py View file @ 87e3e56e
...
@@ -2,9 +2,7 @@
 #
 # See LICENSE for license information.
-from dataclasses import dataclass
 from typing import Optional
-from contextlib import nullcontext
 import torch
 import pytest
...
@@ -18,11 +16,9 @@ from transformer_engine.pytorch.fp8 import (
     fp8_model_init,
 )
 from transformer_engine.pytorch.utils import (
-    get_device_compute_capability,
     init_method_normal,
     scaled_init_method_normal,
     is_bf16_compatible,
-    get_cudnn_version,
 )
 from transformer_engine.pytorch import (
     LayerNormLinear,
...
@@ -32,7 +28,6 @@ from transformer_engine.pytorch import (
     TransformerLayer,
     RMSNorm,
     LayerNorm,
-    get_cpu_offload_context,
 )
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
...
@@ -47,21 +42,17 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
 from transformer_engine.pytorch.tensor.utils import replace_raw_data
 from transformer_engine.pytorch.distributed import checkpoint
-from utils import dtype_tols
+from utils import ModelConfig
 # Only run FP8 tests on supported devices.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
-)
+fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
 # Record initial RNG state from script run.
 seed = 1234
 torch.manual_seed(seed)
 torch.cuda.manual_seed(seed)
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
 NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))
...
@@ -79,88 +70,33 @@ if NVTE_TEST_NVINSPECT_ENABLED:
     )
-def create_meta(scale_factor: float, size: int = 1):
-    meta = tex.FP8TensorMeta()
-    meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
-    meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
-    meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
-    return meta
-if IS_HIP_EXTENSION:
-    from functools import cache
-    @cache
-    def use_hipblaslt() -> bool:
-        return (
-            os.getenv("NVTE_USE_HIPBLASLT") is not None
-            or os.getenv("NVTE_USE_ROCBLAS") is None
-        )
-def custom_amax_to_scale(
-    amax: torch.Tensor,
-    scale: torch.Tensor,
-    fp8_max: torch.Tensor,
-    recipe: recipe.DelayedScaling,
-) -> torch.Tensor:
-    """Custom func to test recipe."""
-    sf = fp8_max / amax
-    sf = torch.where(amax > 0.0, sf, scale)
-    sf = torch.where(torch.isfinite(amax), sf, scale)
-    return sf
-def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
-    """Custom func to test recipe."""
-    return torch.min(amax_history, dim=0).values
-def reset_rng_states() -> None:
-    """revert back to initial RNG state."""
-    global _cpu_rng_state, _cuda_rng_state
-    torch.set_rng_state(_cpu_rng_state)
-    torch.cuda.set_rng_state(_cuda_rng_state)
-@dataclass
-class ModelConfig:
-    """Transformer model configuration"""
-    num_layers: int
-    seq_len: int
-    batch_size: int
-    hidden_size: int
-    num_attention_heads: int
-    kv_channels: Optional[int] = None
-    def is_fp8_supported(self):
-        if self.seq_len * self.batch_size % 16:
-            return False
-        if self.hidden_size % 16:
-            return False
-        return True
+def is_fp8_supported(config: ModelConfig):
+    if (
+        config.max_seqlen_q * config.batch_size % 16
+        or config.max_seqlen_kv * config.batch_size % 16
+    ):
+        return False
+    if config.hidden_size % 16 or config.hidden_size_kv % 16:
+        return False
+    return True
 model_configs = {
-    "126m": ModelConfig(12, 2048, 2, 768, 12),
-    "small": ModelConfig(2, 32, 2, 64, 2),
-    "weird": ModelConfig(2, 37, 3, 69, 3),
-    "large": ModelConfig(1, 128, 2, 512, 4, 128),
+    "126m": ModelConfig(2, 2048, 12, 64, num_layers=12),
+    "small": ModelConfig(2, 32, 2, 32, num_layers=2),
+    "weird": ModelConfig(3, 37, 3, 23, num_layers=2),
+    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
 }
-fp8_recipes = [
-    None,  # Test non-FP8
-    recipe.MXFP8BlockScaling(),  # Test default
-    recipe.Float8CurrentScaling(),  # Test default
-    recipe.Float8BlockScaling(),  # Test default
-    recipe.DelayedScaling(),  # Test default
-    recipe.DelayedScaling(  # Test most_recent algo
-        amax_history_len=16,
-        amax_compute_algo="most_recent",
-    ),
-    recipe.DelayedScaling(  # Test custom amax and scale compute algo
-        fp8_format=recipe.Format.E4M3,
-        amax_compute_algo=custom_amax_compute,
-        scaling_factor_compute_algo=custom_amax_to_scale,
-    ),
-]
+fp8_recipes = []
+if mxfp8_available:
+    fp8_recipes.append(recipe.MXFP8BlockScaling())
+if fp8_block_scaling_available:
+    fp8_recipes.append(recipe.Float8BlockScaling())
+if fp8_available:
+    fp8_recipes.append(recipe.Float8CurrentScaling())
+    fp8_recipes.append(recipe.DelayedScaling())
+fp8_recipes.append(None)
 param_types = [torch.float32, torch.float16]
 if is_bf16_compatible():  # bf16 requires sm_80 or higher
...
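The new module-level is_fp8_supported(config) gates FP8 runs on divisibility-by-16 of the flattened token count and the hidden dimensions, replacing the method on the old local ModelConfig. The check in isolation; field names follow the shared ModelConfig these tests now import from tests/pytorch/utils.py:
# Minimal stand-in for the is_fp8_supported() gate: FP8 GEMMs require the
# flattened token dimension and the hidden dimensions to be multiples of 16.
def fp8_shapes_ok(max_seqlen_q, max_seqlen_kv, batch_size, hidden_size, hidden_size_kv):
    if (max_seqlen_q * batch_size) % 16 or (max_seqlen_kv * batch_size) % 16:
        return False
    if hidden_size % 16 or hidden_size_kv % 16:
        return False
    return True

assert fp8_shapes_ok(32, 32, 2, 64, 64)        # "small"-like config: FP8 OK
assert not fp8_shapes_ok(37, 37, 3, 69, 69)    # "weird" config: skipped under FP8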
@@ -184,66 +120,9 @@ def reset_global_fp8_state():
     FP8GlobalStateManager.reset()
-def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
-    # Initialize loss function and optimizer.
-    loss_fn = torch.nn.MSELoss()
-    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)
-    # Placeholders used for capture.
-    static_input = torch.randn(
-        config.seq_len,
-        config.batch_size,
-        config.hidden_size,
-        device="cuda",
-        dtype=dtype,
-        requires_grad=True,
-    )
-    static_target = torch.randn(
-        config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
-    )
-    real_input = torch.rand_like(static_input)
-    real_target = torch.rand_like(static_target)
-    use_fp8 = fp8_recipe is not None
-    if skip_wgrad:
-        _disable_wgrads(block)
-    # Pre graph capture warmup in a separate stream.
-    s = torch.cuda.Stream()
-    s.wait_stream(torch.cuda.current_stream())
-    with torch.cuda.stream(s):
-        for _ in range(3):
-            optimizer.zero_grad(set_to_none=True)
-            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
-                out = block(static_input)
-            loss = loss_fn(out, static_target)
-            loss.backward()
-            optimizer.step()
-    torch.cuda.current_stream().wait_stream(s)
-    # Capture.
-    g = torch.cuda.CUDAGraph()
-    optimizer.zero_grad(set_to_none=True)
-    with torch.cuda.graph(g):
-        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
-            static_output = block(static_input)
-        static_loss = loss_fn(static_output, static_target)
-        static_loss.backward()
-        optimizer.step()
-    # Fills the graph's input memory with new data to compute on
-    with torch.no_grad():
-        static_input.copy_(real_input)
-        static_target.copy_(real_target)
-    g.replay()
-    torch.cuda.synchronize()
 def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
-        (config.seq_len, config.batch_size, config.hidden_size),
+        (config.max_seqlen_q, config.batch_size, config.hidden_size),
         dtype=torch.float32,
         device="cuda",
         requires_grad=True,
...
@@ -251,7 +130,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states.retain_grad()
     te_inp_attn_mask = torch.randint(
         2,
-        (1, 1, config.seq_len, config.seq_len),
+        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
         dtype=torch.bool,
         device="cuda",
     )
...
@@ -278,14 +157,14 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
 def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
-        (config.seq_len, config.batch_size, config.hidden_size),
+        (config.max_seqlen_q, config.batch_size, config.hidden_size),
         dtype=dtype,
         device="cuda",
         requires_grad=True,
     )
     te_inp_attn_mask = torch.randint(
         2,
-        (1, 1, config.seq_len, config.seq_len),
+        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
         dtype=torch.bool,
         device="cuda",
     )
...
@@ -316,9 +195,9 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
     assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
-def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
+def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
-        (config.seq_len, config.batch_size, config.hidden_size),
+        (config.max_seqlen_q, config.batch_size, config.hidden_size),
         dtype=dtype,
         device="cuda",
         requires_grad=True,
...
@@ -327,16 +206,9 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
     if skip_wgrad:
         _disable_wgrads(block)
-    if cpu_offload:
-        offload_context, sync_function = get_cpu_offload_context(enabled=True)
-    else:
-        offload_context = nullcontext()
-        sync_function = lambda x: x
     use_fp8 = fp8_recipe is not None
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
         te_out = block(te_inp_hidden_states)
-    te_out = sync_function(te_out)
     loss = te_out.sum()
     loss.backward()
     torch.cuda.synchronize()
...
@@ -344,7 +216,7 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
 def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
-        (config.seq_len, config.batch_size, config.hidden_size),
+        (config.max_seqlen_q, config.batch_size, config.hidden_size),
         dtype=dtype,
         device="cuda",
         requires_grad=True,
...
@@ -352,7 +224,7 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_attn_mask = torch.randint(
         2,
-        (config.batch_size, 1, 1, config.seq_len),
+        (config.batch_size, 1, 1, config.max_seqlen_q),
         dtype=torch.bool,
         device="cuda",
     )
...
@@ -370,21 +242,21 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
 def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
-        (config.seq_len, config.batch_size, config.hidden_size),
+        (config.max_seqlen_q, config.batch_size, config.hidden_size),
         dtype=dtype,
         device="cuda",
         requires_grad=True,
     )
     te_inp_attn_mask = torch.randint(
         2,
-        (1, 1, config.seq_len, config.seq_len),
+        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
         dtype=torch.bool,
         device="cuda",
     )
     enc_dec_attn_mask = torch.randint(
         2,
-        (config.batch_size, 1, 1, config.seq_len),
+        (config.batch_size, 1, 1, config.max_seqlen_kv),
         dtype=torch.bool,
         device="cuda",
     )
...
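The masks built by these helpers follow a broadcastable boolean layout: (1, 1, max_seqlen_q, max_seqlen_kv) for full attention masks and (batch_size, 1, 1, seqlen) for padding-style masks, as in the BERT and T5 helpers above. A CPU-only sketch of the shapes; these sanity tests only exercise shape and dtype, and the semantics of True depend on attn_mask_type:
import torch

batch_size, sq, skv = 2, 32, 32
# Full attention mask, broadcast over batch and heads.
self_mask = torch.randint(2, (1, 1, sq, skv), dtype=torch.bool)
# Padding-style mask, broadcast over heads and query positions.
pad_mask = torch.randint(2, (batch_size, 1, 1, skv), dtype=torch.bool)
assert self_mask.dtype == pad_mask.dtype == torch.bool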
@@ -412,7 +284,7 @@ def _test_sanity_common(
...
@@ -412,7 +284,7 @@ def _test_sanity_common(
pytest
.
skip
(
"No gradient computation; Skipping to avoid PyTorch RuntimeError."
)
pytest
.
skip
(
"No gradient computation; Skipping to avoid PyTorch RuntimeError."
)
te_inp
=
torch
.
randn
(
te_inp
=
torch
.
randn
(
(
config
.
seq
_
len
,
config
.
batch_size
,
config
.
hidden_size
),
(
config
.
max_
seqlen
_q
,
config
.
batch_size
,
config
.
hidden_size
),
dtype
=
dtype
,
dtype
=
dtype
,
device
=
"cuda"
,
device
=
"cuda"
,
requires_grad
=
not
skip_dgrad
,
requires_grad
=
not
skip_dgrad
,
...
@@ -440,7 +312,7 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
...
@@ -440,7 +312,7 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
pytest
.
skip
(
"No gradient computation; Skipping to avoid PyTorch RuntimeError."
)
pytest
.
skip
(
"No gradient computation; Skipping to avoid PyTorch RuntimeError."
)
te_inp
=
torch
.
randn
(
te_inp
=
torch
.
randn
(
(
config
.
seq
_
len
,
config
.
batch_size
,
config
.
hidden_size
),
(
config
.
max_
seqlen
_q
,
config
.
batch_size
,
config
.
hidden_size
),
device
=
"cuda"
,
device
=
"cuda"
,
requires_grad
=
True
,
requires_grad
=
True
,
)
)
...
@@ -495,13 +367,7 @@ def test_sanity_layernorm_linear(
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
...
@@ -529,13 +395,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
...
@@ -562,16 +422,10 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
         pytest.skip("Quantized model parameters are not supported in debug mode.")
     config = model_configs[model]
     ffn_hidden_size = 4 * config.hidden_size
-    num_tokens = bs * config.seq_len
+    num_tokens = bs * config.max_seqlen_q
     if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
     use_fp8 = fp8_recipe is not None
...
@@ -607,16 +461,10 @@ def test_sanity_grouped_linear(
     ffn_hidden_size = 4 * config.hidden_size
     # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
     bs = bs * 16
-    num_tokens = bs * config.seq_len * (num_gemms - 1)
+    num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
     if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
     use_fp8 = fp8_recipe is not None
...
@@ -628,7 +476,7 @@ def test_sanity_grouped_linear(
     inp_hidden_states = torch.randn(
         num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
-    m_splits = [bs * config.seq_len] * num_gemms
+    m_splits = [bs * config.max_seqlen_q] * num_gemms
     if empty_split == "first":
         m_splits[0] = 0
     elif empty_split == "last":
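Note on the hunk above: the test drives GroupedLinear with one deliberately empty split. A minimal sketch of that call pattern, for orientation (the layer construction and sizes here are illustrative, not taken from this diff):

# Minimal sketch: GroupedLinear with an empty first split, as exercised above.
# Sizes and layer construction are illustrative, not from this diff.
import torch
import transformer_engine.pytorch as te

num_gemms, hidden = 4, 256
m_splits = [32] * num_gemms
m_splits[0] = 0  # the "first" empty-split case from the test
inp = torch.randn(sum(m_splits), hidden, dtype=torch.bfloat16, device="cuda")
layer = te.GroupedLinear(num_gemms, hidden, hidden, params_dtype=torch.bfloat16, device="cuda")
out = layer(inp, m_splits)  # one independent GEMM per split; an empty split is a no-op
print(out.shape)  # torch.Size([96, 256])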
...
@@ -666,13 +514,7 @@ def test_sanity_layernorm_mlp(
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
...
@@ -697,36 +539,24 @@ def test_sanity_layernorm_mlp(
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", ["small"])
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
-@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 @pytest.mark.parametrize("bias", all_boolean)
-@pytest.mark.parametrize("activation", all_activations)
+@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
 @pytest.mark.parametrize("normalization", all_normalizations)
 @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
-@pytest.mark.parametrize("cpu_offload", all_boolean)
 def test_sanity_gpt(
     dtype,
     fp8_recipe,
     model,
     skip_wgrad,
-    zero_centered_gamma,
     bias,
     activation,
     normalization,
     parallel_attention_mlp,
-    cpu_offload,
 ):
-    if cpu_offload and NVTE_TEST_NVINSPECT_ENABLED:
-        pytest.skip("CPU offload is not supported in debug mode.")
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
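Throughout this file the four inline recipe checks collapse into a single is_fp8_supported(config) gate. Its definition lives in the refactored test utilities and is not part of this excerpt; a hypothetical sketch of what such a gate does, assuming the same module-level availability flag the old inline checks used:

# Hypothetical sketch only -- the real is_fp8_supported() is defined in the test
# utilities and is not shown in this excerpt. fp8_available stands in for the
# module-level capability flag; the divisibility rule is an assumption based on
# the usual FP8 GEMM shape constraints, not a confirmed implementation.
fp8_available = True

def is_fp8_supported(config) -> bool:
    if not fp8_available:
        return False
    tokens = config.batch_size * config.max_seqlen_q
    return config.hidden_size % 16 == 0 and tokens % 16 == 0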
...
@@ -736,7 +566,7 @@ def test_sanity_gpt(
     block = TransformerLayer(
         config.hidden_size,
         4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
         init_method=init_method,
         output_layer_init_method=output_layer_init_method,
         hidden_dropout=0.1,
...
@@ -745,7 +575,6 @@ def test_sanity_gpt(
         params_dtype=dtype,
         apply_residual_connection_post_layernorm=False,
         output_layernorm=False,
-        zero_centered_gamma=zero_centered_gamma,
         bias=bias,
         activation=activation,
         normalization=normalization,
...
@@ -753,7 +582,7 @@ def test_sanity_gpt(
         parallel_attention_mlp=parallel_attention_mlp,
     )
-    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)
+    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)


 def test_sanity_gpt_126m():
...
@@ -770,12 +599,10 @@ def test_sanity_gpt_126m():
         fp8_recipe=fp8_recipe,
         model="126m",
         skip_wgrad=False,
-        zero_centered_gamma=True,
         bias=True,
         activation="gelu",
         normalization="LayerNorm",
         parallel_attention_mlp=False,
-        cpu_offload=False,
     )
...
@@ -783,19 +610,14 @@ def test_sanity_gpt_126m():
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", ["small"])
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
-@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 @pytest.mark.parametrize("normalization", all_normalizations)
-def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
+def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
     config = model_configs[model]
     if fp8_recipe is not None:
         if not fp8_available:
             pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
...
@@ -805,7 +627,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
     block = TransformerLayer(
         config.hidden_size,
         4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
         init_method=init_method,
         output_layer_init_method=output_layer_init_method,
         hidden_dropout=0.1,
...
@@ -814,7 +636,6 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
         params_dtype=dtype,
         apply_residual_connection_post_layernorm=True,
         output_layernorm=True,
-        zero_centered_gamma=zero_centered_gamma,
         self_attn_mask_type="causal",
         normalization=normalization,
         device="cuda",
...
@@ -835,7 +656,6 @@ def test_sanity_bert_126m():
         fp8_recipe=fp8_recipe,
         model="126m",
         skip_wgrad=False,
-        zero_centered_gamma=False,
         normalization="LayerNorm",
     )
...
@@ -844,19 +664,14 @@ def test_sanity_bert_126m():
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", ["small"])
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
-@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 @pytest.mark.parametrize("normalization", all_normalizations)
-def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
+def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
     config = model_configs[model]
     if fp8_recipe is not None:
         if not fp8_available:
             pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
...
@@ -866,7 +681,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
     block = TransformerLayer(
         config.hidden_size,
         4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
         init_method=init_method,
         output_layer_init_method=output_layer_init_method,
         hidden_dropout=0.1,
...
@@ -876,7 +691,6 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
         apply_residual_connection_post_layernorm=False,
         output_layernorm=False,
         layer_type="decoder",
-        zero_centered_gamma=zero_centered_gamma,
         normalization=normalization,
         device="cuda",
     )
...
@@ -896,7 +710,6 @@ def test_sanity_T5_126m():
         fp8_recipe=fp8_recipe,
         model="126m",
         skip_wgrad=False,
-        zero_centered_gamma=False,
         normalization="LayerNorm",
     )
...
@@ -909,13 +722,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
...
@@ -925,7 +732,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
     block = TransformerLayer(
         config.hidden_size,
         4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
         init_method=init_method,
         output_layer_init_method=output_layer_init_method,
         hidden_dropout=0.1,
...
@@ -941,18 +748,11 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", ["small"])
-@pytest.mark.parametrize("skip_wgrad", all_boolean)
-def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
+def test_sanity_drop_path(dtype, fp8_recipe, model):
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
...
@@ -962,7 +762,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
     block = TransformerLayer(
         config.hidden_size,
         4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
         init_method=init_method,
         output_layer_init_method=output_layer_init_method,
         hidden_dropout=0.1,
...
@@ -975,7 +775,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
         device="cuda",
     )
-    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
+    _test_sanity_e2e(block, dtype, config, fp8_recipe, False)


 @pytest.mark.parametrize("dtype", param_types)
...
@@ -986,13 +786,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
...
@@ -1002,7 +796,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
     block = TransformerLayer(
         config.hidden_size,
         4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
         init_method=init_method,
         output_layer_init_method=output_layer_init_method,
         hidden_dropout=0.1,
...
@@ -1015,27 +809,18 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
         device="cuda",
     )
-    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
+    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", ["small"])
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
-@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
-def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if not config.is_fp8_supported():
+        if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
...
@@ -1045,7 +830,7 @@ def test_sanity_gradient_accumulation_fusion(
     block = TransformerLayer(
         config.hidden_size,
         4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
         init_method=init_method,
         output_layer_init_method=output_layer_init_method,
         hidden_dropout=0.1,
...
@@ -1054,7 +839,6 @@ def test_sanity_gradient_accumulation_fusion(
         params_dtype=dtype,
         apply_residual_connection_post_layernorm=False,
         output_layernorm=False,
-        zero_centered_gamma=zero_centered_gamma,
         fuse_qkv_params=True,
         fuse_wgrad_accumulation=True,
         device="cuda",
...
@@ -1063,56 +847,6 @@ def test_sanity_gradient_accumulation_fusion(
     _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)


-@pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("model", ["small"])
-@pytest.mark.parametrize("skip_wgrad", all_boolean)
-@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
-@pytest.mark.parametrize("normalization", all_normalizations)
-def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
-    if IS_HIP_EXTENSION:
-        if not use_hipblaslt():
-            pytest.skip("CUDA graph capture not supported with rocBLAS path")
-    config = model_configs[model]
-    if fp8_recipe is not None:
-        if not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if fp8_recipe.mxfp8() and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if fp8_recipe.float8_block_scaling():
-            pytest.skip("cuda graph not supported for float8_block_scaling recipe")
-        if not config.is_fp8_supported():
-            pytest.skip("Model config does not support FP8")
-    sigma = 0.023
-    init_method = init_method_normal(sigma)
-    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
-    block = TransformerLayer(
-        config.hidden_size,
-        4 * config.hidden_size,
-        config.num_attention_heads,
-        init_method=init_method,
-        output_layer_init_method=output_layer_init_method,
-        hidden_dropout=0.1,
-        attention_dropout=0.1,
-        kv_channels=config.kv_channels,
-        params_dtype=dtype,
-        apply_residual_connection_post_layernorm=False,
-        output_layernorm=False,
-        zero_centered_gamma=zero_centered_gamma,
-        fuse_qkv_params=True,
-        normalization=normalization,
-        device="cuda",
-    )
-    _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)


 def test_model_multiple_cast():
     a = torch.zeros((16, 16), device="cuda")
     m = Linear(16, 32)
...
@@ -1167,133 +901,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
     torch.cuda.synchronize()


-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
-@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
-@pytest.mark.parametrize("model", ["large"])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-def test_sanity_attention_extra_state(model, dtype):
-    config = model_configs[model]
-    outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
-    outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
-    outputs_checkpoint_v1_6 = _run_attention_extra_state(
-        dtype, config, mimic_v1_6=True, checkpoint=True
-    )
-
-    # Check that results match
-    tols = dtype_tols(dtype)
-    if dtype in (torch.float16, torch.bfloat16):
-        tols.update(dict(rtol=2e-2, atol=2e-3))
-    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
-        torch.testing.assert_close(
-            test,
-            ref,
-            **tols,
-        )
-    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
-        torch.testing.assert_close(
-            test,
-            ref,
-            **tols,
-        )
-
-
-def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
-    steps = 10
-    path = "checkpoint.pt"
-    fp8_enabled = True
-    fp8_recipe = recipe.DelayedScaling(
-        margin=0,
-        fp8_format=recipe.Format.HYBRID,
-        amax_history_len=1,
-        amax_compute_algo="most_recent",
-        fp8_dpa=fp8_enabled,
-        fp8_mha=False,
-    )
-
-    reset_rng_states()
-    hidden_states = torch.randn(
-        (config.seq_len, config.batch_size, config.hidden_size),
-        dtype=dtype,
-        device="cuda",
-        requires_grad=True,
-    )
-
-    def get_model(dtype, config):
-        sigma = 0.023
-        init_method = init_method_normal(sigma)
-        output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
-
-        with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
-            block = TransformerLayer(
-                config.hidden_size,
-                4 * config.hidden_size,
-                config.num_attention_heads,
-                init_method=init_method,
-                output_layer_init_method=output_layer_init_method,
-                hidden_dropout=0.0,
-                attention_dropout=0.0,
-                fuse_qkv_params=True,
-                params_dtype=dtype,
-                device="cuda",
-            )
-        return block
-
-    block = get_model(dtype, config)
-    for i in range(steps // 2):
-        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
-            output = block(hidden_states, None)
-        loss = output.sum()
-        loss.backward()
-
-    if checkpoint:
-        sd = block.state_dict()
-        if mimic_v1_6:
-            sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
-                "self_attention.core_attention._extra_state"
-            ]
-            del sd["self_attention.core_attention._extra_state"]
-        torch.save(sd, path)
-
-        param_grads = []
-        for p in block.parameters():
-            if p.requires_grad:
-                param_grads.append(p.grad.clone())
-
-        _cpu_rng_state_new = torch.get_rng_state()
-        _cuda_rng_state_new = torch.cuda.get_rng_state()
-
-        del block
-        block = get_model(dtype, config)
-        block.load_state_dict(torch.load(path, weights_only=False))
-        torch.set_rng_state(_cpu_rng_state_new)
-        torch.cuda.set_rng_state(_cuda_rng_state_new)
-
-        for p in block.parameters():
-            if p.requires_grad:
-                p.grad = param_grads.pop(0)
-
-        assert not param_grads, "Oops!"
-
-    for i in range((steps + 1) // 2):
-        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
-            output = block(hidden_states, None)
-        loss = output.sum()
-        loss.backward()
-
-    torch.cuda.synchronize()
-
-    if os.path.exists(path):
-        os.remove(path)
-
-    outputs = [output, hidden_states.grad]
-    for p in block.parameters():
-        if p.requires_grad:
-            outputs.append(p.grad)
-    return outputs


 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 def test_replace_raw_data_for_float8tensor():
     """Test the functionality of replace_raw_data"""
...
@@ -1389,6 +996,32 @@ def test_sanity_checkpointing_on_callables():
     torch.testing.assert_close(grad_checkpoint, grad_standard)


+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+def test_linear_frozen_weights_memory_default_recipe():
+    """Test that memory usage is optimized when weights are frozen for MXFP8."""
+    dim = 1024
+    linear = Linear(dim, dim, bias=False)
+    x = torch.randn(dim, dim, requires_grad=True, device="cuda")
+
+    # Freeze weights
+    linear.weight.requires_grad = False
+
+    # Forward and backward pass with FP8
+    with fp8_autocast():
+        o = linear(x)
+    g_o = torch.randn_like(o)
+    max_memory_before_backward = torch.cuda.max_memory_allocated()
+    o.backward(g_o)
+    max_memory_after_backward = torch.cuda.max_memory_allocated()
+
+    memory_diff = (max_memory_after_backward - max_memory_before_backward) / 1e6
+    assert memory_diff < 5.5, (
+        f"Memory usage with frozen weights ({memory_diff} MB) should be less than 5.5MB as the"
+        " grad_output should be quantized only columnwise."
+    )


 @pytest.mark.parametrize(
     "module_name",
     ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
...
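A back-of-the-envelope for the 5.5 MB bound asserted in the new test above (my arithmetic, assuming the backward pass allocates the FP32 input gradient plus FP8 copies of grad_output, and ignoring scale tensors and allocator rounding):

# Rough accounting for the < 5.5 MB assertion; assumptions as stated above.
dim = 1024
dgrad_fp32_mb = dim * dim * 4 / 1e6  # x.grad from the dgrad GEMM: ~4.19 MB
go_fp8_mb = dim * dim * 1 / 1e6      # one FP8 copy of grad_output: ~1.05 MB
print(dgrad_fp32_mb + go_fp8_mb)     # ~5.24 MB: columnwise-only quantization fits
print(dgrad_fp32_mb + 2 * go_fp8_mb) # ~6.29 MB: a second (rowwise) copy, which the
                                     # skipped wgrad would have needed, would not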
tests/pytorch/utils.py
View file @ 87e3e56e
...
@@ -4,12 +4,24 @@
 from __future__ import annotations
+import logging
 import os
+from contextlib import contextmanager
+
+import pytest
 import torch

 import transformer_engine
 import transformer_engine.common.recipe
 import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
+from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
+from transformer_engine.pytorch.attention.dot_product_attention.utils import (
+    get_attention_backend,
+    AttentionParams,
+    AttentionLogging,
+)
+from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend


 def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype:
...
@@ -106,3 +118,178 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
     if name == "fp8_block_scaling":
         return transformer_engine.common.recipe.Float8BlockScaling()
     raise ValueError(f"Unsupported quantization scheme ({name})")
+
+
+# Cached RNG state
+_rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+
+
+def reset_rng_states() -> None:
+    """Revert to deterministic RNG state"""
+    global _rng_states
+    if _rng_states is None:
+        torch.manual_seed(1234)
+        torch.cuda.manual_seed(1234)
+        _rng_states = (torch.get_rng_state(), torch.cuda.get_rng_state())
+    else:
+        cpu_rng_state, cuda_rng_state = _rng_states
+        torch.set_rng_state(cpu_rng_state)
+        torch.cuda.set_rng_state(cuda_rng_state)
+
+
+class ModelConfig:
+    def __init__(
+        self,
+        batch_size: int,
+        max_seqlen_q: int,
+        num_heads: int,
+        head_dim_qk: int,
+        max_seqlen_kv: int = None,
+        num_gqa_groups: int = None,
+        head_dim_v: int = None,
+        dropout_p: float = 0.0,
+        attn_mask_type: str = "no_mask",
+        attn_bias_type: str = "no_bias",
+        alibi_type: str = "none",
+        bias_shape: str = "1hss",
+        window_size: Tuple[int, int] = (-1, -1),
+        total_requests: int = None,
+        max_ctx_len: int = None,
+        num_layers: int = 1,
+        eps: float = 1e-5,
+    ):
+        self.batch_size = batch_size
+        self.max_seqlen_q = max_seqlen_q
+        self.max_seqlen_kv = max_seqlen_q if max_seqlen_kv is None else max_seqlen_kv
+        self.num_heads = num_heads
+        self.num_gqa_groups = num_heads if num_gqa_groups is None else num_gqa_groups
+        self.head_dim_qk = head_dim_qk
+        self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
+        if self.head_dim_qk == self.head_dim_v:
+            self.kv_channels = self.head_dim_qk
+        else:
+            self.kv_channels = (self.head_dim_qk, self.head_dim_v)
+        self.hidden_size = self.num_heads * self.head_dim_qk
+        self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v
+        self.dropout_p = dropout_p
+        self.attn_mask_type = attn_mask_type
+        self.attn_bias_type = attn_bias_type
+        self.alibi_type = alibi_type
+        self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross"
+        self.bias_shape = bias_shape
+        self.window_size = window_size
+        self.total_requests = total_requests
+        self.max_ctx_len = max_ctx_len
+        self.num_layers = num_layers
+        self.eps = eps
+
+
+@contextmanager
+def logging_context(highest_level=logging.WARNING):
+    previous_level = logging.root.manager.disable
+    logging.disable(highest_level)
+    try:
+        yield
+    finally:
+        logging.disable(previous_level)
+
+
+def get_available_attention_backends(
+    config: ModelConfig,
+    qkv_dtype: torch.dtype,
+    qkv_layout: str,
+    window_size: Tuple[int, int] = (-1, -1),
+    pad_between_seqs: bool = False,
+    context_parallel: bool = False,
+    deterministic: bool = False,
+    fp8: bool = False,
+    fp8_meta: Optional[Dict[str, Any]] = None,
+    is_training: bool = True,
+    inference_params: Optional[InferenceParams] = None,
+) -> Tuple[List, List]:
+    """Check for all available attention backends that support a model configuration"""
+    os.environ["NVTE_FLASH_ATTN"] = "1"
+    os.environ["NVTE_FUSED_ATTN"] = "1"
+    os.environ["NVTE_UNFUSED_ATTN"] = "1"
+    _attention_backends["backend_selection_requires_update"] = True
+
+    alibi_slopes_shape = None
+    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
+        if config.bias_shape == "1hss":
+            alibi_slopes_shape = [config.num_heads]
+        if config.bias_shape == "bhss":
+            alibi_slopes_shape = [config.batch_size, config.num_heads]
+
+    core_attention_bias_shape = (
+        config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
+    )
+    core_attention_bias_requires_grad = False
+    # d=256 is supported by cuDNN 9.0+ for inference but not training
+    if (
+        config.attn_bias_type == "post_scale_bias"
+        and config.head_dim_qk <= 128
+        and config.head_dim_v <= 128
+    ):
+        core_attention_bias_requires_grad = True
+
+    fused_attn_backends = []
+    available_backends = None
+    flash_attention_backend = None
+    fused_attention_backend = None
+
+    def test():
+        attention_params = AttentionParams(
+            qkv_dtype=qkv_dtype,
+            qkv_layout=qkv_layout,
+            batch_size=config.batch_size,
+            num_heads=config.num_heads,
+            num_gqa_groups=config.num_gqa_groups,
+            max_seqlen_q=config.max_seqlen_q,
+            max_seqlen_kv=config.max_seqlen_kv,
+            head_dim_qk=config.head_dim_qk,
+            head_dim_v=config.head_dim_v,
+            attn_mask_type=config.attn_mask_type,
+            window_size=window_size,
+            alibi_slopes_shape=alibi_slopes_shape,
+            core_attention_bias_type=config.attn_bias_type,
+            core_attention_bias_shape=core_attention_bias_shape,
+            core_attention_bias_requires_grad=core_attention_bias_requires_grad,
+            pad_between_seqs=pad_between_seqs,
+            attention_dropout=config.dropout_p,
+            context_parallel=context_parallel,
+            deterministic=deterministic,
+            fp8=fp8,
+            fp8_meta=fp8_meta,
+            is_training=is_training,
+            inference_params=inference_params,
+        )
+        (
+            use_flash_attention,
+            use_fused_attention,
+            flash_attention_backend,
+            fused_attention_backend,
+            use_unfused_attention,
+            available_backends,
+        ) = get_attention_backend(attention_params)
+        # Set attention.py _attention_backends var using return value
+        # from get_attention_backend()
+        _attention_backends["use_flash_attention"] = use_flash_attention
+        _attention_backends["use_fused_attention"] = use_fused_attention
+        _attention_backends["flash_attention_backend"] = flash_attention_backend
+        _attention_backends["fused_attention_backend"] = fused_attention_backend
+        _attention_backends["use_unfused_attention"] = use_unfused_attention
+        _attention_backends["backend_selection_requires_update"] = False
+        return available_backends, flash_attention_backend, fused_attention_backend
+
+    backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
+    if AttentionLogging._is_logging_setup is False:
+        AttentionLogging.setup_logging()
+    with logging_context(highest_level=AttentionLogging._log_level):
+        for i in range(3):
+            os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
+            _attention_backends["backend_selection_requires_update"] = True
+            available_backends, flash_attention_backend, fused_attention_backend = test()
+            if fused_attention_backend == FusedAttnBackend[backends[i]]:
+                fused_attn_backends.append(fused_attention_backend)
+    return available_backends, flash_attention_backend, fused_attn_backends
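Taken together, these utilities give attention tests a single entry point for backend discovery. A sketch of a typical call site (the ModelConfig values are illustrative; in-tree tests import these helpers directly from the tests/pytorch directory):

# Illustrative call site for the new helpers; config values are made up.
import torch
from utils import ModelConfig, get_available_attention_backends, reset_rng_states

reset_rng_states()  # deterministic CPU/CUDA RNG across repeated test invocations

config = ModelConfig(
    batch_size=2,
    max_seqlen_q=512,
    num_heads=16,
    head_dim_qk=64,
    attn_mask_type="causal",
)
available_backends, flash_backend, fused_backends = get_available_attention_backends(
    config,
    qkv_dtype=torch.bfloat16,
    qkv_layout="bshd_bshd_bshd",
)
flash_ok, fused_ok, unfused_ok = available_backends
print(flash_ok, fused_ok, unfused_ok, fused_backends)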
transformer_engine/common/CMakeLists.txt
View file @ 87e3e56e
...
@@ -126,6 +126,7 @@ if(USE_CUDA)
     transpose/multi_cast_transpose.cu
     transpose/quantize_transpose_square_blockwise.cu
     transpose/quantize_transpose_vector_blockwise.cu
+    transpose/swap_first_dims.cu
     activation/gelu.cu
     fused_attn/flash_attn.cu
     fused_attn/context_parallel.cu
...
@@ -189,6 +190,7 @@ else()
     transpose/multi_cast_transpose.cu
     transpose/quantize_transpose_square_blockwise.cu
     transpose/quantize_transpose_vector_blockwise.cu
+    transpose/swap_first_dims.cu
     activation/gelu.cu
     activation/relu.cu
     activation/swiglu.cu
...
@@ -347,6 +349,8 @@ if(USE_CUDA)
                                    string_code_transpose_rtc_cast_transpose_cu)
   make_string_header_from_file(transpose/rtc/transpose.cu
                                    string_code_transpose_rtc_transpose_cu)
+  make_string_header_from_file(transpose/rtc/swap_first_dims.cu
+                                   string_code_transpose_rtc_swap_first_dims_cu)
   make_string_header_from_file(utils.cuh
                                    string_code_utils_cuh)
 else()
...
@@ -358,6 +362,8 @@ else()
                                    string_code_transpose_rtc_cast_transpose_cu)
   make_string_header_from_file(transpose/rtc/transpose.hip
                                    string_code_transpose_rtc_transpose_cu)
+  make_string_header_from_file(transpose/rtc/swap_first_dims.cu
+                                   string_code_transpose_rtc_swap_first_dims_cu)
 endif()
...
@@ -385,6 +391,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
   set_source_files_properties(activation/gelu.cu
                               activation/relu.cu
                               activation/swiglu.cu
+                              util/cast.cu
                               PROPERTIES
                               COMPILE_OPTIONS "--use_fast_math")
 endif()
...
transformer_engine/common/__init__.py
View file @ 87e3e56e
...
@@ -246,6 +246,18 @@ def _load_cudnn():
     if found:
         return handle

+    # Attempt to locate libcudnn via ldconfig
+    libs = subprocess.check_output(
+        f"ldconfig -p | grep 'libcudnn{_get_sys_extension()}'", shell=True
+    )
+    libs = libs.decode("utf-8").split("\n")
+    sos = []
+    for lib in libs:
+        if "libcudnn" in lib and "=>" in lib:
+            sos.append(lib.split(">")[1].strip())
+    if sos:
+        return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
+
     # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
     return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
...
@@ -267,12 +279,12 @@ def _load_nvrtc():
         return handle

     # Attempt to locate NVRTC via ldconfig
-    libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True)
+    libs = subprocess.check_output(f"ldconfig -p | grep 'libnvrtc{_get_sys_extension()}'", shell=True)
     libs = libs.decode("utf-8").split("\n")
     sos = []
     for lib in libs:
+        if "stub" in lib or "libnvrtc-builtins" in lib:
+            continue
         if "libnvrtc" in lib and "=>" in lib:
             sos.append(lib.split(">")[1].strip())
     if sos:
...
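For reference, `ldconfig -p` prints entries of the form `name (flags) => path`, which is what the `split(">")[1].strip()` above relies on. A standalone rendering of that parsing step (sample lines are made up but follow the real format):

# Standalone illustration of the ldconfig parsing used above.
sample = (
    "\tlibcudnn.so.9 (libc6,x86-64) => /usr/lib/x86_64-linux-gnu/libcudnn.so.9\n"
    "\tlibcudnn.so (libc6,x86-64) => /usr/lib/x86_64-linux-gnu/libcudnn.so"
)
sos = []
for lib in sample.split("\n"):
    if "libcudnn" in lib and "=>" in lib:
        sos.append(lib.split(">")[1].strip())  # keep only the resolved path
print(sos[0])  # /usr/lib/x86_64-linux-gnu/libcudnn.so.9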
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
View file @ 87e3e56e
...
@@ -189,14 +189,26 @@ CommOverlapCore::~CommOverlapCore() {
   if (_atomic_gemm) cudaFree(_counter.dptr());

-  for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]);
+  for (size_t i = 0; i < _stream_compute.size(); i++) {
+    cudaStreamSynchronize(_stream_compute[i]);
+    cudaStreamDestroy(_stream_compute[i]);
+  }
+
+  auto error = cudaGetLastError();
+  if (error != cudaSuccess) {
+    NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error));
+  }

   if (_comm_created) {
+    try {
 #ifdef NVTE_UB_WITH_MPI
     destroy_communicator_mpi(_ub_comm);
 #else
     destroy_communicator(_ub_comm);
 #endif
+    } catch (const std::exception &e) {
+      NVTE_WARN("Error destroying communicator, cleanup may be incomplete:\n", e.what());
+    }
     _comm_created = false;
   }
 }
...
@@ -382,6 +394,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
 CommOverlapBase::~CommOverlapBase() {
   cudaEventDestroy(_start_d2dcopy);
+  cudaStreamSynchronize(_stream_comm);
   cudaStreamDestroy(_stream_comm);
 }
...
@@ -704,6 +717,25 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
 }  // CommOverlapBase::split_overlap_rs

+void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
+                                               cudaStream_t stream_main) {
+  int comm_bytes = _ubuf.bytes();
+  int comm_bytes_per_rank = comm_bytes / _tp_size;
+
+  // We use the overlap GEMM's send/receive streams so these kernels cannot finish
+  // until the previous GEMM has flushed.
+  userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
+                       send_stream);
+  userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
+                       recv_stream);
+  for (auto stream : {send_stream, recv_stream}) {
+    NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, stream));
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
+    // We sync with the comm stream so the destructor can wait for the comm stream
+    // to finish before freeing the ubuf.
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _stop_comm, 0));
+  }
+}
+
 /***************************************************************************************************
  * Comm+GEMM Overlap P2P Base (Ring-Exchange)
  **************************************************************************************************/
...
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
View file @ 87e3e56e
...
@@ -2652,6 +2652,30 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
   }
 }

+void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
+                          const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
+                          int tp_size, communicator *comm, cudaStream_t stream) {
+  for (int j = 1; j < tp_size; j++) {
+    int i = (tp_rank + j) % tp_size;
+    int send_offset = srcoffset + bytes_per_slice * tp_rank;
+    int recv_offset = dstoffset + bytes_per_slice * tp_rank;
+    userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
+                     stream);
+  }
+}
+
+void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
+                          const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
+                          int tp_size, communicator *comm, cudaStream_t stream) {
+  for (int j = tp_size - 1; j > 0; j--) {
+    int i = (tp_rank + j) % tp_size;
+    int send_offset = srcoffset + bytes_per_slice * i;
+    int recv_offset = dstoffset + bytes_per_slice * i;
+    userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
+                     stream);
+  }
+}
+
 // producer
 static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) {
   // Decrement atomic val to signal current output tile finish
...
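userbuffers_send_all pushes this rank's own slice to every peer walking the ring forward, while userbuffers_recv_all posts receives for every peer's slice walking the ring backward; together they amount to an all-gather over the registered buffer. The index math, restated in Python (no real communication, just the schedule the loops above produce):

# Pure-Python rendering of the loop/index structure above, to make the
# all-gather schedule visible. Each entry is (peer_rank, byte_offset).
def send_all_schedule(tp_rank, tp_size, bytes_per_slice):
    # Send my slice (at offset tp_rank * bytes_per_slice) to every other rank,
    # starting from my right-hand ring neighbor.
    return [
        ((tp_rank + j) % tp_size, tp_rank * bytes_per_slice)
        for j in range(1, tp_size)
    ]

def recv_all_schedule(tp_rank, tp_size, bytes_per_slice):
    # Post receives for every peer's slice, walking the ring backward,
    # mirroring the j-descending loop in the CUDA code.
    return [
        ((tp_rank + j) % tp_size, ((tp_rank + j) % tp_size) * bytes_per_slice)
        for j in range(tp_size - 1, 0, -1)
    ]

print(send_all_schedule(1, 4, 1024))  # [(2, 1024), (3, 1024), (0, 1024)]
print(recv_all_schedule(1, 4, 1024))  # [(0, 0), (3, 3072), (2, 2048)]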
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h
View file @ 87e3e56e
...
@@ -312,4 +312,12 @@ void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inp
 void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream);

+void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
+                          const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
+                          int tp_size, communicator *comm, cudaStream_t stream);
+
+void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
+                          const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
+                          int tp_size, communicator *comm, cudaStream_t stream);
+
 #endif  // TRANSFORMER_ENGINE_USERBUFFERS_H_
transformer_engine/common/common.cu
View file @ 87e3e56e
...
@@ -98,6 +98,9 @@ void checkCuDriverContext(CUstream stream) {
 #ifdef __HIP_PLATFORM_AMD__
   return;
 #else
+  // Ensure the thread's "current" CUDA context is set.
+  cuda_driver::ensure_context_exists();
+
   CUcontext ctx;
   const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx);
   switch (driver_status) {
...
@@ -167,10 +170,10 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
   void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) +
                                            (offset_elems * type_num_bits) / 8);
-  NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment),
+  NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_GMEM_ALIGNMENT),
              "Tensor data pointer must be 16B aligned");
-  const int TMA_needed_size = (TMA_gmem_alignment * 8) / type_num_bits;
+  const int TMA_needed_size = (TMA_GMEM_ALIGNMENT * 8) / type_num_bits;
   NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits,
              "-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
...
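TMA_gmem_alignment is renamed to TMA_GMEM_ALIGNMENT but keeps its value of 16 bytes, so the inner-dimension rule is unchanged: globalX must be a multiple of 16*8/type_num_bits elements. The arithmetic, spelled out:

# The shape constraint enforced by create_2D_tensor_map, reproduced in Python.
TMA_GMEM_ALIGNMENT = 16  # bytes, as in common.h

def tma_needed_size(type_num_bits: int) -> int:
    # Elements per 16-byte alignment unit for this dtype.
    return (TMA_GMEM_ALIGNMENT * 8) // type_num_bits

for bits in (4, 8, 16, 32):  # e.g. FP4, FP8, BF16/FP16, FP32
    print(f"{bits}-bit -> multiple of {tma_needed_size(bits)} elements")
# 4-bit -> 32, 8-bit -> 16, 16-bit -> 8, 32-bit -> 4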
transformer_engine/common/common.h
View file @ 87e3e56e
...
@@ -94,7 +94,7 @@ struct SimpleTensor {
             nvte_make_shape(this->shape.data(), this->shape.size())};
   }

-  int numel() const {
+  size_t numel() const {
     size_t acc = 1;
     for (const auto &dim : shape) {
       acc *= dim;
...
@@ -737,7 +737,8 @@ constexpr size_t scale_tensor_alignment_X_colwise = 128;
 constexpr size_t scale_tensor_alignment_Y_colwise = 4;

 // Alignment requirements for the Tensor Memory Accelerator (TMA)
-constexpr int TMA_gmem_alignment = 16;  // global memory address alignment
+constexpr size_t TMA_GMEM_ALIGNMENT = 16;    // global memory address alignment
+constexpr size_t TMA_SHMEM_ALIGNMENT = 128;  // shared memory address alignment

 inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
   return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
...
transformer_engine/common/fused_attn/fused_attn.cpp
View file @ 87e3e56e
...
@@ -183,7 +183,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&
       (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
-      !requires_64bit_ragged_offset) {
+      !requires_64bit_ragged_offset &&
+      // 9.10.0: known bugs with SDPA FP8
+      (cudnn_runtime_version != 91000)) {
     if (cudnn_runtime_version >= 8900) {
       backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
     } else {
...
@@ -239,20 +241,20 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
         // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1
         (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
          layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
-        // 9.10: any head_dim + any arch + fprop + paged
-        // 9.10: any head_dim + any arch + fprop + non_paged + sq > 1
-        // 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
-        (!is_training && cudnn_runtime_version >= 91000 &&
+        // 9.10.2: any head_dim + any arch + fprop + paged
+        // 9.10.2: any head_dim + any arch + fprop + non_paged + sq > 1
+        // 9.10.2: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
+        (!is_training && cudnn_runtime_version >= 91002 &&
         (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 ||
          (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK &&
           attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
        // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
        (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
         cudnn_runtime_version >= 91100)) &&
-       // 9.11 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
-       (!(cudnn_runtime_version == 91100 && is_training && sm_arch_ == 90 && head_dim_qk >= 128 &&
-          head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) &&
-          head_dim_qk != head_dim_v))) &&
+       // 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
+       (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200) && is_training &&
+          sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 &&
+          !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) &&
      // bias type
      ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
       (cudnn_runtime_version >= 8906 &&
...
@@ -358,7 +360,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
          max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
          dropout == 0.0)))) &&
       // check 64-bit ragged offset support
-      (supported_ragged_offset_size)) {
+      (supported_ragged_offset_size) &&
+      // 9.10.0/9.10.1: known bugs with SDPA F16
+      (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) {
     flag_arb = true;
   }
   if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
...
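Both gates added here are release blacklists rather than minimum-version checks. Restated as standalone predicates (a paraphrase of the C++ conditions above, not a Transformer Engine API):

# Paraphrase of the new version gates in nvte_get_fused_attn_backend; the real
# logic lives in C++ and combines with many other conditions.
def fp8_backend_allowed(cudnn_runtime_version: int) -> bool:
    # 9.10.0 (91000): known bugs with SDPA FP8
    return cudnn_runtime_version != 91000

def f16_arbitrary_seqlen_allowed(cudnn_runtime_version: int) -> bool:
    # 9.10.0/9.10.1: known bugs with SDPA F16
    return cudnn_runtime_version not in (91000, 91001)

assert not fp8_backend_allowed(91000)
assert not f16_arbitrary_seqlen_allowed(91001)
assert f16_arbitrary_seqlen_allowed(91002)  # 9.10.2 is accepted again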
transformer_engine/common/fused_router/fused_moe_aux_loss.cu
View file @ 87e3e56e
...
@@ -90,7 +90,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
    * Section: Reduce to get the sum of aggregated_probs_per_expert
    */
   CompType intermediate_result =
-      warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
+      warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, ReduceFuncType::SUM, lane_id);
   __syncwarp();
   if (lane_id == 0) {
...
@@ -146,7 +146,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
    * Section: Reduce to get the sum of aggregated_probs_per_expert
    */
   CompType intermediate_result =
-      warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
+      warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, ReduceFuncType::SUM, lane_id);
   __syncwarp();
   if (lane_id == 0) {
...
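Both hunks only touch call sites of warp_reduce_on_shmem, whose new ReduceFuncType parameter is introduced in fused_router/utils.h below. For readers unfamiliar with the primitive, here is a self-contained CUDA sketch (my own, not TE's implementation) of a warp-cooperative sum over a small buffer: each lane accumulates a strided slice, then the partials are combined with shuffles:

#include <cstdio>
#include <cuda_runtime.h>

// Minimal sketch of a warp-cooperative sum: each lane accumulates a strided
// slice of the buffer, then the 32 partials are combined via shuffles.
__global__ void warp_sum_kernel(const float* data, int n, float* out) {
  const int lane = threadIdx.x;  // launched with a single 32-thread warp
  float val = 0.0f;
  for (int i = lane; i < n; i += 32) val += data[i];
  for (int offset = 16; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffffu, val, offset);
  if (lane == 0) *out = val;
}

int main() {
  const int n = 100;
  float h[n], *d, *dout;
  for (int i = 0; i < n; ++i) h[i] = 1.0f;
  cudaMalloc(&d, n * sizeof(float));
  cudaMalloc(&dout, sizeof(float));
  cudaMemcpy(d, h, n * sizeof(float), cudaMemcpyHostToDevice);
  warp_sum_kernel<<<1, 32>>>(d, n, dout);
  float result = 0.0f;
  cudaMemcpy(&result, dout, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", result);  // expect 100.0
  cudaFree(d);
  cudaFree(dout);
  return 0;
}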
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu
View file @ 87e3e56e
...
@@ -107,7 +107,8 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi
   if (score_function == 0) {
     if (topk > 1) {
-      auto sum_logits = warp_reduce_on_shmem(local_logits, num_experts, sum, lane_id);
+      auto sum_logits =
+          warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id);
       for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
         local_logits[i] = static_cast<DataType>(static_cast<double>(local_logits[i]) /
                                                 (static_cast<double>(sum_logits) + epsilon));
...
@@ -231,13 +232,15 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int
    */
   // Sigmoid Post-processing bwd when topk > 1
   if (topk > 1 && score_function == 0) {
-    auto sum_fwd_input = warp_reduce_on_shmem(local_act_from_fwd, num_experts, sum, lane_id);
+    auto sum_fwd_input =
+        warp_reduce_on_shmem(local_act_from_fwd, num_experts, ReduceFuncType::SUM, lane_id);
     // Put the result of output * grad to the comp_buf
     for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
       local_comp_buf[i] = local_grad[i] * local_act_from_fwd[i];
     }
     __syncwarp();
-    auto sum_Output_x_Grad = warp_reduce_on_shmem(local_comp_buf, num_experts, sum, lane_id);
+    auto sum_Output_x_Grad =
+        warp_reduce_on_shmem(local_comp_buf, num_experts, ReduceFuncType::SUM, lane_id);
     // In-place update
     for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
       local_grad[i] =
...
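The backward hunk needs exactly these two warp sums because, for the epsilon-normalized scores s_i = a_i / (S + eps) with S = sum_j a_j computed in the forward pass, the chain rule gives dL/da_i = g_i / (S + eps) - (sum_j g_j * a_j) / (S + eps)^2: one reduction over the forward activations and one over grad times activation. A host-side reference of that gradient (a sketch of the math only, before folding in the sigmoid's own derivative; not TE code):

#include <vector>

// Reference (sketch) for the backward of s_i = a_i / (S + eps), S = sum_j a_j:
//   dL/da_i = g_i / (S + eps) - (sum_j g_j * a_j) / (S + eps)^2
std::vector<double> norm_sigmoid_bwd(const std::vector<double>& act,
                                     const std::vector<double>& grad, double eps) {
  double S = 0.0, GA = 0.0;  // the two reductions in the kernel above
  for (size_t i = 0; i < act.size(); ++i) {
    S += act[i];
    GA += grad[i] * act[i];
  }
  std::vector<double> out(act.size());
  for (size_t i = 0; i < act.size(); ++i)
    out[i] = grad[i] / (S + eps) - GA / ((S + eps) * (S + eps));
  return out;
}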
transformer_engine/common/fused_router/fused_topk_with_score_function.cu
View file @ 87e3e56e
...
@@ -220,7 +220,7 @@ __global__ void fused_topk_with_score_function_forward_kernel(
   // score_function == 0 means sigmoid
   if (score_function == 0) {
     if (topk > 1) {
-      double sum_scores = warp_reduce_on_shmem(topk_scores, topk, sum, lane_id);
+      double sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id);
       for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
         topk_scores[i] = static_cast<double>(topk_scores[i]) / (sum_scores + epsilon);
       }
...
@@ -362,7 +362,7 @@ __global__ void fused_topk_with_score_function_backward_kernel(
        /*data ptr = */ local_act_from_fwd,
        /*mask ptr = */ local_routing_map,
        /*data size = */ num_experts,
-       /*reduce func = */ sum, lane_id);
+       /*reduce func = */ ReduceFuncType::SUM, lane_id);
   // Put the result of output * grad to the comp_buf
   for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
     local_comp_buf[i] = (local_routing_map[i] ? static_cast<double>(local_grad[i]) *
...
@@ -374,7 +374,7 @@ __global__ void fused_topk_with_score_function_backward_kernel(
        /*data ptr = */ local_comp_buf,
        /*mask ptr = */ local_routing_map,
        /*data size = */ num_experts,
-       /*reduce func = */ sum, lane_id);
+       /*reduce func = */ ReduceFuncType::SUM, lane_id);
   // In-place update
   for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
     if (local_routing_map[i]) {
...
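In the forward hunk, the sigmoid path renormalizes the selected scores by the sum of the top-k values plus epsilon. A host-side sketch of that sigmoid, top-k, renormalize sequence (my reference, not the kernel; std::partial_sort stands in for the warp-cooperative selection):

#include <algorithm>
#include <cmath>
#include <utility>
#include <vector>

// Host-side sketch: scores = sigmoid(logits), keep top-k, then (for topk > 1)
// divide each kept score by (sum of kept scores + epsilon).
std::vector<std::pair<int, double>> sigmoid_topk(const std::vector<double>& logits,
                                                 int topk, double epsilon) {
  std::vector<std::pair<int, double>> scored(logits.size());
  for (size_t i = 0; i < logits.size(); ++i)
    scored[i] = {static_cast<int>(i), 1.0 / (1.0 + std::exp(-logits[i]))};
  std::partial_sort(scored.begin(), scored.begin() + topk, scored.end(),
                    [](const auto& a, const auto& b) { return a.second > b.second; });
  scored.resize(topk);
  double s_sum = 0.0;
  for (const auto& p : scored) s_sum += p.second;
  if (topk > 1)
    for (auto& p : scored) p.second /= (s_sum + epsilon);
  return scored;
}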
transformer_engine/common/fused_router/utils.h
View file @ 87e3e56e
...
@@ -30,14 +30,28 @@ __device__ inline T sum(T a, T b) {
   return a + b;
 }

+enum ReduceFuncType {
+  SUM,
+  MAX,
+};
+
 template <typename T>
-__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, T (*reduce_func)(T, T),
+__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncType type,
                                          int lane_id) {
+  T (*reduce_func)(T, T);
+  double default_val = 0;
+  if (type == ReduceFuncType::SUM) {
+    reduce_func = sum;
+    default_val = 0;
+  } else if (type == ReduceFuncType::MAX) {
+    reduce_func = max;
+    default_val = -std::numeric_limits<double>::infinity();
+  }
   // Some value is hanlded in local thread
   // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ...
   // Reduce the value in local thread
   volatile double val =
-      lane_id < data_size ? static_cast<double>(data_ptr[lane_id]) : static_cast<double>(0);
+      lane_id < data_size ? static_cast<double>(data_ptr[lane_id]) : default_val;
   for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
     val = reduce_func(val, data_ptr[i]);
   }
...
@@ -69,13 +83,22 @@ __device__ inline void apply_sigmoid_on_float(DataType *scores, int data_size, i
 template <typename T>
 __device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size,
-                                                T (*reduce_func)(T, T), int lane_id) {
+                                                ReduceFuncType type, int lane_id) {
+  T (*reduce_func)(T, T);
+  double default_val = 0;
+  if (type == ReduceFuncType::SUM) {
+    reduce_func = sum;
+    default_val = 0;
+  } else if (type == ReduceFuncType::MAX) {
+    reduce_func = max;
+    default_val = -std::numeric_limits<double>::infinity();
+  }
   // Some value is hanlded in local thread
   // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ...
   // Reduce the value in local thread
-  volatile double val = lane_id < data_size && mask[lane_id]
-                            ? static_cast<double>(data_ptr[lane_id])
-                            : static_cast<double>(0);
+  volatile double val =
+      lane_id < data_size && mask[lane_id] ? static_cast<double>(data_ptr[lane_id]) : default_val;
   for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
     if (mask[i]) {
       val = reduce_func(val, data_ptr[i]);
...
@@ -128,7 +151,7 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_
   float sum_Output_x_Grad = warp_reduce_on_shmem(
       /*data ptr = */ comp_buf,
       /*data size = */ data_size,
-      /*reduce func = */ sum, lane_id);
+      /*reduce func = */ ReduceFuncType::SUM, lane_id);
   // In-place update
   for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
     if (mask) {
...
@@ -147,14 +170,16 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_
 template <typename DataType>
 __device__ inline void apply_softmax_on_float(DataType *scores, int data_size, int lane_id) {
   // 1. compute the max of value
-  float max_val = static_cast<float>(warp_reduce_on_shmem(scores, data_size, max, lane_id));
+  float max_val =
+      static_cast<float>(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::MAX, lane_id));
   // 2. value -> exp_value
   for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
     scores[i] = static_cast<float>(exp(static_cast<float>(scores[i]) - max_val));
   }
   __syncwarp();
   // 3. compute the sum of exp_value
-  float sum_val = static_cast<float>(warp_reduce_on_shmem(scores, data_size, sum, lane_id));
+  float sum_val =
+      static_cast<float>(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::SUM, lane_id));
   // 4. update the softmax value
   for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
     scores[i] = static_cast<float>(scores[i]) / sum_val;
...
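The refactor above is more than cosmetic: with the old function-pointer parameter, every out-of-range lane was seeded with 0, which is the identity for SUM but wrong for MAX (a max over all-negative data would be clamped to 0). Dispatching on the enum lets the helper pair each operation with its identity element. A minimal sketch of that idea (hypothetical names, not TE's exact code):

#include <limits>

enum class Reduce { SUM, MAX };

// Each reduction pairs with its own identity element, so lanes with no data
// cannot skew the result (0 would be wrong for a MAX over negative inputs).
inline double identity_for(Reduce op) {
  return op == Reduce::SUM ? 0.0 : -std::numeric_limits<double>::infinity();
}

inline double apply(Reduce op, double a, double b) {
  return op == Reduce::SUM ? a + b : (a > b ? a : b);
}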
@@ -165,19 +190,29 @@ __device__ inline void apply_softmax_on_float(DataType *scores, int data_size, i
 template <typename T>
 __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices,
                                            T *topk_scores, int lane_id) {
+  // Check if the index is masked by the later iteration
+  auto is_masked = [&topk_indices](int k, int index) {
+    if (k == 0) return false;
+    for (int i = 0; i < k; i++) {
+      if (topk_indices[i] == index) return true;
+    }
+    return false;
+  };
   // Topk Times: Find the max value and its index
   // Then mask it, and record the index in the topk_indices
   // After looping topk times, the topk_indices will be the topk indices
   for (int k = 0; k < topk; k++) {
     // Find the max value and its index
     volatile double val =
-        (lane_id < data_size) ? static_cast<double>(scores[lane_id]) : static_cast<double>(0);
+        (lane_id < data_size && !is_masked(k, lane_id))
+            ? static_cast<double>(scores[lane_id])
+            : -std::numeric_limits<double>::infinity();
     volatile int index = (lane_id < data_size) ? lane_id : 0;
     // Some value is hanlded in local thread
     // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ...
     // Reduce the value in local thread
     for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
-      volatile double cur_val = scores[i];
+      volatile double cur_val = (is_masked(k, i))
+                                    ? -std::numeric_limits<double>::infinity()
+                                    : static_cast<double>(scores[i]);
       if (cur_val > val) {
         val = cur_val;
         index = i;
...
@@ -200,17 +235,9 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
     if (lane_id == 0) {
       topk_indices[k] = index;
       topk_scores[k] = val;
-      scores[index] = static_cast<double>(-1.0) - val;  // make the selected experts using val = - 1 - val
     }
     __syncwarp();
   }
-  // Reset the scores to the original value
-  for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
-    scores[topk_indices[i]] =
-        static_cast<double>(-1.0) - static_cast<double>(scores[topk_indices[i]]);
-  }
 }
 // Current TE only support float32/bf16/fp16, float64 probs should be considered in the future
...
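The naive_topk_and_mask rewrite drops the old mutate-and-restore trick, where a selected entry was overwritten with -1 - val and patched back afterwards, in favor of an is_masked scan over the indices already chosen, leaving scores untouched throughout. A host-side sketch of that repeated-argmax-with-exclusion scheme (my reference, not the device code):

#include <limits>
#include <vector>

// Repeated argmax with an exclusion list: previously selected indices are
// skipped via a lookup instead of temporarily mutating the scores buffer.
void topk_by_argmax(const std::vector<double>& scores, int topk,
                    std::vector<int>& idx_out, std::vector<double>& val_out) {
  idx_out.clear();
  val_out.clear();
  for (int k = 0; k < topk; ++k) {
    double best = -std::numeric_limits<double>::infinity();
    int best_i = 0;
    for (int i = 0; i < static_cast<int>(scores.size()); ++i) {
      bool taken = false;
      for (int j = 0; j < k; ++j) taken |= (idx_out[j] == i);  // exclusion scan
      if (!taken && scores[i] > best) {
        best = scores[i];
        best_i = i;
      }
    }
    idx_out.push_back(best_i);
    val_out.push_back(best);
  }
}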
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @ 87e3e56e
...
@@ -253,8 +253,9 @@ using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublas
 void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                  const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
                  cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
-                 bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split,
-                 int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) {
+                 float alpha, float beta, bool use_split_accumulator, int math_sm_count,
+                 int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter,
+                 cudaStream_t stream) {
   // Tensor dims in row-major order
   const int A0 = inputA->flat_first_dim();
   const int A1 = inputA->flat_last_dim();
...
@@ -310,13 +311,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                "fp8 Aux output for gemm + gelu fusion not supported!");
   }
   if (is_fp8_dtype(outputD->data.dtype)) {
-    NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
+    NVTE_CHECK(beta == 0.0f, "Accumulation mode not supported with FP8 GEMM output!");
   }
-  float one = 1.0;
-  float zero = 0.0;
-  float beta = (accumulate) ? one : zero;
   cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();
   cublasLtMatmulDesc_t operationDesc = nullptr;
...
@@ -601,7 +598,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   // D = alpha * (A * B) + beta * C
   NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
-                                   static_cast<const void *>(&one),          /* alpha */
+                                   static_cast<const void *>(&alpha),        /* alpha */
                                    param.A,                                  /* A */
                                    Adesc, param.B,                           /* B */
                                    Bdesc, static_cast<const void *>(&beta),  /* beta */
...
@@ -752,8 +749,27 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
 #else
   cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu,
               (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
               (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr,
               wspace->data.shape[0],
-              accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
+              1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, 0, 0, false,
+              nullptr, stream);
 #endif  //__HIP_PLATFORM_AMD__
 }
+
+void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
+                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
+                             bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
+                             bool use_split_accumulator, int math_sm_count, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_cublas_gemm_scaled);
+  using namespace transformer_engine;
+  const Tensor *inputA = convertNVTETensorCheck(A);
+  const Tensor *inputB = convertNVTETensorCheck(B);
+  Tensor *outputD = convertNVTETensor(D);
+  const Tensor *biasTensor = convertNVTETensor(bias);
+  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
+  Tensor *wspace = convertNVTETensor(workspace);
+
+  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu,
+              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad,
+              wspace->data.dptr, wspace->data.shape[0], alpha, beta, use_split_accumulator,
+              math_sm_count, 0, 0, false, nullptr, stream);
+}

 void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
...
@@ -846,8 +862,8 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
 #else
   cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu,
               (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
               (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr,
               wspace->data.shape[0],
-              accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
-              inputCounter, stream);
+              1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, m_split,
+              n_split, gemm_producer, inputCounter, stream);
 #endif  //__HIP_PLATFORM_AMD__
 }
...
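The GEMM change replaces the bool accumulate flag with explicit scaling factors: alpha multiplies the product and beta scales the existing output, matching the cublasLtMatmul comment D = alpha * (A * B) + beta * C, and the old flag maps to beta = accumulate ? 1.0f : 0.0f with alpha = 1.0f. The new nvte_cublas_gemm_scaled entry point exposes both. A hedged call sketch using only the signature shown above; tensor and workspace construction is elided, and whether null bias/pre-GELU handles are accepted is an assumption on my part:

#include <cuda_runtime.h>
#include <transformer_engine/gemm.h>

// Sketch only: assumes A, B, D, workspace were created elsewhere, and assumes
// null bias / pre-GELU handles are permitted (NVTETensor is an opaque handle).
void accumulate_product(NVTETensor A, NVTETensor B, NVTETensor D,
                        NVTETensor workspace, cudaStream_t stream) {
  const float alpha = 1.0f;  // scale on op(A) * op(B)
  const float beta = 1.0f;   // keep existing D, i.e. D += A * B
  nvte_cublas_gemm_scaled(A, B, D, /*bias=*/nullptr, /*pre_gelu_out=*/nullptr,
                          /*transa=*/false, /*transb=*/false, /*grad=*/false,
                          workspace, alpha, beta,
                          /*use_split_accumulator=*/false, /*math_sm_count=*/0,
                          stream);
}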