OpenDAS / TransformerEngine / Commits

Commit 44740c6c, authored Jul 22, 2025 by yuguo
Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine
Parents: 8113d9e0, 7a9a0825

Changes: 162. Showing 20 changed files with 3592 additions and 225 deletions (+3592 -225).
Changed files (additions / deletions):

tests/pytorch/test_numerics.py  +260 -8
tests/pytorch/test_onnx_export.py  +1154 -0
tests/pytorch/test_parallel_cross_entropy.py  +11 -7
tests/pytorch/test_permutation.py  +251 -86
tests/pytorch/test_sanity.py  +38 -1
tests/pytorch/utils.py  +1 -0
transformer_engine/common/CMakeLists.txt  +6 -0
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp  +10 -0
transformer_engine/common/fused_attn/fused_attn.cpp  +5 -1
transformer_engine/common/fused_attn/kv_cache.cu  +76 -40
transformer_engine/common/fused_router/fused_moe_aux_loss.cu  +290 -0
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu  +324 -0
transformer_engine/common/fused_router/fused_topk_with_score_function.cu  +497 -0
transformer_engine/common/fused_router/utils.h  +226 -0
transformer_engine/common/gemm/cublaslt_gemm.cu  +79 -30
transformer_engine/common/include/transformer_engine/fused_router.h  +132 -0
transformer_engine/common/include/transformer_engine/multi_stream.h  +22 -0
transformer_engine/common/include/transformer_engine/padding.h  +27 -0
transformer_engine/common/multi_tensor/adam.cu  +182 -51
transformer_engine/common/multi_tensor/multi_tensor_apply.cuh  +1 -1
tests/pytorch/test_numerics.py

...
@@ -106,7 +106,7 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
 mask_types = ["causal", "no_mask"]

-NVTE_TEST_NVINSPECT_ENABLED = os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False)
+NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))

 if NVTE_TEST_NVINSPECT_ENABLED:
     # The numerics of all the layers should work the same,
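The motivation for this one-line change is worth spelling out: os.environ.get returns a string whenever the variable is set, and any non-empty string is truthy in Python, so even NVTE_TEST_NVINSPECT_ENABLED=0 used to enable the debug path. A minimal sketch of the before/after behavior (standard library only):

    import os

    os.environ["NVTE_TEST_NVINSPECT_ENABLED"] = "0"

    old_flag = os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False)   # -> "0" (a string)
    new_flag = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))  # -> 0 (an int)

    assert bool(old_flag) is True   # any non-empty string is truthy, even "0"
    assert bool(new_flag) is False  # the int conversion restores the intended semantics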
...

@@ -1059,8 +1059,11 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
     assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])

-def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False):
+def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False, recipe=None):
     reset_rng_states()
+    fp8 = recipe is not None
+    if fp8:
+        FP8GlobalStateManager.reset()
     inp_hidden_states = torch.randn(
         (config.seq_len, bs, config.hidden_size),
...
@@ -1070,9 +1073,10 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False)
     )
     inp_hidden_states.retain_grad()

-    out = block(inp_hidden_states)
-    if isinstance(out, (List, Tuple)):
-        out = out[0]
+    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+        out = block(inp_hidden_states)
+        if isinstance(out, (List, Tuple)):
+            out = out[0]
     loss = out.sum()
     loss.backward()
     if delay_wgrad_compute:
...
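For readers unfamiliar with the pattern the helper now follows: in Transformer Engine, FP8 execution is scoped by fp8_autocast, and enabled=False runs the module in its native dtype, which is how the recipe=None case degrades gracefully. A minimal usage sketch along the same lines (shapes and the DelayedScaling choice are illustrative, not taken from the diff):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    model = te.Linear(128, 128, params_dtype=torch.bfloat16, device="cuda")
    x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)

    fp8_recipe = recipe.DelayedScaling()  # any recipe; None would disable FP8 below
    with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        out = model(x)
    out.sum().backward()  # backward runs outside the autocast scope, as in the helper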
@@ -1268,6 +1272,64 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)

+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("model", ["small"])
+@pytest.mark.parametrize("recipe", fp8_recipes + [None])
+def test_linear_accuracy_save_original_input(dtype, model, recipe):
+    bs = 1
+    fuse_wgrad_accumulation = True
+    fp8_model_params = False
+    fp8 = recipe is not None
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+    if fp8 and recipe.delayed():
+        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    config = model_configs[model]
+    if config.seq_len % 16 != 0 and fp8:
+        pytest.skip("FP8 requires sequence length to be divisible by 16.")
+
+    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+        te_linear_ref = Linear(
+            config.hidden_size,
+            4 * config.hidden_size,
+            bias=False,
+            params_dtype=dtype,
+            device="cuda",
+            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+            save_original_input=False,
+        ).eval()
+        te_linear = Linear(
+            config.hidden_size,
+            4 * config.hidden_size,
+            bias=False,
+            params_dtype=dtype,
+            device="cuda",
+            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+            save_original_input=True,
+        ).eval()
+
+    # Share params
+    with torch.no_grad():
+        te_linear_ref.weight = Parameter(te_linear.weight.clone())
+        if fuse_wgrad_accumulation:
+            weight = getattr(te_linear, f"weight")
+            weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
+            te_linear_ref.weight.main_grad = weight.main_grad.clone()
+
+    te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe)
+    te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe)
+
+    # Should be bit-wise match
+    for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
+        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", ["126m"])
...
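For context on the flag being exercised: as the name suggests, save_original_input=True appears to make the layer keep the unquantized input for the weight-gradient GEMM instead of its FP8 copy, which would explain why the test demands bit-wise identical outputs against a save_original_input=False reference and why the DelayedScaling recipe, whose backward consumes the saved quantized tensor, is skipped. This reading is inferred from the test's skips and tolerances, not stated in the diff itself.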
@@ -1768,6 +1830,111 @@ def test_grouped_linear_accuracy(
             device="cuda",
             fuse_wgrad_accumulation=fuse_wgrad_accumulation,
             delay_wgrad_compute=delay_wgrad_compute,
+            save_original_input=False,
         ).eval()
         sequential_linear = torch.nn.ModuleList(
             [
                 Linear(
                     config.hidden_size,
                     4 * config.hidden_size,
                     bias=bias,
                     params_dtype=dtype,
                     parallel_mode=parallel_mode,
                     device="cuda",
                     fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                 ).eval()
                 for _ in range(num_gemms)
             ]
         )

     # Share params
     with torch.no_grad():
         for i in range(num_gemms):
             sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
             if bias:
                 sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
             if fuse_wgrad_accumulation:
                 weight_i = getattr(grouped_linear, f"weight{i}")
                 weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
                 sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()

     outputs_ref = _test_grouped_linear_accuracy(
         sequential_linear,
         num_gemms,
         bs,
         dtype,
         config,
         recipe,
         fp8,
         fuse_wgrad_accumulation,
         delay_wgrad_compute,
     )
     outputs = _test_grouped_linear_accuracy(
         grouped_linear,
         num_gemms,
         bs,
         dtype,
         config,
         recipe,
         fp8,
         fuse_wgrad_accumulation,
         delay_wgrad_compute,
     )

     # Should be bit-wise match
     for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)

 @pytest.mark.parametrize("dtype", param_types, ids=str)
 @pytest.mark.parametrize("num_gemms", [3])
 @pytest.mark.parametrize("bs", [1])
 @pytest.mark.parametrize("model", ["126m"])
 @pytest.mark.parametrize("recipe", fp8_recipes + [None])
 @pytest.mark.parametrize("fp8_model_params", [False])
 @pytest.mark.parametrize("fuse_wgrad_accumulation", [True])
 @pytest.mark.parametrize("bias", [False])
 @pytest.mark.parametrize("delay_wgrad_compute", [True])
 def test_grouped_linear_accuracy_save_original_input(
     dtype,
     num_gemms,
     bs,
     model,
     recipe,
     fp8_model_params,
     fuse_wgrad_accumulation,
     bias,
     delay_wgrad_compute,
     parallel_mode=None,
 ):
     fp8 = recipe is not None
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if fp8 and recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
         pytest.skip(reason_for_no_fp8_block_scaling)
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
     config = model_configs[model]
     if config.seq_len % 16 != 0 and fp8:
         pytest.skip("FP8 requires sequence length to be divisible by 16.")

     with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
         grouped_linear = GroupedLinear(
             num_gemms,
             config.hidden_size,
             4 * config.hidden_size,
             bias=bias,
             params_dtype=dtype,
             parallel_mode=parallel_mode,
             device="cuda",
             fuse_wgrad_accumulation=fuse_wgrad_accumulation,
             delay_wgrad_compute=delay_wgrad_compute,
             save_original_input=True,
         ).eval()
         sequential_linear = torch.nn.ModuleList(
             [
...
@@ -1948,7 +2115,89 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
 @pytest.mark.parametrize("recipe", fp8_recipes)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 def test_padding_grouped_linear_accuracy(
-    dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None
+    dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None,
 ):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if recipe.float8_block_scaling() and not fp8_block_scaling_available:
         pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.seq_len % 16 != 0 and fp8:
         pytest.skip("FP8 requires sequence length to be divisible by 16.")

     with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
         grouped_linear = TorchGroupedLinearWithPadding(
             num_gemms,
             config.hidden_size,
             4 * config.hidden_size,
             bias=False,
             params_dtype=dtype,
             parallel_mode=parallel_mode,
             fp8=fp8,
         ).eval()

     with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
         ref_grouped_linear = GroupedLinear(
             num_gemms,
             config.hidden_size,
             4 * config.hidden_size,
             bias=False,
             params_dtype=dtype,
             parallel_mode=parallel_mode,
             device="cuda",
+            save_original_input=False,
         ).eval()

     # Share params
     with torch.no_grad():
         inner_grouped_linear = grouped_linear.linear_fn
         for i in range(num_gemms):
             setattr(
                 ref_grouped_linear,
                 f"weight{i}",
                 Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
             )

     outputs = _test_padding_grouped_linear_accuracy(
         grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
     )
     outputs_ref = _test_padding_grouped_linear_accuracy(
         ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
     )

     # Should be bit-wise match
     for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)

 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("num_gemms", [3])
 @pytest.mark.parametrize("bs", [1])
 @pytest.mark.parametrize("model", ["126m"])
 @pytest.mark.parametrize("fp8", [True])
 @pytest.mark.parametrize("recipe", fp8_recipes)
 @pytest.mark.parametrize("fp8_model_params", [False])
 def test_padding_grouped_linear_accuracy_save_original_input(
     dtype,
     num_gemms,
     bs,
     model,
     fp8,
     recipe,
     fp8_model_params,
     parallel_mode=None,
 ):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
...
@@ -1958,6 +2207,8 @@ def test_padding_grouped_linear_accuracy(
pytest
.
skip
(
"FP8 parameters are not supported in debug mode."
)
if
recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8
and
recipe
.
delayed
():
pytest
.
skip
(
"DelayedScaling recipe is not supported with save_original_input"
)
config
=
model_configs
[
model
]
if
config
.
seq_len
%
16
!=
0
and
fp8
:
...
...
@@ -1983,6 +2234,7 @@ def test_padding_grouped_linear_accuracy(
params_dtype
=
dtype
,
parallel_mode
=
parallel_mode
,
device
=
"cuda"
,
save_original_input
=
True
,
).
eval
()
# Share params
...
...
@@ -2334,9 +2586,9 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
if
(
backend
==
"FusedAttention"
and
get_device_compute_capability
()
==
(
8
,
9
)
and
get_cudnn_version
()
<
(
9
,
1
1
,
0
)
and
get_cudnn_version
()
<
(
9
,
1
2
,
0
)
):
pytest
.
skip
(
"Skip KV cache for sm89 and cuDNN < 9.1
1
"
)
pytest
.
skip
(
"Skip KV cache for sm89 and cuDNN < 9.1
2
"
)
os
.
environ
[
"NVTE_FLASH_ATTN"
]
=
"0"
os
.
environ
[
"NVTE_FUSED_ATTN"
]
=
"0"
...
...
tests/pytorch/test_onnx_export.py (new file, mode 0 → 100644)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
This file contains tests for exporting TransformerEngine models to ONNX.

The purpose of these tests is validation that TE models are converted to their correct ONNX
representation. Toward this end, each test captures the output of a TE module forward pass,
converts the TE module to ONNX, and uses ONNX Runtime (ORT) to execute the ONNX graph and
validate the output against TE's output.

Until FP8 is introduced to the ONNX standard, FP8 QuantizeLinear/DequantizeLinear is implemented
using custom ORT operations.

To run many repetitive tests use pytest-loop:
    $ python3 -m pip install pytest-loop
    $ pytest --loop 1000 tests/pytorch/test_onnx_export.py::test_export_layernorm

For reproducibility use: torch.manual_seed(0)
"""

import os
import tempfile
import pytest
import warnings
import numpy as np
import onnxruntime as ort
import torch
import random
from torch import nn as nn
from typing import Optional, Union, Tuple, List
from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op

import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import get_default_init_method
# Global test configuration knobs.

# Enable this to serialize test inputs and outputs to file (as a Polygraphy RunResults instance).
SAVE_TEST_IO = bool(int(os.getenv("NVTE_ONNX_EXPORT_SAVE_TEST_IO", "0")))

if SAVE_TEST_IO:
    from polygraphy.json import save_json
    from polygraphy.comparator import RunResults

# The directory where generated ONNX test models are stored.
NVTE_TEST_ARTIFACTS_DIR = os.environ.get("NVTE_TEST_ARTIFACTS_DIR")
NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(
    tempfile.gettempdir(), "./gen_onnx_models"
)

# The directory where this file is stored.
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
skip_MXFP8 = pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)

fp8_recipes = [
    None,
    recipe.DelayedScaling(),
    recipe.MXFP8BlockScaling(),
]

supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
@onnx_op(
    op_type="trt::TRT_FP8QuantizeLinear",
    domain="trt",
    inputs=[
        PyCustomOpDef.dt_float,
        PyCustomOpDef.dt_float,
    ],
    outputs=[PyCustomOpDef.dt_uint8],
)
def trt_fp8_quantize(t, scale):
    """FP8 quantization extension for ONNX Runtime."""
    x = torch.from_numpy(t).cuda()
    q = te.tensor.float8_tensor.Float8Quantizer(
        scale=1 / torch.from_numpy(scale).cuda(),
        amax=torch.zeros([1]).cuda(),
        fp8_dtype=tex.DType.kFloat8E4M3,
    )
    return q(x)._data.cpu().numpy()

@onnx_op(
    op_type="trt::TRT_FP8DequantizeLinear",
    domain="trt",
    inputs=[
        PyCustomOpDef.dt_uint8,
        PyCustomOpDef.dt_float,
    ],
    outputs=[PyCustomOpDef.dt_float],
)
def trt_fp8_dequantize(t, scale):
    """FP8 dequantization extension for ONNX Runtime."""
    x = torch.from_numpy(t).cuda()
    q = te.tensor.float8_tensor.Float8Quantizer(
        scale=1 / torch.from_numpy(scale).cuda(),
        amax=torch.zeros([1]).cuda(),
        fp8_dtype=tex.DType.kFloat8E4M3,
    )
    quantizer_tensor = q.create_tensor_from_data(x, fake_dtype=torch.float32)
    return quantizer_tensor.dequantize().cpu().numpy()

@onnx_op(
    op_type="trt::TRT_MXFP8QuantizeLinear",
    domain="trt",
    inputs=[
        PyCustomOpDef.dt_float,
    ],
    outputs=[PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_uint8],
)
def trt_mxfp8_quantize(t):
    """MXFP8 quantization extension for ONNX Runtime."""
    x = torch.from_numpy(t).cuda()
    q = te.tensor.mxfp8_tensor.MXFP8Quantizer(tex.DType.kFloat8E4M3)
    return q(x)._rowwise_data.cpu().numpy(), q(x)._rowwise_scale_inv.cpu().numpy()

@onnx_op(
    op_type="trt::TRT_MXFP8DequantizeLinear",
    domain="trt",
    inputs=[
        PyCustomOpDef.dt_uint8,
        PyCustomOpDef.dt_uint8,
    ],
    outputs=[PyCustomOpDef.dt_float],
)
def trt_mxfp8_dequantize(t, scale_inv):
    """MXFP8 dequantization extension for ONNX Runtime."""
    x = torch.from_numpy(t).cuda()
    scale_inv_tensor = torch.from_numpy(scale_inv).cuda()
    q = te.tensor.mxfp8_tensor.MXFP8Quantizer(tex.DType.kFloat8E4M3)
    quantizer_tensor = q.create_tensor_from_data(x, scale_inv_tensor, fake_dtype=torch.float32)
    return quantizer_tensor.dequantize().cpu().numpy()
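These @onnx_op kernels are only visible to an InferenceSession that loads the onnxruntime-extensions library; a minimal sketch of that wiring, which is the same pattern create_ort_session uses further down in this file (the model filename is a placeholder):

    import onnxruntime as ort
    from onnxruntime_extensions import get_library_path

    sess_options = ort.SessionOptions()
    sess_options.register_custom_ops_library(get_library_path())  # exposes the @onnx_op kernels
    session = ort.InferenceSession(
        "model_with_trt_qdq.onnx",  # placeholder: any graph using the trt:: Q/DQ ops
        sess_options,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )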
@pytest.fixture()
def seed_default_rng():
    """Reseed the PRNG for test reproducibility"""
    torch.manual_seed(1234)

@pytest.fixture()
def set_max_seq_len(max_seq_len=128):
    """Set the maximum sequence length that can be used for attention masking"""
    os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{max_seq_len}"

@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
    FP8GlobalStateManager.reset()
def do_export(
    model: torch.nn.Module,
    inp: torch.Tensor,
    fname: str,
    fp8_recipe: recipe.Recipe,
    input_names: List[str] = None,
    output_names: List[str] = None,
    dynamic_shapes: List[str] = None,
):
    """Export to ONNX"""
    input_names = input_names or ["input"]
    output_names = output_names or ["output"]

    with torch.inference_mode(), te.fp8_autocast(
        enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
    ), warnings.catch_warnings():
        warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*")

        model.cuda().eval()
        os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True)
        fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)

        inps = inp if isinstance(inp, list) or isinstance(inp, tuple) else (inp,)
        assert len(inps) == len(input_names)
        inds_to_del = [i for i in range(len(inps)) if inps[i] is None]
        input_names = [input_names[i] for i in range(len(inps)) if i not in inds_to_del]

        model(*inps)  # warm-up run
        with te.export.onnx_export(True):
            model(*inps)
        with te.export.onnx_export(True):
            torch.onnx.export(
                model,
                inps,
                fname,
                dynamo=True,
                custom_translation_table=te_translation_table,
                verbose=True,
                dynamic_shapes=dynamic_shapes,
                input_names=input_names,
                output_names=output_names,
                optimize=inps[0].dtype != torch.bfloat16,
                # optimizer does not work with bfloat16 yet - will need to change that after onnxscript supports bfloat16
            )
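For orientation, a typical call into do_export from the tests below looks like this (shapes are hypothetical; the dynamic_shapes key must match the exported module's forward argument name, here "inp" for te.Linear):

    import torch
    import transformer_engine.pytorch as te

    model = te.Linear(64, 64).cuda().eval()
    inp = torch.randn(4, 64, device="cuda")
    bs = torch.export.Dim("bs", min=2, max=1256)  # dynamic batch dimension
    do_export(model, inp, "te.linear_example.onnx", fp8_recipe=None,
              dynamic_shapes={"inp": {0: bs}})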
def to_numpy(tensor):
    if isinstance(tensor, torch.Tensor):
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.type(torch.float32)
        tensor = tensor.detach().cpu().numpy()
    return tensor

def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int):
    """Initialize the FP8 quantization scales in module"""
    module.init_fp8_metadata(num_gemms)
    for quantizer in module.quantizers["scaling_fwd"]:
        quantizer.scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale
def te_infer(
    model: torch.nn.Module,
    inps: Union[Tuple[torch.Tensor], torch.Tensor],
    is_fp8: bool,
    fp8_recipe: recipe.Recipe,
):
    """Transformer Engine forward propagation."""
    with torch.inference_mode(), te.fp8_autocast(
        enabled=is_fp8, fp8_recipe=fp8_recipe
    ), warnings.catch_warnings():
        te_outputs = model(*inps if isinstance(inps, tuple) else (inps,))
        if not isinstance(te_outputs, tuple):
            te_outputs = (te_outputs,)
        return te_outputs
def compare_outputs(
    onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname
):
    """Compare ORT and TE outputs."""
    assert len(onnx_outputs) == len(te_outputs)
    # Compare ORT and PyTorch outputs.
    for onnx_output, te_output in zip(onnx_outputs, te_outputs):
        # np.isclose: abs(a - b) <= (atol + rtol * abs(b))
        te_output = to_numpy(te_output)
        onnx_output = to_numpy(onnx_output)
        ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol)
        mismatches = ac.nonzero()
        mismatched_ids = [loc for loc in zip(*mismatches)]
        if mismatched_ids:
            # Log some information in case of error.
            print("*" * 100)
            nb_errors = len(mismatched_ids)
            nb_vals = min(nb_errors, max_errors_printed)
            print(f"Detected {nb_errors} diverging values (output shape={onnx_output.shape})")
            print(f"Showing first {nb_vals} errors (ONNX -- TE):")
            abs_err = np.abs(onnx_output - te_output)
            errors = abs_err[mismatches]
            for loc in mismatched_ids[:nb_vals]:
                ref = te_output[loc]
                print(
                    f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} >"
                    f" {atol + rtol * abs(ref)}"
                )
            print(f"Max error: {np.max(errors)}")
            if nb_errors > allow_cnt_errors:
                raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors")
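A quick numeric illustration of the mismatch criterion used above (values are made up):

    import numpy as np

    a = np.array([1.000, 2.000])   # ONNX Runtime output
    b = np.array([1.0005, 2.100])  # TE reference
    # np.isclose: abs(a - b) <= atol + rtol * abs(b); here atol=1e-3, rtol=1e-5
    mask = ~np.isclose(a, b, atol=1e-3, rtol=1e-5)
    print(mask)  # [False  True]: only the second element diverges beyond tolerance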
def serialize_inputs_outputs(
    fname: str,
    inputs: Union[Tuple[torch.Tensor], torch.Tensor],
    te_outputs: List[torch.Tensor],
    input_names: Optional[List[str]] = None,
    output_names: Optional[List[str]] = None,
):
    if not SAVE_TEST_IO:
        return

    fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)

    input_names = input_names or ["input"]
    output_names = output_names or ["output"]
    inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,)
    named_inputs = zip(input_names, inputs)
    input_data = [{k: v.cpu() for k, v in named_inputs if v is not None}]
    json_fname = fname[: -len(".onnx")] + "_inputs.json"
    save_json(input_data, json_fname, description="custom input data")

    json_fname = fname[: -len(".onnx")] + "_output.json"
    named_outputs = zip(output_names, te_outputs)
    output_data = {k: v.detach().cpu() for k, v in named_outputs if v is not None}
    custom_outputs = RunResults()
    custom_outputs.add([output_data], runner_name="custom_runner")
    custom_outputs.save(json_fname)
def validate_result(
    fname: str,
    inps: Union[Tuple[torch.Tensor], torch.Tensor],
    model: torch.nn.Module,
    atol: float = 1.0e-8,  # np.isclose default atol
    rtol: float = 1.0e-5,  # np.isclose default rtol
    max_errors_printed: int = 10,
    is_fp8: bool = False,
    allow_cnt_errors: int = 0,
    input_names: List[str] = None,
    output_names: List[str] = None,
    te_outputs: List[torch.Tensor] = None,
):
    """Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX
    representation using ONNX Runtime (ORT) and ensure they are close.

    The purpose of the output comparison is to validate that TE models are converted to
    their correct ONNX representation by testing that TE and ORT outputs match within some
    small threshold (allowing for finite precision errors).

    Argument `allow_cnt_errors` reduces test failure noise due to spurious errors by ignoring
    a very small number (0-3) of outliers. This is fine to do because these outliers are due to
    small kernel implementation differences between TE and ORT and do not imply an incorrect ONNX
    representation (the tests assume both ORT and TE kernels are correct).

    Argument `te_outputs` can be used to provide pre-computed TE outputs.
    """

    def create_ort_session(fname: str, is_fp8: bool):
        def load_custom_ops(session_opts: ort.SessionOptions):
            """For FP8 validation with ORT we need to load our custom FP8 Q/DQ extension."""
            session_opts.register_custom_ops_library(get_library_path())
            print("registered custom FP8 Q/DQ ops!")

        """Create an ONNX Runtime session for validation."""
        kwargs = {"providers": ["CUDAExecutionProvider", "CPUExecutionProvider"]}
        if is_fp8:
            sess_options = ort.SessionOptions()
            load_custom_ops(sess_options)
            kwargs["sess_options"] = sess_options

        s = ort.InferenceSession(fname, **kwargs)
        return s

    def create_ort_input_dict(session, inputs):
        inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,)
        input_names = [x.name for x in session.get_inputs()]
        inps = [to_numpy(x) for x in inputs if x is not None]
        inp_dict = dict(zip(input_names, inps))
        return inp_dict

    input_names = input_names or ["input"]
    output_names = output_names or ["output"]

    # Run ORT session and TE model.
    fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
    if not te_outputs:
        te_outputs = te_infer(model, inps, is_fp8)
    ort_s = create_ort_session(fname, is_fp8)
    input_feed = create_ort_input_dict(ort_s, inps)
    onnx_outputs = ort_s.run(None, input_feed=input_feed)
    compare_outputs(
        onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname
    )
def create_meta(scale_factor: float, size: int = 1):
    meta = tex.FP8TensorMeta()
    meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
    meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
    meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
    return meta

def dtype2str(dtype: torch.dtype, fake_bf16_io=False):
    if fake_bf16_io:
        assert dtype == torch.bfloat16
        return "_fake_bf16"
    return {
        torch.float32: "_fp32",
        torch.float16: "_fp16",
        torch.bfloat16: "_bf16",
    }[dtype]

def as_te_type(dtype: torch.dtype):
    return {
        torch.float32: tex.DType.kFloat32,
        torch.float16: tex.DType.kFloat16,
        torch.bfloat16: tex.DType.kBFloat16,
    }[dtype]

def get_attn_mask_str(use_mask, attn_mask_type):
    # See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names.
    if attn_mask_type is None:
        return "_mask" if use_mask else "_no-mask"
    attn_mask_str = "_arbitrary-no-mask"
    attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str
    attn_mask_str = (
        "_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str
    )
    return attn_mask_str

"""
Test cases begin here.
"""
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [True, False])
@pytest.mark.parametrize(
    "precision, use_bias",
    [
        (torch.float32, False),
        (torch.float32, True),
        (torch.float16, False),
        (torch.float16, True),
        # Todo: cannot configure BF16 when bias is disabled (ORT issue?)
        (torch.bfloat16, False),
        # Todo: cannot configure BF16 when bias is enabled (ORT issue?)
        (torch.bfloat16, True),
    ],
)
def test_export_linear(
    seed_default_rng,
    scale_factor: float,
    fp8_recipe: recipe.Recipe,
    use_bias: bool,
    return_bias: bool,
    precision: torch.dtype,
):
    # Skip FP8 tests on non-hopper devices
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if return_bias and not use_bias:
        pytest.skip("Cannot return bias when bias is disabled")

    # Set dimensions (these are arbitrary).
    batch_size = 4
    in_features = 64
    out_features = 64
    hidden_size = 64

    class Test_Linear(nn.Module):
        def __init__(self, in_features, out_features, use_bias, return_bias, precision):
            super().__init__()
            self.linear = te.Linear(
                in_features,
                out_features,
                bias=use_bias,
                return_bias=return_bias,
                params_dtype=precision,
            )

        def forward(self, inp):
            ret = self.linear(inp)
            return ret

    inp = torch.randn(batch_size, hidden_size, in_features, device="cuda", dtype=precision)
    fp8_str = "_fp8" if fp8_recipe is not None else ""
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx"
    with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to(
            device="cuda"
        )
        # dynamic shape
        bs = torch.export.Dim("bs", min=2, max=1256)
        do_export(
            model,
            inp,
            fname,
            fp8_recipe,
            dynamic_shapes={"inp": {0: bs}},
        )
        te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
        serialize_inputs_outputs(fname, inp, te_outputs)
        if precision in (torch.bfloat16,):
            return
        if fp8_recipe is None:
            validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
        else:
            validate_result(
                fname, inp, model, atol=1e-2, is_fp8=fp8_recipe is not None, te_outputs=te_outputs
            )
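If you want to inspect one of the artifacts this test writes, the standard onnx API suffices; a sketch (the exact filename depends on the parametrization and on NVTE_TEST_ARTIFACTS_DIR, so the path here is a placeholder):

    import onnx

    model_proto = onnx.load("/tmp/gen_onnx_models/te.linear_fp32.onnx")  # placeholder path
    onnx.checker.check_model(model_proto)
    print([node.op_type for node in model_proto.graph.node][:10])  # peek at the first ops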
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize(
    "precision",
    [
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm(
    seed_default_rng,
    scale_factor: float,
    fp8_recipe: recipe.Recipe,
    precision: torch.dtype,
    zero_centered_gamma: bool,
    normalization: str,
):
    # Skip FP8 tests on non-hopper devices
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Set dimensions (these are arbitrary).
    batch_size = 4
    in_features = 64
    out_features = 256
    hidden_size = 256

    inp = torch.ones(batch_size, in_features, out_features, device="cuda", dtype=precision)
    fp8_str = "_fp8" if fp8_recipe is not None else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.layernorm_linear{fp8_str}{high_prec_str}.onnx"

    with torch.no_grad():
        with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
            layernorm_cls = te.LayerNorm if normalization == "LayerNorm" else te.RMSNorm
            model = layernorm_cls(
                hidden_size,
                params_dtype=precision,
                zero_centered_gamma=zero_centered_gamma,
            ).to(device="cuda")
            # dynamic shape
            bs = torch.export.Dim("bs", min=2, max=1256)
            do_export(model, inp, fname, fp8_recipe, dynamic_shapes={"input": {0: bs}})
            te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
            serialize_inputs_outputs(fname, inp, te_outputs)
            if precision in (torch.bfloat16,):
                return
            if fp8_recipe is None:
                validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
            elif precision != torch.bfloat16:
                validate_result(
                    fname,
                    inp,
                    model,
                    atol=1e-3,
                    is_fp8=fp8_recipe is not None,
                    te_outputs=te_outputs,
                )
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("return_bias", [True, False])
@pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize(
    "precision, use_bias",
    [
        (torch.float32, False),
        (torch.float32, True),
        (torch.float16, True),
        (torch.float16, False),
        (torch.bfloat16, True),
        (torch.bfloat16, False),
    ],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_linear(
    seed_default_rng,
    scale_factor: float,
    fp8_recipe: recipe.Recipe,
    use_bias: bool,
    return_bias: bool,
    return_layernorm_output: bool,
    precision: torch.dtype,
    zero_centered_gamma: bool,
    normalization: str,
):
    # Skip FP8 tests on non-hopper devices
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if return_bias and not use_bias:
        pytest.skip("Cannot return bias when bias is disabled")

    # Set dimensions (these are arbitrary).
    in_features = 64
    out_features = 256
    hidden_size = 256

    inp = torch.randn(in_features, out_features, device="cuda", dtype=precision)
    fp8_str = "_fp8" if fp8_recipe is not None else ""
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx"

    with torch.no_grad():
        with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
            model = te.LayerNormLinear(
                hidden_size,
                3 * hidden_size,
                bias=use_bias,
                return_bias=return_bias,
                return_layernorm_output=return_layernorm_output,
                params_dtype=precision,
                zero_centered_gamma=zero_centered_gamma,
                normalization=normalization,
            ).to(device="cuda")
            if fp8_recipe is not None:
                set_layer_scale(model, scale_factor, num_gemms=2)
            do_export(model, inp, fname, fp8_recipe)
            te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
            serialize_inputs_outputs(fname, inp, te_outputs)
            if precision in (torch.bfloat16,):
                return
            if fp8_recipe is None:
                validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
            elif precision != torch.bfloat16:
                validate_result(
                    fname,
                    inp,
                    model,
                    atol=1e-3,
                    is_fp8=fp8_recipe is not None,
                    te_outputs=te_outputs,
                )
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("return_bias", [True, False])
@pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize(
    "precision, use_bias",
    [
        (torch.float32, False),
        (torch.float32, True),
        (torch.float16, True),
        (torch.float16, False),
        (torch.bfloat16, True),
        (torch.bfloat16, False),
    ],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("activation", supported_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_mlp(
    seed_default_rng,
    scale_factor: float,
    fp8_recipe: recipe.Recipe,
    use_bias: bool,
    return_bias: bool,
    return_layernorm_output: bool,
    precision: torch.dtype,
    zero_centered_gamma: bool,
    activation: str,
    normalization: str,
):
    # Skip FP8 tests on non-hopper devices
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if return_bias and not use_bias:
        pytest.skip("Cannot return bias when bias is disabled")

    # Set dimensions (these are arbitrary).
    in_features = 64
    out_features = 256
    hidden_size = 256
    ffn_hidden_size = 256

    inp = torch.randn(in_features, out_features, device="cuda", dtype=precision)
    fp8_str = "_fp8" if fp8_recipe is not None else ""
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx"
    with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        model = te.LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
            bias=use_bias,
            return_bias=return_bias,
            return_layernorm_output=return_layernorm_output,
            params_dtype=precision,
            zero_centered_gamma=zero_centered_gamma,
            activation=activation,
            normalization=normalization,
        ).to(device="cuda")
        if fp8_recipe is not None:
            set_layer_scale(model, scale_factor, num_gemms=2)
        do_export(model, inp, fname, fp8_recipe)
        te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
        serialize_inputs_outputs(fname, inp, te_outputs)
    if precision in (torch.bfloat16,):
        return
    atol = (
        2e-2 if fp8_recipe is not None else (5e-1 if activation == "swiglu" else 1e-3)
    )  # TODO(pgadzinski) - check 2e-2
    validate_result(fname, inp, model, atol=atol, is_fp8=fp8_recipe is not None, te_outputs=te_outputs)
@pytest.mark.parametrize(
    "precision, use_mask, attn_mask_type",
    [
        (torch.float32, True, "arbitrary"),  # calls forward_torch_softmax (apply user mask)
        (torch.float32, False, "no_mask"),  # calls forward_torch_softmax (apply no mask)
        (torch.float16, False, "causal"),  # calls forward_torch_softmax (apply dynamic onnx mask)
        (torch.float16, True, "arbitrary"),  # calls forward_torch_softmax (apply user mask)
        (torch.float16, False, "no_mask"),  # calls forward_torch_softmax (apply no mask)
        (torch.bfloat16, False, "causal"),  # calls forward_torch_softmax (apply dynamic onnx mask)
        (torch.bfloat16, True, "arbitrary"),  # calls forward_torch_softmax (apply user mask)
        (torch.bfloat16, False, "no_mask"),  # calls forward_torch_softmax (apply no mask)
    ],
)
def test_export_core_attention(
    seed_default_rng,
    set_max_seq_len,
    precision: torch.dtype,
    use_mask: bool,
    attn_mask_type: str,
):
    # Set dimensions (these are arbitrary).
    seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
    qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels)
    qkv_format = "sbhd"

    query_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
    key_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
    value_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
    input_names = ["query", "key", "value", "attention_mask"]
    attention_mask = None
    if use_mask:
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(batch_size, 1, 1, seq_len, device="cuda", dtype=precision)
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
    inp = (query_layer, key_layer, value_layer, attention_mask)

    mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    high_prec_str = dtype2str(precision)
    fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"

    model = te.attention.DotProductAttention(
        num_attention_heads=num_attention_heads,
        kv_channels=kv_channels,
        attention_dropout=0.5,
        qkv_format=qkv_format,
        attn_mask_type=attn_mask_type,
    ).to(device="cuda")
    do_export(model, inp, fname, input_names=input_names, fp8_recipe=None)
    te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None)
    serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
    if precision in (torch.bfloat16,):
        return
    validate_result(
        fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs
    )
test_configs_multihead_attention = [
    # "use_mask, attn_mask_type"
    (False, "no_mask"),  # calls ScaledSoftmax
    (True, "arbitrary"),  # calls ScaledMaskedSoftmax
]
test_configs_attention_type = [
    # "input_layernorm, attention_type, fuse_qkv_params"
    (True, "self", True),
    (False, "self", True),
    (True, "self", False),
    (False, "self", False),
    (True, "cross", True),
    (False, "cross", True),
    (True, "cross", False),
    (False, "cross", False),
]
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize("input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type)
def test_export_multihead_attention(
    seed_default_rng,
    set_max_seq_len,
    fp8_recipe: recipe.Recipe,
    use_mask: bool,
    attn_mask_type: str,
    precision: torch.dtype,
    return_layernorm_output: bool,
    input_layernorm: bool,
    attention_type: str,
    fuse_qkv_params: bool,
):
    # Skip FP8 tests on non-hopper devices
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    hidden_size = 256
    sequence_length = 128
    batch_size = 4
    num_attention_heads = 32
    kv_channels = 8
    attention_dropout = 0.1
    layernorm_epsilon = 1e-5
    init_method = output_layer_init_method = get_default_init_method()
    attention_args = (
        hidden_size,
        num_attention_heads,
        kv_channels,
        attention_dropout,
        layernorm_epsilon,
        init_method,
        output_layer_init_method,
    )

    hidden_states_context = torch.randn(
        sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
    )
    attention_mask = None
    if use_mask and attn_mask_type != "causal":
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(
            batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision
        )
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)

    encoder_output = None

    if attention_type == "cross":
        encoder_output = torch.randn(
            sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
        )

    fp8_str = "_fp8" if fp8_recipe is not None else ""
    dtype_str = dtype2str(precision)
    attn_type_str = "_self-attention" if attention_type == "self" else "_cross-attention"
    fuse_qkv_str = "_fused-qkv" if fuse_qkv_params else ""
    attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    input_ln_str = "_input-ln" if input_layernorm else ""
    fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx"

    model = te.MultiheadAttention(
        *attention_args,
        attn_mask_type=attn_mask_type,
        params_dtype=precision,
        return_layernorm_output=return_layernorm_output,
        input_layernorm=input_layernorm,
        attention_type=attention_type,
        fuse_qkv_params=fuse_qkv_params,
        return_bias=True,
    ).to(device="cuda")

    inp_context = (hidden_states_context, attention_mask, encoder_output)
    input_names = ["hidden_states", "attention_mask", "encoder_output"]
    output_names = ["attention_output", "attention_bias"]
    seq = torch.export.Dim("seq", min=2, max=1256)
    bs = torch.export.Dim("bs", min=2, max=1256)
    do_export(
        model,
        inp_context,
        fname,
        fp8_recipe,
        input_names=input_names,
        output_names=output_names,
        dynamic_shapes={
            "hidden_states": {0: seq, 1: bs},
            "attention_mask": {2: seq, 0: bs} if use_mask else None,
            "encoder_output": {0: seq, 1: bs} if attention_type == "cross" else None,
        },
    )
    te_outputs = te_infer(model, inp_context, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
    serialize_inputs_outputs(
        fname, inp_context, te_outputs, input_names=input_names, output_names=output_names
    )
    if precision in (torch.bfloat16,):
        return
    if fp8_recipe is None:
        validate_result(
            fname,
            inp_context,
            model,
            atol=1e-3,
            input_names=input_names,
            output_names=output_names,
            te_outputs=te_outputs,
        )
    else:
        validate_result(
            fname,
            inp_context,
            model,
            atol=1e-2,
            is_fp8=fp8_recipe is not None,
            input_names=input_names,
            output_names=output_names,
            allow_cnt_errors=3,
            te_outputs=te_outputs,
        )

    # In GPT generative phase (inference) the input sequence is smaller than the maximum
    # allowed sequence length and we want to test this condition.
    # Pretend that we're in generative phase when it makes sense (causal mask and self-attention).
    is_generative_phase = attn_mask_type == "causal" and attention_type == "self"
    if is_generative_phase:
        seq_len_offset = 8
        hidden_states_generative = torch.randn(
            sequence_length - seq_len_offset,
            batch_size,
            hidden_size,
            dtype=precision,
            device="cuda",
        )
        inp_generative = (hidden_states_generative, attention_mask, encoder_output)
        if fp8_recipe is None:
            validate_result(
                fname,
                inp_generative,
                model,
                atol=1e-3,
                input_names=input_names,
                output_names=output_names,
            )
        else:
            validate_result(
                fname,
                inp_generative,
                model,
                atol=1e-2,
                is_fp8=fp8_recipe is not None,
                input_names=input_names,
                output_names=output_names,
                allow_cnt_errors=3,
            )
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("output_layernorm", [True, False])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("fuse_qkv_params", [False, True])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("activation", supported_activations)
def test_export_transformer_layer(
    seed_default_rng,
    set_max_seq_len,
    fp8_recipe: recipe.Recipe,
    use_mask: bool,
    attn_mask_type: str,
    output_layernorm: bool,
    precision: torch.dtype,
    fuse_qkv_params: bool,
    zero_centered_gamma: bool,
    activation: str,
):
    # Skip FP8 tests on non-hopper devices
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Layer configuration
    hidden_size = 64
    sequence_length = 128
    batch_size = 1
    ffn_hidden_size = 256
    num_attention_heads = 4

    input_tensor = torch.rand(
        sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
    )
    input_names = ["input", "attention_mask"]
    attention_mask = None
    if use_mask and attn_mask_type != "causal":
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(
            batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision
        )
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
    inp = (input_tensor, attention_mask)

    fp8_str = "_fp8" if fp8_recipe is not None else ""
    fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
    high_prec_str = dtype2str(precision)
    attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}_{activation}.onnx"

    model = te.TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
        self_attn_mask_type=attn_mask_type,
        output_layernorm=output_layernorm,
        params_dtype=precision,
        fuse_qkv_params=fuse_qkv_params,
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
    ).to(device="cuda")
    do_export(model, inp, fname, fp8_recipe, input_names=input_names)
    te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
    serialize_inputs_outputs(
        fname,
        inp,
        te_outputs,
        input_names=input_names,
    )
    if precision in (torch.bfloat16,):
        return
    atol = 5e-1 if fp8_recipe is not None else (5e-1 if activation == "swiglu" else 5e-3)
    validate_result(
        fname,
        inp,
        model,
        atol=atol,
        is_fp8=fp8_recipe is not None,
        input_names=input_names,
        te_outputs=te_outputs,
    )
@skip_FP8
@skip_MXFP8
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [True])
def test_export_gpt_generation(
    seed_default_rng,
    set_max_seq_len,
    fp8_recipe: recipe.Recipe,
    precision: torch.dtype,
    zero_centered_gamma: bool,
):
    """Test that the ONNX model can correctly handle inputs with different shapes and that
    the attention mask is adjusted on-the-fly to different sequence lengths.
    """

    # Skip FP8 tests on non-hopper devices
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Layer configuration
    hidden_size = 64
    sequence_length = 128
    batch_size = 4
    ffn_hidden_size = 256
    num_attention_heads = 4
    attention_mask = None
    use_mask = True
    attn_mask_type = "causal"
    fuse_qkv_params = True
    output_layernorm = False

    fp8_str = "_fp8" if fp8_recipe is not None else ""
    fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
    high_prec_str = dtype2str(precision)
    attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    fname = f"te.transformer_layer_generative{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}.onnx"

    model = te.TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
        self_attn_mask_type=attn_mask_type,
        output_layernorm=output_layernorm,
        params_dtype=precision,
        fuse_qkv_params=fuse_qkv_params,
        zero_centered_gamma=zero_centered_gamma,
    ).to(device="cuda")

    # "Context phase": use full input sequence length
    input_names = ["input"]
    output_names = ["output"]
    input_tensor = torch.rand(
        sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
    )
    inp = (input_tensor,)
    # dynamic shape
    seq = torch.export.Dim("seq", min=2, max=1256)
    bs = torch.export.Dim("bs", min=2, max=1256)
    do_export(
        model,
        inp,
        fname,
        fp8_recipe,
        dynamic_shapes={"hidden_states": {0: seq, 1: bs}},
    )
    te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
    serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names, output_names=output_names)
    if precision not in (torch.bfloat16,):
        validate_result(
            fname,
            inp,
            model,
            atol=1e-2,
            is_fp8=fp8_recipe is not None,
            input_names=input_names,
            te_outputs=te_outputs,
        )

    # "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the sequence to mult of 8 and for MXFP8 we need to pad to mult of 32.
    sequence_length = 1 if fp8_recipe is None else 32
    input_tensor = torch.rand(
        sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
    )
    inp = (input_tensor, attention_mask)
    te_outputs = te_infer(model, inp, is_fp8=fp8_recipe is not None, fp8_recipe=fp8_recipe)
    serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
    if precision not in (torch.bfloat16,):
        validate_result(
            fname,
            inp,
            model,
            atol=1e-2,
            is_fp8=fp8_recipe is not None,
            input_names=input_names,
            te_outputs=te_outputs,
        )
@pytest.mark.parametrize("enabled", [True, False])
def test_export_ctx_manager(enabled):
    assert is_in_onnx_export_mode() == False
    with te.onnx_export(enabled):
        assert is_in_onnx_export_mode() == enabled
    assert is_in_onnx_export_mode() == False
tests/pytorch/test_parallel_cross_entropy.py

...
@@ -61,22 +61,26 @@ class TestParallelCrossEntropy:
         test_loss = self.test_loss_func(
             self.input_test, self.tar_test, label_smoothing, reduce_loss, None
         )

-        if reduce_loss:
-            test_loss.backward()

         ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)

+        # Handle backward pass based on the test scenario
         if reduce_loss:
+            test_loss.backward()
             ref_loss.backward()
+        else:
+            test_loss.sum().backward()
+            ref_loss.sum().backward()

         test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
         torch.testing.assert_close(test_loss, ref_loss, check_dtype=False)
         if ignore_idx:
             print(test_loss, ref_loss)

-        if reduce_loss:
-            torch.testing.assert_close(
-                torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
-            )
+        # Compare gradients when backward pass was called
+        torch.testing.assert_close(
+            torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
+        )

         self.input_test = None
         self.input_ref = None
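The new else branch exists because PyTorch only creates the implicit grad_output for scalar tensors; calling backward() on a vector loss raises. A short demonstration of that standard behavior:

    import torch

    loss = torch.randn(4, requires_grad=True) * 2.0  # non-scalar, non-leaf tensor
    try:
        loss.backward()  # raises: grad can be implicitly created only for scalar outputs
    except RuntimeError as err:
        print(err)
    loss.sum().backward()  # reducing to a scalar first is the standard fix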
tests/pytorch/test_permutation.py

...
@@ -326,33 +326,37 @@ def _test_permutation_index_map(
         te_unpermute_output_ = te_unpermute_output.float()
         te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()

-    torch.testing.assert_close(
-        pytorch_permute_output.float(),
-        te_permute_output_,
-        msg=f"Mismatch in te_permute fwd",
-    )
-    torch.testing.assert_close(
-        pytorch_permute_fwd_input.grad.float(),
-        te_permute_fwd_input_grad,
-        msg=f"Mismatch in te_permute bwd",
-        **tols,
-    )
-    torch.testing.assert_close(
-        pytorch_unpermute_output.float(),
-        te_unpermute_output_,
-        msg=f"Mismatch in te_unpermute fwd",
-        **tols,
-    )
-    torch.testing.assert_close(
-        pytorch_unpermute_fwd_input.grad.float(),
-        te_unpermute_fwd_input_grad,
-        msg=f"Mismatch in te_unpermute bwd",
-        **tols,
-    )
-    if with_probs:
-        torch.testing.assert_close(
-            probs.grad.float(),
-            te_probs.grad.float(),
-            msg=f"Mismatch in te_unpermute bwd",
-            **tols,
-        )
+    if not BENCHMARK:
+        torch.testing.assert_close(
+            pytorch_permute_output.float(),
+            te_permute_output_,
+            msg=f"Mismatch in te_permute fwd",
+        )
+        torch.testing.assert_close(
+            pytorch_permute_fwd_input.grad.float(),
+            te_permute_fwd_input_grad,
+            msg=f"Mismatch in te_permute bwd",
+            **tols,
+        )
+        torch.testing.assert_close(
+            pytorch_unpermute_output.float(),
+            te_unpermute_output_,
+            msg=f"Mismatch in te_unpermute fwd",
+            **tols,
+        )
+        torch.testing.assert_close(
+            pytorch_unpermute_fwd_input.grad.float(),
+            te_unpermute_fwd_input_grad,
+            msg=f"Mismatch in te_unpermute bwd",
+            **tols,
+        )
+        if with_probs:
+            torch.testing.assert_close(
+                probs.grad.float(),
+                te_probs.grad.float(),
+                msg=f"Mismatch in te_unpermute bwd",
+                **tols,
+            )

     if not pytorch_permute_fwd_input.numel():
         print("Empty pytorch_permute_fwd_input activation test passed.")
...
@@ -538,34 +542,38 @@ def _test_permutation_mask_map(
    te_unpermute_output_ = te_unpermute_output.float()
    te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()

    if not BENCHMARK:
        torch.testing.assert_close(
            pytorch_permute_output.float(),
            te_permute_output_,
            msg=f"Mismatch in te_permute fwd",
            **tols,
        )
        torch.testing.assert_close(
            pytorch_permute_fwd_input.grad.float(),
            te_permute_fwd_input_grad,
            msg=f"Mismatch in te_permute bwd",
            **tols,
        )
        torch.testing.assert_close(
            pytorch_unpermute_output.float(),
            te_unpermute_output_,
            msg=f"Mismatch in te_unpermute fwd",
            **tols,
        )
        torch.testing.assert_close(
            pytorch_unpermute_fwd_input.grad.float(),
            te_unpermute_fwd_input_grad,
            msg=f"Mismatch in te_unpermute bwd",
            **tols,
        )
        if with_probs:
            torch.testing.assert_close(
                probs.grad.float(),
                te_probs.grad.float(),
                msg=f"Mismatch in te_unpermute bwd",
                **tols,
            )

    if not pytorch_permute_fwd_input.numel():
        print("Empty pytorch_permute_fwd_input activation test passed.")
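The mask-map variant drives the same kernels from a boolean routing map instead of an index map. A small sketch of how such a map is typically derived from router logits (hypothetical helper; shapes [num_tokens, num_expert]):

import torch

def make_routing_map(logits, topK):
    # routing_map[t, e] is True when expert e is in token t's top-k.
    _, idx = torch.topk(logits, k=topK, dim=1)
    return torch.zeros_like(logits, dtype=torch.bool).scatter_(1, idx, True)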
...
...
@@ -827,18 +835,19 @@ def _test_moe_chunk_sort(
    te_output_ = te_output.float()
    te_fwd_input_grad = te_fwd_input.grad.float()

    if not BENCHMARK:
        torch.testing.assert_close(
            pytorch_output.float(),
            te_output_,
            msg=f"Mismatch in te_permute fwd",
            **tols,
        )
        torch.testing.assert_close(
            pytorch_fwd_input.grad.float(),
            te_fwd_input_grad,
            msg=f"Mismatch in te_permute bwd",
            **tols,
        )

    if not pytorch_fwd_input.numel():
        print("Empty pytorch_fwd_input activation test passed.")
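Chunk sorting reorders contiguous per-expert chunks of the permuted activations. A plain-PyTorch sketch of the operation under test (illustrative only):

import torch

def sort_chunks_by_index(x, split_sizes, sorted_idxs):
    # Split rows into chunks of the given sizes, then concatenate in the new order.
    chunks = torch.split(x, split_sizes)
    return torch.cat([chunks[i] for i in sorted_idxs], dim=0)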
...
...
@@ -887,6 +896,7 @@ def _test_permutation_mask_map_alongside_probs(
    topK,
    num_out_tokens,
    tp_size,
    BENCHMARK=False,
):
    if topK > num_expert:
        pytest.skip("topK should be smaller than the number of experts.")
...
...
@@ -1016,21 +1026,73 @@ def _test_permutation_mask_map_alongside_probs(
    te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
    te_unpermute_output_ = te_unpermute_output.float()

    if not BENCHMARK:
        torch.testing.assert_close(
            pytorch_unpermute_output.float(),
            te_unpermute_output_,
            msg=f"Mismatch in fused_unpermute fwd",
            **tols,
        )
        torch.testing.assert_close(
            pytorch_permute_fwd_input.grad.float(),
            te_permute_fwd_input_grad,
            msg=f"Mismatch in fused_permute bwd",
            **tols,
        )
        torch.testing.assert_close(
            probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in prob grad", **tols
        )

    if BENCHMARK:
        t1 = perf_test_cuda_kernel(
            lambda: te_permute_with_probs(
                te_permute_fwd_input, te_probs, routing_map, num_out_tokens=num_out_tokens
            )
        )
        print(f"permute\t\tfwd: TE: {t1:.3f} ms")

        te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs(
            te_permute_fwd_input,
            te_probs,
            routing_map,
            num_out_tokens=num_out_tokens,
        )
        te_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                te_permute_output,
                te_permute_bwd_input,
                forward_input=[te_permute_fwd_input],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"permute\t\tbwd: TE: {t2:.3f} ms")

        chunk_sort_fwd_input = te_permute_output.detach()
        chunk_sort_fwd_input.requires_grad_(True)
        chunk_sort_fwd_probs = te_permuted_probs.detach()
        chunk_sort_fwd_probs.requires_grad_(True)
        t1 = perf_test_cuda_kernel(
            lambda: te_sort_chunks_by_index_with_probs(
                chunk_sort_fwd_input, chunk_sort_fwd_probs, split_sizes_cuda, sorted_idxs_cuda
            )
        )
        print(f"chunk sort\t\tfwd: TE: {t1:.3f} ms")

        chunk_sort_output, _ = te_sort_chunks_by_index_with_probs(
            chunk_sort_fwd_input, chunk_sort_fwd_probs, split_sizes_cuda, sorted_idxs_cuda
        )
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                chunk_sort_output,
                te_permute_bwd_input,
                forward_input=[chunk_sort_fwd_input],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"chunk sort\t\tbwd: TE: {t2:.3f} ms")


def perf_test_cuda_kernel(cuda_kernel_fn):
...
...
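The body of perf_test_cuda_kernel is collapsed in this view; a typical CUDA-event timing harness consistent with how it is called above would be (an assumed implementation, not the file's actual code):

import torch

def perf_test_cuda_kernel(cuda_kernel_fn, warmup=10, iters=100):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(warmup):
        cuda_kernel_fn()
    start.record()
    for _ in range(iters):
        cuda_kernel_fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call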
@@ -1063,7 +1125,7 @@ if is_bf16_compatible():
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
...
...
@@ -1092,7 +1154,7 @@ def test_permutation_index_map(
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
...
...
@@ -1138,7 +1200,7 @@ def test_permutation_mask_map_empty_input(te_dtype):
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
...
...
@@ -1193,7 +1255,7 @@ fp8_recipes = [
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
...
...
@@ -1225,7 +1287,7 @@ def test_permutation_mask_map_fp8(
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_index_map_topk1_no_probs(
    te_dtype,
,
...
...
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_mask_map_topk1_no_probs(
    te_dtype,
...
...
@@ -1279,7 +1341,7 @@ def test_permutation_mask_map_topk1_no_probs(
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("tp_size", [1, 2, 8])
@pytest.mark.parametrize("hidden_size", [4096])
def test_chunk_permutation(
...
...
@@ -1372,5 +1434,108 @@ def test_permutation_single_case():
    )


def benchmark_single_case(
    te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
):
    torch.cuda.nvtx.range_push(
        f"{num_tokens}-{num_expert}-{hidden_size}-{topK}-{ep_size}-{tp_size}"
    )

    torch.cuda.nvtx.range_push("permutation_index_map_with_probs")
    _test_permutation_index_map(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        with_probs=True,
        BENCHMARK=True,
    )
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("permutation_mask_map_with_probs")
    _test_permutation_mask_map(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        with_probs=True,
        BENCHMARK=True,
    )
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("permutation_mask_map_without_probs")
    _test_permutation_mask_map(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        with_probs=False,
        BENCHMARK=True,
    )
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("permutation_mask_map_alongside_probs")
    _test_permutation_mask_map_alongside_probs(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
        tp_size=tp_size,
        BENCHMARK=True,
    )
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_pop()


def benchmark_multiple_cases():
    print("GPU:", torch.cuda.get_device_name(0))

    # te_dtype = tex.DType.kFloat32
    # te_dtype = tex.DType.kFloat16
    te_dtype = tex.DType.kBFloat16

    ep_size = 64
    tp_size = 2
    num_tokens = 4096
    num_expert = 256
    hidden_size = 7168
    topK = 8
    num_out_tokens = num_tokens * topK
    benchmark_single_case(
        te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
    )

    ep_size = 8
    tp_size = 1
    num_tokens = 8192 * 2
    num_expert = 128
    hidden_size = 4096
    topK = 6
    num_out_tokens = num_tokens * topK
    benchmark_single_case(
        te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
    )

    ep_size = 64
    tp_size = 2
    num_tokens = 16384
    num_expert = 4
    hidden_size = 7168
    topK = 1
    num_out_tokens = num_tokens * topK
    benchmark_single_case(
        te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
    )


if __name__ == "__main__":
    test_permutation_single_case()
    benchmark_multiple_cases()
tests/pytorch/test_sanity.py
View file @
44740c6c
...
...
@@ -47,7 +47,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from utils import dtype_tols

# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
...
...
@@ -56,6 +56,28 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Record initial RNG state from script run.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))

if NVTE_TEST_NVINSPECT_ENABLED:
    # The sanity tests should work the same when debug=True. They are fed a dummy
    # feature to prevent switching off debug, which can happen if no feature is active.
    import nvdlfw_inspect.api as debug_api

    debug_api.initialize(
        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )


def create_meta(scale_factor: float, size: int = 1):
    meta = tex.FP8TensorMeta()
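The int(...) coercion is the point of this change: environment variables are strings, so the string "0" is truthy. A quick illustration:

import os

os.environ["NVTE_TEST_NVINSPECT_ENABLED"] = "0"
print(bool(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False)))     # True ("0" is a non-empty string)
print(bool(int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))))  # False, as intended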
...
...
@@ -90,6 +112,13 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
    return torch.min(amax_history, dim=0).values


def reset_rng_states() -> None:
    """revert back to initial RNG state."""
    global _cpu_rng_state, _cuda_rng_state
    torch.set_rng_state(_cpu_rng_state)
    torch.cuda.set_rng_state(_cuda_rng_state)


@dataclass
class ModelConfig:
    """Transformer model configuration"""
...
...
@@ -529,6 +558,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("Quantized model parameters are not supported in debug mode.")
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.seq_len
...
...
@@ -570,6 +601,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
def test_sanity_grouped_linear(
    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
):
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("FP8 model parameters are not supported in debug mode.")
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
...
...
@@ -682,6 +715,8 @@ def test_sanity_gpt(
    parallel_attention_mlp,
    cpu_offload,
):
    if cpu_offload and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("CPU offload is not supported in debug mode.")
    config = model_configs[model]

    if fp8_recipe is not None:
...
...
@@ -1367,6 +1402,8 @@ def test_inference_mode(
    quantization: Optional[str],
) -> None:
    """Test heuristics for initializing quantized weights"""
    if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    # Tensor dimensions
    sequence_length = 32
...
...
tests/pytorch/utils.py
View file @
44740c6c
...
...
@@ -93,6 +93,7 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
    if name in ("fp8", "fp8_delayed_scaling"):
        return transformer_engine.common.recipe.DelayedScaling(
            fp8_format=transformer_engine.common.recipe.Format.E4M3,
            amax_history_len=8,
        )
    if name == "fp8_current_scaling":
        return transformer_engine.common.recipe.Float8CurrentScaling(
...
...
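A hedged usage sketch for the make_recipe helper above with TE's autocast (the layer and shapes here are illustrative):

import torch
import transformer_engine.pytorch as te

recipe = make_recipe("fp8_delayed_scaling")
layer = te.Linear(1024, 1024).cuda()
inp = torch.randn(32, 1024, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = layer(inp)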
transformer_engine/common/CMakeLists.txt
View file @
44740c6c
...
...
@@ -158,6 +158,9 @@ if(USE_CUDA)
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu
fused_router/fused_moe_aux_loss.cu
fused_router/fused_score_for_moe_aux_loss.cu
fused_router/fused_topk_with_score_function.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
...
...
@@ -211,6 +214,9 @@ else()
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu
fused_router/fused_moe_aux_loss.cu
fused_router/fused_score_for_moe_aux_loss.cu
fused_router/fused_topk_with_score_function.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
...
...
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp
View file @
44740c6c
...
...
@@ -100,6 +100,16 @@ bool has_mnnvl_fabric(int device_id) {
  }
  return false;
#else
  // Check run-time CUDA version
  if (transformer_engine::cuda::cudart_version() < 12040) {
    if (getenv("NVTE_UBDEBUG")) {
      printf(
          "TransformerEngine does not support multi-node NVLINK "
          "since it is not being run with CUDA version >= 12.4.\n");
    }
    return false;
  }
  bool mnnvl_fabric_support = false;
  CUdevice dev;
  NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, device_id);
...
...
transformer_engine/common/fused_attn/fused_attn.cpp
View file @
44740c6c
...
...
@@ -248,7 +248,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
              attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
            // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
            (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
             cudnn_runtime_version >= 91100))) &&
          // 9.11 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
          (!(cudnn_runtime_version == 91100 && is_training && sm_arch_ == 90 &&
             head_dim_qk >= 128 && head_dim_v >= 128 &&
             !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v)) &&
          // bias type
          ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
           (cudnn_runtime_version >= 8906 &&
...
...
transformer_engine/common/fused_attn/kv_cache.cu
View file @
44740c6c
...
...
@@ -10,6 +10,8 @@
namespace transformer_engine {
namespace kv_cache {

constexpr int block_size = 1024;

template <typename dtype>
__global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices,
                                        int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
...
...
@@ -22,21 +24,29 @@ __global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *bat
      actual_b = i + 1;
    }
  }
  bool flag = (batch_indices[0] != 0);
  for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) {
    if (flag || ((batch_indices[batch_idx] - batch_indices[0]) != batch_idx)) {
      int num_tokens = (cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]) -
                       (cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]);
      int num_elts_k = h_kv * d_k;
      int num_elts_v = h_kv * d_v;
      int num_elts = max(num_elts_k, num_elts_v);
      for (int token_idx = blockIdx.x; token_idx < num_tokens; token_idx += gridDim.x) {
        int src_offset = batch_indices[batch_idx] * max_seq_len + token_idx;
        int des_offset = batch_idx * max_seq_len + token_idx;
        dtype *k_cache_src_offset = k_cache + src_offset * num_elts_k;
        dtype *k_cache_des_offset = k_cache + des_offset * num_elts_k;
        dtype *v_cache_src_offset = v_cache + src_offset * num_elts_v;
        dtype *v_cache_des_offset = v_cache + des_offset * num_elts_v;
        for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
          if (i < num_elts_k) {
            *(k_cache_des_offset + i) = *(k_cache_src_offset + i);
          }
          if (i < num_elts_v) {
            *(v_cache_des_offset + i) = *(v_cache_src_offset + i);
          }
        }
      }
    }
  }
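Functionally, the reindex kernel compacts surviving sequences' previously cached tokens into contiguous batch slots before new tokens are appended. A simplified Python sketch of the data movement (it ignores the in-place overlap handling that the kernel's flag logic covers):

def reindex_kv_cache_ref(k_cache, v_cache, batch_indices, cached_lens, new_lens):
    # k_cache/v_cache: [max_batch, max_seq_len, h_kv * d]; batch_indices[dst] = source slot.
    for dst, src in enumerate(batch_indices):
        keep = cached_lens[dst] - new_lens[dst]  # tokens already present in the cache
        if src != dst:
            k_cache[dst, :keep] = k_cache[src, :keep]
            v_cache[dst, :keep] = v_cache[src, :keep]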
...
...
@@ -55,19 +65,26 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac
  if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) {
    for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
      int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
      int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
      int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
      int num_elts_k = h_kv * d_k;
      int num_elts_v = h_kv * d_v;
      int hd = h_kv * max(d_k, d_v);
      for (int i = blockIdx.y; i < new_len; i += gridDim.y) {
        int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
        dtype *new_token_id_k = new_k + (batch_idx * max_ctx_len + i) * num_elts_k;
        dtype *new_token_id_v = new_v + (batch_idx * max_ctx_len + i) * num_elts_v;
        dtype *token_id_k =
            k_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_k;
        dtype *token_id_v =
            v_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_v;
        for (int j = threadIdx.x; j < hd; j += blockDim.x) {
          if (j < num_elts_k) {
            *(token_id_k + j) = *(new_token_id_k + j);
          }
          if (j < num_elts_v) {
            *(token_id_v + j) = *(new_token_id_v + j);
          }
        }
      }
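The paged addressing above reduces to a page-table lookup plus an in-page offset. In Python terms (illustrative):

def paged_slot(page_table_row, pos, page_size):
    # pos = cached_len - new_len + i, the token's absolute position in its sequence.
    page_idx = page_table_row[pos // page_size]
    return page_idx * page_size + pos % page_size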
...
...
@@ -76,14 +93,24 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac
      int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
      int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
      int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
      int num_elts_k = h_kv * d_k;
      int num_elts_v = h_kv * d_v;
      int hd = h_kv * max(d_k, d_v);
      for (int i = blockIdx.y; i < new_len; i += gridDim.y) {
        int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
        dtype *new_token_id_k = new_k + (i * b + batch_idx) * num_elts_k;
        dtype *new_token_id_v = new_v + (i * b + batch_idx) * num_elts_v;
        dtype *token_id_k =
            k_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_k;
        dtype *token_id_v =
            v_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_v;
        for (int j = threadIdx.x; j < hd; j += blockDim.x) {
          if (j < num_elts_k) {
            *(token_id_k + j) = *(new_token_id_k + j);
          }
          if (j < num_elts_v) {
            *(token_id_v + j) = *(new_token_id_v + j);
          }
        }
      }
...
...
@@ -92,16 +119,24 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac
      int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
      int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
      int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
      int num_elts_k = h_kv * d_k;
      int num_elts_v = h_kv * d_v;
      int hd = h_kv * max(d_k, d_v);
      for (int i = blockIdx.y; i < new_len; i += gridDim.y) {
        int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
        dtype *new_token_id_k = new_k + (cu_new_lens[batch_idx] + i) * num_elts_k;
        dtype *new_token_id_v = new_v + (cu_new_lens[batch_idx] + i) * num_elts_v;
        dtype *token_id_k =
            k_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_k;
        dtype *token_id_v =
            v_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_v;
        for (int j = threadIdx.x; j < hd; j += blockDim.x) {
          if (j < num_elts_k) {
            *(token_id_k + j) = *(new_token_id_k + j);
          }
          if (j < num_elts_v) {
            *(token_id_v + j) = *(new_token_id_v + j);
          }
        }
      }
...
...
@@ -116,14 +151,15 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso
                               bool is_non_paged, cudaStream_t stream) {
  if (new_k.has_data() && new_v.has_data() && k_cache.has_data() && v_cache.has_data()) {
    if (is_non_paged) {
      reindex_kv_cache_kernel<<<max_seq_len, block_size, 0, stream>>>(
          reinterpret_cast<dtype *>(k_cache.data.dptr),
          reinterpret_cast<dtype *>(v_cache.data.dptr),
          reinterpret_cast<int *>(page_table.data.dptr),
          reinterpret_cast<int *>(cu_new_lens.data.dptr),
          reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
    }
    dim3 grid_size(b, max_ctx_len);
    copy_to_kv_cache_kernel<<<grid_size, block_size, 0, stream>>>(
        reinterpret_cast<dtype *>(new_k.data.dptr), reinterpret_cast<dtype *>(new_v.data.dptr),
        reinterpret_cast<dtype *>(k_cache.data.dptr), reinterpret_cast<dtype *>(v_cache.data.dptr),
        reinterpret_cast<int *>(page_table.data.dptr),
...
...
transformer_engine/common/fused_router/fused_moe_aux_loss.cu
0 → 100644
View file @
44740c6c
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <assert.h>
#include <cooperative_groups.h>
#include <cuda_runtime.h>
#include <transformer_engine/fused_router.h>
#include "../common.h"
#include "../util/logging.h"
#include "../utils.cuh"
#include "common/util/cuda_runtime.h"
#include "utils.h"
namespace transformer_engine {

// Use double to handle all the calculations
using CompType = double;
template <typename DataType, typename IndexType>
__global__ void fused_moe_aux_loss_forward_kernel(const DataType *probs,
                                                  const IndexType *tokens_per_expert,
                                                  int total_num_tokens, int num_experts,
                                                  int num_rows, int num_cols, int topk, float coeff,
                                                  DataType *aux_loss, float *Const_buf) {
#if __CUDA_ARCH__ >= 900
  // Using cooperative_groups to manage the cluster
  namespace cg = cooperative_groups;
  cg::cluster_group cluster = cg::this_cluster();
  int thread_id = cg::this_grid().thread_rank();
  int lane_id = thread_id % kThreadsPerWarp;
  int warp_id = thread_id / kThreadsPerWarp;
  int warp_num = blockDim.x * gridDim.x / kThreadsPerWarp;
  // Only 1 block in the cluster
  int block_id = cluster.block_rank();
  int block_num = cluster.dim_blocks().x;
  int cluster_id = blockIdx.x / block_num;
  if (cluster_id > 0) return;  // Only use the cluster 0

  extern __shared__ float shmem_aux_loss[];
  CompType *aggregated_probs_per_expert = reinterpret_cast<CompType *>(shmem_aux_loss);
  // Clear the shmem
  for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
    aggregated_probs_per_expert[i] = CompType(0);
  }
  __syncthreads();

  /**
   * Section: Reduce the probs to the aggregated_probs_per_expert
   * 1. reduce on the block
   * 2. reduce on the cluster
   */
  // Loop: for all positions in each row
  for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
    CompType tmp = CompType(0);
    // Loop: for all rows that this warp is responsible for
    for (int j = warp_id; j < num_rows; j += warp_num) {
      tmp += CompType(probs[j * num_cols + i]);
    }
    atomicAdd(&aggregated_probs_per_expert[i], tmp);
  }
  cluster.sync();

  // The block 0 will reduce the results of all blocks
  if (block_id == 0) {
    for (int i = 1; i < block_num; i++) {
      // Map the shared memory of the block i to the current block
      CompType *dst_smem = reinterpret_cast<CompType *>(cluster.map_shared_rank(shmem_aux_loss, i));
      for (int j = threadIdx.x; j < num_cols; j += blockDim.x) {
        atomicAdd(&aggregated_probs_per_expert[j], dst_smem[j]);
      }
    }
  }
  cluster.sync();

  /**
   * Section: aggregated_probs_per_expert * tokens_per_expert
   * In-place update on shmem
   */
  if (block_id == 0) {
    for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
      aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]);
    }
    __syncthreads();

    if (warp_id == 0) {
      /**
       * Section: Reduce to get the sum of aggregated_probs_per_expert
       */
      CompType intermediate_result =
          warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
      __syncwarp();

      if (lane_id == 0) {
        /**
         * Section: Compute the aux_loss
         */
        float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens;
        aux_loss[0] = static_cast<DataType>(static_cast<double>(intermediate_result) * C_coeff);
        Const_buf[0] = C_coeff;
      }
    }
  }
#else
  // Use only 1 block/1024 threads to avoid the grid sync
  if (blockIdx.x > 0) return;
  int warp_num = blockDim.x / kThreadsPerWarp;
  int warp_id = threadIdx.x / kThreadsPerWarp;
  int lane_id = threadIdx.x % kThreadsPerWarp;

  extern __shared__ float shmem_aux_loss[];
  CompType *aggregated_probs_per_expert = reinterpret_cast<CompType *>(shmem_aux_loss);
  // Clear the shmem
  for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
    aggregated_probs_per_expert[i] = CompType(0);
  }
  __syncthreads();

  /**
   * Section: Reduce the probs to the aggregated_probs_per_expert
   */
  // Loop: for all positions in each row
  for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
    CompType tmp = CompType(0);
    // Loop: for all rows that this warp is responsible for
    for (int j = warp_id; j < num_rows; j += warp_num) {
      tmp += CompType(probs[j * num_cols + i]);
    }
    atomicAdd(&aggregated_probs_per_expert[i], tmp);
  }
  __syncthreads();

  /**
   * Section: aggregated_probs_per_expert * tokens_per_expert
   * In-place update on shmem
   */
  for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
    aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]);
  }
  __syncthreads();

  if (warp_id == 0) {
    /**
     * Section: Reduce to get the sum of aggregated_probs_per_expert
     */
    CompType intermediate_result =
        warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
    __syncwarp();

    if (lane_id == 0) {
      /**
       * Section: Compute the aux_loss
       */
      float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens;
      aux_loss[0] = static_cast<DataType>(static_cast<double>(intermediate_result) * C_coeff);
      Const_buf[0] = C_coeff;
    }
  }
#endif
}
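A PyTorch reference for the quantity this kernel produces (the standard MoE load-balancing auxiliary loss, written under the same definitions; a sketch, not the production path):

import torch

def moe_aux_loss_ref(probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff):
    # probs: [num_rows, num_experts]; tokens_per_expert: [num_experts]
    agg = probs.double().sum(dim=0)  # aggregated probability mass per expert
    c_coeff = (num_experts * coeff) / (topk * total_num_tokens * total_num_tokens)
    return (agg * tokens_per_expert.double()).sum() * c_coeff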
template <typename DataType, typename IndexType>
void fused_moe_aux_loss_forward_kernel_launcher(const DataType *probs,
                                                const IndexType *tokens_per_expert,
                                                int total_num_tokens, int num_experts, int num_rows,
                                                int num_cols, int topk, float coeff,
                                                DataType *aux_loss, float *Const_buf,
                                                cudaStream_t stream) {
  if (cuda::sm_arch(cuda::current_device()) >= 90) {
    cudaLaunchConfig_t config = {0};
    int cluster_size = 8;
    config.gridDim = cluster_size;
    config.blockDim = 1024;
    config.dynamicSmemBytes = sizeof(CompType) * num_cols;
    config.stream = stream;

    // Update the max cluster size based on the device
    cudaOccupancyMaxPotentialClusterSize(
        &cluster_size,
        reinterpret_cast<void *>(fused_moe_aux_loss_forward_kernel<DataType, IndexType>), &config);

    cudaLaunchAttribute attribute[1];
    attribute[0].id = cudaLaunchAttributeClusterDimension;
    attribute[0].val.clusterDim.x = cluster_size;
    attribute[0].val.clusterDim.y = 1;
    attribute[0].val.clusterDim.z = 1;
    config.numAttrs = 1;
    config.attrs = attribute;

    cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs,
                       tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk,
                       coeff, aux_loss, Const_buf);
  } else {
    size_t smem_size = sizeof(CompType) * num_cols;
    fused_moe_aux_loss_forward_kernel<DataType, IndexType>
        <<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts,
                                         num_rows, num_cols, topk, coeff, aux_loss, Const_buf);
  }
}

void fused_moe_aux_loss_forward(const Tensor &probs, const Tensor &tokens_per_expert,
                                int total_num_tokens, int num_experts, int num_rows, int num_cols,
                                int topk, float coeff, Tensor &aux_loss, Tensor &Const_buf,
                                cudaStream_t stream) {
  TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
      probs.data.dtype, DataType,
      TE_ROUTER_INDEX_TYPE_SWITCH_ALL(
          tokens_per_expert.data.dtype, IndexType,
          fused_moe_aux_loss_forward_kernel_launcher<DataType, IndexType>(
              reinterpret_cast<DataType *>(probs.data.dptr),
              reinterpret_cast<IndexType *>(tokens_per_expert.data.dptr), total_num_tokens,
              num_experts, num_rows, num_cols, topk, coeff,
              reinterpret_cast<DataType *>(aux_loss.data.dptr),
              reinterpret_cast<float *>(Const_buf.data.dptr), stream);););
}

template <typename DataType, typename IndexType>
__global__ void fused_moe_aux_loss_backward_kernel(const float *Const_buf,
                                                   const IndexType *tokens_per_expert, int num_rows,
                                                   int num_cols, DataType *grad_aux_loss,
                                                   DataType *grad_probs) {
  int global_warp_num = gridDim.x * blockDim.x / kThreadsPerWarp;
  int global_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kThreadsPerWarp;
  int lane_id = threadIdx.x % kThreadsPerWarp;

  // Loop: for all positions in each row
  for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
    float C_coeff = Const_buf[0];
    IndexType tokens_per_expert_i = tokens_per_expert[i];
    double grad_aux_loss_value = static_cast<double>(grad_aux_loss[0]);
    // Loop: for all rows
    for (int j = global_warp_id; j < num_rows; j += global_warp_num) {
      grad_probs[j * num_cols + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value;
    }
  }
}

template <typename DataType, typename IndexType>
void fused_moe_aux_loss_backward_kernel_launcher(const float *Const_buf,
                                                 const IndexType *tokens_per_expert, int num_rows,
                                                 int num_cols, DataType *grad_aux_loss,
                                                 DataType *grad_probs, cudaStream_t stream) {
  // Meta data for the kernel
  int block_size = 256;
  int grid_size = (num_rows + block_size - 1) / block_size;
  fused_moe_aux_loss_backward_kernel<DataType, IndexType><<<grid_size, block_size, 0, stream>>>(
      Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss, grad_probs);
}

void fused_moe_aux_loss_backward(const Tensor &Const_buf, const Tensor &tokens_per_expert,
                                 int num_rows, int num_cols, Tensor &grad_aux_loss,
                                 Tensor &grad_probs, cudaStream_t stream) {
  TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
      grad_aux_loss.data.dtype, DataType,
      TE_ROUTER_INDEX_TYPE_SWITCH_ALL(
          tokens_per_expert.data.dtype, IndexType,
          fused_moe_aux_loss_backward_kernel_launcher<DataType, IndexType>(
              reinterpret_cast<float *>(Const_buf.data.dptr),
              reinterpret_cast<IndexType *>(tokens_per_expert.data.dptr), num_rows, num_cols,
              reinterpret_cast<DataType *>(grad_aux_loss.data.dptr),
              reinterpret_cast<DataType *>(grad_probs.data.dptr), stream);););
}

}  // namespace transformer_engine

void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
                                     int total_num_tokens, int num_experts, int num_rows,
                                     int num_cols, int topk, float coeff, NVTETensor aux_loss,
                                     NVTETensor Const_buf, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_moe_aux_loss_forward);
  using namespace transformer_engine;
  fused_moe_aux_loss_forward(*convertNVTETensorCheck(probs),
                             *convertNVTETensorCheck(tokens_per_expert), total_num_tokens,
                             num_experts, num_rows, num_cols, topk, coeff,
                             *convertNVTETensorCheck(aux_loss), *convertNVTETensorCheck(Const_buf),
                             stream);
}

void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf,
                                      const NVTETensor tokens_per_expert, int num_rows,
                                      int num_cols, NVTETensor grad_aux_loss, NVTETensor grad_probs,
                                      cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_moe_aux_loss_backward);
  using namespace transformer_engine;
  fused_moe_aux_loss_backward(*convertNVTETensorCheck(Const_buf),
                              *convertNVTETensorCheck(tokens_per_expert), num_rows, num_cols,
                              *convertNVTETensorCheck(grad_aux_loss),
                              *convertNVTETensorCheck(grad_probs), stream);
}
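Since the loss is linear in each prob, the backward broadcasts one scalar per expert column; a sketch of the same math:

import torch

def moe_aux_loss_bwd_ref(c_coeff, tokens_per_expert, grad_aux_loss, num_rows):
    # grad_probs[j, i] = c_coeff * tokens_per_expert[i] * grad_aux_loss, for every row j
    row = c_coeff * tokens_per_expert.double() * float(grad_aux_loss)
    return row.unsqueeze(0).expand(num_rows, -1)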
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu
0 → 100644
View file @
44740c6c
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <assert.h>
#include <cuda_runtime.h>
#include <transformer_engine/fused_router.h>
#include "../common.h"
#include "../util/logging.h"
#include "../utils.cuh"
#include "utils.h"
namespace transformer_engine {

template <typename DataType>
__global__ void fused_score_for_moe_aux_loss_forward_kernel(
    const DataType *logits, int num_tokens, int num_experts, int topk, int score_function,
    DataType *scores, bool *routing_map, DataType *intermediate_output) {
  /***
   * Section: Global Variables/Addresses init
   * - Assume the sizeof(DataType) >= sizeof(int),
   *   so the DataType addresses are assigned first to avoid alignment issues
   * - Each warp is responsible for one token and has its own shared memory buffer,
   *   so __syncwarp() is used instead of __syncthreads()
   */
  // Used variables/addresses init
  int num_token_per_block = blockDim.x / kThreadsPerWarp;
  int warp_id = threadIdx.x / kThreadsPerWarp;
  int lane_id = threadIdx.x % kThreadsPerWarp;
  extern __shared__ float shmem_scores_for_aux_loss[];
  DataType *logits_buf = reinterpret_cast<DataType *>(shmem_scores_for_aux_loss);
  DataType *topk_logits_buf =
      reinterpret_cast<DataType *>(logits_buf + num_experts * num_token_per_block);
  int *topk_indices_buf = reinterpret_cast<int *>(topk_logits_buf + topk * num_token_per_block);
  // The address of buffers on the current warp
  DataType *local_logits = logits_buf + warp_id * num_experts;
  DataType *topk_logits = topk_logits_buf + warp_id * topk;
  int *topk_indices = topk_indices_buf + warp_id * topk;

  /***
   * Section: Main Loop
   * - Each warp is responsible for one token
   */
  int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block;
  for (int round = blockIdx.x; round < total_round; round += gridDim.x) {
    int token_offset_cur_warp = round * num_token_per_block + warp_id;
    // Each warp is responsible for one token
    if (token_offset_cur_warp >= num_tokens) break;

    /***
     * Section: Init buffer
     * - Clear the global buffer which will accept the result of this round
     * - Clear/Init the shmem buffer used by the current warp this round
     * - Load the logits to shmem
     */
    int pos_offset = token_offset_cur_warp * num_experts;
    // Clear the routing_map (num_experts)
    for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
      routing_map[pos_offset + i] = false;
      if (score_function == 1) {
        intermediate_output[pos_offset + i] = -std::numeric_limits<DataType>::infinity();
      }
    }
    // Load the logits to shmem
    for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
      local_logits[i] = logits[pos_offset + i];
    }
    __threadfence_block();
    __syncwarp();

    /***
     * Section: Preprocess
     * Possibly preprocess the scores before the topk operation
     * - Pre-softmax
     * - Sigmoid
     * - Sigmoid post-processing when topk > 1
     * This is an in-place scores update
     */
    // score_function == 1 means softmax
    if (score_function == 1) {
      // Apply softmax to the logits before the topk
      apply_softmax_on_float(local_logits, num_experts, lane_id);
      __syncwarp();
      // Save the softmax output for backward
      for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
        intermediate_output[pos_offset + i] = local_logits[i];
      }
    }
    // score_function == 0 means sigmoid
    if (score_function == 0) {
      // Apply sigmoid to the logits
      apply_sigmoid_on_float(local_logits, num_experts, lane_id);
      __syncwarp();
      // Save the sigmoid output for backward
      for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
        intermediate_output[pos_offset + i] = local_logits[i];
      }
    }
    __syncwarp();  // Confirm the scores are written to the softmax/sigmoid output

    if (score_function == 0) {
      if (topk > 1) {
        auto sum_logits = warp_reduce_on_shmem(local_logits, num_experts, sum, lane_id);
        for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
          local_logits[i] = static_cast<DataType>(static_cast<double>(local_logits[i]) /
                                                  (static_cast<double>(sum_logits) + epsilon));
        }
      }
      __syncwarp();
    }

    /***
     * Section: Topk
     * Get the topk indices
     */
    naive_topk_and_mask(local_logits, num_experts, topk, topk_indices, topk_logits, lane_id);
    __syncwarp();

    // Write the routing_map to the output tensor
    for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
      routing_map[pos_offset + topk_indices[i]] = true;
    }
    // Write the scores to the output tensor
    for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
      scores[pos_offset + i] = local_logits[i];
    }
    __threadfence_block();
    __syncwarp();
  }
}
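In PyTorch terms, the per-token computation this kernel fuses is roughly (a sketch; score_function 1 = softmax, 0 = sigmoid, and the epsilon value here is an assumption):

import torch

def scores_for_aux_loss_ref(logits, topk, score_function, eps=1e-20):
    if score_function == 1:
        scores = torch.softmax(logits, dim=-1)
    else:
        scores = torch.sigmoid(logits)
        if topk > 1:
            scores = scores / (scores.sum(dim=-1, keepdim=True) + eps)
    _, idx = torch.topk(scores, k=topk, dim=-1)
    routing_map = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, idx, True)
    return scores, routing_map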
template <typename DataType>
void fused_score_for_moe_aux_loss_forward_kernel_launcher(const DataType *logits, int num_tokens,
                                                          int num_experts, int topk,
                                                          int score_function, DataType *scores,
                                                          bool *routing_map,
                                                          DataType *intermediate_output,
                                                          cudaStream_t stream) {
  // Meta data for the kernel
  size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp;
  size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block;
  size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType)  // logits
                              + topk * num_token_per_block * sizeof(DataType)       // topk_logits
                              + topk * num_token_per_block * sizeof(int);           // topk_indices
  fused_score_for_moe_aux_loss_forward_kernel<DataType>
      <<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
          logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
          intermediate_output);
}

void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, int num_experts,
                                          int topk, int score_function, Tensor &scores,
                                          Tensor &routing_map, Tensor &intermediate_output,
                                          cudaStream_t stream) {
  TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
      logits.data.dtype, DataType,
      fused_score_for_moe_aux_loss_forward_kernel_launcher<DataType>(
          reinterpret_cast<DataType *>(logits.data.dptr), num_tokens, num_experts, topk,
          score_function, reinterpret_cast<DataType *>(scores.data.dptr),
          reinterpret_cast<bool *>(routing_map.data.dptr),
          reinterpret_cast<DataType *>(intermediate_output.data.dptr), stream););
}

template <typename DataType>
__global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *intermediate_output,
                                                             const DataType *grad_scores,
                                                             int num_tokens, int num_experts,
                                                             int topk, int score_function,
                                                             DataType *grad_logits) {
  /***
   * Section: Global Variables/Addresses init
   * - Assume the sizeof(DataType) >= sizeof(int)
   * - Each warp is responsible for one token and has its own shared memory buffer,
   *   so __syncwarp() is used instead of __syncthreads()
   */
  // Used variables/addresses init
  int num_token_per_block = blockDim.x / kThreadsPerWarp;
  int warp_id = threadIdx.x / kThreadsPerWarp;
  int lane_id = threadIdx.x % kThreadsPerWarp;
  extern __shared__ float shmem[];
  DataType *grad_scores_buf = reinterpret_cast<DataType *>(shmem);
  // To store the output of softmax/sigmoid from the fwd
  DataType *act_from_fwd_buf =
      reinterpret_cast<DataType *>(grad_scores_buf + num_experts * num_token_per_block);
  DataType *comp_buf =
      reinterpret_cast<DataType *>(act_from_fwd_buf + num_experts * num_token_per_block);
  // The address of buffers on the current warp
  DataType *local_grad = grad_scores_buf + warp_id * num_experts;
  DataType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts;
  DataType *local_comp_buf = comp_buf + warp_id * num_experts;

  /***
   * Section: Main Loop
   * - Each warp is responsible for one token
   */
  int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block;
  for (int round = blockIdx.x; round < total_round; round += gridDim.x) {
    int token_offset_cur_warp = round * num_token_per_block + warp_id;
    // Each warp is responsible for one token
    if (token_offset_cur_warp >= num_tokens) break;

    /***
     * Section: Init buffer
     * - Clear the global buffer which will accept the result of this round
     * - Clear/Init the shmem buffer used by the current warp this round
     * - Load the dgrad/output_from_fwd to shmem
     */
    int pos_offset = token_offset_cur_warp * num_experts;
    // Clear the logits_grad in global mem
    for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
      grad_logits[pos_offset + i] = 0.0f;
    }
    // Load the dgrad/output_from_fwd to shmem
    for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
      local_grad[i] = grad_scores[pos_offset + i];
      local_act_from_fwd[i] = intermediate_output[pos_offset + i];
    }
    __threadfence_block();
    __syncwarp();

    /***
     * Section: Backward of ops before the topk
     * - Pre-softmax bwd
     * - Sigmoid post-processing bwd when topk > 1
     * - Sigmoid bwd
     * - Write the grad_logits to the global mem
     */
    // Sigmoid post-processing bwd when topk > 1
    if (topk > 1 && score_function == 0) {
      auto sum_fwd_input = warp_reduce_on_shmem(local_act_from_fwd, num_experts, sum, lane_id);
      // Put the result of output * grad to the comp_buf
      for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
        local_comp_buf[i] = local_grad[i] * local_act_from_fwd[i];
      }
      __syncwarp();
      auto sum_Output_x_Grad = warp_reduce_on_shmem(local_comp_buf, num_experts, sum, lane_id);
      // In-place update
      for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
        local_grad[i] = static_cast<double>(local_grad[i]) /
                            (static_cast<double>(sum_fwd_input) + epsilon) -
                        static_cast<double>(sum_Output_x_Grad) /
                            ((static_cast<double>(sum_fwd_input) + epsilon) *
                             (static_cast<double>(sum_fwd_input) + epsilon));
      }
    }
    __syncwarp();
    // Pre-softmax bwd
    if (score_function == 1) {
      apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, nullptr,
                                 num_experts, lane_id);
      __syncwarp();
    }
    // Sigmoid bwd
    if (score_function == 0) {
      apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id);
      __syncwarp();
    }
    // Write the grad_logits to the global mem
    for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
      grad_logits[pos_offset + i] = local_grad[i];
    }
    __syncwarp();
  }
}

template <typename DataType>
void fused_score_for_moe_aux_loss_backward_kernel_launcher(const DataType *intermediate_output,
                                                           const DataType *grad_scores,
                                                           int num_tokens, int num_experts,
                                                           int topk, int score_function,
                                                           DataType *grad_logits,
                                                           cudaStream_t stream) {
  // Meta data for the kernel
  size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp;
  size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block;
  size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType)    // grad_scores
                              + num_experts * num_token_per_block * sizeof(DataType)  // act_from_fwd
                              + num_experts * num_token_per_block * sizeof(DataType);  // comp_buf
  fused_score_for_moe_aux_loss_backward_kernel<DataType>
      <<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
          intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function,
          grad_logits);
}

void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output,
                                           const Tensor &grad_scores, int num_tokens,
                                           int num_experts, int topk, int score_function,
                                           Tensor &grad_logits, cudaStream_t stream) {
  TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
      grad_scores.data.dtype, DataType,
      fused_score_for_moe_aux_loss_backward_kernel_launcher<DataType>(
          reinterpret_cast<DataType *>(intermediate_output.data.dptr),
          reinterpret_cast<DataType *>(grad_scores.data.dptr), num_tokens, num_experts, topk,
          score_function, reinterpret_cast<DataType *>(grad_logits.data.dptr), stream););
}

}  // namespace transformer_engine

void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_tokens,
                                               int num_experts, int topk, int score_function,
                                               NVTETensor scores, const NVTETensor routing_map,
                                               const NVTETensor intermediate_output,
                                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_score_for_moe_aux_loss_forward);
  using namespace transformer_engine;
  fused_score_for_moe_aux_loss_forward(
      *convertNVTETensorCheck(logits), num_tokens, num_experts, topk, score_function,
      *convertNVTETensorCheck(scores), *convertNVTETensorCheck(routing_map),
      *convertNVTETensorCheck(intermediate_output), stream);
}

void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_output,
                                                const NVTETensor grad_scores, int num_tokens,
                                                int num_experts, int topk, int score_function,
                                                NVTETensor grad_logits, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_score_for_moe_aux_loss_backward);
  using namespace transformer_engine;
  fused_score_for_moe_aux_loss_backward(*convertNVTETensorCheck(intermediate_output),
                                        *convertNVTETensorCheck(grad_scores), num_tokens,
                                        num_experts, topk, score_function,
                                        *convertNVTETensorCheck(grad_logits), stream);
}
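For the sigmoid path with top-k normalization, the gradient the backward kernel applies before the elementwise sigmoid derivative is, with $x$ the saved sigmoid outputs, $S = \sum_j x_j$, and $g$ the incoming gradient:

$\dfrac{\partial L}{\partial x_i} = \dfrac{g_i}{S + \epsilon} - \dfrac{\sum_j g_j x_j}{(S + \epsilon)^2}$

which is exactly the sum_fwd_input / sum_Output_x_Grad expression in the in-place update above.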
transformer_engine/common/fused_router/fused_topk_with_score_function.cu
0 → 100644
View file @
44740c6c
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <assert.h>
#include <cuda_runtime.h>
#include <transformer_engine/fused_router.h>
#include "../common.h"
#include "../util/logging.h"
#include "../utils.cuh"
#include "utils.h"
namespace
transformer_engine
{
template
<
typename
DataType
,
typename
BiasType
>
__global__
void
fused_topk_with_score_function_forward_kernel
(
const
DataType
*
logits
,
int
num_tokens
,
int
num_experts
,
int
topk
,
bool
use_pre_softmax
,
int
num_groups
,
int
group_topk
,
float
scaling_factor
,
int
score_function
,
const
BiasType
*
expert_bias
,
DataType
*
probs
,
bool
*
routing_map
,
DataType
*
intermediate_output
)
{
/***
* Section: Global Variables/Addresses init
* - Assume the sizeof(DataType) >= sizeof(int),
* So DataType address is assigned firstly to avoid the alignment issue
* - Each warp is responsible for one token, and has own shared memory buffer.
* Then __syncwarp() is used instead of __syncthreads()
*/
// Used variables/addresses init
int
num_token_per_block
=
blockDim
.
x
/
kThreadsPerWarp
;
int
warp_id
=
threadIdx
.
x
/
kThreadsPerWarp
;
int
lane_id
=
threadIdx
.
x
%
kThreadsPerWarp
;
extern
__shared__
float
shmem
[];
DataType
*
scores_buf
=
reinterpret_cast
<
DataType
*>
(
shmem
);
DataType
*
topk_scores_buf
=
reinterpret_cast
<
DataType
*>
(
scores_buf
+
num_experts
*
num_token_per_block
);
DataType
*
group_scores_buf
=
nullptr
,
*
masked_scores_buf
=
nullptr
;
int
*
topk_indices_buf
=
nullptr
;
if
(
group_topk
>
0
)
{
masked_scores_buf
=
reinterpret_cast
<
DataType
*>
(
topk_scores_buf
+
topk
*
num_token_per_block
);
group_scores_buf
=
reinterpret_cast
<
DataType
*>
(
masked_scores_buf
+
num_experts
*
num_token_per_block
);
topk_indices_buf
=
reinterpret_cast
<
int
*>
(
group_scores_buf
+
num_groups
*
num_token_per_block
);
}
else
{
topk_indices_buf
=
reinterpret_cast
<
int
*>
(
topk_scores_buf
+
topk
*
num_token_per_block
);
}
// The address of buffers on the current warp
DataType
*
scores
=
scores_buf
+
warp_id
*
num_experts
;
DataType
*
topk_scores
=
topk_scores_buf
+
warp_id
*
topk
;
DataType
*
masked_scores
=
masked_scores_buf
+
warp_id
*
num_experts
;
DataType
*
group_scores
=
group_scores_buf
+
warp_id
*
num_groups
;
int
*
topk_indices
=
topk_indices_buf
+
warp_id
*
topk
;
/***
* Section: Main Loop
* - Each warp is responsible for one token
*/
int
total_round
=
(
num_tokens
+
num_token_per_block
-
1
)
/
num_token_per_block
;
for
(
int
round
=
blockIdx
.
x
;
round
<
total_round
;
round
+=
gridDim
.
x
)
{
int
token_offset_cur_warp
=
round
*
num_token_per_block
+
warp_id
;
    // Each warp is responsible for one token
    if (token_offset_cur_warp >= num_tokens) break;

    /***
     * Section: Init buffer
     * - Clear the global buffer which will accept the result of this round
     * - Clear/Init the shmem buffer used by the current warp this round
     * - Load the logits to shmem
     */
    int pos_offset = token_offset_cur_warp * num_experts;
    // Clear the probs/routing_map (num_experts)
    for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
      probs[pos_offset + i] = 0.0f;
      routing_map[pos_offset + i] = false;
      if (score_function == 1) {
        intermediate_output[pos_offset + i] = -std::numeric_limits<DataType>::infinity();
      }
    }
    // Load the logits to shmem
    for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
      scores[i] = logits[pos_offset + i];
    }
    // If group_topk > 0, init the masked_scores to -inf
    if (group_topk > 0) {
      for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
        masked_scores[i] = -std::numeric_limits<DataType>::infinity();
      }
    }
    __threadfence_block();
    __syncwarp();

    /***
     * Section: Preprocess
     * Possibly preprocess the scores before the topk operation:
     * - Pre-softmax
     * - Sigmoid
     * - Expert bias
     * This is an in-place scores update.
     */
    // score_function == 1 means softmax
    if (use_pre_softmax && score_function == 1) {
      // Apply softmax to the logits before the topk
      apply_softmax_on_float(scores, num_experts, lane_id);
      __syncwarp();
      // Save the softmax output for backward
      for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
        intermediate_output[pos_offset + i] = scores[i];
      }
    }
    // score_function == 0 means sigmoid
    if (score_function == 0) {
      // Apply sigmoid to the logits
      apply_sigmoid_on_float(scores, num_experts, lane_id);
      __syncwarp();
      // Save the sigmoid output for backward
      for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
        intermediate_output[pos_offset + i] = scores[i];
      }
    }
    __syncwarp();  // Confirm the scores are written to the softmax/sigmoid output
    // Expert bias is only used in the sigmoid case
    if (expert_bias && score_function == 0) {
      for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
        scores[i] = static_cast<DataType>(static_cast<double>(scores[i]) +
                                          static_cast<double>(expert_bias[i]));
      }
    }
    __syncwarp();

    /***
     * Section: Topk
     * Get the topk indices:
     * - group_topk
     * - naive topk
     * - topk with expert bias
     */
    // Topk on the scores
    // A non-empty bias only happens in the sigmoid case
    if (group_topk > 0) {
      int group_size = num_experts / num_groups;
      // Top-(topk / group_topk) within each group
      for (int i = 0; i < num_groups; i++) {
        naive_topk_and_mask(/*scores ptr = */ scores + i * group_size,
                            /*data size = */ group_size,
                            /*topk = */ topk / group_topk,
                            /*topk indices ptr = */ topk_indices,
                            /*topk scores ptr = */ topk_scores,
                            /*lane id = */ lane_id);
        __syncwarp();
        // Compute the group score
        if (lane_id == 0) {
          DataType tmp = 0.0f;
          for (int j = 0; j < topk / group_topk; j++) {
            tmp = tmp + topk_scores[j];
          }
          group_scores[i] = tmp;
        }
        __syncwarp();
      }
      // Select the topk groups
      naive_topk_and_mask(/*scores ptr = */ group_scores,
                          /*data size = */ num_groups,
                          /*topk = */ group_topk,
                          /*topk indices ptr = */ topk_indices,
                          /*topk scores ptr = */ topk_scores,
                          /*lane id = */ lane_id);
      __syncwarp();
      // Copy the unmasked scores to the buffer
      for (int i = 0; i < group_topk; i++) {
        int st = topk_indices[i] * group_size;
        int ed = st + group_size;
        for (int j = st + lane_id; j < ed; j += kThreadsPerWarp) {
          masked_scores[j] = scores[j];
        }
      }
      __syncwarp();
      naive_topk_and_mask(masked_scores, num_experts, topk, topk_indices, topk_scores, lane_id);
    } else {
      naive_topk_and_mask(scores, num_experts, topk, topk_indices, topk_scores, lane_id);
    }
    __syncwarp();

    /***
     * Section: Postprocess
     * Possibly postprocess the scores after the topk operation:
     * - Revert expert bias
     * - Softmax
     * - Sigmoid post-processing when topk > 1
     * - Write the result with scaling_factor
     */
    // Revert expert bias from the topk scores
    if (expert_bias && score_function == 0) {
      for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
        topk_scores[i] = static_cast<double>(topk_scores[i]) -
                         static_cast<double>(expert_bias[topk_indices[i]]);
      }
    }
    __syncwarp();
    // score_function == 1 means softmax
    if (!use_pre_softmax && score_function == 1) {
      // Apply softmax to the topk logits
      apply_softmax_on_float(topk_scores, topk, lane_id);
      __syncwarp();
      // Save the softmax output for backward
      for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
        intermediate_output[pos_offset + topk_indices[i]] = topk_scores[i];
      }
    }
    // score_function == 0 means sigmoid
    if (score_function == 0) {
      if (topk > 1) {
        double sum_scores = warp_reduce_on_shmem(topk_scores, topk, sum, lane_id);
        for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
          topk_scores[i] = static_cast<double>(topk_scores[i]) / (sum_scores + epsilon);
        }
      }
      __syncwarp();
    }
    // Write the probs/routing_map to the output tensor
    for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
      routing_map[pos_offset + topk_indices[i]] = true;
      probs[pos_offset + topk_indices[i]] = scaling_factor * static_cast<double>(topk_scores[i]);
    }
    __threadfence_block();
    __syncwarp();
  }
}
template <typename DataType, typename BiasType>
void fused_topk_with_score_function_forward_kernel_launcher(
    const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax,
    int num_groups, int group_topk, float scaling_factor, int score_function,
    const BiasType *expert_bias, DataType *probs, bool *routing_map,
    DataType *intermediate_output, cudaStream_t stream) {
  size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp;
  size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block;
  size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType)  // scores
                              + topk * num_token_per_block * sizeof(DataType)   // topk_scores
                              + topk * num_token_per_block * sizeof(int);       // topk_indices
  if (group_topk > 0) {
    shared_memory_size += num_groups * num_token_per_block * sizeof(DataType);   // group_scores
    shared_memory_size += num_experts * num_token_per_block * sizeof(DataType);  // masked_scores
  }
  fused_topk_with_score_function_forward_kernel<DataType, BiasType>
      <<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
          logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk,
          scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output);
}
void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, int num_experts,
                                            int topk, bool use_pre_softmax, int num_groups,
                                            int group_topk, float scaling_factor,
                                            int score_function, const Tensor expert_bias,
                                            Tensor probs, Tensor routing_map,
                                            Tensor intermediate_output, cudaStream_t stream) {
  TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
      logits.data.dtype, DataType,
      TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
          expert_bias.data.dtype, BiasType,
          fused_topk_with_score_function_forward_kernel_launcher<DataType, BiasType>(
              reinterpret_cast<DataType *>(logits.data.dptr), num_tokens, num_experts, topk,
              use_pre_softmax, num_groups, group_topk, scaling_factor, score_function,
              reinterpret_cast<BiasType *>(expert_bias.data.dptr),
              reinterpret_cast<DataType *>(probs.data.dptr),
              reinterpret_cast<bool *>(routing_map.data.dptr),
              reinterpret_cast<DataType *>(intermediate_output.data.dptr), stream);););
}
template <typename DataType>
__global__ void fused_topk_with_score_function_backward_kernel(
    // Input tensors
    const bool *routing_map, const DataType *intermediate_output, const DataType *grad_probs,
    // Other parameters
    int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor,
    int score_function,
    // Output tensor
    DataType *grad_logits) {
  /***
   * Section: Global Variables/Addresses init
   * - Assume that sizeof(DataType) >= sizeof(int).
   * - Each warp is responsible for one token and has its own shared memory buffer,
   *   so __syncwarp() is used instead of __syncthreads().
   */
  // Used variables/addresses init
  int num_token_per_block = blockDim.x / kThreadsPerWarp;
  int warp_id = threadIdx.x / kThreadsPerWarp;
  int lane_id = threadIdx.x % kThreadsPerWarp;
  extern __shared__ float shmem[];
  DataType *grad_probs_buf = reinterpret_cast<DataType *>(shmem);
  // To store the output of softmax/sigmoid from the fwd
  DataType *act_from_fwd_buf =
      reinterpret_cast<DataType *>(grad_probs_buf + num_experts * num_token_per_block);
  DataType *comp_buf =
      reinterpret_cast<DataType *>(act_from_fwd_buf + num_experts * num_token_per_block);
  // To store the routing_map from the fwd
  bool *routing_map_buf = reinterpret_cast<bool *>(comp_buf + num_experts * num_token_per_block);
  // The addresses of the buffers for the current warp
  DataType *local_grad = grad_probs_buf + warp_id * num_experts;
  DataType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts;
  DataType *local_comp_buf = comp_buf + warp_id * num_experts;
  bool *local_routing_map = routing_map_buf + warp_id * num_experts;

  /***
   * Section: Main Loop
   * - Each warp is responsible for one token
   */
  int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block;
  for (int round = blockIdx.x; round < total_round; round += gridDim.x) {
    int token_offset_cur_warp = round * num_token_per_block + warp_id;
    // Each warp is responsible for one token
    if (token_offset_cur_warp >= num_tokens) break;

    /***
     * Section: Init buffer
     * - Clear the global buffer which will accept the result of this round
     * - Clear/Init the shmem buffer used by the current warp this round
     * - Load the dgrad/output_from_fwd to shmem
     */
    int pos_offset = token_offset_cur_warp * num_experts;
    // Clear the logits_grad in global mem
    for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
      grad_logits[pos_offset + i] = 0.0f;
    }
    // Load the dgrad/output_from_fwd to shmem
    for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
      local_grad[i] = grad_probs[pos_offset + i];
      local_act_from_fwd[i] = intermediate_output[pos_offset + i];
      local_routing_map[i] = routing_map[pos_offset + i];
    }
    __threadfence_block();
    __syncwarp();

    /***
     * Section: Backward of ops after the topk
     * - Backward of the applied scaling_factor
     * - Sigmoid post-processing bwd when topk > 1
     * - Softmax bwd if use_pre_softmax is false
     */
    // Backward of the applied scaling_factor
    // In-place update
    for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
      if (local_routing_map[i]) {
        local_grad[i] = static_cast<double>(local_grad[i]) * scaling_factor;
      }
    }
    __syncwarp();
    // Sigmoid post-processing bwd when topk > 1
    if (topk > 1 && score_function == 0) {
      double sum_fwd_input = masked_warp_reduce_on_shmem(
          /*data ptr = */ local_act_from_fwd,
          /*mask ptr = */ local_routing_map,
          /*data size = */ num_experts,
          /*reduce func = */ sum, lane_id);
      // Put the result of output * grad into the comp_buf
      for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
        local_comp_buf[i] = (local_routing_map[i]
                                 ? static_cast<double>(local_grad[i]) *
                                       static_cast<double>(local_act_from_fwd[i])
                                 : 0.0f);
      }
      __syncwarp();
      double sum_Output_x_Grad = masked_warp_reduce_on_shmem(
          /*data ptr = */ local_comp_buf,
          /*mask ptr = */ local_routing_map,
          /*data size = */ num_experts,
          /*reduce func = */ sum, lane_id);
      // In-place update
      for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
        if (local_routing_map[i]) {
          local_grad[i] =
              static_cast<double>(local_grad[i]) / (sum_fwd_input + epsilon) -
              sum_Output_x_Grad / ((sum_fwd_input + epsilon) * (sum_fwd_input + epsilon));
        } else {
          local_grad[i] = 0.0f;
        }
      }
    }
    __syncwarp();
    // Softmax bwd if use_pre_softmax is false
    if (!use_pre_softmax && score_function == 1) {
      apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf,
                                 local_routing_map, num_experts, lane_id);
      __syncwarp();
    }

    /***
     * Section: Backward of topk
     * Mask the unselected positions in the grad
     */
    for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
      if (!local_routing_map[i]) {
        local_grad[i] = 0.0f;
      }
    }
    __syncwarp();

    /***
     * Section: Backward of ops before the topk
     * - Pre-softmax bwd
     * - Sigmoid bwd
     * - Write the grad_logits to the global mem
     */
    // Pre-softmax bwd
    if (score_function == 1 && use_pre_softmax) {
      apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, nullptr,
                                 num_experts, lane_id);
      __syncwarp();
    }
    // Sigmoid bwd
    if (score_function == 0) {
      apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id);
      __syncwarp();
    }
    // Write the grad_logits to the global mem
    for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
      grad_logits[pos_offset + i] = local_grad[i];
    }
    __syncwarp();
  }
}
template <typename DataType>
void fused_topk_with_score_function_backward_kernel_launcher(
    const bool *routing_map, const DataType *intermediate_output, const DataType *grad_probs,
    int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor,
    int score_function, DataType *grad_logits, cudaStream_t stream) {
  // Meta data for the kernel
  size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp;
  size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block;
  size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType)  // grad_probs
                              + num_experts * num_token_per_block * sizeof(DataType)  // act_from_fwd
                              + num_experts * num_token_per_block * sizeof(DataType)  // comp_buf
                              + num_experts * num_token_per_block * sizeof(bool);     // routing_map
  fused_topk_with_score_function_backward_kernel<DataType>
      <<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
          routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk,
          use_pre_softmax, scaling_factor, score_function, grad_logits);
}
void fused_topk_with_score_function_backward(const Tensor &routing_map,
                                             const Tensor &intermediate_output,
                                             const Tensor &grad_probs, int num_tokens,
                                             int num_experts, int topk, bool use_pre_softmax,
                                             float scaling_factor, int score_function,
                                             Tensor &grad_logits, cudaStream_t stream) {
  TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
      grad_logits.data.dtype, DataType,
      fused_topk_with_score_function_backward_kernel_launcher<DataType>(
          reinterpret_cast<bool *>(routing_map.data.dptr),
          reinterpret_cast<DataType *>(intermediate_output.data.dptr),
          reinterpret_cast<DataType *>(grad_probs.data.dptr), num_tokens, num_experts, topk,
          use_pre_softmax, scaling_factor, score_function,
          reinterpret_cast<DataType *>(grad_logits.data.dptr), stream););
}

}  // namespace transformer_engine
void nvte_fused_topk_with_score_function_forward(
    const NVTETensor logits, int num_tokens, int num_experts, int topk, int use_pre_softmax,
    int num_groups, int group_topk, float scaling_factor, int score_function,
    const NVTETensor expert_bias, NVTETensor probs, NVTETensor routing_map,
    NVTETensor intermediate_output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_topk_with_score_function_forward);
  using namespace transformer_engine;
  fused_topk_with_score_function_forward(
      *convertNVTETensorCheck(logits), num_tokens, num_experts, topk,
      static_cast<bool>(use_pre_softmax), num_groups, group_topk, scaling_factor, score_function,
      *convertNVTETensorCheck(expert_bias), *convertNVTETensorCheck(probs),
      *convertNVTETensorCheck(routing_map), *convertNVTETensorCheck(intermediate_output), stream);
}

void nvte_fused_topk_with_score_function_backward(
    const NVTETensor routing_map, const NVTETensor intermediate_output,
    const NVTETensor grad_probs, int num_tokens, int num_experts, int topk, int use_pre_softmax,
    float scaling_factor, int score_function, NVTETensor grad_logits, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_topk_with_score_function_backward);
  using namespace transformer_engine;
  fused_topk_with_score_function_backward(
      *convertNVTETensorCheck(routing_map), *convertNVTETensorCheck(intermediate_output),
      *convertNVTETensorCheck(grad_probs), num_tokens, num_experts, topk,
      static_cast<bool>(use_pre_softmax), scaling_factor, score_function,
      *convertNVTETensorCheck(grad_logits), stream);
}
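The two launchers above share the same occupancy scheme: kThreadsPerBlock / kThreadsPerWarp = 4 tokens resident per CTA, with dynamic shared memory sized per token. A minimal host-side sketch of the forward sizing math, assuming DataType = float (standalone C++; the names mirror the launcher, nothing here is TE API):

#include <cstddef>

constexpr size_t kThreadsPerWarp = 32;
constexpr int kThreadsPerBlock = 128;  // 4 warps, so 4 tokens resident per CTA

// Mirrors the dynamic shared-memory sizing in the forward launcher above.
size_t forward_shmem_bytes(int num_experts, int topk, int num_groups, int group_topk) {
  const size_t tokens_per_block = kThreadsPerBlock / kThreadsPerWarp;
  size_t bytes = num_experts * tokens_per_block * sizeof(float)  // scores
                 + topk * tokens_per_block * sizeof(float)       // topk_scores
                 + topk * tokens_per_block * sizeof(int);        // topk_indices
  if (group_topk > 0) {
    bytes += num_groups * tokens_per_block * sizeof(float);      // group_scores
    bytes += num_experts * tokens_per_block * sizeof(float);     // masked_scores
  }
  return bytes;
}

The grouped-topk branch simply stacks the per-token group_scores and masked_scores buffers on top of the base layout.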
transformer_engine/common/fused_router/utils.h
0 → 100644
View file @ 44740c6c
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {

constexpr size_t kThreadsPerWarp = 32;
constexpr int kThreadsPerBlock = 128;  // Using 4 warps in 1 CTA; each warp handles 1 token.
constexpr float epsilon = 1e-20;
template <typename T>
__device__ inline T max(T a, T b) {
  return a > b ? a : b;
}

template <typename T>
__device__ inline T sum(T a, T b) {
  return a + b;
}
template <typename T>
__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, T (*reduce_func)(T, T),
                                         int lane_id) {
  // Some values are handled in the local thread:
  // thread 0 is responsible for the 0-th, 32-th, 64-th, 96-th, ... elements.
  // Reduce the values in the local thread
  volatile double val =
      lane_id < data_size ? static_cast<double>(data_ptr[lane_id]) : static_cast<double>(0);
  for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
    val = reduce_func(val, data_ptr[i]);
  }
  // Warp shuffle between threads
  val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16));
  val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8));
  val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4));
  val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2));
  val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1));
  __syncwarp();
  return T(val);
}
template <typename DataType>
__device__ inline void apply_sigmoid_on_float(DataType *scores, int data_size, int lane_id) {
  for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
    scores[i] = static_cast<float>(1.0f / (1.0f + exp(-static_cast<float>(scores[i]))));
  }
}
template <typename T>
__device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size,
                                                T (*reduce_func)(T, T), int lane_id) {
  // Some values are handled in the local thread:
  // thread 0 is responsible for the 0-th, 32-th, 64-th, 96-th, ... elements.
  // Reduce the values in the local thread
  volatile double val = lane_id < data_size && mask[lane_id]
                            ? static_cast<double>(data_ptr[lane_id])
                            : static_cast<double>(0);
  for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
    if (mask[i]) {
      val = reduce_func(val, data_ptr[i]);
    }
  }
  // Warp shuffle between threads
  val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16));
  val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8));
  val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4));
  val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2));
  val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1));
  __syncwarp();
  return T(val);
}
template <typename DataType>
__device__ inline void apply_sigmoid_bwd_on_float(DataType *grad, DataType *fwd_output,
                                                  int data_size, int lane_id) {
  for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
    grad[i] = static_cast<double>(grad[i]) * static_cast<double>(fwd_output[i]) *
              (1 - static_cast<double>(fwd_output[i]));
  }
}
template <typename DataType>
__device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_output,
                                                  DataType *comp_buf, bool *mask, int data_size,
                                                  int lane_id) {
  // Put the result of output * grad into the comp_buf
  for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
    if (mask) {
      if (mask[i])
        comp_buf[i] = static_cast<float>(grad[i]) * static_cast<float>(fwd_output[i]);
      else
        comp_buf[i] = 0.0f;
    } else {
      comp_buf[i] = static_cast<float>(grad[i]) * static_cast<float>(fwd_output[i]);
    }
  }
  __syncwarp();
  float sum_Output_x_Grad = warp_reduce_on_shmem(/*data ptr = */ comp_buf,
                                                 /*data size = */ data_size,
                                                 /*reduce func = */ sum, lane_id);
  // In-place update
  for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
    if (mask) {
      if (mask[i])
        grad[i] = static_cast<float>(fwd_output[i]) *
                  (static_cast<float>(grad[i]) - sum_Output_x_Grad);
      else
        grad[i] = 0.0f;
    } else {
      grad[i] = static_cast<float>(fwd_output[i]) *
                (static_cast<float>(grad[i]) - sum_Output_x_Grad);
    }
  }
}
template <typename DataType>
__device__ inline void apply_softmax_on_float(DataType *scores, int data_size, int lane_id) {
  // 1. Compute the max of the values
  float max_val = static_cast<float>(warp_reduce_on_shmem(scores, data_size, max, lane_id));
  // 2. value -> exp_value
  for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
    scores[i] = static_cast<float>(exp(static_cast<float>(scores[i]) - max_val));
  }
  __syncwarp();
  // 3. Compute the sum of the exp_values
  float sum_val = static_cast<float>(warp_reduce_on_shmem(scores, data_size, sum, lane_id));
  // 4. Update the softmax value
  for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
    scores[i] = static_cast<float>(scores[i]) / sum_val;
  }
  __syncwarp();
}
template <typename T>
__device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices,
                                           T *topk_scores, int lane_id) {
  // Loop topk times: find the max value and its index,
  // then mask it and record the index in topk_indices.
  // After looping topk times, topk_indices holds the topk indices.
  for (int k = 0; k < topk; k++) {
    // Find the max value and its index
    volatile double val =
        (lane_id < data_size) ? static_cast<double>(scores[lane_id]) : static_cast<double>(0);
    volatile int index = (lane_id < data_size) ? lane_id : 0;
    // Some values are handled in the local thread:
    // thread 0 is responsible for the 0-th, 32-th, 64-th, 96-th, ... elements.
    // Reduce the values in the local thread
    for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
      volatile double cur_val = scores[i];
      if (cur_val > val) {
        val = cur_val;
        index = i;
      }
    }
    // Warp shuffle between threads
    for (int s = 16; s > 0; s /= 2) {
      volatile auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s);
      volatile auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s);
      if (shuffled_val > val) {
        val = shuffled_val;
        index = shuffled_index;
      }
    }
    if (lane_id == 0) {
      topk_indices[k] = index;
      topk_scores[k] = val;
      // Mask the selected expert with val = -1 - val so it cannot be selected again
      scores[index] = static_cast<double>(-1.0) - val;
    }
    __syncwarp();
  }
  // Reset the scores to their original values
  for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
    scores[topk_indices[i]] =
        static_cast<double>(-1.0) - static_cast<double>(scores[topk_indices[i]]);
  }
}
// Currently TE only supports float32/bf16/fp16; float64 probs should be considered in the future
#define TE_ROUTER_PROBS_TYPE_SWITCH_ALL(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
#define TE_ROUTER_INDEX_TYPE_SWITCH_ALL(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kInt32: { \
using type = int32_t; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt64: { \
using type = int64_t; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_
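A note on the masking trick in naive_topk_and_mask above: since the candidate scores are softmax/sigmoid outputs (non-negative), the map x -> -1 - x pushes a selected entry below every remaining score, and applying the same map again restores the original value. A tiny standalone check of that involution (plain C++, illustration only):

#include <cassert>
#include <cmath>

int main() {
  double score = 0.73;               // softmax/sigmoid outputs lie in [0, 1]
  double masked = -1.0 - score;      // mask: drops below any unselected score
  assert(masked < 0.0);              // so the entry can never win another max-scan
  double restored = -1.0 - masked;   // applying the map twice restores the value
  assert(std::fabs(restored - score) < 1e-12);
  return 0;
}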
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @ 44740c6c
...
...
@@ -229,6 +229,13 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
   return ret;
 }

+/* cuBLAS version number at run-time */
+size_t cublas_version() {
+  // Cache version to avoid cuBLAS logging overhead
+  static size_t version = cublasLtGetVersion();
+  return version;
+}
+
 }  // namespace

 #endif  // __HIP_PLATFORM_AMD__
...
...
@@ -357,10 +364,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                                                    &fastAccuMode, sizeof(fastAccuMode)));

   // Scaling factors.
-#if CUDA_VERSION >= 12080
+#if CUBLAS_VERSION >= 120800
   cublasLtMatmulMatrixScale_t scaling_mode_a;
   cublasLtMatmulMatrixScale_t scaling_mode_b;
-#endif
+#endif  // CUBLAS_VERSION >= 120800
   if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) {
     void *A_scale_inverse = param.A_scale_inv;
     void *B_scale_inverse = param.B_scale_inv;
...
...
@@ -370,10 +377,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                      CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                      &B_scale_inverse, sizeof(B_scale_inverse)));
-#if CUDA_VERSION >= 12080
+#if CUBLAS_VERSION >= 120800
     scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
     scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
 #endif  // CUBLAS_VERSION >= 120800
   } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) {
 #if CUBLAS_VERSION >= 120800
+    NVTE_CHECK(cublas_version() >= 120800,
+               "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
     fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
     fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
...
...
@@ -386,17 +397,24 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
     scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
     // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
     // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
-    if (cublasLtGetVersion() <= 120803) {
+    if (cublas_version() <= 120803) {
       const int64_t dummy_a_vec_stride = 1;
       NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
           operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
           sizeof(dummy_a_vec_stride)));
     }
 #else
+    NVTE_ERROR("MXFP8 requires cuBLAS 12.8+, but compile-time cuBLAS version is ",
+               CUBLAS_VERSION);
 #endif  // CUBLAS_VERSION >= 120800
   } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ||
               inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) &&
              (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
              (inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
-#if CUDA_VERSION >= 12090
+#if CUBLAS_VERSION >= 120900
+    NVTE_CHECK(cublas_version() >= 120900,
+               "FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ",
+               cublas_version());
     float *A_scale_inverse = reinterpret_cast<float *>(param.A_scale_inv);
     float *B_scale_inverse = reinterpret_cast<float *>(param.B_scale_inv);
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
...
...
@@ -415,20 +433,24 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                         ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
                         : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
 #else
-    NVTE_ERROR("FP8 block scaling requires CUDA 12.9+");
-#endif  // CUDA_VERSION >= 12090
+    NVTE_ERROR("FP8 block scaling requires cuBLAS 12.9+, but compile-time cuBLAS version is ",
+               CUBLAS_VERSION);
+#endif  // CUBLAS_VERSION >= 120900
   } else {
     NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " +
                to_string(inputB->scaling_mode) + ".");
   }

-#if CUDA_VERSION >= 12080
-  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
-                                                   &scaling_mode_a, sizeof(scaling_mode_a)));
-  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE,
-                                                   &scaling_mode_b, sizeof(scaling_mode_b)));
-#endif
+#if CUBLAS_VERSION >= 120800
+  if (cublas_version() >= 120800) {
+    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
+        operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a)));
+    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
+        operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b)));
+  }
+#endif  // CUBLAS_VERSION >= 120800
   if (is_fp8_dtype(outputD->data.dtype)) {
     // Accumulation mode not supported for FP8 output
     C = nullptr;
...
...
@@ -436,13 +458,15 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
         operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                      CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax,
                                                      sizeof(D_amax)));
-#if CUDA_VERSION >= 12080
-    // NOTE: In all current cases where FP8 output is supported, the input is
-    // scaled identically to the output.
-    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
-        operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a)));
-#endif
+#if CUBLAS_VERSION >= 120800
+    if (cublas_version() >= 120800) {
+      // NOTE: In all current cases where FP8 output is supported, the input is
+      // scaled identically to the output.
+      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
+          operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode_a,
+          sizeof(scaling_mode_a)));
+    }
+#endif  // CUBLAS_VERSION >= 120800

     // For FP8 output, cuBLAS requires C_type to match bias_type and
     // be FP16/BF16
     const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF;
...
...
@@ -510,9 +534,24 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                    &epilogue, sizeof(epilogue)));

-  if (counter != nullptr) {
+#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
+  NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
+             CUDA_VERSION);
+#endif
+#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
+  NVTE_ERROR(
+      "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
+      CUBLAS_VERSION);
+#endif
+#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \
+    CUBLAS_VERSION < 130000
+  if (counter != nullptr) {
+    NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
+               "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
+               cuda::cudart_version());
+    NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
+               "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
+               cublas_version());
     if (m_split == 0) m_split = 1;
     if (n_split == 0) n_split = 1;
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
...
...
@@ -530,8 +569,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
           operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter,
           sizeof(counter)));
     }
   }
 #endif

   NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
   NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
...
...
@@ -723,17 +762,27 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
                              int n_split, bool gemm_producer, const NVTETensor counter,
                              cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas,
                              int compute_stream_offset) {
   NVTE_API_CALL(nvte_cublas_atomic_gemm);
-  using namespace transformer_engine;

 #ifndef __HIP_PLATFORM_AMD__
-  int cudart_version;
-  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
-  NVTE_CHECK(cudart_version >= 12020 && cudart_version < 13000,
-             "Cuda version >=12.2 and <13.0 is required for atomic gemm.");
-  NVTE_CHECK(cublasLtGetVersion() >= 120205 && cublasLtGetVersion() < 130000,
-             "Cublas version >=12.2.5 and <13.0 is required for atomic gemm.");
+  // Check CUDA and cuBLAS versions
+#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
+  NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
+             CUDA_VERSION);
+#endif
+#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
+  NVTE_ERROR(
+      "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
+      CUBLAS_VERSION);
+#endif
+  NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
+             "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
+             cuda::cudart_version());
+  NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
+             "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
+             cublas_version());
 #endif

+  using namespace transformer_engine;
   const Tensor *inputA = convertNVTETensorCheck(A);
   const Tensor *inputB = convertNVTETensorCheck(B);
   Tensor *outputD = convertNVTETensor(D);
...
...
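The pattern this diff introduces guards each cuBLAS feature twice: a compile-time #if CUBLAS_VERSION check so the code still builds against older headers, plus a run-time cublas_version() check because the loaded library may be older than the build headers. A hedged sketch of the same idiom outside TE (the require_mxfp8_support helper is hypothetical, not part of the library):

#include <cublasLt.h>
#include <stdexcept>

// Cached run-time version, same idea as the commit's cublas_version() helper.
static size_t loaded_cublas_version() {
  static size_t version = cublasLtGetVersion();  // query once, reuse thereafter
  return version;
}

// Hypothetical helper: fail fast when MXFP8 scaling modes are unavailable.
void require_mxfp8_support() {
#if CUBLAS_VERSION >= 120800
  // Headers are new enough; still verify the library actually loaded at run time.
  if (loaded_cublas_version() < 120800) {
    throw std::runtime_error("MXFP8 GEMM needs cuBLAS 12.8+ at run time");
  }
#else
  // Headers predate the feature, so it cannot have been compiled in at all.
  throw std::runtime_error("MXFP8 GEMM was not compiled in (cuBLAS headers < 12.8)");
#endif
}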
transformer_engine/common/include/transformer_engine/fused_router.h
0 → 100644
View file @ 44740c6c
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_H_
#define TRANSFORMER_ENGINE_FUSED_ROUTER_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Apply topk + softmax/sigmoid to the input tensor. Grouped topk is supported.
*
* \param[in] logits Logits from the gating GEMM.
* \param[in] num_tokens Number of tokens.
* \param[in] num_experts Number of experts.
* \param[in] topk Topk value.
* \param[in] use_pre_softmax Whether to use softmax before topk.
* \param[in] num_groups Number of groups in grouped topk.
* \param[in] group_topk Grouped topk value.
* \param[in] scaling_factor Scaling factor.
* \param[in] score_function Score function, 0: sigmoid, 1: softmax.
* \param[in] expert_bias Expert bias. (Only used at the sigmoid case)
* \param[out] probs Output tensor for probabilities.
* \param[out] routing_map Output tensor for routing map.
* \param[out] intermediate_output Output tensor for intermediate output. (Softmax/sigmoid output)
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_topk_with_score_function_forward(const NVTETensor logits, int num_tokens,
                                                 int num_experts, int topk, int use_pre_softmax,
                                                 int num_groups, int group_topk,
                                                 float scaling_factor, int score_function,
                                                 const NVTETensor expert_bias, NVTETensor probs,
                                                 NVTETensor routing_map,
                                                 NVTETensor intermediate_output,
                                                 cudaStream_t stream);
/*! \brief Backward pass for fused topk + softmax/sigmoid.
*
* \param[in] routing_map Routing map.
* \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output)
* \param[in] grad_probs Gradient of probs.
* \param[in] num_tokens Number of tokens.
* \param[in] num_experts Number of experts.
* \param[in] topk Topk value.
* \param[in] use_pre_softmax Whether to use softmax before topk.
* \param[in] scaling_factor Scaling factor.
* \param[in] score_function Score function, 0: sigmoid, 1: softmax.
* \param[out] grad_logits Gradient of logits.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map,
                                                  const NVTETensor intermediate_output,
                                                  const NVTETensor grad_probs, int num_tokens,
                                                  int num_experts, int topk, int use_pre_softmax,
                                                  float scaling_factor, int score_function,
                                                  NVTETensor grad_logits, cudaStream_t stream);
/*! \brief Forward pass for computing scores/routing map for auxiliary loss.
*
* \param[in] logits Logits from the gating GEMM.
* \param[in] num_tokens Number of tokens.
* \param[in] num_experts Number of experts.
* \param[in] topk Topk value.
* \param[in] score_function Score function, 0: sigmoid, 1: softmax.
* \param[out] scores Output tensor for scores.
* \param[in] routing_map Routing map.
* \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output)
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_tokens,
                                               int num_experts, int topk, int score_function,
                                               NVTETensor scores, const NVTETensor routing_map,
                                               const NVTETensor intermediate_output,
                                               cudaStream_t stream);
/*! \brief Backward pass for computing scores/routing map for auxiliary loss.
*
* \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output)
* \param[in] grad_scores Gradient of scores.
* \param[in] num_tokens Number of tokens.
* \param[in] num_experts Number of experts.
* \param[in] topk Topk value.
* \param[in] score_function Score function, 0: sigmoid, 1: softmax.
* \param[out] grad_logits Gradient of logits.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_output,
                                                const NVTETensor grad_scores, int num_tokens,
                                                int num_experts, int topk, int score_function,
                                                NVTETensor grad_logits, cudaStream_t stream);
/*! \brief Forward pass for auxiliary loss.
*
* \param[in] probs Probabilities from the forward pass.
* \param[in] tokens_per_expert Number of tokens per expert.
* \param[in] total_num_tokens Number of total tokens. Will be used in seq/global aux loss.
* \param[in] num_experts Number of experts.
* \param[in] num_rows Number of rows of probs.
* \param[in] num_cols Number of columns of probs.
* \param[in] topk Topk value.
* \param[in] coeff Coefficient.
* \param[out] aux_loss Output GPU scalar for auxiliary loss.
* \param[out] Const_buf Output GPU scalar for temporary constant buffer for backward pass.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
                                     int total_num_tokens, int num_experts, int num_rows,
                                     int num_cols, int topk, float coeff, NVTETensor aux_loss,
                                     NVTETensor Const_buf, cudaStream_t stream);
/*! \brief Backward pass for auxiliary loss.
*
* \param[in] Const_buf Constant buffer from the forward pass.
* \param[in] tokens_per_expert Number of tokens per expert.
* \param[in] num_rows Number of rows of probs.
* \param[in] num_cols Number of columns of probs.
* \param[in] grad_aux_loss Gradient of auxiliary loss.
* \param[out] grad_probs Gradient of probs.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf,
                                      const NVTETensor tokens_per_expert, int num_rows,
                                      int num_cols, NVTETensor grad_aux_loss,
                                      NVTETensor grad_probs, cudaStream_t stream);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_FUSED_ROUTER_H_
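For orientation, a hedged usage sketch of the first forward entry point. Tensor construction is elided: logits/probs/routing_map/intermediate_output are assumed to be valid, pre-allocated NVTETensor handles of shape [num_tokens, num_experts], and the parameter values are illustrative:

#include <cuda_runtime.h>
#include <transformer_engine/fused_router.h>

// Sketch: route tokens with top-2 softmax scoring, no grouped topk, no scaling.
void route_tokens(NVTETensor logits, NVTETensor expert_bias, NVTETensor probs,
                  NVTETensor routing_map, NVTETensor intermediate_output,
                  int num_tokens, int num_experts, cudaStream_t stream) {
  nvte_fused_topk_with_score_function_forward(
      logits, num_tokens, num_experts,
      /*topk=*/2, /*use_pre_softmax=*/0,
      /*num_groups=*/0, /*group_topk=*/0,
      /*scaling_factor=*/1.0f, /*score_function=*/1,  // 1: softmax
      expert_bias, probs, routing_map, intermediate_output, stream);
}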
transformer_engine/common/include/transformer_engine/multi_stream.h
View file @ 44740c6c
...
...
@@ -11,6 +11,8 @@
#ifndef TRANSFORMER_ENGINE_MULTI_STREAM_H
#define TRANSFORMER_ENGINE_MULTI_STREAM_H
#include "cuda_runtime.h"
#ifdef __cplusplus
extern "C" {
#endif
...
...
@@ -18,6 +20,26 @@ extern "C" {
/*! \brief Number of CUDA streams to use in multi-stream operations */
int nvte_get_num_compute_streams();

/*! \brief Get a CUDA stream for compute operations.
 *
 * \param[in] idx Index of the stream to retrieve.
 * \return A cudaStream_t.
 *
 * This function returns a CUDA stream that can be used for compute operations.
 * The index should be in the range [0, nvte_get_num_compute_streams() - 1].
 */
cudaStream_t nvte_get_compute_stream(const int idx);
/*! \brief Get a CUDA event for compute operations.
*
* \param[in] idx Index of the event to retrieve.
* \return A cudaEvent_t.
*
* This function returns a CUDA event that can be used to synchronize compute operations.
* The index should be in the range [0, nvte_get_num_compute_streams() - 1].
*/
cudaEvent_t nvte_get_compute_stream_event(const int idx);
#ifdef __cplusplus
}  // extern "C"
#endif
...
...
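A hedged sketch of the intended fork/join use of these two accessors: spread work round-robin over the compute streams, then make a caller-owned stream wait on each slice through the paired events (the work-enqueue step is left as a comment):

#include <cuda_runtime.h>
#include <transformer_engine/multi_stream.h>

void fan_out_and_join(cudaStream_t main_stream) {
  const int n = nvte_get_num_compute_streams();
  for (int i = 0; i < n; ++i) {
    cudaStream_t s = nvte_get_compute_stream(i);
    // ... enqueue the i-th slice of work on stream s ...
    cudaEvent_t e = nvte_get_compute_stream_event(i);
    cudaEventRecord(e, s);                   // mark completion of this slice
    cudaStreamWaitEvent(main_stream, e, 0);  // main stream resumes after it
  }
}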
transformer_engine/common/include/transformer_engine/padding.h
View file @ 44740c6c
...
...
@@ -44,6 +44,33 @@ extern "C" {
void nvte_multi_padding(size_t num_tensors, const NVTETensor *input_list, NVTETensor *output_list,
                        const int *padded_num_rows_list, cudaStream_t stream);
/*! \brief Unpadding multiple tensors (reverse operation of padding).
*
* NOTE: Unpadding mode only removes bottom rows.
*
* For example, 4x3 matrix unpad to 3x3 matrix.
*
* source
* | 1 | 2 | 3 |
* | 4 | 5 | 6 |
* | 7 | 8 | 9 |
* | 0 | 0 | 0 |
*
* destination
* | 1 | 2 | 3 |
* | 4 | 5 | 6 |
* | 7 | 8 | 9 |
*
* \param[in] num_tensors Number of tensors.
* \param[in] input_list List of 2D padded input tensors.
* \param[in,out] output_list List of unpadded tensors. Dimensions
* match original unpadded tensors.
* \param[in] unpadded_num_rows_list List of unpadded num rows corresponding to input tensors.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_multi_unpadding(size_t num_tensors, const NVTETensor *input_list,
                          NVTETensor *output_list, const int *unpadded_num_rows_list,
                          cudaStream_t stream);

#ifdef __cplusplus
}  // extern "C"
#endif
...
...
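A hedged round-trip sketch of the padding/unpadding pair. The tensor handles are assumed to be pre-allocated with matching padded/unpadded shapes; only the row bookkeeping is shown, with rows rounded up to a multiple of 16 as an illustrative choice:

#include <vector>

#include <cuda_runtime.h>
#include <transformer_engine/padding.h>

void pad_roundtrip(const NVTETensor *src, NVTETensor *padded, NVTETensor *restored,
                   const int *rows, size_t count, cudaStream_t stream) {
  std::vector<int> padded_rows(count);
  for (size_t i = 0; i < count; ++i) {
    padded_rows[i] = (rows[i] + 15) / 16 * 16;  // round row counts up to a multiple of 16
  }
  nvte_multi_padding(count, src, padded, padded_rows.data(), stream);  // append zero rows
  nvte_multi_unpadding(count, padded, restored, rows, stream);         // strip them again
}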
transformer_engine/common/multi_tensor/adam.cu
View file @ 44740c6c
...
...
@@ -226,8 +226,8 @@ struct AdamFunctorMasterParamRemainder {
         r_m[ii] = static_cast<MATH_T>(m[i]);
         r_v[ii] = static_cast<MATH_T>(v[i]);
-        local_p[ii] = static_cast<int16_t>(p[i]);
-        local_p_rem[ii] = static_cast<int16_t>(p_remainder[i]);
+        local_p[ii] = p[i];
+        local_p_rem[ii] = p_remainder[i];
       } else {
         r_g[ii] = MATH_T(0);
         r_m[ii] = MATH_T(0);
...
...
@@ -281,8 +281,8 @@ struct AdamFunctorMasterParamRemainder {
     for (int ii = 0; ii < ILP; ii++) {
       int i = i_start + threadIdx.x + ii * blockDim.x;
       if (i < n && i < chunk_size) {
-        p_remainder[i] = static_cast<int16_t>(local_p_rem[ii]);
-        p[i] = static_cast<int16_t>(local_p[ii]);
+        p_remainder[i] = local_p_rem[ii];
+        p[i] = local_p[ii];
         m[i] = static_cast<FULL_T>(r_m[ii]);
         v[i] = static_cast<FULL_T>(r_v[ii]);
...
...
@@ -467,8 +467,8 @@ struct AdamCapturableFunctor {
       int i = i_start + threadIdx.x + ii * blockDim.x;
       if (i < n && i < chunk_size) {
         p[i] = static_cast<T>(r_p[ii]);
-        m[i] = static_cast<T>(r_m[ii]);
-        v[i] = static_cast<T>(r_v[ii]);
+        m[i] = static_cast<FULL_T>(r_m[ii]);
+        v[i] = static_cast<FULL_T>(r_v[ii]);
       }
     }
   }
...
...
@@ -578,9 +578,6 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
                             const float beta1, const float beta2, const float epsilon,
                             const int step, const int mode, const int bias_correction,
                             const float weight_decay, const int device_id, cudaStream_t stream) {
-  const size_t num_tensor_lists = tensor_lists.size();
-  const size_t num_tensors_per_list = tensor_lists[0].size();
-
   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
   if (bias_correction == 1) {
...
...
@@ -588,16 +585,48 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
     bias_correction2 = 1 - std::pow(beta2, step);
   }

-  size_t max_size = 0;
+  // Check tensor list sizes
+  // 4 tensor lists: g, p, m, v
+  // 5 tensor lists: g, p, m, v, p_master
+  const size_t num_tensor_lists = tensor_lists.size();
+  NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5,
+             "Expected 4 or 5 tensor lists, but found ", num_tensor_lists);
+  const size_t num_tensors_per_list = tensor_lists[0].size();
+  for (size_t i = 1; i < num_tensor_lists; i++) {
+    NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
+               " has size=", tensor_lists[i].size(),
+               ", but expected size=", num_tensors_per_list);
+  }
+
+  // Check tensor dtypes
+  const auto g_in_type_te = tensor_lists[0][0]->dtype();
+  const auto p_in_type_te = tensor_lists[1][0]->dtype();
+  for (size_t j = 0; j < num_tensors_per_list; j++) {
+    NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
+               " has dtype=", to_string(tensor_lists[0][j]->dtype()),
+               ", but expected dtype=", to_string(g_in_type_te));
+    NVTE_CHECK(tensor_lists[1][j]->dtype() == p_in_type_te, "Param tensor ", j,
+               " has dtype=", to_string(tensor_lists[1][j]->dtype()),
+               ", but expected dtype=", to_string(p_in_type_te));
+    NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[2][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[3][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    if (num_tensor_lists == 5) {
+      NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j,
+                 " has dtype=", to_string(tensor_lists[4][j]->dtype()),
+                 ", but expected dtype=", to_string(DType::kFloat32));
+    }
+  }

+  // Check if 64-bit indices are required
   bool requires_64bit_indexing = false;
   for (size_t i = 0; i < num_tensor_lists; i++) {
     for (size_t j = 0; j < num_tensors_per_list; j++) {
-      if (tensor_lists[i][j]->numel() > max_size) {
-        max_size = tensor_lists[i][j]->numel();
-        if (max_size >= INT_MAX) {
-          requires_64bit_indexing = true;
-          break;
-        }
+      if (tensor_lists[i][j]->numel() >= INT_MAX) {
+        requires_64bit_indexing = true;
+        break;
       }
     }
     if (requires_64bit_indexing) {
...
...
@@ -605,16 +634,10 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
     }
   }

-  const auto g_in_type_te = tensor_lists[0][0]->dtype();
-  const auto p_in_type_te = tensor_lists[1][0]->dtype();
-
-  // case 4: g, p, m, v
-  // case 5: g, p, m, v, p_master
-  NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5, "tensor list must contain 4 or 5");
-
+  // Launch kernel
   if (requires_64bit_indexing) {
     if (num_tensor_lists == 4) {
-      // Assume single type across p,g,m1,m2 now
+      // g, p, m, v
       TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
           p_in_type_te, p_in_type,
           TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
...
...
@@ -638,7 +661,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
     }
   } else {
     if (num_tensor_lists == 4) {
-      // Assume single type across p,g,m1,m2 now
+      // g, p, m, v
      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
          p_in_type_te, p_in_type,
          TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
...
...
@@ -648,6 +671,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
           stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
           (adamMode_t)mode, weight_decay);));
     } else {
+      // g, p, m, v, p_master
       TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
           p_in_type_te, p_in_type,
           TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
...
...
@@ -668,8 +692,6 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
                                             const float epsilon, const int step, const int mode,
                                             const int bias_correction, const float weight_decay,
                                             const int device_id, cudaStream_t stream) {
-  const size_t num_tensor_lists = tensor_lists.size();
-
   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
   if (bias_correction == 1) {
...
...
@@ -677,23 +699,43 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
     bias_correction2 = 1 - std::pow(beta2, step);
   }

-  const auto g_in_type_te = tensor_lists[0][0]->dtype();
-  const auto p_in_type_te = tensor_lists[1][0]->dtype();
-
-  // case 5: g, p, m, v, p_master
-  NVTE_CHECK(num_tensor_lists == 5, "tensor list must contain 5");
-  NVTE_CHECK(p_in_type_te == DType::kBFloat16,
-             "Adam with BF16 param remainders requires BF16 params");
+  // Check tensor list sizes
+  // 5 tensor lists: g, p, m, v, p_remainder
+  const size_t num_tensor_lists = tensor_lists.size();
+  NVTE_CHECK(num_tensor_lists == 5, "Expected 5 tensor lists, but found ", num_tensor_lists);
+  const size_t num_tensors_per_list = tensor_lists[0].size();
+  for (size_t i = 1; i < num_tensor_lists; i++) {
+    NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
+               " has size=", tensor_lists[i].size(),
+               ", but expected size=", num_tensors_per_list);
+  }
+
+  // g, p, m, v, p_master
+  // Check tensor dtypes
+  const auto g_in_type_te = tensor_lists[0][0]->dtype();
+  for (size_t j = 0; j < num_tensors_per_list; j++) {
+    NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
+               " has dtype=", to_string(tensor_lists[0][j]->dtype()),
+               ", but expected dtype=", to_string(g_in_type_te));
+    NVTE_CHECK(tensor_lists[1][j]->dtype() == DType::kBFloat16, "Param tensor ", j,
+               " has dtype=", to_string(tensor_lists[1][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kBFloat16));
+    NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[2][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[3][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kInt16, "Param remainder tensor ", j,
+               " has dtype=", to_string(tensor_lists[4][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kInt16));
+  }

+  // Launch kernel
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       g_in_type_te, g_in_type,
       multi_tensor_apply<BLOCK_SIZE, 5>((int64_t)chunk_size, noop_flag, tensor_lists,
                                         AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(),
                                         device_id, stream, beta1, beta2, bias_correction1,
                                         bias_correction2, epsilon, lr, (adamMode_t)mode,
                                         weight_decay););
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
...
...
@@ -703,9 +745,6 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
                                 const int step, const int mode, const int bias_correction,
                                 const float weight_decay, const DType fp8_dtype,
                                 const int device_id, cudaStream_t stream) {
-  const size_t num_tensor_lists = tensor_lists.size();
-  const size_t num_tensors_per_list = tensor_lists[0].size();
-
   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
   if (bias_correction == 1) {
...
...
@@ -713,16 +752,53 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
     bias_correction2 = 1 - std::pow(beta2, step);
   }

-  size_t max_size = 0;
+  // Check tensor list sizes
+  // 8 tensor lists: g, p_fp8, m, v, p_master, scale, amax, scale_inv
+  const size_t num_tensor_lists = tensor_lists.size();
+  NVTE_CHECK(num_tensor_lists == 8, "Expected 8 tensor lists, but found ", num_tensor_lists);
+  const size_t num_tensors_per_list = tensor_lists[0].size();
+  for (size_t i = 1; i < num_tensor_lists; i++) {
+    NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
+               " has size=", tensor_lists[i].size(),
+               ", but expected size=", num_tensors_per_list);
+  }
+
+  // Check tensor dtypes
+  const auto g_in_type_te = tensor_lists[0][0]->dtype();
+  for (size_t j = 0; j < num_tensors_per_list; j++) {
+    NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
+               " has dtype=", to_string(tensor_lists[0][j]->dtype()),
+               ", but expected dtype=", to_string(g_in_type_te));
+    NVTE_CHECK(tensor_lists[1][j]->dtype() == fp8_dtype ||
+                   tensor_lists[1][j]->dtype() == DType::kByte,
+               "Param tensor ", j, " has dtype=", to_string(tensor_lists[1][j]->dtype()),
+               ", but expected dtype=", to_string(fp8_dtype));
+    NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[2][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[3][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j,
+               " has dtype=", to_string(tensor_lists[4][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[5][j]->dtype() == DType::kFloat32, "Scale tensor ", j,
+               " has dtype=", to_string(tensor_lists[5][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[6][j]->dtype() == DType::kFloat32, "Absmax tensor ", j,
+               " has dtype=", to_string(tensor_lists[6][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[7][j]->dtype() == DType::kFloat32, "Scale-inverse tensor ", j,
+               " has dtype=", to_string(tensor_lists[7][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+  }

+  // Check if 64-bit indices are required
   bool requires_64bit_indexing = false;
   for (size_t i = 0; i < num_tensor_lists; i++) {
     for (size_t j = 0; j < num_tensors_per_list; j++) {
-      if (tensor_lists[i][j]->numel() > max_size) {
-        max_size = tensor_lists[i][j]->numel();
-        if (max_size >= INT_MAX) {
-          requires_64bit_indexing = true;
-          break;
-        }
+      if (tensor_lists[i][j]->numel() >= INT_MAX) {
+        requires_64bit_indexing = true;
+        break;
       }
     }
     if (requires_64bit_indexing) {
...
...
@@ -730,11 +806,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
     }
   }

-  const auto g_in_type_te = tensor_lists[0][0]->dtype();
-
-  // case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv
-  NVTE_CHECK(num_tensor_lists == 8, "tensor list must contain 8 tensors");
-
+  // Launch kernel
   if (requires_64bit_indexing) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
        fp8_dtype, FP8_T,
...
...
@@ -765,6 +837,34 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
                                        Tensor step, const int mode, const int bias_correction,
                                        const float weight_decay, Tensor inv_scale,
                                        const int device_id, cudaStream_t stream) {
+  // Check tensor list sizes
+  // 4 tensor lists: g, p, m, v
+  const size_t num_tensor_lists = tensor_lists.size();
+  NVTE_CHECK(num_tensor_lists == 4, "Expected 4 tensor lists, but found ", num_tensor_lists);
+  const size_t num_tensors_per_list = tensor_lists[0].size();
+  for (size_t i = 1; i < num_tensor_lists; i++) {
+    NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
+               " has size=", tensor_lists[i].size(),
+               ", but expected size=", num_tensors_per_list);
+  }
+
+  // Check tensor dtypes
+  const auto g_in_type_te = tensor_lists[0][0]->dtype();
+  for (size_t j = 0; j < num_tensors_per_list; j++) {
+    NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
+               " has dtype=", to_string(tensor_lists[0][j]->dtype()),
+               ", but expected dtype=", to_string(g_in_type_te));
+    NVTE_CHECK(tensor_lists[1][j]->dtype() == g_in_type_te, "Param tensor ", j,
+               " has dtype=", to_string(tensor_lists[1][j]->dtype()),
+               ", but expected dtype=", to_string(g_in_type_te));
+    NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[2][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[3][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+  }
+
+  // Launch kernel
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       tensor_lists[0][0]->dtype(), dtype,
       multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
...
...
@@ -783,6 +883,37 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
                                               const int bias_correction,
                                               const float weight_decay, Tensor inv_scale,
                                               const int device_id, cudaStream_t stream) {
+  // Check tensor list sizes
+  // 5 tensor lists: g, p, m, v, p_master
+  const size_t num_tensor_lists = tensor_lists.size();
+  NVTE_CHECK(num_tensor_lists == 5, "Expected 5 tensor lists, but found ", num_tensor_lists);
+  const size_t num_tensors_per_list = tensor_lists[0].size();
+  for (size_t i = 1; i < num_tensor_lists; i++) {
+    NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
+               " has size=", tensor_lists[i].size(),
+               ", but expected size=", num_tensors_per_list);
+  }
+
+  // Check tensor dtypes
+  const auto g_in_type_te = tensor_lists[0][0]->dtype();
+  for (size_t j = 0; j < num_tensors_per_list; j++) {
+    NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
+               " has dtype=", to_string(tensor_lists[0][j]->dtype()),
+               ", but expected dtype=", to_string(g_in_type_te));
+    NVTE_CHECK(tensor_lists[1][j]->dtype() == g_in_type_te, "Param tensor ", j,
+               " has dtype=", to_string(tensor_lists[1][j]->dtype()),
+               ", but expected dtype=", to_string(g_in_type_te));
+    NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[2][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[3][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j,
+               " has dtype=", to_string(tensor_lists[4][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+  }
+
+  // Launch kernel
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       tensor_lists[0][0]->dtype(), dtype,
       multi_tensor_apply<BLOCK_SIZE, 5>(chunk_size, noop_flag, tensor_lists,
...
...
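The bias-correction factors computed at the top of each launcher are the standard Adam ones; a minimal scalar sketch of how they enter the update (plain C++, L2/AdamW mode details omitted; this is not the fused kernel):

#include <cmath>

// One scalar Adam step with the same bias-correction handling as the launchers:
// bias_correction == 1 enables the 1 - beta^step factors, otherwise they stay 1.
float adam_step(float p, float g, float &m, float &v, int step, float lr, float beta1,
                float beta2, float epsilon, int bias_correction) {
  m = beta1 * m + (1.0f - beta1) * g;      // first-moment EMA
  v = beta2 * v + (1.0f - beta2) * g * g;  // second-moment EMA
  float bc1 = 1.0f, bc2 = 1.0f;
  if (bias_correction == 1) {
    bc1 = 1.0f - std::pow(beta1, step);
    bc2 = 1.0f - std::pow(beta2, step);
  }
  return p - lr * (m / bc1) / (std::sqrt(v / bc2) + epsilon);
}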
transformer_engine/common/multi_tensor/multi_tensor_apply.cuh
View file @ 44740c6c
...
...
@@ -52,7 +52,7 @@ class OptionalCUDAGuard {
   ~OptionalCUDAGuard() {
     if (device_changed_) {
-      NVTE_CHECK_CUDA(cudaSetDevice(prev_device_));
+      cudaSetDevice(prev_device_);
     }
   }
...
...
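The change above removes the error check from the destructor because NVTE_CHECK_CUDA throws on failure, and throwing during stack unwinding terminates the program. A condensed sketch of the guard's shape after the change (constructor details are paraphrased; the real class lives in multi_tensor_apply.cuh):

#include <cuda_runtime.h>

class OptionalCUDAGuard {
 public:
  explicit OptionalCUDAGuard(int device) {
    cudaGetDevice(&prev_device_);
    if (device != prev_device_) {
      cudaSetDevice(device);
      device_changed_ = true;
    }
  }
  ~OptionalCUDAGuard() {
    // Restore the previous device but deliberately ignore the return code:
    // destructors must not throw.
    if (device_changed_) cudaSetDevice(prev_device_);
  }

 private:
  int prev_device_ = -1;
  bool device_changed_ = false;
};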