OpenDAS / TransformerEngine / Commits / c1a1c04e

Commit c1a1c04e, authored Dec 27, 2025 by wenjh
Merge nv_main(2.10) to main

Signed-off-by: wenjh <wenjh@sugon.com>

Parents: e698a0a7, 66aed3ae
Changes: 208
Showing 20 changed files with 2556 additions and 330 deletions (+2556, -330).
tests/pytorch/test_custom_recipe.py                            +42   -0
tests/pytorch/test_fused_rope.py                               +138  -32
tests/pytorch/test_numerics.py                                 +92   -33
tests/pytorch/test_onnx_export.py                              +1    -1
tests/pytorch/test_sanity.py                                   +3    -1
tests/pytorch/utils.py                                         +3    -0
transformer_engine/common/CMakeLists.txt                       +290  -195
transformer_engine/common/__init__.py                          +95   -48
transformer_engine/common/activation/activation_template.h    +7    -20
transformer_engine/common/activation/gelu.cu                   +26   -0
transformer_engine/common/activation/relu.cu                   +26   -0
transformer_engine/common/activation/swiglu.cu                 +13   -0
transformer_engine/common/cast/cast.cu                         +104  -0
transformer_engine/common/cast/core/common.cuh                 +99   -0
transformer_engine/common/cast/dispatch/dequantize.cuh         +56   -0
transformer_engine/common/cast/dispatch/gated.cuh              +161  -0
transformer_engine/common/cast/dispatch/quantize.cuh           +336  -0
transformer_engine/common/cast/fp8/dequantize_fp8.cuh          +56   -0
transformer_engine/common/cast/fp8/gated_fp8.cuh               +402  -0
transformer_engine/common/cast/fp8/quantize_fp8.cuh            +606  -0
tests/pytorch/test_custom_recipe.py

```python
@@ -17,6 +17,48 @@ from transformer_engine.pytorch import (
    Float8CurrentScalingQuantizer,
)
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import (
    nvfp4_ref_rht_2d_quantizer_factory,
)


@pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear"])
def test_custom_recipe_sanity_modules_nvfp4(module_type):
    """Test modules with NVFP4 custom recipe support"""
    available, reason = te.is_fp8_available(return_reason=True)
    if not torch.cuda.is_available() or not available:
        pytest.skip(f"FP8 unsupported on this device: {reason}")

    torch.manual_seed(0)

    # Simple linear layer with dims divisible by 16
    in_features = 64
    out_features = 64
    batch = 32

    if module_type == "Linear":
        model = Linear(in_features, out_features, params_dtype=torch.bfloat16, bias=False).cuda()
    elif module_type == "LayerNormLinear":
        model = LayerNormLinear(
            in_features, out_features, params_dtype=torch.bfloat16, bias=False
        ).cuda()
    else:  # OpsLinear
        model = te_ops.Linear(
            in_features, out_features, device="cuda", dtype=torch.bfloat16, bias=False
        )

    inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    # Use NVFP4 quantizer factory
    custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory)

    # Execute with custom recipe
    with autocast(enabled=True, recipe=custom_recipe):
        out = model(inp)
    loss = out.float().sum()
    loss.backward()

    # Basic sanity: gradients exist
    assert inp.grad is not None


@pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear", "LayerNormMLP"])
```
tests/pytorch/test_fused_rope.py

```python
@@ -58,10 +58,6 @@ def test_fused_rope(
        # are with the maximum length of the rope embeddings.
        pytest.skip("Skipping test with margin=0 and start_positions=True")

    if start_positions == True and cp_size > 1:
        # `start_positions` is only supported for `cp_size=1` and inference.
        pytest.skip("Skipping test with cp_size>1 and start_positions=True")

    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
    t = torch.rand(

@@ -102,11 +98,8 @@ def test_fused_rope(
        cp_rank=cp_rank,
    ).to(dtype)
    loss_unfused = loss_func(output_unfused)

    if not isinstance(start_positions, torch.Tensor):
        loss_unfused.backward()
        grad_unfused = t.grad.detach().clone()
        t.grad = None

    # fused

@@ -121,17 +114,12 @@ def test_fused_rope(
        cp_rank=cp_rank,
    )
    loss_fused = loss_func(output_fused)

    if not isinstance(start_positions, torch.Tensor):
        loss_fused.backward()
        grad_fused = t.grad.detach().clone()
        t.grad = None

    torch.testing.assert_close(output_fused, output_unfused)
    if not isinstance(start_positions, torch.Tensor):
        torch.testing.assert_close(grad_fused, grad_unfused)
    assert output_fused.is_contiguous()

@@ -156,10 +144,6 @@ def test_fused_rope_thd(
    margin: int,
) -> None:
    if start_positions == True and cp_size > 1:
        # `start_positions` is only supported for `cp_size=1` and inference.
        pytest.skip("Skipping test with cp_size>1 and start_positions=True")

    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
    cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]

@@ -214,8 +198,6 @@ def test_fused_rope_thd(
        cp_rank=cp_rank,
    ).to(dtype)
    loss_unfused = loss_func(output_unfused)

    if not isinstance(start_positions, torch.Tensor):
        loss_unfused.backward()
        grad_unfused = t.grad.detach().clone()
        t.grad = None

@@ -233,20 +215,144 @@ def test_fused_rope_thd(
        cp_rank=cp_rank,
    )
    loss_fused = loss_func(output_fused)

    if not isinstance(start_positions, torch.Tensor):
        loss_fused.backward()
        grad_fused = t.grad.detach().clone()
        t.grad = None

    torch.testing.assert_close(output_fused, output_unfused)
    if not isinstance(start_positions, torch.Tensor):
        torch.testing.assert_close(grad_fused, grad_unfused)
    assert output_fused.is_contiguous()


@pytest.mark.parametrize("start_positions", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [1.0])
@pytest.mark.parametrize("loss_func", [_overlapping_grad])
@pytest.mark.parametrize("cp_size", [2])
@pytest.mark.parametrize("interleaved", [False, True])
def test_unfused_rope_thd_vs_bshd(
    dtype: torch.dtype,
    hidden_size: int,
    rotary_percent: float,
    loss_func: Callable,
    cp_size: int,
    interleaved: bool,
    start_positions: bool,
) -> None:
    """
    This is just a sanity check to ensure that the unfused RoPE in THD/SBHD/BSHD
    formats are the same.
    """
    device = torch.device("cuda:0")
    seqlen, max_seqlen = 16, 2048
    batch_size, head_num = 4, 256

    # NOTE: dtype=torch.int32 is important, otherwise the cumsum will be in int64 and
    # that causes unexpected issues.
    seq_lens = torch.tensor([seqlen for _ in range(batch_size)], dtype=torch.int32)
    cu_seqlens = torch.cumsum(
        torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens]), dim=0
    ).to(device=device, dtype=torch.int32)

    # Create a tensor in THD format
    thd = torch.rand(
        (cu_seqlens[-1] // cp_size, head_num, hidden_size),
        dtype=dtype,
        device=device,
    )
    thd.requires_grad = True

    # Clone the tensor to create a tensor in BSHD format
    bshd = thd.view(batch_size, -1, head_num, hidden_size).clone().detach()
    bshd = bshd.to(dtype=dtype, device=device)
    bshd.requires_grad = True

    # Clone the tensor to create a tensor in SBHD format
    sbhd = bshd.transpose(1, 0).clone().detach()
    sbhd = sbhd.to(dtype=dtype, device=device)
    sbhd.requires_grad = True

    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
    emb = rotary_pos_emb(max_seqlen)
    assert emb.is_contiguous()

    start_positions = cu_seqlens[:-1] if start_positions else None

    for cp_rank in range(cp_size):
        # unfused bshd
        output_unfused_bshd = apply_rotary_pos_emb(
            bshd.float(),
            emb,
            start_positions=start_positions,
            interleaved=interleaved,
            fused=False,
            tensor_format="bshd",
            cu_seqlens=cu_seqlens,
            cp_size=cp_size,
            cp_rank=cp_rank,
        ).to(dtype)
        loss_unfused_bshd = loss_func(output_unfused_bshd)
        loss_unfused_bshd.backward()
        grad_unfused_bshd = bshd.grad.detach().clone()
        bshd.grad = None

        # unfused sbhd
        output_unfused_sbhd = apply_rotary_pos_emb(
            sbhd.float(),
            emb,
            start_positions=start_positions,
            interleaved=interleaved,
            fused=False,
            tensor_format="sbhd",
            cu_seqlens=cu_seqlens,
            cp_size=cp_size,
            cp_rank=cp_rank,
        ).to(dtype)
        loss_unfused_sbhd = loss_func(output_unfused_sbhd)
        loss_unfused_sbhd.backward()
        grad_unfused_sbhd = sbhd.grad.detach().clone()
        sbhd.grad = None

        # unfused thd
        output_unfused_thd = apply_rotary_pos_emb(
            thd.float(),
            emb,
            start_positions=start_positions,
            tensor_format="thd",
            interleaved=interleaved,
            fused=False,
            cu_seqlens=cu_seqlens,
            cp_size=cp_size,
            cp_rank=cp_rank,
        ).to(dtype)
        loss_unfused_thd = loss_func(output_unfused_thd)
        loss_unfused_thd.backward()
        grad_unfused_thd = thd.grad.detach().clone()
        thd.grad = None

        torch.testing.assert_close(
            output_unfused_bshd.reshape(*output_unfused_thd.shape), output_unfused_thd
        )
        torch.testing.assert_close(
            output_unfused_sbhd.transpose(1, 0).reshape(*output_unfused_thd.shape),
            output_unfused_thd,
        )
        torch.testing.assert_close(
            grad_unfused_bshd.reshape(*grad_unfused_thd.shape), grad_unfused_thd
        )
        torch.testing.assert_close(
            grad_unfused_sbhd.transpose(1, 0).reshape(*grad_unfused_thd.shape), grad_unfused_thd
        )
        assert output_unfused_thd.is_contiguous()
        assert output_unfused_bshd.is_contiguous()
        assert output_unfused_sbhd.is_contiguous()


@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096])
```
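For readers unfamiliar with the layout names: when all sequences have equal length, a BSHD tensor is just a reshaped view of the THD (token-packed) layout, and SBHD is its batch/sequence transpose. A minimal self-contained sketch (plain PyTorch, no Transformer Engine imports; the sizes are hypothetical stand-ins for the test's) of the reshape identities the `assert_close` calls above rely on:

```python
import torch

# Hypothetical sizes mirroring the test: 4 sequences of length 16 each.
batch, seqlen, heads, hidden = 4, 16, 8, 32

# THD packs the tokens of all sequences along one leading dim.
thd = torch.arange(batch * seqlen * heads * hidden, dtype=torch.float32)
thd = thd.reshape(batch * seqlen, heads, hidden)

# With uniform sequence lengths, BSHD is a pure reshape of THD ...
bshd = thd.view(batch, seqlen, heads, hidden)
# ... and SBHD is its (1, 0) transpose.
sbhd = bshd.transpose(1, 0)

# The identities the test's assertions depend on:
assert torch.equal(bshd.reshape(*thd.shape), thd)
assert torch.equal(sbhd.transpose(1, 0).reshape(*thd.shape), thd)
```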
tests/pytorch/test_numerics.py

```python
@@ -41,21 +41,22 @@ from transformer_engine.pytorch import (
    is_mxfp8_available,
    is_fp8_block_scaling_available,
    is_bf16_available,
    is_nvfp4_available,
)
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states, get_available_attention_backends

# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
    return_reason=True
)
nvfp4_available = is_nvfp4_available()
sm_80plus = get_device_compute_capability() >= (8, 0)

@@ -120,6 +121,43 @@ if NVTE_TEST_NVINSPECT_ENABLED:
    )


def nvfp4_rht_and_2d_quantization():
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
        random_hadamard_transform=True, fp4_2d_quantization=False
    )
    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
        random_hadamard_transform=False, fp4_2d_quantization=True
    )
    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
        random_hadamard_transform=True, fp4_2d_quantization=False
    )
    return nvfp4_recipe


def check_rht_usage(recipe: recipe.Recipe) -> bool:
    # If using RHT, we can only support bf16.
    # Check fp4_quant_fwd_inp, fp4_quant_fwd_weight, and fp4_quant_bwd_grad.
    if recipe.nvfp4():
        if (
            recipe.fp4_quant_fwd_inp.random_hadamard_transform
            or recipe.fp4_quant_fwd_weight.random_hadamard_transform
            or recipe.fp4_quant_bwd_grad.random_hadamard_transform
        ):
            return True
    return False


def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> list:
    supported_input_dtypes = []
    if recipe.nvfp4():
        supported_input_dtypes.append(torch.bfloat16)
        # If not using RHT, we can add fp32 as well.
        if not check_rht_usage(recipe):
            supported_input_dtypes.append(torch.float32)
    return supported_input_dtypes


fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())

@@ -128,6 +166,8 @@ if fp8_block_scaling_available:
if fp8_available:
    fp8_recipes.append(recipe.Float8CurrentScaling())
    fp8_recipes.append(recipe.DelayedScaling())
if nvfp4_available:
    fp8_recipes.append(nvfp4_rht_and_2d_quantization())

use_cutlass_grouped_gemm = [False]
# Only enable cutlass grouped gemm on Hopper

@@ -135,23 +175,6 @@ if torch.cuda.get_device_capability() == (9, 0):
    use_cutlass_grouped_gemm.append(True)


def is_fused_attn_available(
    config: ModelConfig,
    dtype: torch.dtype,
    qkv_layout="bshd_bshd_bshd",
    is_training=True,
    deterministic=False,
):
    _, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=qkv_layout,
        is_training=is_training,
        deterministic=deterministic,
    )
    return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends


def get_causal_attn_mask(sq: int) -> torch.Tensor:
    return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()

@@ -612,6 +635,11 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8 and recipe.nvfp4():
        if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
            )
    config = model_configs[model]

@@ -729,6 +757,11 @@ def test_gpt_full_activation_recompute(
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8 and recipe.nvfp4():
        if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
            )
    config = model_configs[model]

@@ -872,8 +905,6 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model):
    config = model_configs[model]
    if not is_fused_attn_available(config, dtype, deterministic=True):
        pytest.skip("No attention backend available.")
    outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
    outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)

@@ -920,10 +951,6 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
    config = model_configs[model]
    if not is_fused_attn_available(
        config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
    ):
        pytest.skip("No attention backend available.")
    te_gpt = TransformerLayer(
        hidden_size=config.hidden_size,

@@ -1035,10 +1062,6 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
    config = model_configs[model]
    if not is_fused_attn_available(
        config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
    ):
        pytest.skip("No attention backend available.")
    te_mha = MultiheadAttention(
        config.hidden_size,

@@ -1327,6 +1350,12 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")
    if recipe is not None and recipe.nvfp4():
        if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
            )
    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        te_linear_ref = Linear(
            config.hidden_size,

@@ -1770,8 +1799,8 @@ def _test_grouped_linear_accuracy(
    split_size = 1
    if fp8:
        split_size = 16
        if recipe.mxfp8() or recipe.nvfp4():
            split_size = 32
    m = config.max_seqlen_q // split_size
    dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
    dist.append(dist[-1])
    # Manually add a zero

@@ -1849,6 +1878,12 @@ def test_grouped_linear_accuracy(
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")
    if recipe is not None and recipe.nvfp4():
        if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
            )
    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = GroupedLinear(
            num_gemms,

@@ -1993,6 +2028,12 @@ def test_grouped_linear_accuracy_save_original_input(
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")
    if recipe is not None and recipe.nvfp4():
        if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
            )
    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = GroupedLinear(
            num_gemms,

@@ -2086,7 +2127,7 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
    def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
        align_size = 16
        if recipe.mxfp8() or recipe.nvfp4():
            align_size = 32
        padded_tokens_per_expert = [
            (num_tokens + align_size - 1) // align_size * align_size

@@ -2207,6 +2248,12 @@ def test_padding_grouped_linear_accuracy(
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")
    if recipe is not None and recipe.nvfp4():
        if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
            )
    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = TorchGroupedLinearWithPadding(
            num_gemms,

@@ -2284,6 +2331,12 @@ def test_padding_grouped_linear_accuracy_save_original_input(
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")
    if recipe is not None and recipe.nvfp4():
        if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
            )
    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = TorchGroupedLinearWithPadding(
            num_gemms,

@@ -2499,6 +2552,12 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe):
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if recipe.nvfp4():
        if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
            )
    config = model_configs[model]
    outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe)
```
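The NVFP4 dtype gating above (random Hadamard transform anywhere in the recipe restricts inputs to bf16; otherwise fp32 is also allowed) is easy to sanity-check in isolation. A minimal sketch with hypothetical stand-ins for the recipe objects; the real `Recipe`/`QParams` classes live in `transformer_engine.common.recipe` and may differ:

```python
from dataclasses import dataclass, field

import torch


@dataclass
class FakeQParams:
    """Hypothetical stand-in for recipe.QParams, for illustration only."""
    random_hadamard_transform: bool = False


@dataclass
class FakeNVFP4Recipe:
    """Stand-in mirroring what nvfp4_rht_and_2d_quantization() configures."""
    fp4_quant_fwd_inp: FakeQParams = field(default_factory=lambda: FakeQParams(True))
    fp4_quant_fwd_weight: FakeQParams = field(default_factory=lambda: FakeQParams(False))
    fp4_quant_bwd_grad: FakeQParams = field(default_factory=lambda: FakeQParams(True))

    def nvfp4(self) -> bool:
        return True


def uses_rht(r) -> bool:
    # Mirrors check_rht_usage: RHT in any of the three roles restricts inputs to bf16.
    return r.nvfp4() and (
        r.fp4_quant_fwd_inp.random_hadamard_transform
        or r.fp4_quant_fwd_weight.random_hadamard_transform
        or r.fp4_quant_bwd_grad.random_hadamard_transform
    )


def supported_input_dtypes(r) -> list:
    dtypes = [torch.bfloat16]
    if not uses_rht(r):
        dtypes.append(torch.float32)
    return dtypes


print(supported_input_dtypes(FakeNVFP4Recipe()))  # [torch.bfloat16]
```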
tests/pytorch/test_onnx_export.py

```python
@@ -68,7 +68,7 @@ if fp8_available:
    fp8_recipes.append(recipe.DelayedScaling())
fp8_recipes.append(None)

supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "clamped_swiglu"]

all_normalizations = ["LayerNorm", "RMSNorm"]
```
tests/pytorch/test_sanity.py

```python
@@ -123,6 +123,7 @@ all_activations = [
    "sreglu",
    "silu",
    "swiglu",
    "clamped_swiglu",
]

all_normalizations = ["LayerNorm", "RMSNorm"]

@@ -566,7 +567,7 @@ def test_sanity_layernorm_mlp(
    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
    activation_params = None if activation != "clamped_swiglu" else {"limit": 7.0, "alpha": 1.702}
    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,

@@ -574,6 +575,7 @@ def test_sanity_layernorm_mlp(
        output_layer_init_method=output_layer_init_method,
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        activation_params=activation_params,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
```
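The `{"limit": 7.0, "alpha": 1.702}` parameters match the GPT-OSS-style clamped SwiGLU. The authoritative formula lives in TE's activation kernels, which this diff does not show; under that assumption, a plausible PyTorch reference looks like this:

```python
import torch


def clamped_swiglu_ref(x: torch.Tensor, limit: float = 7.0, alpha: float = 1.702) -> torch.Tensor:
    """Plausible reference for GPT-OSS-style clamped SwiGLU (assumption, not TE's kernel).

    Gated activations split the last dimension in half: one gate branch,
    one linear branch.
    """
    gate, linear = x.chunk(2, dim=-1)
    gate = gate.clamp(max=limit)                  # one-sided clamp on the gate branch
    linear = linear.clamp(min=-limit, max=limit)  # symmetric clamp on the linear branch
    return gate * torch.sigmoid(alpha * gate) * (linear + 1)


y = clamped_swiglu_ref(torch.randn(4, 64))
assert y.shape == (4, 32)
```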
tests/pytorch/utils.py

```python
@@ -205,6 +205,7 @@ class ModelConfig:
        window_size: Tuple[int, int] = (-1, -1),
        context_parallel: bool = False,
        cp_comm_type: str = "p2p",
        return_max_logit=False,
        total_requests: int = None,
        max_ctx_len: int = None,
        num_layers: int = 1,

@@ -233,6 +234,7 @@ class ModelConfig:
        self.window_size = check_set_window_size(self.attn_mask_type, window_size)
        self.context_parallel = context_parallel
        self.cp_comm_type = cp_comm_type
        self.return_max_logit = return_max_logit
        self.total_requests = total_requests
        self.max_ctx_len = max_ctx_len
        self.num_layers = num_layers

@@ -318,6 +320,7 @@ def get_available_attention_backends(
        is_training=is_training,
        inference_params=inference_params,
        softmax_type=config.softmax_type,
        return_max_logit=config.return_max_logit,
    )
    (
        use_flash_attention,
```
transformer_engine/common/CMakeLists.txt

```cmake
@@ -29,15 +29,6 @@ endif()
# Language options
if(USE_CUDA)
  set(CMAKE_CXX_STANDARD 17)
  set(CMAKE_CUDA_STANDARD 17)
  set(CMAKE_CUDA_STANDARD_REQUIRED ON)

@@ -54,8 +45,62 @@ if(USE_CUDA)
  # CUDA Toolkit
  find_package(CUDAToolkit REQUIRED)
  if(CUDAToolkit_VERSION VERSION_LESS 12.1)
    message(FATAL_ERROR "CUDA 12.1+ is required, but found CUDA ${CUDAToolkit_VERSION}")
  endif()

  # Process GPU architectures
  if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
    if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
      set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
    elseif(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
      set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
    else()
      set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
    endif()
  endif()

  # Process CMAKE_CUDA_ARCHITECTURES to separate generic and specific architectures
  set(NVTE_GENERIC_ARCHS)
  set(NVTE_SPECIFIC_ARCHS)

  # Check for architecture 100
  list(FIND CMAKE_CUDA_ARCHITECTURES "100" arch_100_index)
  if(NOT arch_100_index EQUAL -1)
    list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "100")
    list(APPEND NVTE_GENERIC_ARCHS "100")
    list(APPEND NVTE_SPECIFIC_ARCHS "100a")
    if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
      list(APPEND NVTE_SPECIFIC_ARCHS "103a")
    endif()
  endif()

  # Check for architecture 101 (if we see this we are in toolkit <= 12.9)
  list(FIND CMAKE_CUDA_ARCHITECTURES "101" arch_101_index)
  if(NOT arch_101_index EQUAL -1)
    list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "101")
    list(APPEND NVTE_GENERIC_ARCHS "101")
    list(APPEND NVTE_SPECIFIC_ARCHS "101a")
  endif()

  # Check for architecture 110 (if we see this we are in toolkit >= 13.0)
  list(FIND CMAKE_CUDA_ARCHITECTURES "110" arch_110_index)
  if(NOT arch_110_index EQUAL -1)
    list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "110")
    list(APPEND NVTE_GENERIC_ARCHS "110")
    list(APPEND NVTE_SPECIFIC_ARCHS "110f")
  endif()

  # Check for architecture 120
  list(FIND CMAKE_CUDA_ARCHITECTURES "120" arch_120_index)
  if(NOT arch_120_index EQUAL -1)
    list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "120")
    list(APPEND NVTE_GENERIC_ARCHS "120")
    if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
      list(APPEND NVTE_SPECIFIC_ARCHS "120f")
    else()
      list(APPEND NVTE_SPECIFIC_ARCHS "120a")
    endif()
  endif()

  # cuDNN frontend API

@@ -110,38 +155,32 @@ set(CUTLASS_TOOLS_INCLUDE_DIR
# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)

if(USE_CUDA)
  # NVIDIA MathDX include directory (from Python package install location)
  if(NOT DEFINED MATHDX_INCLUDE_DIR)
    execute_process(
      COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx
      OUTPUT_VARIABLE _PIP_SHOW_MATHDX
      ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR
      RESULT_VARIABLE _PIP_SHOW_MATHDX_RES
      OUTPUT_STRIP_TRAILING_WHITESPACE)
    if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0)
      message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}")
    endif()
    string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}")
    if(NOT _MATHDX_LOC_MATCH)
      message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}")
    endif()
    set(MATHDX_LOCATION "${CMAKE_MATCH_1}")
    set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include")
  endif()
  if(NOT EXISTS "${MATHDX_INCLUDE_DIR}")
    message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.")
  endif()
endif()

# Configure Transformer Engine library
include_directories(${PROJECT_SOURCE_DIR}/..)

set(transformer_engine_SOURCES)
set(transformer_engine_cpp_sources)
set(transformer_engine_cuda_sources)
set(transformer_engine_cuda_arch_specific_sources)
if(USE_CUDA)
  list(APPEND transformer_engine_cpp_sources
       cudnn_utils.cpp
       transformer_engine.cpp
       fused_attn/fused_attn.cpp
       gemm/config.cpp
       normalization/common.cpp
       normalization/layernorm/ln_api.cpp
       normalization/rmsnorm/rmsnorm_api.cpp
       util/cuda_driver.cpp
       util/cuda_nvml.cpp
       util/cuda_runtime.cpp
       util/multi_stream.cpp
       util/rtc.cpp
       comm_gemm_overlap/userbuffers/ipcsocket.cc
       comm_gemm_overlap/userbuffers/userbuffers-host.cpp
       comm_gemm_overlap/comm_gemm_overlap.cpp)
  list(APPEND transformer_engine_cuda_sources
       common.cu
       multi_tensor/adam.cu
       multi_tensor/compute_scale.cu

@@ -153,40 +192,23 @@ if(USE_CUDA)
       transpose/cast_transpose_fusion.cu
       transpose/transpose_fusion.cu
       transpose/multi_cast_transpose.cu
       transpose/quantize_transpose_vector_blockwise.cu
       transpose/swap_first_dims.cu
       dropout/dropout.cu
       fused_attn/flash_attn.cu
       fused_attn/context_parallel.cu
       fused_attn/kv_cache.cu
       fused_attn/fused_attn_f16_max512_seqlen.cu
       fused_attn/fused_attn_f16_arbitrary_seqlen.cu
       fused_attn/fused_attn_fp8.cu
       fused_attn/utils.cu
       gemm/cublaslt_gemm.cu
       normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
       normalization/layernorm/ln_fwd_cuda_kernel.cu
       normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
       normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
       permutation/permutation.cu
       util/padding.cu
       swizzle/swizzle.cu
       swizzle/swizzle_block_scaling.cu
       fused_softmax/scaled_masked_softmax.cu

@@ -200,26 +222,91 @@ if(USE_CUDA)
       recipe/delayed_scaling.cu
       recipe/fp8_block_scaling.cu
       recipe/nvfp4.cu
       comm_gemm_overlap/userbuffers/userbuffers.cu)
  list(APPEND transformer_engine_cuda_arch_specific_sources
       gemm/cutlass_grouped_gemm.cu
       cast/cast.cu
       activation/gelu.cu
       activation/relu.cu
       activation/swiglu.cu
       transpose/quantize_transpose_square_blockwise.cu
       transpose/quantize_transpose_vector_blockwise_fp4.cu
       hadamard_transform/hadamard_transform.cu
       hadamard_transform/hadamard_transform_cast_fusion.cu)

  # Compiling the files with the worst compilation time first to hopefully overlap
  # better with the faster-compiling cpp files
  list(APPEND transformer_engine_SOURCES
       ${transformer_engine_cuda_arch_specific_sources}
       ${transformer_engine_cuda_sources}
       ${transformer_engine_cpp_sources})

  # Set compile options for CUDA sources with generic architectures
  foreach(cuda_source IN LISTS transformer_engine_cuda_sources)
    set(arch_compile_options)
    foreach(arch IN LISTS NVTE_GENERIC_ARCHS)
      list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}")
    endforeach()
    if(arch_compile_options)
      set_property(SOURCE ${cuda_source} APPEND PROPERTY COMPILE_OPTIONS ${arch_compile_options})
    endif()
  endforeach()

  # Set compile options for CUDA sources with specific architectures
  foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources)
    set(arch_compile_options)
    foreach(arch IN LISTS NVTE_SPECIFIC_ARCHS)
      list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}")
    endforeach()
    if(arch_compile_options)
      set_property(SOURCE ${cuda_source} APPEND PROPERTY COMPILE_OPTIONS ${arch_compile_options})
    endif()
  endforeach()

  if(NVTE_WITH_CUBLASMP)
    list(APPEND transformer_engine_SOURCES comm_gemm/comm_gemm.cpp)
  endif()

  add_library(transformer_engine SHARED ${transformer_engine_SOURCES})

  # CUTLASS kernels require SM90a and cause hang in debug build
  set_property(SOURCE gemm/cutlass_grouped_gemm.cu
               APPEND PROPERTY
               COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a;-g0")
else()
  list(APPEND transformer_engine_cpp_sources
       cudnn_utils.cpp
       transformer_engine.cpp
       gemm/config.cpp
       normalization/common.cpp
       normalization/layernorm/ln_api.cpp
       normalization/rmsnorm/rmsnorm_api.cpp
       util/cuda_driver.cpp
       util/cuda_nvml.cpp
       util/cuda_runtime.cpp
       util/multi_stream.cpp
       util/rtc.cpp
       comm_gemm_overlap/userbuffers/ipcsocket.cc
       comm_gemm_overlap/userbuffers/userbuffers-host.cpp
       comm_gemm_overlap/comm_gemm_overlap.cpp)
  list(APPEND transformer_engine_cuda_sources
       common.cu
       fused_attn/flash_attn.cu
       fused_attn/context_parallel.cu
       fused_attn/kv_cache.cu
       fused_attn/utils.cu
       multi_tensor/adam.cu
       multi_tensor/compute_scale.cu
       multi_tensor/l2norm.cu

@@ -230,31 +317,21 @@ else()
       transpose/cast_transpose_fusion.cu
       transpose/transpose_fusion.cu
       transpose/multi_cast_transpose.cu
       transpose/quantize_transpose_vector_blockwise.cu
       transpose/swap_first_dims.cu
       dropout/dropout.cu
       gemm/cublaslt_gemm.cu
       gemm/hipblas_gemm.cu
       normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
       normalization/layernorm/ln_fwd_cuda_kernel.cu
       normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
       normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
       permutation/permutation.cu
       util/padding.cu
       swizzle/swizzle.cu
       swizzle/swizzle_block_scaling.cu
       fused_softmax/scaled_masked_softmax.cu

@@ -267,10 +344,22 @@ else()
       recipe/current_scaling.cu
       recipe/delayed_scaling.cu
       recipe/fp8_block_scaling.cu
       recipe/nvfp4.cu
       comm_gemm_overlap/userbuffers/userbuffers.cu)
  list(APPEND transformer_engine_cuda_arch_specific_sources
       util/cast.cu
       activation/gelu.cu
       activation/relu.cu
       activation/swiglu.cu
       transpose/quantize_transpose_square_blockwise.cu
       transpose/quantize_transpose_vector_blockwise_fp4.cu)

  # Compiling the files with the worst compilation time first to hopefully overlap
  # better with the faster-compiling cpp files
  list(APPEND transformer_engine_SOURCES
       ${transformer_engine_cuda_arch_specific_sources}
       ${transformer_engine_cuda_sources}
       ${transformer_engine_cpp_sources})
  if(NVTE_WITH_CUBLASMP)
    list(APPEND transformer_engine_SOURCES comm_gemm/comm_gemm.cpp)

@@ -311,14 +400,16 @@ else()
    add_library(transformer_engine SHARED ${te_hip_sources})
endif()

target_include_directories(transformer_engine PUBLIC
                           "${CMAKE_CURRENT_SOURCE_DIR}/include")

# Configure dependencies
if(USE_CUDA)
  target_link_libraries(transformer_engine PUBLIC
                        CUDA::cublas
                        CUDA::cudart
                        CUDNN::cudnn_all)
  target_include_directories(transformer_engine PRIVATE
                             ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
  target_include_directories(transformer_engine SYSTEM PRIVATE

@@ -439,7 +530,8 @@ target_include_directories(transformer_engine PRIVATE
                           "${CMAKE_CURRENT_BINARY_DIR}/string_headers")

# Compiler options
set(nvte_sources_with_fast_math)
list(APPEND nvte_sources_with_fast_math
     fused_softmax/scaled_masked_softmax.cu
     fused_softmax/scaled_upper_triang_masked_softmax.cu
     fused_softmax/scaled_aligned_causal_masked_softmax.cu
     multi_tensor/adam.cu

@@ -449,20 +541,23 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
     multi_tensor/sgd.cu
     fused_attn/flash_attn.cu
     fused_attn/context_parallel.cu
     fused_attn/kv_cache.cu)

option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH
       "Compile activation kernels with --use_fast_math option" OFF)
if(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
  list(APPEND nvte_sources_with_fast_math
       activation/gelu.cu
       activation/relu.cu
       activation/swiglu.cu)
endif()

if(USE_CUDA)
  foreach(cuda_source IN LISTS nvte_sources_with_fast_math)
    set_property(SOURCE ${cuda_source} APPEND PROPERTY COMPILE_OPTIONS "--use_fast_math")
  endforeach()
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
else()

@@ -491,10 +586,10 @@ else()
endif()

# Number of parallel build jobs
if($ENV{MAX_JOBS})
  set(BUILD_JOBS_STR $ENV{MAX_JOBS})
elseif($ENV{NVTE_BUILD_MAX_JOBS})
  set(BUILD_JOBS_STR $ENV{NVTE_BUILD_MAX_JOBS})
else()
  set(BUILD_JOBS_STR "max")
endif()
```
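To make the new architecture bookkeeping concrete, here is a small Python model of the same splitting logic and the `--generate-code` flags each bucket produces (illustration only; the build itself does this in CMake, as above):

```python
def split_archs(cuda_archs, toolkit):
    """Mirror the CMake logic: peel 100/101/110/120 off into a generic bucket
    plus arch-specific ('a'/'f') variants; everything else is left alone."""
    generic, specific, rest = [], [], []
    for arch in cuda_archs:
        if arch == "100":
            generic.append("100")
            specific.append("100a")
            if toolkit >= (12, 9):
                specific.append("103a")
        elif arch == "101":  # only seen with toolkit <= 12.9
            generic.append("101")
            specific.append("101a")
        elif arch == "110":  # only seen with toolkit >= 13.0
            generic.append("110")
            specific.append("110f")
        elif arch == "120":
            generic.append("120")
            specific.append("120f" if toolkit >= (12, 9) else "120a")
        else:
            rest.append(arch)
    return generic, specific, rest


def gencode_flags(archs):
    return [f"--generate-code=arch=compute_{a},code=sm_{a}" for a in archs]


generic, specific, rest = split_archs(["70", "80", "89", "90", "100", "120"], (12, 9))
print(gencode_flags(specific))
# ['--generate-code=arch=compute_100a,code=sm_100a',
#  '--generate-code=arch=compute_103a,code=sm_103a',
#  '--generate-code=arch=compute_120f,code=sm_120f']
```

The design choice: sources in `transformer_engine_cuda_arch_specific_sources` (quantize/cast, activations, Hadamard transform, CUTLASS grouped GEMM) get the arch-specific `100a`/`103a`/`120f`-style targets, while everything else compiles only for the generic architectures, keeping binary size and compile time down.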
transformer_engine/common/__init__.py

```python
@@ -8,22 +8,19 @@ import ctypes
import functools
import glob
import importlib
import logging
from importlib.metadata import version, distribution, PackageNotFoundError
import os
from pathlib import Path
import platform
import subprocess
import sys
import sysconfig
from typing import Optional, Tuple

from torch.utils.cpp_extension import IS_HIP_EXTENSION

_logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=None)
def _is_package_installed(package) -> bool:
    """Check if the given package is installed via pip."""
    # This is needed because we only want to return true

@@ -31,12 +28,34 @@ def _is_package_installed(package) -> bool:
    # if it's importable in the current directory due to
    # the presence of the shared library module.
    try:
        distribution(package)
    except PackageNotFoundError:
        return False
    return True


@functools.lru_cache(maxsize=None)
def _is_package_installed_from_wheel(package) -> bool:
    """Check if the given package is installed via PyPI."""
    if not _is_package_installed(package):
        return False
    te_dist = distribution(package)
    te_wheel_file = ""
    for file_path in te_dist.files:
        if file_path.name == "WHEEL":
            te_wheel_file = te_dist.locate_file("") / file_path
    if not te_wheel_file:
        return False
    with te_wheel_file.open("r") as f:
        for line in f:
            if line.startswith("Root-Is-Purelib:"):
                return line.strip().split(":")[1].strip().lower() == "true"
    return False

@@ -112,6 +131,19 @@ def _get_shared_object_file(library: str) -> Path:
    )


def get_te_core_package_info() -> Tuple[bool, str, str]:
    """
    Check if the Transformer Engine core package is installed.
    Returns the module name and version if found.
    """
    te_core_packages = ("transformer-engine-cu12", "transformer-engine-cu13")
    for package in te_core_packages:
        if _is_package_installed(package):
            return True, package, version(package)
    return False, "", ""


@functools.lru_cache(maxsize=None)
def load_framework_extension(framework: str) -> None:
    """

@@ -130,37 +162,28 @@ def load_framework_extension(framework: str) -> None:
    if framework == "torch":
        extra_dep_name = "pytorch"

    # Find the TE packages. The core and framework packages can only be installed via PyPI.
    # For the `transformer-engine` package, we need to check explicitly.
    te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info()
    te_framework_installed = _is_package_installed(module_name)
    te_installed = _is_package_installed("transformer_engine")
    te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine")

    assert te_installed, "Could not find `transformer_engine`."

    # If the framework extension pip package is installed, it means that TE is installed via
    # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
    # extension are all installed via PyPI and have matching versions.
    if te_framework_installed:
        assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
        assert te_core_installed, "Could not find TE core package `transformer-engine-cu*`."
        assert version(module_name) == version("transformer-engine") == te_core_version, (
            "Transformer Engine package version mismatch. Found"
            f" {module_name} v{version(module_name)}, transformer-engine"
            f" v{version('transformer-engine')}, and {te_core_package_name}"
            f" v{te_core_version}. Install transformer-engine using "
            f"'pip3 install --no-build-isolation transformer-engine[{extra_dep_name}]==VERSION'"
        )

    # After all checks are completed, load the shared object file.

@@ -170,6 +193,35 @@ def load_framework_extension(framework: str) -> None:
    spec.loader.exec_module(solib)


def sanity_checks_for_pypi_installation() -> None:
    """Ensure that the package is installed correctly if using PyPI."""
    te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info()
    te_installed = _is_package_installed("transformer_engine")
    te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine")

    assert te_installed, "Could not find `transformer-engine`."

    # If the core package is installed via PyPI.
    if te_core_installed:
        assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
        assert version("transformer-engine") == te_core_version, (
            "Transformer Engine package version mismatch. Found "
            f"transformer-engine v{version('transformer-engine')} "
            f"and {te_core_package_name} v{te_core_version}."
        )
    # Only the metapackage is found, an invalid use case.
    elif te_installed_via_pypi:
        raise RuntimeError(
            "Found empty `transformer-engine` meta package installed. "
            "Install `transformer-engine` with framework extensions via "
            "'pip3 install --no-build-isolation transformer-engine[pytorch,jax]==VERSION' "
            "or 'pip3 install transformer-engine[core]' for the TE core lib only. The `core_cu12` "
            "or `core_cu13` extra deps can be used to specify the CUDA version for the TE core lib."
        )


@functools.lru_cache(maxsize=None)
def _get_sys_extension() -> str:
    """File extension for shared objects."""

@@ -253,9 +305,7 @@ def _load_cudnn():
    if not IS_HIP_EXTENSION:
        # Attempt to locate libcudnn via ldconfig
        libs = subprocess.check_output(["ldconfig", "-p"])
        libs = libs.decode("utf-8").split("\n")
        sos = []
        for lib in libs:

@@ -285,9 +335,7 @@ def _load_nvrtc():
        return handle

    # Attempt to locate NVRTC via ldconfig
    libs = subprocess.check_output(["ldconfig", "-p"])
    libs = libs.decode("utf-8").split("\n")
    sos = []
    for lib in libs:

@@ -317,9 +365,7 @@ def _load_curand():
        return handle

    # Attempt to locate cuRAND via ldconfig
    libs = subprocess.check_output(["ldconfig", "-p"])
    libs = libs.decode("utf-8").split("\n")
    sos = []
    for lib in libs:

@@ -340,15 +386,16 @@ def _load_core_library():
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
    try:
        sanity_checks_for_pypi_installation()
        _CUDNN_LIB_CTYPES = _load_cudnn()
        _NVRTC_LIB_CTYPES = _load_nvrtc()
        _CURAND_LIB_CTYPES = _load_curand()
        _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
        _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
        # Needed to find the correct headers for NVRTC kernels.
        if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
            os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir()
    except OSError:
        pass

    _TE_LIB_CTYPES = _load_core_library()
```
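The `_load_cudnn`/`_load_nvrtc`/`_load_curand` hunks above replace `subprocess.check_output(f"ldconfig -p | grep 'lib...'", shell=True)` with a plain `check_output(["ldconfig", "-p"])` call, moving the filtering into Python and avoiding the shell. The filtering loop body is elided in the diff; a sketch of how such filtering might look (the `libcudnn` prefix and `.so` suffix are taken from the deleted grep pattern, and the actual loop in the source may differ):

```python
import subprocess


def find_shared_objects(name_prefix: str, extension: str = ".so") -> list:
    """List loader-visible shared objects whose name starts with name_prefix.

    Illustrative sketch only; the elided loop in the diff may differ.
    """
    out = subprocess.check_output(["ldconfig", "-p"]).decode("utf-8")
    sos = []
    for line in out.split("\n"):
        # ldconfig -p lines look like:
        #   "\tlibcudnn.so.9 (libc6,x86-64) => /usr/lib/x86_64-linux-gnu/libcudnn.so.9"
        line = line.strip()
        if line.startswith(name_prefix) and extension in line:
            sos.append(line)
    return sos


print(find_shared_objects("libcudnn"))
```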
transformer_engine/common/activation/activation_template.h

```cpp
@@ -14,26 +14,17 @@
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>

#include "../cast/dispatch/gated.cuh"
#include "../cast/dispatch/quantize.cuh"
#include "../common.h"
#include "../util/math.h"
#include "../util/vectorized_pointwise.h"

namespace transformer_engine {

template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  using namespace detail;
  constexpr bool IS_ACT = true;
  dispatch::quantize_fwd_helper<IS_ACT, Empty, OP>(input, output, nullptr, stream);
}

template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>

@@ -42,20 +33,17 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
  using namespace detail;
  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = true;
  constexpr NVTETensor dbias = nullptr;
  constexpr NVTETensor workspace = nullptr;
  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, OP>(grad, input, output, dbias,
                                                              workspace, nullptr, stream);
}

template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) {
  using namespace detail;
  dispatch::quantize_gated_fwd_helper<Param, ActOP>(input, output, p, stream);
}

template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),

@@ -63,8 +51,7 @@ template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType
void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p,
                   cudaStream_t stream) {
  using namespace detail;
  dispatch::quantize_gated_bwd_helper<Param, ActOP, DActOP>(grad, input, output, p, stream);
}

}  // namespace transformer_engine
```
transformer_engine/common/activation/gelu.cu

```cpp
@@ -20,6 +20,19 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
  dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dgelu);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_geglu);
  using namespace transformer_engine;

@@ -48,6 +61,19 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
  dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_qgeglu);
  using namespace transformer_engine;
```
transformer_engine/common/activation/relu.cu

```cpp
@@ -20,6 +20,19 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
  dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_drelu);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_reglu);
  using namespace transformer_engine;

@@ -48,6 +61,19 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
  dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dsrelu);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_sreglu);
  using namespace transformer_engine;
```
transformer_engine/common/activation/swiglu.cu

```cpp
@@ -20,6 +20,19 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
  dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dsilu);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_swiglu);
  using namespace transformer_engine;
```
transformer_engine/common/
util
/cast.cu
→
transformer_engine/common/
cast
/cast.cu
View file @
c1a1c04e
...
...
@@ -12,36 +12,20 @@
#include <transformer_engine/cast.h>
#include <transformer_engine/multi_stream.h>
#include <cfloat>
#include <limits>
#include <mutex>
#include <string>
#include "../common.h"
#include "../transpose/cast_transpose.h"
#include "../util/multi_stream.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "cast_kernels.cuh"
#include "dequantize_kernels.cuh"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/activation.h"
#include "dispatch/dequantize.cuh"
#include "dispatch/quantize.cuh"
#include "transformer_engine/transpose.h"
void
nvte_quantize
(
const
NVTETensor
input
,
NVTETensor
output
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_quantize
);
using
namespace
transformer_engine
;
constexpr
bool
IS_DBIAS
=
false
;
constexpr
bool
IS_DACT
=
false
;
constexpr
bool
IS_ACT
=
false
;
constexpr
NVTETensor
dbias
=
nullptr
;
constexpr
NVTETensor
workspace
=
nullptr
;
constexpr
const
NVTETensor
grad
=
nullptr
;
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
nullptr
>
(
input
,
grad
,
output
,
dbias
,
workspace
,
nullptr
,
stream
);
dispatch
::
quantize_fwd_helper
<
IS_ACT
,
Empty
,
nullptr
>
(
input
,
output
,
nullptr
,
stream
);
}
void
nvte_quantize_noop
(
const
NVTETensor
input
,
NVTETensor
output
,
NVTETensor
noop
,
...
...
@@ -61,15 +45,8 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
NVTE_API_CALL
(
nvte_quantize_v2
);
using
namespace
transformer_engine
;
constexpr
bool
IS_DBIAS
=
false
;
constexpr
bool
IS_DACT
=
false
;
constexpr
bool
IS_ACT
=
false
;
constexpr
NVTETensor
dbias
=
nullptr
;
constexpr
NVTETensor
workspace
=
nullptr
;
constexpr
const
NVTETensor
grad
=
nullptr
;
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
nullptr
>
(
input
,
grad
,
output
,
dbias
,
workspace
,
quant_config
,
stream
);
dispatch
::
quantize_fwd_helper
<
IS_ACT
,
Empty
,
nullptr
>
(
input
,
output
,
quant_config
,
stream
);
}
void
nvte_quantize_dbias
(
const
NVTETensor
input
,
NVTETensor
output
,
NVTETensor
dbias
,
...
...
@@ -79,87 +56,17 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = false;
  constexpr bool IS_ACT = false;
  constexpr const NVTETensor activation_input = nullptr;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dgelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dgelu<fp32, fp32>>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dsilu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsilu<fp32, fp32>>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_drelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, drelu<fp32, fp32>>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dqgelu<fp32, fp32>>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dsrelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsrelu<fp32, fp32>>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, nullptr>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}
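For reference, each of the fused backward entry points above produces dbias as a column-wise reduction of the activation-scaled gradient. A minimal host-side sketch of those semantics, with plain C++ buffers standing in for the NVTETensor handles (reference_dbias is a hypothetical name, not part of the API):

// Reference semantics only: dbias[j] = sum_i grad_after_dact[i][j], accumulated in fp32.
#include <cstddef>
#include <vector>

std::vector<float> reference_dbias(const std::vector<float> &grad, std::size_t rows,
                                   std::size_t cols) {
  std::vector<float> dbias(cols, 0.0f);
  for (std::size_t i = 0; i < rows; ++i)
    for (std::size_t j = 0; j < cols; ++j)
      dbias[j] += grad[i * cols + j];  // the kernels apply dact to grad before this reduction
  return dbias;
}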
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_dequantize);
  using namespace transformer_engine;
  detail::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output),
                            stream);
  dispatch::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output),
                              stream);
}
void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
...
@@ -168,12 +75,7 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
  NVTE_API_CALL(nvte_multi_tensor_quantize);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = false;
  constexpr bool IS_ACT = false;
  constexpr NVTETensor dbias = nullptr;
  constexpr NVTETensor workspace = nullptr;
  constexpr const NVTETensor grad = nullptr;

  const size_t num_streams = nvte_get_num_compute_streams();
...
@@ -186,9 +88,8 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
  }
  for (int i = 0; i < num_tensors; i++) {
    detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
        inputs[i], grad, outputs[i], dbias, workspace, nullptr,
        detail::get_compute_stream(i % num_streams));
    dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(
        inputs[i], outputs[i], quant_configs, detail::get_compute_stream(i % num_streams));
  }

  // record events on compute streams
...
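A minimal host-side sketch of driving the quantize/dequantize C API above; NVTETensor construction (data, amax, scale, scale_inv buffers) is elided, and the helper name is illustrative:

// Hypothetical round-trip through the C API shown above.
void fp8_roundtrip(NVTETensor hp_in, NVTETensor fp8_buf, NVTETensor hp_out, cudaStream_t stream) {
  nvte_quantize(hp_in, fp8_buf, stream);     // high precision -> FP8, updates amax/scale_inv
  nvte_dequantize(fp8_buf, hp_out, stream);  // FP8 -> high precision via scale_inv
}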
transformer_engine/common/cast/core/common.cuh
0 → 100644
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file common.cuh
 *  \brief Common functions in quantize.
 */

#ifndef TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
#define TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_

#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>

#include "../../common.h"
#include "../../utils.cuh"

namespace transformer_engine {
namespace dispatch {
namespace common {

inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) {
  const size_t N = product(t->data.shape);
  const bool isFullTile = (N % elems_per_block == 0);
  return isFullTile;
}

inline bool dimensions_supported_by_TMA(const Tensor *const t) {
  const size_t cols = t->flat_last_dim();
  constexpr size_t TMA_bytes = 16;
  const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
  return cols % alignment_requirement == 0;
}

namespace kernel {

constexpr size_t THREADS_PER_BLOCK = 256;

template <int nvec, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial,
                        const size_t rows, const size_t cols) {
  using ComputeVec = Vec<float, nvec>;
  using OutputVec = Vec<OType, nvec>;

  const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

  if (thread_id * nvec >= cols) {
    return;
  }

  const float *const thread_in_base = dbias_partial + thread_id * nvec;
  OType *const thread_out_base = dbias_output + thread_id * nvec;

  ComputeVec ldg_vec;
  ComputeVec acc_vec;
  acc_vec.clear();
  for (int i = 0; i < rows; ++i) {
    ldg_vec.load_from(thread_in_base + i * cols);
#pragma unroll
    for (int e = 0; e < nvec; ++e) {
      acc_vec.data.elt[e] += ldg_vec.data.elt[e];
    }
  }

  OutputVec stg_vec;
#pragma unroll
  for (int e = 0; e < nvec; ++e) {
    stg_vec.data.elt[e] = static_cast<OType>(acc_vec.data.elt[e]);
  }
  stg_vec.store_to(thread_out_base);
}
}  // namespace kernel

template <typename IType>
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
                  cudaStream_t stream) {
  using namespace kernel;
  constexpr size_t reduce_dbias_store_bytes = 8;  // stg.64
  constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType);

  NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape.");

  const size_t reduce_dbias_num_blocks = DIVUP(cols, THREADS_PER_BLOCK * reduce_dbias_nvec);
  reduce_dbias_kernel<reduce_dbias_nvec, IType>
      <<<reduce_dbias_num_blocks, THREADS_PER_BLOCK, 0, stream>>>(
          reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, rows, cols);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

}  // namespace common
}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
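To make the vector-width arithmetic in reduce_dbias concrete: with a 2-byte output type such as bf16, reduce_dbias_nvec = 8 / 2 = 4, so each thread reduces four columns and writes them with one 64-bit store. A compile-time sketch under that assumption (the names here are illustrative, not part of the file above):

#include <cstddef>

constexpr std::size_t kThreadsPerBlock = 256;       // mirrors THREADS_PER_BLOCK above
constexpr std::size_t kStoreBytes = 8;              // one stg.64 per thread
constexpr std::size_t kNvecBf16 = kStoreBytes / 2;  // 4 columns per thread for bf16

constexpr std::size_t divup(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

// e.g. cols = 4096 -> divup(4096, 256 * 4) = 4 blocks for the final dbias reduction.
static_assert(divup(4096, kThreadsPerBlock * kNvecBf16) == 4);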
transformer_engine/common/cast/dispatch/dequantize.cuh
0 → 100644
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file dequantize.cuh
 *  \brief Dequantize dispatcher.
 */

#ifndef TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_

#include <transformer_engine/transformer_engine.h>

#include "../../common.h"
#include "../fp8/dequantize_fp8.cuh"
#include "../mxfp8/dequantize_mxfp8.cuh"
#include "../nvfp4/dequantize_nvfp4.cuh"

namespace transformer_engine {
namespace dispatch {

inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) {
  CheckInputTensor(input, "cast_input");
  CheckOutputTensor(*output, "cast_output");

  switch (input.scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
      NVTE_CHECK(is_fp8_dtype(input.data.dtype) || is_int8_dtype(input.data.dtype),
                 "Input must have FP8 or INT8 type.");
      NVTE_CHECK(!is_fp8_dtype(output->data.dtype) && !is_int8_dtype(output->data.dtype),
                 "Output must be in higher precision.");
      NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
      fp8::dequantize(input, output, stream);
      break;
    }
    case NVTE_MXFP8_1D_SCALING: {
      if (is_supported_by_CC_100()) {
        mxfp8::dequantize(input, output, stream);
      } else {
        NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
      }
      break;
    }
    case NVTE_NVFP4_1D_SCALING: {
      nvfp4::dequantize(input, output, stream);
      break;
    }
    default:
      NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
  }
}

}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_
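Note the MXFP8 branch hard-errors below SM 10.0 rather than falling back. A hedged sketch of the equivalent capability check a caller could run up front (is_supported_by_CC_100 is internal to TE; the plain device-property query here is a stand-in for it):

#include <cuda_runtime.h>

// Hypothetical caller-side guard mirroring the dispatcher's SM 10.0 requirement.
bool mxfp8_dequantize_supported(int device) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) return false;
  return prop.major >= 10;  // Blackwell-class (SM 10.0) or newer
}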
transformer_engine/common/cast/dispatch/gated.cuh
0 → 100644
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file gated.cuh
 *  \brief Gated dispatcher.
 */

#ifndef TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_

#include <transformer_engine/transformer_engine.h>

#include "../../common.h"
#include "../../utils.cuh"
#include "../fp8/gated_fp8.cuh"
#include "../mxfp8/gated_mxfp8.cuh"

namespace transformer_engine {
namespace dispatch {

template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p,
                               cudaStream_t stream) {
  const Tensor input = *convertNVTETensorCheck(nvte_input);
  Tensor *output = convertNVTETensorCheck(nvte_output);
  CheckInputTensor(input, "input");
  CheckOutputTensor(*output, "output", /*allow_empty=*/false);

  const size_t rows = input.flat_first_dim();
  const size_t cols = input.flat_last_dim() / 2;

  NVTE_CHECK(input.flat_last_dim() % 2 == 0,
             "Wrong input shape. Expected (after flattening) last dimension to be even, ",
             "got [", input.flat_first_dim(), ", ", input.flat_last_dim(), "].");
  NVTE_CHECK(output->flat_last_dim() == cols,
             "Wrong output shape. Expected (after flattening) [*, ", cols, "], got [",
             output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
  NVTE_CHECK(output->has_data() || output->has_columnwise_data(),
             "Either rowwise or columnwise output data need to be allocated.");

  switch (output->scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
      const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
      if (use_tma_kernels) {
        Tensor dummy_grad_tensor;
        fp8::cast_gated_tma</*IS_BWD=*/false, ParamOP, ActOP, nullptr>(input, dummy_grad_tensor,
                                                                       output, p, stream);
      } else {
        fp8::cast_gated_fwd<ParamOP, ActOP>(input, output, p, stream);
      }
      break;
    }
    case NVTE_MXFP8_1D_SCALING: {
      NVTE_CHECK(cols % 32 == 0,
                 "Invalid input shape. Expected the last dimension to be "
                 "divisible by 32, but got ",
                 cols, ".");
      if (output->has_data()) {
        NVTE_CHECK(is_fp8_dtype(output->data.dtype),
                   "The type of the output tensor should be FP8.");
      }
      if (output->has_columnwise_data()) {
        NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype),
                   "The type of the columnwise output tensor should be FP8.");
      }
      NVTE_CHECK(is_supported_by_CC_100(),
                 "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+");
      Tensor dummy_grad_tensor;
      mxfp8::quantize_gated</*IS_BWD=*/false, ParamOP, ActOP, nullptr>(input, dummy_grad_tensor,
                                                                       output, p, stream);
      break;
    }
    default:
      NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + ".");
  }
}

template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
          float (*DActOP)(float, const ParamOP &)>
void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input,
                               NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) {
  const Tensor &grad = *(convertNVTETensorCheck(nvte_grad));
  const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input);
  Tensor *output = convertNVTETensorCheck(nvte_output);
  CheckInputTensor(grad, "grad");
  CheckInputTensor(gated_input, "gated_input");
  CheckOutputTensor(*output, "output", /*allow_empty=*/false);

  NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ",
             gated_input.flat_last_dim(), ".");

  const size_t rows = gated_input.flat_first_dim();
  const size_t cols = gated_input.flat_last_dim() / 2;

  NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision.");
  NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match.");
  NVTE_CHECK(grad.flat_first_dim() == rows,
             "Wrong Grad shape. Expected first dimension (after flattening) [", rows,
             ", *], got [", grad.flat_first_dim(), ", ", grad.flat_last_dim(), "].");
  NVTE_CHECK(grad.flat_last_dim() == cols,
             "Wrong Grad shape. Expected last dimension (after flattening) [", cols,
             ", *], got [", grad.flat_first_dim(), ", ", grad.flat_last_dim(), "].");
  NVTE_CHECK(output->has_data() || output->has_columnwise_data(),
             "Either rowwise or columnwise output data need to be allocated.");
  NVTE_CHECK(output->flat_first_dim() == rows,
             "Wrong output shape. Expected (after flattening) [", rows, ", *], got [",
             output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
  NVTE_CHECK(output->flat_last_dim() == cols * 2,
             "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [",
             output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
  NVTE_CHECK(gated_input.data.shape == output->data.shape,
             "Gated input and output shapes must match. Input shape: ", gated_input.data.shape,
             ", output shape: ", output->data.shape, ".");

  switch (output->scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
      const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
      if (use_tma_kernels) {
        fp8::cast_gated_tma</*IS_BWD=*/true, ParamOP, ActOP, DActOP>(gated_input, grad, output, p,
                                                                     stream);
      } else {
        fp8::cast_gated_bwd<ParamOP, ActOP, DActOP>(gated_input, grad, output, p, stream);
      }
      break;
    }
    case NVTE_MXFP8_1D_SCALING: {
      NVTE_CHECK(cols % 32 == 0,
                 "Invalid input shape. Expected the last dimension to be "
                 "divisible by 32, but got ",
                 cols, ".");
      if (output->has_data()) {
        NVTE_CHECK(is_fp8_dtype(output->data.dtype),
                   "The type of the output tensor should be FP8.");
      }
      if (output->has_columnwise_data()) {
        NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype),
                   "The type of the columnwise output tensor should be FP8.");
      }
      NVTE_CHECK(is_supported_by_CC_100(),
                 "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+");
      mxfp8::quantize_gated</*IS_BWD=*/true, ParamOP, ActOP, DActOP>(gated_input, grad, output, p,
                                                                     stream);
      break;
    }
    default:
      NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + ".");
  }
}

}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_
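For orientation, the forward helper above reads a [rows, 2*cols] input as per-row [activation | gate] halves (the two input tensor maps use column offsets 0 and cols) and emits a [rows, cols] output, out[i][j] = Act(in[i][j]) * in[i][cols + j]. A scalar reference of that layout (the names and the plain function-pointer signature are illustrative):

#include <cstddef>

// Scalar reference for the gated forward layout.
void gated_fwd_reference(const float *in, float *out, std::size_t rows, std::size_t cols,
                         float (*act_op)(float)) {
  for (std::size_t i = 0; i < rows; ++i) {
    const float *row = in + i * 2 * cols;  // [act | gate] halves of one row
    for (std::size_t j = 0; j < cols; ++j) {
      out[i * cols + j] = act_op(row[j]) * row[cols + j];
    }
  }
}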
transformer_engine/common/cast/dispatch/quantize.cuh
0 → 100644
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file quantize.cuh
 *  \brief Quantize dispatcher.
 */

#ifndef TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_

#include <transformer_engine/transformer_engine.h>

#include "../../common.h"
#include "../../transpose/cast_transpose.h"
#include "../../util/vectorized_pointwise.h"
#include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh"
#include "../nvfp4/quantize_nvfp4.cuh"
#include "../nvfp4/quantize_transpose_nvfp4.cuh"

namespace transformer_engine {
namespace dispatch {

template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
                         const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  using namespace detail;
  const Tensor *input_tensor = convertNVTETensorCheck(input);
  Tensor *output_tensor = convertNVTETensorCheck(output);

  // Quantization config
  QuantizationConfig quant_config_cpp;
  if (quant_config != nullptr) {
    quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
  }

  // Noop flag
  Tensor dummy_tensor;
  Tensor *noop_tensor = &dummy_tensor;
  if (quant_config_cpp.noop_tensor != nullptr) {
    noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
  }

  // Check for unsupported options
  if (quant_config_cpp.stochastic_rounding) {
    NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
               "Stochastic rounding is only supported for NVFP4 quantization.");
  }
  NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(),
             "Either rowwise or columnwise output data need to be allocated.");

  // Dispatch to quantization kernel depending on data format
  switch (output_tensor->scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
      const Tensor *dummy_input_tensor = nullptr;
      Tensor *dummy_dbias_tensor = nullptr;
      Tensor *dummy_workspace_tensor = nullptr;
      if (output_tensor->has_columnwise_data()) {
        const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
        if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1') {
          NVTE_CHECK(false, "NVTE_INT8_SIM_FP8_TENSORWISE need not be transposed!");
        }
        NVTE_CHECK(output_tensor->has_data(),
                   "Quantizing in only the columnwise direction not supported yet!");
        if constexpr (!IS_ACT) {
          cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream);
        } else {
          cast_transpose_fused</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, float, ParamOP, OP>(
              *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor,
              dummy_workspace_tensor, stream);
        }
      } else if (output_tensor->has_data()) {
        fp8::quantize</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, ParamOP, OP>(
            *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor,
            dummy_workspace_tensor, stream);
      }
      break;
    }
    case NVTE_MXFP8_1D_SCALING: {
      const Tensor *dummy_input_tensor = nullptr;
      Tensor *dummy_dbias_tensor = nullptr;
      Tensor *dummy_workspace_tensor = nullptr;
      mxfp8::quantize</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, ParamOP, OP>(
          *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor,
          dummy_workspace_tensor, stream);
      break;
    }
    case NVTE_NVFP4_1D_SCALING: {
      NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING");

      // Check tensors
      CheckNoopTensor(*noop_tensor, "cast_noop");
      CheckInputTensor(*input_tensor, "input");
      CheckOutputTensor(*output_tensor, "output", false);

      // Choose kernel
      int32_t rows = input_tensor->flat_first_dim();
      int32_t cols = input_tensor->flat_last_dim();
      auto dtype = input_tensor->dtype();
      bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
                                  (cols % 32 == 0) && output_tensor->has_data();

      // Launch NVFP4 quantize kernel
      if (use_optimized_kernel) {
        if (quant_config_cpp.nvfp4_2d_quantization) {
          nvfp4::quantize_transpose</*use_2d_quantization=*/true>(
              *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
        } else {
          nvfp4::quantize_transpose</*use_2d_quantization=*/false>(
              *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
        }
      } else {
        auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
                                                                  : output_tensor->columnwise_amax;
        quantize_transpose_vector_blockwise_fp4(
            /*input=*/input_tensor->data, /*global_amax=*/global_amax,
            /*scale_inv=*/output_tensor->scale_inv,
            /*scale_inv_t=*/output_tensor->columnwise_scale_inv,
            /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
            /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
            /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
            /*swizzled_scale=*/false,
            /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
            /*rng_state=*/quant_config_cpp.rng_state,
            /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
            /*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
      }
      break;
    }
    case NVTE_BLOCK_SCALING_2D: {
      // TODO(kwyss): IS_ACT, ParamOP, OP parameters support.
      NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
      quantize_transpose_square_blockwise(
          input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
          output_tensor->data, output_tensor->columnwise_data, epsilon,
          /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
          /*noop_tensor=*/noop_tensor->data, stream);
      break;
    }
    case NVTE_BLOCK_SCALING_1D: {
      // TODO(kwyss): IS_ACT, ParamOP, OP parameters support.
      NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
      FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
      FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
      if (output_tensor->has_data()) {
        bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
                                Float8BlockScaleTensorFormat::COMPACT);
        rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
                                         : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
      }
      if (output_tensor->has_columnwise_data()) {
        bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
                                   Float8BlockScaleTensorFormat::COMPACT);
        columnwise_option = columnwise_compact
                                ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
                                : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
      }
      quantize_transpose_vector_blockwise(
          input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
          output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
          columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
      break;
    }
    default:
      NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
  }
}

template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                         NVTETensor dbias, NVTETensor workspace,
                         const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  using namespace detail;
  const Tensor *grad_tensor = convertNVTETensorCheck(grad);
  const Tensor *input_tensor = convertNVTETensor(input);
  Tensor *output_tensor = convertNVTETensorCheck(output);
  Tensor *dbias_tensor = convertNVTETensor(dbias);
  Tensor *workspace_tensor = convertNVTETensor(workspace);

  // Quantization config
  QuantizationConfig quant_config_cpp;
  if (quant_config != nullptr) {
    quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
  }

  // Noop flag
  Tensor dummy_tensor;
  Tensor *noop_tensor = &dummy_tensor;
  if (quant_config_cpp.noop_tensor != nullptr) {
    noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
  }

  // Check for unsupported options
  if (quant_config_cpp.stochastic_rounding) {
    NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
               "Stochastic rounding is only supported for NVFP4 quantization.");
  }
  NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(),
             "Either rowwise or columnwise output data need to be allocated.");

  // Dispatch to quantization kernel depending on data format
  switch (output_tensor->scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
      if (output_tensor->has_columnwise_data()) {
        const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
        if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1') {
          NVTE_CHECK(false, "NVTE_INT8_SIM_FP8_TENSORWISE need not be transposed!");
        }
        NVTE_CHECK(output_tensor->has_data(),
                   "Quantizing in only the columnwise direction not supported yet!");
        if constexpr (!IS_DBIAS && !IS_DACT) {
          cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream);
        } else {
          cast_transpose_fused<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, float, ParamOP, OP>(
              *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream);
        }
      } else if (output_tensor->has_data()) {
        fp8::quantize<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, ParamOP, OP>(
            *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor,
            workspace_tensor, stream);
      }
      break;
    }
    case NVTE_MXFP8_1D_SCALING: {
      mxfp8::quantize<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, ParamOP, OP>(
          *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
          stream);
      break;
    }
    case NVTE_NVFP4_1D_SCALING: {
      NVTE_CHECK((!IS_DBIAS && !IS_DACT),
                 "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING");

      // Check tensors
      CheckNoopTensor(*noop_tensor, "cast_noop");
      CheckInputTensor(*grad_tensor, "input");
      CheckOutputTensor(*output_tensor, "output", false);

      // Choose kernel
      int32_t rows = grad_tensor->flat_first_dim();
      int32_t cols = grad_tensor->flat_last_dim();
      auto dtype = grad_tensor->dtype();
      bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
                                  (cols % 32 == 0) && output_tensor->has_data();

      // Launch NVFP4 quantize kernel
      if (use_optimized_kernel) {
        if (quant_config_cpp.nvfp4_2d_quantization) {
          nvfp4::quantize_transpose</*use_2d_quantization=*/true>(
              *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
        } else {
          nvfp4::quantize_transpose</*use_2d_quantization=*/false>(
              *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
        }
      } else {
        auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
                                                                  : output_tensor->columnwise_amax;
        quantize_transpose_vector_blockwise_fp4(
            /*input=*/grad_tensor->data, /*global_amax=*/global_amax,
            /*scale_inv=*/output_tensor->scale_inv,
            /*scale_inv_t=*/output_tensor->columnwise_scale_inv,
            /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
            /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
            /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
            /*swizzled_scale=*/false,
            /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
            /*rng_state=*/quant_config_cpp.rng_state,
            /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
            /*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
      }
      break;
    }
    case NVTE_BLOCK_SCALING_2D: {
      // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support.
      NVTE_CHECK((!IS_DBIAS && !IS_DACT),
                 "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
      quantize_transpose_square_blockwise(
          grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
          output_tensor->data, output_tensor->columnwise_data, epsilon,
          /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
          /*noop_tensor=*/noop_tensor->data, stream);
      break;
    }
    case NVTE_BLOCK_SCALING_1D: {
      // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support.
      NVTE_CHECK((!IS_DBIAS && !IS_DACT),
                 "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
      FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
      FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
      if (output_tensor->has_data()) {
        bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
                                Float8BlockScaleTensorFormat::COMPACT);
        rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
                                         : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
      }
      if (output_tensor->has_columnwise_data()) {
        bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
                                   Float8BlockScaleTensorFormat::COMPACT);
        columnwise_option = columnwise_compact
                                ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
                                : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
      }
      quantize_transpose_vector_blockwise(
          grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
          output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
          columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
      break;
    }
    default:
      NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
  }
}

}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_
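As the config handling above shows, a null NVTEQuantizationConfig simply value-initializes QuantizationConfig, i.e. no noop tensor and no stochastic rounding. A minimal sketch of calling the v2 entry point that feeds this helper (argument order follows the nvte_quantize_v2 hunk shown earlier):

// Sketch: explicit defaults via a null config, as nvte_quantize itself does.
void quantize_with_defaults(NVTETensor in, NVTETensor out, cudaStream_t stream) {
  nvte_quantize_v2(in, out, /*quant_config=*/nullptr, stream);
}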
transformer_engine/common/cast/fp8/dequantize_fp8.cuh
0 → 100644
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file dequantize_fp8.cuh
 *  \brief CUDA kernels to dequantize from FP8.
 */

#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_
#define TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_

#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>

#include "../../common.h"
#include "../../util/math.h"
#include "../../util/vectorized_pointwise.h"
#include "../../utils.cuh"

namespace transformer_engine {
namespace dispatch {
namespace fp8 {

struct DequantizeParam {
  const float *scale_inv;
};

__device__ inline float dequantize_func(float value, const DequantizeParam &param) {
  return value * (*(param.scale_inv));
}

inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
  const size_t N = product(input.data.shape);
  TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
      input.data.dtype, IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
          output->data.dtype, OType,

          constexpr int nvec = 32 / sizeof(OType);
          DequantizeParam p;
          p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
          VectorizedUnaryKernelLauncher<nvec, DequantizeParam, dequantize_func>(
              reinterpret_cast<const IType *>(input.data.dptr), nullptr,
              reinterpret_cast<OType *>(output->data.dptr), nullptr, nullptr, nullptr, N, p,
              stream););  // NOLINT(*)
  );                      // NOLINT(*)
}

}  // namespace fp8
}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_
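Numerically, dequantize_func is the whole story for delayed tensor scaling: every element is multiplied by the single per-tensor scale_inv. A host-side reference under that assumption; for instance, a stored value of 96.0 with scale_inv = 2^-5 recovers 3.0:

#include <cstddef>

// Host-side reference for per-tensor FP8 dequantization (illustrative only).
// q holds the FP8 payload already widened to float.
void dequantize_reference(const float *q, float scale_inv, float *out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) out[i] = q[i] * scale_inv;
}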
transformer_engine/common/cast/fp8/gated_fp8.cuh
0 → 100644
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file gated_fp8.cuh
 *  \brief CUDA kernels to cast to FP8 with gated activations.
 */

#ifndef TRANSFORMER_ENGINE_GATED_FP8_CUH_
#define TRANSFORMER_ENGINE_GATED_FP8_CUH_

#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>

#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../util/vectorized_pointwise.h"
#include "../../utils.cuh"

namespace transformer_engine {
namespace dispatch {
namespace fp8 {
namespace kernel {

constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 512;
constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X;
constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X;  // 4 = 512 / 128
constexpr size_t BUFFERS_NUM = 2;
constexpr size_t BUFFER_DIM_Y = 32;
constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X;  // 128
constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y;  // 32
constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X;  // 128

constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y;  // 8 = 32 / 4
constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y;                 // 4 = 128 / 32
static_assert(ITERATIONS >= 1);

#ifndef __HIP_PLATFORM_AMD__
template <bool IS_BWD, typename ParamOP, float (*ActOP)(float, const ParamOP &),
          float (*DActOP)(float, const ParamOP &), typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
    cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad,
                          const __grid_constant__ CUtensorMap tensor_map_input_act,
                          const __grid_constant__ CUtensorMap tensor_map_input_gate,
                          const __grid_constant__ CUtensorMap tensor_map_output_act,
                          const __grid_constant__ CUtensorMap tensor_map_output_gate,
                          float *const amax_ptr, float *const scale_inv_ptr,
                          const float *const scale_ptr, const size_t rows, const size_t cols,
                          const ParamOP p) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
  const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X;

  const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
  const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X;

  const size_t thread_offset_Y = tid_Y;
  const size_t thread_offset_X = tid_X;

  float amax = 0;
  const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;

  extern __shared__ char dynamic_shmem[];
  uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) Does not guarantee the pointer to be aligned!
  uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
                     ~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));

  constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X;
  constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems;
  constexpr size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);

  constexpr size_t grad_mem = IS_BWD ? buff_size_aligned_in : 0;

  constexpr size_t in_act_mem = buff_size_aligned_in;
  constexpr size_t in_gate_mem = buff_size_aligned_in;
  constexpr size_t in_mem = in_act_mem + in_gate_mem;

  constexpr size_t out_act_mem = buff_size_aligned_out;
  constexpr size_t in_transaction_size = buff_elems * sizeof(IType);

  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
  IType *in_grad_sh = reinterpret_cast<IType *>(dshmem);
  IType *in_act_sh = reinterpret_cast<IType *>(dshmem + grad_mem);
  IType *in_gate_sh = reinterpret_cast<IType *>(dshmem + grad_mem + in_act_mem);
  OType *out_act_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem);
  OType *out_gate_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_act_mem);

  const uint64_t *TMAP_grad_in = reinterpret_cast<const uint64_t *>(&tensor_map_grad);
  const uint64_t *TMAP_in_act = reinterpret_cast<const uint64_t *>(&tensor_map_input_act);
  const uint64_t *TMAP_in_gate = reinterpret_cast<const uint64_t *>(&tensor_map_input_gate);
  const uint64_t *TMAP_output_act = reinterpret_cast<const uint64_t *>(&tensor_map_output_act);
  const uint64_t *TMAP_output_gate = reinterpret_cast<const uint64_t *>(&tensor_map_output_gate);

  const bool is_master_thread = (threadIdx.x == 0);

  // Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ alignas(8) uint64_t mbar[ITERATIONS];

  initialize_barriers<ITERATIONS, THREADS_PER_CHUNK>(mbar, is_master_thread);

  int parity = 0;

  // Prefetch data of the first stage
  if constexpr (IS_BWD) {
    copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh,
                        TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate,
                        chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0],
                        is_master_thread);
  } else {
    copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh,
                        TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size,
                        &mbar[0], is_master_thread);
  }

#pragma unroll
  for (int it = 0; it < ITERATIONS; ++it) {
    const size_t buff = it % BUFFERS_NUM;
    const size_t next_it = it + 1;
    if (next_it < ITERATIONS) {
      const size_t next_buff = next_it % BUFFERS_NUM;
      const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y;
      const size_t chunk_it_offset_x = chunk_offset_X;
      if constexpr (IS_BWD) {
        copy_2d_to_sharedx3(&in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x,
                            chunk_it_offset_y, &in_act_sh[next_buff * buff_elems], TMAP_in_act,
                            chunk_it_offset_x, chunk_it_offset_y,
                            &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x,
                            chunk_it_offset_y, in_transaction_size, &mbar[next_it],
                            is_master_thread);
      } else {
        copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x,
                            chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate,
                            chunk_it_offset_x, chunk_it_offset_y, in_transaction_size,
                            &mbar[next_it], is_master_thread);
      }
    }

    ptx::fence_proxy_async_shared_cta();

    // Wait for the data to have arrived
    ptx::mbarrier_wait_parity(&mbar[it], parity);

    IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems;
    IType *in_act_sh_curr = in_act_sh + buff * buff_elems;
    IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems;
    OType *out_act_sh_curr = out_act_sh + buff * buff_elems;
    OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems;

#pragma unroll
    for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
      const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
      const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y;
      const size_t shmem_offset_x = thread_offset_X;
      const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;

      float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
      float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
      bool dgate_elt = true;  // gating is ideally an identity function

      if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
        // In case of GPT OSS, clamp the activation and gate values
        dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;  // Derivative of clamp
        gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1;
      }

      if constexpr (IS_BWD) {
        float grad_elt = static_cast<float>(in_grad_sh_curr[shmem_idx]);
        const float x = act_elt;
        float act_x;
        float dact_x;

        if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
          const float x = min(act_elt, p.limit);
          const float s = sigmoidf(p.alpha * x);
          act_x = x * s;
          if (act_elt <= p.limit) {
            dact_x = s + s * (1 - s) * p.alpha * x;
          } else {
            dact_x = 0.0f;
          }
        } else {
          if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
            const float s = sigmoidf(x);
            act_x = x * s;
            dact_x = x * s * (1 - s) + s;
          } else {
            act_x = ActOP(x, p);
            dact_x = DActOP(x, p);
          }
        }

        float after_dact = dact_x * grad_elt * gate_elt;
        float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f;

        out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dact);
        out_gate_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dgate);

        amax = fmaxf(amax, fabsf(after_dact));
        amax = fmaxf(amax, fabsf(after_dgate));
      } else {
        const float after_act = ActOP(act_elt, p) * gate_elt;
        out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_act);
        amax = fmaxf(amax, fabsf(after_act));
      }
    }

    // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence)
    ptx::fence_proxy_async_shared_cta();
    __syncthreads();
    // After syncthreads, writes by all threads are visible to TMA engine.

    // Initiate TMA transfer to copy shared memory to global memory
    if (is_master_thread) {
      const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y;
      const size_t chunk_it_offset_x = chunk_offset_X;

      // dGeLU
      ptx::cp_async_bulk_tensor_2d_shared_to_global(
          TMAP_output_act, chunk_it_offset_x, chunk_it_offset_y,
          reinterpret_cast<uint64_t *>(out_act_sh_curr));

      if constexpr (IS_BWD) {
        // dGate
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y,
            reinterpret_cast<uint64_t *>(out_gate_sh_curr));
      }

      // Create a "bulk async-group" out of the previous bulk copy operation.
      ptx::cp_async_bulk_commit_group();

      // Wait for TMA transfer to have finished reading shared memory.
      ptx::cp_async_bulk_wait_group_read<BUFFERS_NUM - 1>();
    }
  }

  ptx::cp_async_bulk_wait_group_read<0>();
  __syncthreads();

  if (amax_ptr != nullptr) {
    const int warp_id = threadIdx.x / THREADS_PER_WARP;
    // Reduce the amax over the block
    amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(amax, warp_id);
    // Update the global amax
    if (is_master_thread) {
      atomicMaxFloat(amax_ptr, amax);
    }
  }

  // Update scale-inverse
  if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) {
    reciprocal<float>(scale_inv_ptr, scale);
  }

  // Destroy the barriers. This invalidates the memory region of the barrier.
  // If further computations were to take place in the kernel, this allows the
  // memory location of the shared memory barrier to be reused.
  if (is_master_thread) {
#pragma unroll
    for (int it = 0; it < ITERATIONS; ++it) {
      ptx::mbarrier_invalid(&mbar[it]);
    }
  }
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
#endif
}  // namespace kernel

template <bool IS_BWD, typename ParamOP, float (*ActOP)(float, const ParamOP &),
          float (*DActOP)(float, const ParamOP &)>
void cast_gated_tma(const Tensor &gated_input, const Tensor &grad, Tensor *output, ParamOP &p,
                    cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
  assert(false);
#else
  using namespace kernel;
  checkCuDriverContext(stream);
  NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function.");

  const size_t rows = gated_input.flat_first_dim();
  const size_t cols = gated_input.flat_last_dim() / 2;
  const size_t output_cols = (IS_BWD ? 2 : 1) * cols;

  const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
  const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);

  float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
  float *const scale_inv_ptr = reinterpret_cast<float *>(output->scale_inv.dptr);
  float *const scale_ptr = reinterpret_cast<float *>(output->scale.dptr);

  const dim3 block_dim(THREADS_PER_CHUNK);
  const dim3 grid_dim(blocks_X, blocks_Y);

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      gated_input.dtype(), IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output->dtype(), OType,

          alignas(64) CUtensorMap tensor_map_grad{};
          alignas(64) CUtensorMap tensor_map_input_act{};
          alignas(64) CUtensorMap tensor_map_input_gate{};
          alignas(64) CUtensorMap tensor_map_output_act{};
          alignas(64) CUtensorMap tensor_map_output_gate{};

          if constexpr (IS_BWD) {
            create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X,
                                 cols, 0, typeToNumBits(gated_input.dtype()));
          }

          const uint32_t tensor_stride_elems = output_cols;
          create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y,
                               SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype()));
          create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y,
                               SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype()));
          create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y,
                               SHMEM_DIM_X, tensor_stride_elems, 0,
                               typeToNumBits(output->dtype()));
          create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y,
                               SHMEM_DIM_X, tensor_stride_elems, cols,
                               typeToNumBits(output->dtype()));

          const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
          const size_t buff_size_aligned_in =
              DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
          const size_t buff_size_aligned_out =
              DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
          const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0);
          const size_t in_act_mem = buff_size_aligned_in;
          const size_t in_gate_mem = buff_size_aligned_in;
          const size_t out_act_mem = buff_size_aligned_out;
          const size_t out_gate_mem = buff_size_aligned_out;

          const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) +
                                    (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT;

          auto kernel = cast_fp8_gated_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType>;
          NVTE_CHECK_CUDA(cudaFuncSetAttribute(
              kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));

          kernel<<<grid_dim, block_dim, shmem_size, stream>>>(
              tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act,
              tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p);
          NVTE_CHECK_CUDA(cudaGetLastError()););  // NOLINT(*)
  );                                              // NOLINT(*)
#endif
}

template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated_fwd(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.dtype(), IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output->dtype(), OType,

          constexpr int nvec = 32 / sizeof(IType);
          GatedActivationKernelLauncher<nvec, fp32, ParamOP, ActOP>(
              reinterpret_cast<const IType *>(input.data.dptr),
              reinterpret_cast<OType *>(output->data.dptr),
              reinterpret_cast<const fp32 *>(output->scale.dptr),
              reinterpret_cast<fp32 *>(output->amax.dptr),
              reinterpret_cast<fp32 *>(output->scale_inv.dptr), input.flat_first_dim(),
              output->flat_last_dim(), p, stream););  // NOLINT(*)
  );                                                  // NOLINT(*)
}

template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
          float (*DActOP)(float, const ParamOP &)>
void cast_gated_bwd(const Tensor &input, const Tensor &grad, Tensor *output, ParamOP &p,
                    cudaStream_t stream) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.dtype(), IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output->dtype(), OType,

          constexpr int nvec = 32 / sizeof(IType);
          DGatedActivationKernelLauncher<nvec, fp32, ParamOP, ActOP, DActOP>(
              reinterpret_cast<const IType *>(grad.data.dptr),
              reinterpret_cast<const IType *>(input.data.dptr),
              reinterpret_cast<OType *>(output->data.dptr),
              reinterpret_cast<const fp32 *>(output->scale.dptr),
              reinterpret_cast<fp32 *>(output->amax.dptr),
              reinterpret_cast<fp32 *>(output->scale_inv.dptr), grad.flat_first_dim(),
              grad.flat_last_dim(), p, stream););  // NOLINT(*)
  );                                               // NOLINT(*)
}

}  // namespace fp8
}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_GATED_FP8_CUH_
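The specialized SiLU branch in the kernel above exploits the identity d/dx[x*sigma(x)] = sigma(x) + x*sigma(x)(1 - sigma(x)), so a single sigmoid evaluation serves both the activation and its derivative. A scalar sketch of that fused step (helper names here are illustrative):

#include <cmath>

struct SiluPair {
  float act;
  float dact;
};

// One sigmoid yields both silu(x) and silu'(x), as in cast_fp8_gated_kernel.
inline SiluPair silu_and_dsilu(float x) {
  const float s = 1.0f / (1.0f + std::exp(-x));
  return {x * s, x * s * (1.0f - s) + s};
}
// The gated backward then forms: d_act = dact * grad * gate, d_gate = act * grad.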
transformer_engine/common/cast/fp8/quantize_fp8.cuh
0 → 100644
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file quantize_fp8.cuh
* \brief CUDA kernels to quantize to FP8.
*/
#ifndef TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_
#define TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
#include <cfloat>
#include <cstddef>
#include <cstdint>
#include <limits>
#include "../../common.h"
#include "../../transpose/cast_transpose.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../util/vectorized_pointwise.h"
#include "../../utils.cuh"
#include "../core/common.cuh"
namespace
transformer_engine
{
namespace
dispatch
{
namespace
fp8
{
namespace
quantize_2D_kernel
{
constexpr
size_t
FP8_CHUNK_DIM_Y
=
128
;
constexpr
size_t
FP8_CHUNK_DIM_X
=
128
;
constexpr
size_t
FP8_THREADS_PER_CHUNK
=
128
;
constexpr
size_t
FP8_BUFFERS_NUM
=
2
;
constexpr
size_t
FP8_PREFETCH_BUFFERS_NUM
=
1
;
static_assert
(
FP8_PREFETCH_BUFFERS_NUM
<
FP8_BUFFERS_NUM
);
constexpr
size_t
FP8_BUFFER_DIM_Y
=
16
;
constexpr
size_t
FP8_BUFFER_DIM_X
=
FP8_CHUNK_DIM_X
;
// 128
constexpr
size_t
FP8_SHMEM_DIM_Y
=
FP8_BUFFER_DIM_Y
;
// 16
constexpr
size_t
FP8_SHMEM_DIM_X
=
FP8_BUFFER_DIM_X
;
// 128
constexpr
size_t
FP8_BUFF_STAGES_NUM
=
FP8_BUFFER_DIM_Y
;
// 16
constexpr
size_t
FP8_ITERATIONS
=
FP8_CHUNK_DIM_Y
/
FP8_BUFFER_DIM_Y
;
// 8 = 128 / 16
static_assert
(
FP8_ITERATIONS
>=
FP8_PREFETCH_BUFFERS_NUM
);
#ifndef __HIP_PLATFORM_AMD__
template
<
bool
IS_DBIAS
,
bool
IS_DACT
,
typename
ParamOP
,
float
(
*
OP
)(
float
,
const
ParamOP
&
),
typename
IType
,
typename
OType
>
__global__
void
__launch_bounds__
(
FP8_THREADS_PER_CHUNK
)
cast_fp8_2D_kernel
(
const
__grid_constant__
CUtensorMap
tensor_map_input
,
const
__grid_constant__
CUtensorMap
tensor_map_act_input
,
const
__grid_constant__
CUtensorMap
tensor_map_output
,
float
*
const
dbias_workspace
,
float
*
const
amax_ptr
,
float
*
const
scale_inv_ptr
,
const
float
*
const
scale_ptr
,
const
size_t
rows
,
const
size_t
cols
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const
size_t
block_offset_Y
=
blockIdx
.
y
*
FP8_CHUNK_DIM_Y
;
const
size_t
block_offset_X
=
blockIdx
.
x
*
FP8_CHUNK_DIM_X
;
const
size_t
tid_Y
=
threadIdx
.
x
/
FP8_THREADS_PER_CHUNK
;
const
size_t
tid_X
=
threadIdx
.
x
%
FP8_THREADS_PER_CHUNK
;
const
size_t
thread_offset_Y
=
tid_Y
;
const
size_t
thread_offset_X
=
tid_X
;
const
size_t
dbias_offset_Y
=
blockIdx
.
y
+
tid_Y
;
const
size_t
my_column
=
blockIdx
.
x
*
FP8_CHUNK_DIM_X
+
thread_offset_X
;
const
bool
col_out_of_bounds
=
my_column
>=
cols
;
const
size_t
dbias_stride
=
cols
;
float
partial_dbias
=
0.
f
;
float
amax
=
0
;
const
float
scale
=
(
scale_ptr
!=
nullptr
)
?
*
scale_ptr
:
1
;
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
__shared__
alignas
(
TMA_SHMEM_ALIGNMENT
)
IType
in_sh
[
FP8_BUFFERS_NUM
][
FP8_SHMEM_DIM_Y
][
FP8_SHMEM_DIM_X
];
__shared__
alignas
(
TMA_SHMEM_ALIGNMENT
)
IType
act_in_sh
[
FP8_BUFFERS_NUM
][
FP8_SHMEM_DIM_Y
][
FP8_SHMEM_DIM_X
];
__shared__
alignas
(
TMA_SHMEM_ALIGNMENT
)
OType
out_sh
[
FP8_BUFFERS_NUM
][
FP8_SHMEM_DIM_Y
][
FP8_SHMEM_DIM_X
];
constexpr
size_t
shmem_buff_size
=
sizeof
(
in_sh
)
/
FP8_BUFFERS_NUM
;
const
bool
is_master_thread
=
(
threadIdx
.
x
==
0
);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__
alignas
(
8
)
uint64_t
mbar
[
FP8_ITERATIONS
];
initialize_barriers
<
FP8_ITERATIONS
,
FP8_THREADS_PER_CHUNK
>
(
mbar
,
is_master_thread
);
int
parity
=
0
;
const
size_t
chunk_offset_Y
=
block_offset_Y
;
const
size_t
chunk_offset_X
=
block_offset_X
;
#pragma unroll
for
(
int
prefetch_buff
=
0
;
prefetch_buff
<
FP8_PREFETCH_BUFFERS_NUM
;
++
prefetch_buff
)
{
const
size_t
chunk_stage_offset_Y
=
chunk_offset_Y
+
prefetch_buff
*
FP8_BUFFER_DIM_Y
;
const
size_t
chunk_stage_offset_X
=
chunk_offset_X
;
if
constexpr
(
IS_DACT
)
{
copy_2d_to_sharedx2
(
&
in_sh
[
prefetch_buff
],
&
tensor_map_input
,
chunk_stage_offset_X
,
chunk_stage_offset_Y
,
&
act_in_sh
[
prefetch_buff
],
&
tensor_map_act_input
,
chunk_stage_offset_X
,
chunk_stage_offset_Y
,
shmem_buff_size
,
&
mbar
[
prefetch_buff
],
is_master_thread
);
}
else
{
copy_2d_to_shared
(
&
in_sh
[
prefetch_buff
],
&
tensor_map_input
,
chunk_stage_offset_X
,
chunk_stage_offset_Y
,
shmem_buff_size
,
&
mbar
[
prefetch_buff
],
is_master_thread
);
}
}
#pragma unroll
for
(
int
iter
=
0
;
iter
<
FP8_ITERATIONS
;
++
iter
)
{
const
size_t
buff
=
iter
%
FP8_BUFFERS_NUM
;
const
    size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM;
    const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y;

    if (next_iter < FP8_ITERATIONS) {
      const size_t next_buff = next_iter % FP8_BUFFERS_NUM;
      const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y;
      const size_t chunk_it_offset_x = chunk_offset_X;
      if constexpr (IS_DACT) {
        copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
                            chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input,
                            chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size,
                            &mbar[next_iter], is_master_thread);
      } else {
        copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
                          chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread);
      }
    }

    // Wait for the data to have arrived
    ptx::mbarrier_wait_parity(&mbar[iter], parity);

#pragma unroll
    for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) {
      const size_t stage_offset_Y = stage;
      const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y;
      const size_t shmem_offset_x = thread_offset_X;
      const size_t row = row_base + shmem_offset_y;
      const bool row_out_of_bounds = row >= rows;
      const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds;

      float elt = static_cast<float>(in_sh[buff][shmem_offset_y][shmem_offset_x]);
      if constexpr (IS_DACT) {
        float act_in_elt = static_cast<float>(act_in_sh[buff][shmem_offset_y][shmem_offset_x]);
        elt *= OP(act_in_elt, {});
      }
      if constexpr (IS_DBIAS) {
        if constexpr (IS_DACT) {
          if (!out_of_bounds) {
            partial_dbias += elt;
          }
        } else {
          // If no activation, elt is 0 so we can safely do this
          partial_dbias += elt;
        }
      }
      __builtin_assume(amax >= 0);
      if (IS_DACT) {
        if (!out_of_bounds) {
          amax = fmaxf(amax, fabsf(elt));
        }
      } else {
        // If no activation, elt is 0 so we can safely do this
        amax = fmaxf(amax, fabsf(elt));
      }
      out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast<OType>(elt * scale);
    }

    // Wait for shared memory writes to be visible to TMA engine.
    ptx::fence_proxy_async_shared_cta();
    __syncthreads();
    // After syncthreads, writes by all threads are visible to TMA engine.

    // Initiate TMA transfer to copy shared memory to global memory
    if (is_master_thread) {
      const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y;
      const size_t chunk_it_offset_x = chunk_offset_X;
      ptx::cp_async_bulk_tensor_2d_shared_to_global(
          reinterpret_cast<const uint64_t *>(&tensor_map_output), chunk_it_offset_x,
          chunk_it_offset_y, reinterpret_cast<uint64_t *>(&out_sh[buff]));

      // Create a "bulk async-group" out of the previous bulk copy operation.
      ptx::cp_async_bulk_commit_group();

      // Wait for TMA transfer to have finished reading shared memory.
      ptx::cp_async_bulk_wait_group_read<FP8_PREFETCH_BUFFERS_NUM>();
    }
  }

  ptx::cp_async_bulk_wait_group_read<0>();
  __syncthreads();

  parity ^= 1;

  if constexpr (IS_DBIAS) {
    const size_t dbias_offset_X = my_column;
    const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X;
    if (!col_out_of_bounds) {
      dbias_workspace[dbias_offset] = partial_dbias;
    }
  }

  if (amax_ptr != nullptr) {
    const int warp_id = threadIdx.x / THREADS_PER_WARP;
    // Reduce the amax over the block
    amax = reduce_max<FP8_THREADS_PER_CHUNK / THREADS_PER_WARP>(amax, warp_id);
    // Update the global amax
    if (is_master_thread) {
      atomicMaxFloat(amax_ptr, amax);
    }
  }

  // Update scale-inverse
  if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) {
    reciprocal<float>(scale_inv_ptr, scale);
  }

  destroy_barriers<FP8_ITERATIONS>(mbar, is_master_thread);
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
#endif

}  // namespace quantize_2D_kernel
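// Illustrative sketch, not part of this header: how the quantities produced by
// cast_fp8_2D_kernel fit together under per-tensor FP8 scaling. Only the
// elt * scale cast, the amax reduction (reduce_max + atomicMaxFloat), and
// scale_inv = 1 / scale appear in the kernel itself; next_scale below is a
// hypothetical helper showing one common delayed-scaling recipe.
//
//   q        = static_cast<OType>(x * scale);       // the out_sh writes above
//   x_approx = static_cast<float>(q) * scale_inv;   // dequantization
//
//   inline float next_scale(float amax, float fp8_max /* 448.f for E4M3 */) {
//     return (amax > 0.f) ? fp8_max / amax : 1.f;   // map max |x| near fp8_max
//   }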
namespace quantize_1D_kernel {

using namespace quantize_2D_kernel;

constexpr size_t CHUNKS_PER_BLOCK = 128;
constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK;
constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK;
constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE;
constexpr size_t CHUNKS_PER_ITERATION = 32;
constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE;
constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION;
constexpr size_t SHMEM_BUFFERS = 2;
static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0);
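// Worked example of the geometry above (illustrative; FP8_THREADS_PER_CHUNK is
// defined earlier in this header and is assumed to be 64 here for concreteness):
//   CHUNK_SIZE      = 64 elements (one per thread)
//   ELEMS_PER_BLOCK = 128 chunks * 64 = 8192 elements per CUDA block
//   SHMEM_DIM       = 32 chunks  * 64 = 2048 elements per shared buffer
//   ITERATIONS      = 128 / 32        = 4 TMA round trips per block
// With a 2-byte input type (e.g. bf16) each in_sh buffer is then 4 KiB, which
// is also the per-iteration TMA transaction size in the kernel below.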
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &), typename IType,
          typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr,
                       float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK;
  const IType *input = input_ptr + block_offset;
  OType *output = output_ptr + block_offset;

  float amax = 0;
  const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;

  // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
  __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM];
  __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM];

  constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS;
  constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS;

  const bool is_master_thread = (threadIdx.x == 0);

  // Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ alignas(8) uint64_t mbar[ITERATIONS];

  initialize_barriers<ITERATIONS, THREADS_PER_BLOCK>(mbar, is_master_thread);

  int parity = 0;

  copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread);

#pragma unroll
  for (int iter = 0; iter < ITERATIONS; ++iter) {
    const size_t buff = iter % SHMEM_BUFFERS;
    const size_t it_offset = iter * SHMEM_DIM;

    const size_t next_iter = iter + 1;
    const size_t next_buff = next_iter % SHMEM_BUFFERS;
    const size_t next_iter_offset = next_iter * SHMEM_DIM;

    if (next_iter < ITERATIONS) {
      copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN,
                        &(mbar[next_iter]), is_master_thread);
    }

    ptx::fence_proxy_async_shared_cta();

    // Wait for the data to have arrived
    ptx::mbarrier_wait_parity(&mbar[iter], parity);

#pragma unroll
    for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) {
      const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x;
      float elt = static_cast<float>(in_sh[buff][shmem_offset]);
      if constexpr (IS_ACT) {
        elt = OP(elt, {});
      }
      __builtin_assume(amax >= 0);
      amax = fmaxf(amax, fabsf(elt));
      out_sh[buff][shmem_offset] = static_cast<OType>(elt * scale);
    }

    // Wait for shared memory writes to be visible to TMA engine.
    ptx::fence_proxy_async_shared_cta();
    __syncthreads();
    // After syncthreads, writes by all threads are visible to TMA engine.

    // Initiate TMA transfer to copy shared memory to global memory
    if (is_master_thread) {
      ptx::cp_async_bulk_tensor_1d_shared_to_global(
          reinterpret_cast<uint64_t *>(output + it_offset),
          reinterpret_cast<uint64_t *>(&out_sh[buff]), transaction_size_OUT);

      // Create a "bulk async-group" out of the previous bulk copy operation.
      ptx::cp_async_bulk_commit_group();

      // Wait for TMA transfer to have finished reading shared memory.
      ptx::cp_async_bulk_wait_group_read<1>();
    }
  }

  ptx::cp_async_bulk_wait_group_read<0>();
  __syncthreads();

  if (amax_ptr != nullptr) {
    const int warp_id = threadIdx.x / THREADS_PER_WARP;
    // Reduce the amax over the block
    amax = reduce_max<THREADS_PER_BLOCK / THREADS_PER_WARP>(amax, warp_id);
    // Update the global amax
    if (is_master_thread) {
      atomicMaxFloat(amax_ptr, amax);
    }
  }

  // Update scale-inverse
  if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) {
    reciprocal<float>(scale_inv_ptr, scale);
  }

  destroy_barriers<ITERATIONS>(mbar, is_master_thread);
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

}  // namespace quantize_1D_kernel
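// Host-side miniature of the two-buffer pipeline used by cast_fp8_1D_kernel
// above (illustrative pseudocode; prefetch/wait/process/flush stand in for
// copy_1d_to_shared, mbarrier_wait_parity, the cast loop, and the
// cp_async_bulk_* store, respectively):
//
//   prefetch(/*iter=*/0, /*buff=*/0);                    // issued before the loop
//   for (int iter = 0; iter < ITERATIONS; ++iter) {
//     const int buff = iter % SHMEM_BUFFERS;             // SHMEM_BUFFERS == 2
//     if (iter + 1 < ITERATIONS)
//       prefetch(iter + 1, (iter + 1) % SHMEM_BUFFERS);  // overlap copy & compute
//     wait(iter);                                        // data for `buff` arrived
//     process(buff);                                     // cast + amax in registers
//     flush(buff);                                       // shared -> global via TMA
//   }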
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_1D(const Tensor &input, Tensor *output, cudaStream_t stream) {
  using namespace quantize_1D_kernel;

  const size_t N = product(input.data.shape);
  const bool isFullTile = (N % ELEMS_PER_BLOCK == 0);
  NVTE_CHECK(isFullTile, "Only full tiles are supported.");
  NVTE_CHECK(is_fp8_dtype(output->dtype()) || is_int8_dtype(output->dtype()),
             "Output must have FP8 or int8 type.");
  NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");

  const size_t chunks = DIVUP(N, CHUNK_SIZE);
  const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK);

  float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
  float *const scale_inv_ptr = reinterpret_cast<float *>(output->scale_inv.dptr);
  const float *const scale_ptr = reinterpret_cast<float *>(output->scale.dptr);

  const dim3 block(THREADS_PER_BLOCK);
  const dim3 grid(blocks);

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.dtype(), IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
          output->dtype(), OType,
          const IType *input_ptr = reinterpret_cast<const IType *>(input.data.dptr);
          OType *output_ptr = reinterpret_cast<OType *>(output->data.dptr);
          cast_fp8_1D_kernel<IS_ACT, ParamOP, OP, IType, OType><<<grid, block, 0, stream>>>(
              input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N););  // NOLINT(*)
  );                                                                            // NOLINT(*)
  NVTE_CHECK_CUDA(cudaGetLastError());
}
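// Hypothetical call site for quantize_1D (for illustration only; real callers
// supply their own activation op). With IS_ACT == false the OP template
// argument is never invoked, so detail::identity (defined later in this
// header) is just a placeholder:
//
//   quantize_1D<false, detail::Empty, detail::identity>(input, &output, stream);
//
// Note the hard requirement checked above: product(input.data.shape) must be a
// multiple of ELEMS_PER_BLOCK; otherwise dispatch falls back to
// CastVectorizedUnaryKernelLauncher.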
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias,
                 Tensor *workspace, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
  assert(false);
#else
  using namespace quantize_2D_kernel;
  checkCuDriverContext(stream);

  const size_t rows = input.flat_first_dim();
  const size_t cols = input.flat_last_dim();
  const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y);
  const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X);
  const size_t blocks_Y = chunks_Y;
  const size_t blocks_X = chunks_X;

  const size_t dbias_rows = blocks_Y;
  const size_t dbias_cols = cols;

  NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type.");
  NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");

  if constexpr (IS_DBIAS) {
    NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
    NVTE_CHECK(dbias->data.shape == std::vector<size_t>{cols}, "Wrong shape of DBias.");
    NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor.");

    if (workspace->data.dptr == nullptr) {
      workspace->data.shape = {dbias_rows, dbias_cols};
      workspace->data.dtype = DType::kFloat32;
      return;
    }
  }

  float *const workspace_ptr = IS_DBIAS ? reinterpret_cast<float *>(workspace->data.dptr) : nullptr;
  float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
  float *const scale_inv_ptr = reinterpret_cast<float *>(output->scale_inv.dptr);
  float *const scale_ptr = reinterpret_cast<float *>(output->scale.dptr);

  const dim3 block(FP8_THREADS_PER_CHUNK);
  const dim3 grid(blocks_X, blocks_Y);

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
          output->data.dtype, OType,

          alignas(64) CUtensorMap tensor_map_input{};
          alignas(64) CUtensorMap tensor_map_act_input{};
          alignas(64) CUtensorMap tensor_map_output{};

          create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y,
                               FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype));

          if constexpr (IS_DACT) {
            create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols,
                                 FP8_SHMEM_DIM_Y, FP8_SHMEM_DIM_X, cols, 0,
                                 typeToNumBits(input.data.dtype));
          }

          create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y,
                               FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype));

          cast_fp8_2D_kernel<IS_DBIAS, IS_DACT, ParamOP, OP, IType, OType>
          <<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_act_input, tensor_map_output,
                                       workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows,
                                       cols);
          NVTE_CHECK_CUDA(cudaGetLastError());

          if constexpr (IS_DBIAS) {
            common::reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
          });  // NOLINT(*)
  );           // NOLINT(*)
#endif
}
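// Workspace-query convention of quantize_2D above (illustrative caller; the
// allocate step is hypothetical). When IS_DBIAS is set and workspace->data.dptr
// is null, the function only records the required shape/dtype and returns, so
// a dbias fusion takes two calls:
//
//   Tensor workspace;                                              // dptr == nullptr
//   quantize_2D<true, true, ParamOP, OP>(input, &act_in, &out, &dbias, &workspace, stream);
//   allocate(&workspace);                                          // blocks_Y x cols, fp32
//   quantize_2D<true, true, ParamOP, OP>(input, &act_in, &out, &dbias, &workspace, stream);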
namespace detail {

using Empty = transformer_engine::Empty;

__device__ inline float identity(float value, const Empty &) { return value; }

}  // namespace detail

#ifdef __HIP_PLATFORM_AMD__
template <typename ParamOP, float (*OP)(float, const ParamOP &)>
struct KernelType {
  static constexpr auto op = OP;
};
#endif
template <typename ParamOP, float (*OP)(float, const ParamOP &)>
void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output,
                                       cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
  using kernel = KernelType<ParamOP, OP>;
  constexpr float (*UnaryOP)(float, const ParamOP &) =
      (kernel::op == nullptr) ? KernelType<ParamOP, &detail::identity>::op : kernel::op;
#else
  constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP;
#endif
  const size_t N = product(input.data.shape);
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(
          output->data.dtype, OType,

          if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) {
            constexpr int nvec = 32 / sizeof(IType);
            VectorizedUnaryKernelLauncher<nvec, ParamOP, UnaryOP>(
                reinterpret_cast<const IType *>(input.data.dptr),
                reinterpret_cast<const fp32 *>(noop->data.dptr),
                reinterpret_cast<OType *>(output->data.dptr),
                reinterpret_cast<const fp32 *>(output->scale.dptr),
                reinterpret_cast<fp32 *>(output->amax.dptr),
                reinterpret_cast<fp32 *>(output->scale_inv.dptr), N, {}, stream);
          } else {
            NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
          });  // NOLINT(*)
  );           // NOLINT(*)
}
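// The vector width above follows from the input element size:
// nvec = 32 / sizeof(IType), i.e. each thread moves 32 bytes per vectorized
// access. For example: IType = bf16 (2 bytes) gives nvec = 16 elements, and
// IType = fp32 (4 bytes) gives nvec = 8.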
template <typename ParamOP, float (*OP)(float, const ParamOP &)>
void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output,
                                           cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
  using kernel = KernelType<ParamOP, OP>;
  constexpr float (*UnaryOP)(float, const ParamOP &) =
      (kernel::op == nullptr) ? KernelType<ParamOP, &detail::identity>::op : kernel::op;
#else
  constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP;
#endif
  const size_t N = product(input->data.shape);
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input->data.dtype, IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(
          output->data.dtype, OType,

          if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) {
            constexpr int nvec = 32 / sizeof(IType);
            VectorizedUnaryGradKernelLauncher<nvec, ParamOP, UnaryOP>(
                reinterpret_cast<const IType *>(grad.data.dptr),
                reinterpret_cast<const IType *>(input->data.dptr),
                reinterpret_cast<OType *>(output->data.dptr),
                reinterpret_cast<const fp32 *>(output->scale.dptr),
                reinterpret_cast<fp32 *>(output->amax.dptr),
                reinterpret_cast<fp32 *>(output->scale_inv.dptr), N, {}, stream);
          } else {
            NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
          });  // NOLINT(*)
  );           // NOLINT(*)
}
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
          float (*OP)(float, const ParamOP &)>
void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output,
              Tensor *dbias, Tensor *workspace, cudaStream_t stream) {
  using namespace quantize_1D_kernel;
  CheckNoopTensor(*noop, "cast_noop");
  CheckInputTensor(input, "cast_input");
  CheckOutputTensor(*output, "cast_output");

  if constexpr (IS_DBIAS) {
    NVTE_CHECK(dbias != nullptr);
    CheckOutputTensor(*dbias, "dbias");
  }
  if constexpr (IS_DACT) {
    NVTE_CHECK(act_input != nullptr);
    CheckInputTensor(*act_input, "activation_input");
    NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match.");
    NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match.");
  }

  NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision.");
  NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");

  // Supported on arch >= 10.0
  if (is_supported_by_CC_100()) {
    if (!IS_DBIAS && !IS_DACT) {
      if (common::full_tile_1D_tensor(output, ELEMS_PER_BLOCK) && is_fp8_dtype(output->dtype()) &&
          is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) &&
          is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) {
        // Aligned AND FP8
        quantize_1D<IS_ACT, ParamOP, OP>(input, output, stream);
      } else {
        // Unaligned
        CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
      }
    } else if (!IS_DBIAS && IS_DACT) {
      if (common::dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) &&
          is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) &&
          is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) &&
          is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) {
        // Aligned AND FP8 (+dAct)
        quantize_2D<IS_DBIAS, IS_DACT, ParamOP, OP>(input, act_input, output, dbias, workspace,
                                                    stream);
      } else {
        // Unaligned
        CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
      }
    } else {
      quantize_2D<IS_DBIAS, IS_DACT, ParamOP, OP>(input, act_input, output, dbias, workspace,
                                                  stream);
    }
  } else {
    if (IS_DBIAS) {
      // zhongboz: should we just ignore IS_ACT here?
      NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) +
                 " or IS_DBIAS=true" + " on GPU with compute capability < 10.0.");
    }
    if (!IS_DACT) {
      CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
    } else {
      CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
    }
  }
}
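// Dispatch summary for quantize<> above (a paraphrase of the branches, for
// quick reference):
//   CC >= 10.0, !IS_DBIAS, !IS_DACT, FP8 + aligned full 1D tile -> quantize_1D (TMA)
//   CC >= 10.0, !IS_DBIAS,  IS_DACT, FP8 + TMA-friendly dims    -> quantize_2D (TMA)
//   CC >= 10.0,  IS_DBIAS (any IS_DACT)                         -> quantize_2D
//   CC <  10.0                                                  -> vectorized fallback
//                                                                  (IS_DBIAS errors out)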
}  // namespace fp8
}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_