OpenDAS / TransformerEngine · Commits

Commit 27ddce40, authored Oct 11, 2025 by wenjh

Merge branch 'nv_main'

Parents: d262ef4c, 5b3092a0
Changes: 208

Showing 20 changed files with 1823 additions and 193 deletions (+1823, -193).
tests/pytorch/distributed/run_layer_with_overlap.py                           +69  -12
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py                 +7   -1
tests/pytorch/test_cpu_offloading.py                                         +126  -70
tests/pytorch/test_fused_rope.py                                             +141   -6
tests/pytorch/test_fusible_ops.py                                            +154  -21
tests/pytorch/test_numerics.py                                               +167  -18
tests/pytorch/test_onnx_export.py                                             +71   -6
tests/pytorch/test_parallel_cross_entropy.py                                  +38  -21
tests/pytorch/test_sanity.py                                                  +12   -1
tests/pytorch/utils.py                                                         +1   -1
transformer_engine/common/CMakeLists.txt                                      +40   -6
transformer_engine/common/__init__.py                                         +38   -0
transformer_engine/common/comm_gemm/comm_gemm.cpp                            +519   -0
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp             +26  -16
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp   +1   -1
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu        +20   -7
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h          +3   -3
transformer_engine/common/common.cu                                           +21   -1
transformer_engine/common/common.h                                            +14   -2
transformer_engine/common/dropout/dropout.cu                                 +355   -0
tests/pytorch/distributed/run_layer_with_overlap.py

@@ -12,6 +12,8 @@ import argparse
 import warnings
 import pprint
 import yaml
+from contextlib import nullcontext
+from functools import partial

 import torch
 import torch.distributed as dist
@@ -35,8 +37,9 @@ class multi_module_model(torch.nn.Module):
         self.num_layers = num_layers
         self.layers = torch.nn.ModuleList([module(*args, **kwargs) for _ in range(num_layers)])

-    def forward(self, x):
-        for layer in self.layers:
-            x = layer(x)
+    def forward(self, x, layer_contexts):
+        for layer, context in zip(self.layers, layer_contexts):
+            with context():
+                x = layer(x)
         return x
@@ -237,12 +240,46 @@ def _parse_args(argv=None, namespace=None):
         default=False,
         help="Print out additional debug information.",
     )
+    parser.add_argument(
+        "--first-last-layers-bf16",
+        action="store_true",
+        default=False,
+        help="Use bf16 for first and last N layers.",
+    )
+    parser.add_argument(
+        "--num-layers-at-start-in-bf16",
+        type=int,
+        default=0,
+        help="Number of layers at the start to run in bf16.",
+    )
+    parser.add_argument(
+        "--num-layers-at-end-in-bf16",
+        type=int,
+        default=0,
+        help="Number of layers at the end to run in bf16.",
+    )
     args = parser.parse_args(argv, namespace)

     if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
         warnings.warn(f"{args.layer_type.__name__} does not support CUDA Graphs!")
        args.use_cuda_graphs = False

+    if not args.first_last_layers_bf16 and (
+        args.num_layers_at_start_in_bf16 > 0 or args.num_layers_at_end_in_bf16 > 0
+    ):
+        warnings.warn(
+            "num-layers-at-start-in-bf16 and num-layers-at-end-in-bf16 are only supported when"
+            " first-last-layers-bf16 is enabled!"
+        )
+        args.num_layers_at_start_in_bf16 = 0
+        args.num_layers_at_end_in_bf16 = 0
+
+    if args.num_layers_at_start_in_bf16 + args.num_layers_at_end_in_bf16 > args.num_layers:
+        raise ValueError(
+            "num-layers-at-start-in-bf16 + num-layers-at-end-in-bf16 must be less than or equal to"
+            " num-layers!"
+        )
+
     return args
@@ -381,10 +418,21 @@ def _train(opts):
         "qkv_dgrad": {"method": "ring_exchange"},
         "fc1_dgrad": {"method": "ring_exchange"},
     }
+    quantization_modes = [
+        (
+            te.module.base.UserBufferQuantizationMode.FP8
+            if opts.fp8
+            else te.module.base.UserBufferQuantizationMode.NONE
+        )
+    ]
+    if opts.first_last_layers_bf16 and opts.fp8:
+        quantization_modes.append(te.module.base.UserBufferQuantizationMode.NONE)
     te.module.base.initialize_ub(
         [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
         opts.tp,
         use_fp8=opts.fp8,
+        quantization_modes=quantization_modes,
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
         ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg,
@@ -423,6 +471,16 @@ def _train(opts):
     elif opts.quantization == "mxfp8":
         fp8_recipe = MXFP8BlockScaling()

+    layer_contexts = [
+        (
+            partial(te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world)
+            if opts.num_layers_at_start_in_bf16 <= i
+            and i < (opts.num_layers - opts.num_layers_at_end_in_bf16)
+            else nullcontext
+        )
+        for i in range(opts.num_layers)
+    ]
+
     # Prepare random input tensors
     test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
     test_x.retain_grad()
@@ -435,8 +493,7 @@ def _train(opts):
     # Execute fwd/bwd and collect tensors to test
     def run_fwd_bwd(model, x):
         with torch.amp.autocast("cuda", dtype=torch.bfloat16):
-            with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
-                y = model(x)
+            y = model(x, layer_contexts)
         if isinstance(y, tuple):
             out, *_ = y
         else:
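The per-layer context list is the core of this change: layers inside the FP8 window get a partial(te.fp8_autocast, ...) factory, while the leading and trailing bf16 layers get nullcontext. A minimal sketch of the pattern with plain PyTorch modules; the DummyLayer and make_layer_contexts names are illustrative, not from the diff:

    from contextlib import nullcontext

    import torch


    class DummyLayer(torch.nn.Module):
        def forward(self, x):
            return x * 2


    def make_layer_contexts(num_layers, num_start_bf16, num_end_bf16, quantized_ctx):
        # One context factory per layer: quantized in the middle, plain at the ends.
        return [
            quantized_ctx if num_start_bf16 <= i < num_layers - num_end_bf16 else nullcontext
            for i in range(num_layers)
        ]


    layers = torch.nn.ModuleList([DummyLayer() for _ in range(4)])
    # In the real script the quantized factory is partial(te.fp8_autocast, enabled=True, ...);
    # nullcontext stands in here so the sketch runs anywhere.
    contexts = make_layer_contexts(4, num_start_bf16=1, num_end_bf16=1, quantized_ctx=nullcontext)

    x = torch.ones(2)
    for layer, context in zip(layers, contexts):
        with context():
            x = layer(x)
    print(x)  # tensor([16., 16.])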
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py

@@ -506,7 +506,13 @@ def main() -> None:
             model_config.num_heads * model_config.head_dim,
         ],
         torch.distributed.get_world_size(group),
         use_fp8=model_config.quantization is not None,
+        quantization_modes=[
+            (
+                te.module.base.UserBufferQuantizationMode.FP8
+                if model_config.quantization is not None
+                else te.module.base.UserBufferQuantizationMode.NONE
+            )
+        ],
         dtype=model_config.dtype,
         bootstrap_backend=bootstrap_backend,
         ub_cfgs=userbuffer_configs,
tests/pytorch/test_cpu_offloading.py

@@ -2,8 +2,11 @@
 #
 # See LICENSE for license information.

+import contextlib
+import gc
 import os
-from contextlib import nullcontext
+from typing import Iterable, Optional

 import pytest
 import torch
@@ -11,15 +14,16 @@ import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
+from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
 from utils import ModelConfig, get_available_attention_backends

-# Check if FP8 is supported
+# Check supported quantization schemes
 fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
 mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()

-fp8_recipes = [None]
+quantization_recipes: Optional[recipe.Recipe] = [None]
 if fp8_available:
-    fp8_recipes.append(recipe.Float8CurrentScaling())
-    fp8_recipes.append(recipe.DelayedScaling())
+    quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))

 model_config = {
     "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
@@ -48,85 +52,139 @@ model_types = {
     "transformer_layer": lambda: te.TransformerLayer(
         SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
     ),
+    "linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
+    "layernorm_mlp_ops": lambda: te.ops.Sequential(
+        te.ops.LayerNorm(SIZE, dtype=torch.bfloat16),
+        te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
+        te.ops.GELU(),
+        te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
+    ),
 }


-def _get_input():
-    return torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16).cuda()
+def _make_input() -> torch.Tensor:
+    """Generate random input tensor."""
+    return torch.randn(
+        (128, SIZE, SIZE),
+        dtype=torch.bfloat16,
+        device="cuda",
+        requires_grad=True,
+    )


-def _get_fp8_weight_cache_size(models, fp8_recipe):
-    """
-    Calculate the total FP8 weight cache size (in MB) for a list of models.
-    """
-    if fp8_recipe is None:
-        return 0
-    params_bytes = 0
-    for model in models:
-        for name, param in model.named_parameters():
-            if "weight" in name:
-                params_bytes += param.numel()
-    # One byte for columnwise and one byte for rowwise,
-    # hence multiply by 2 and convert to MB
-    # there is 1 byte of scale per 32 elements in mxFP8
-    factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1
-    return (2 * params_bytes * factor_for_scale_inv_tensor) / (1024**2)
+def _warmup_model(
+    modules: Iterable[torch.nn.Module],
+    quantization_recipe: Optional[recipe.Recipe],
+) -> None:
+    """Perform forward and backward pass"""
+    tensor = _make_input()
+    for module in modules:
+        with te.fp8_autocast(
+            enabled=quantization_recipe is not None,
+            fp8_recipe=quantization_recipe,
+        ):
+            tensor = module(tensor)
+    tensor.sum().backward()


-def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload):
-    tensor = _get_input()
+def _estimate_cached_weight_size(
+    model_name: str,
+    modules: Iterable[torch.nn.Module],
+    quantization_recipe: Optional[recipe.Recipe],
+) -> float:
+    """Calculate the memory (in MiB) needed for weight caching."""
+
+    # The weight params are cached directly for unquantized compute
+    if quantization_recipe is None:
+        return 0
+
+    # Count number of weight param elements
+    param_elements = 0
+    for module in modules:
+        for param in module.parameters():
+            if param.dim() == 2:
+                param_elements += param.numel()
+
+    # FP8 tensor-scaling caches one byte per element
+    if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling():
+        if not is_non_tn_fp8_gemm_supported() and model_name not in (
+            "linear_op",
+            "layernorm_mlp_ops",
+        ):
+            # Modules do not deallocate FP8 transpose for weights
+            return 2 * param_elements / 1024**2
+        return param_elements / 1024**2
+
+    # MXFP8 caches one data byte per element and one scale byte per 32
+    # elements
+    if quantization_recipe.mxfp8():
+        if model_name not in ("linear_op", "layernorm_mlp_ops"):
+            # Modules do not deallocate column-wise MXFP8 data for weights
+            return 2 * param_elements * (1 + 1 / 32) / 1024**2
+        return param_elements * (1 + 1 / 32) / 1024**2
+
+    raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})")
+
+
+def _measure_cached_memory(
+    modules: Iterable[torch.nn.Module],
+    quantization_recipe: Optional[recipe.Recipe],
+    cpu_offload: bool,
+) -> float:
+    """Measure the growth in allocated GPU memory in MiB after a model forward pass.
+
+    Memory measurement excludes the input and output tensors.
+    """
+
+    # Reset memory
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # Context and sync function for CPU offloading
     if cpu_offload:
         offload_context, sync_function = te.get_cpu_offload_context(
             enabled=True,
-            num_layers=len(models) - 1,
-            model_layers=len(models),
+            num_layers=len(modules),
+            model_layers=len(modules) + 1,
             offload_activations=True,
             offload_weights=False,
         )
     else:
-        offload_context = nullcontext()
+        offload_context = contextlib.nullcontext()
         sync_function = lambda x: x

-    for model in models:
+    # Forward pass, with dummy step to trigger offload for last module
+    inp = _make_input()
+    tensor = inp
+    memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
+    for module in modules:
         with te.fp8_autocast(
-            enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
+            enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe
         ), offload_context:
-            tensor = model(tensor)
+            tensor = module(tensor)
         tensor = sync_function(tensor)
+    with offload_context:
+        tensor = tensor.clone()
+    tensor = sync_function(tensor)
+    memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2)

-    max_mem_used = torch.cuda.memory_allocated() / (1024**2)
     torch.cuda.synchronize()

+    # Backward pass
     tensor.sum().backward()
     torch.cuda.synchronize()

-    return max_mem_used
+    # Memory usage in MiB
+    return memory_after_forward - memory_before_forward


-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("model_key", model_types.keys())
-def test_cpu_offload(fp8_recipe, model_key) -> None:
-    """
-    We run three configurations:
-    (1) No offloading: All activations remain on the GPU between forward and backward passes.
-    (2) No offloading (one layer): Only the first layer's activations remain on the GPU between
-        forward and backward passes.
-    (3) With offloading (all layers): Only the last layer's activations remain on the GPU
-        between forward and backward passes, while all other layers are offloaded to the CPU.
-
-    We expect the memory consumption of configurations (2) and (3) to be similar, with
-    the difference being the size of the FP8 cache that is not offloaded to the CPU.
-    We also expect this memory consumption to be smaller than in scenario (1).
-    """
-    import gc
-
+@pytest.mark.parametrize("quantization_recipe", quantization_recipes)
+@pytest.mark.parametrize("model_name", model_types.keys())
+def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None:
+    """Check that CPU offloading runs and has expected memory usage."""
     gc.collect()

-    model_cls = model_types[model_key]
-    models_list = [model_cls() for _ in range(NUM_LAYERS)]
-    if model_key in ["multihead_attention", "transformer_layer"]:
+    # Construct model
+    modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)]
+    if model_name in ["multihead_attention", "transformer_layer"]:
         available_backends, *_ = get_available_attention_backends(
             model_config["small"],
             qkv_dtype=torch.bfloat16,
@@ -138,20 +196,18 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
             os.environ["NVTE_FLASH_ATTN"] = "0"
             _attention_backends["backend_selection_requires_update"] = True

-    without_offloading = _measure_memory_between_forward_and_backward(
-        models_list, fp8_recipe, False
-    )
-    without_offloading_one_layer = _measure_memory_between_forward_and_backward(
-        models_list[:1], fp8_recipe, False
-    )
-    with_offloading = _measure_memory_between_forward_and_backward(models_list, fp8_recipe, True)
+    # Warmup
+    _warmup_model(modules_list, quantization_recipe)

-    assert with_offloading < without_offloading
+    # Measure cached memory after forward pass
+    memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False)
+    memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True)

-    # The only difference between the memory consumption of with_offloading
-    # and without_offloading_one_layer should be the size of the FP8 weights cache,
-    # which is not offloaded to the CPU.
-    memory_consumption_diff = abs(with_offloading - without_offloading_one_layer)
-    assert (
-        memory_consumption_diff < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON
-    )
+    # Check for expected memory usage
+    assert memory_with_offload < memory_without_offload
+    memory_from_cached_weights = _estimate_cached_weight_size(
+        model_name,
+        modules_list,
+        quantization_recipe,
+    )
+    assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON
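The new measurement strategy snapshots allocated GPU memory around the forward pass and subtracts the bytes still held by the live output. A hedged, standalone sketch of that idea (measure_forward_growth is an illustrative name, not part of the test file, and a CUDA device is assumed):

    import gc

    import torch


    def measure_forward_growth(modules, inp):
        """Growth in allocated GPU memory (MiB) across a forward pass, excluding the output."""
        gc.collect()
        torch.cuda.empty_cache()
        before = torch.cuda.memory_allocated() / (1024**2)
        tensor = inp
        for module in modules:
            tensor = module(tensor)
        # Subtract the output tensor's bytes so only cached/retained state is counted.
        after = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2)
        return after - before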
tests/pytorch/test_fused_rope.py

@@ -1,14 +1,21 @@
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
-from typing import Callable, Tuple, Union
+from typing import Callable, Tuple, Union, List
 import math
 import torch
 import pytest
 from transformer_engine.pytorch.attention.rope import (
     RotaryPositionEmbedding,
     apply_rotary_pos_emb,
+    apply_fused_qkv_rotary_pos_emb,
 )


 # Gradient is a broadcasted scalar
-def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
-    return output.sum() * 2
+def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
+    if isinstance(output, List):
+        return sum(t.sum() * 2 for t in output)
+    else:
+        return output.sum() * 2


 # Gradient is a full tensor
-def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
-    t = torch.ones_like(output)
-    return torch.sum(output * t)
+def _non_overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
+    if isinstance(output, List):
+        return sum(torch.sum(t * torch.ones_like(t)) for t in output)
+    else:
+        t = torch.ones_like(output)
+        return torch.sum(output * t)
@@ -238,3 +245,131 @@ def test_fused_rope_thd(
     torch.testing.assert_close(grad_fused, grad_unfused)
     assert output_fused.is_contiguous()
+
+
+@pytest.mark.parametrize("start_positions", [True, False])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096])
+@pytest.mark.parametrize("hidden_size", [64, 128, 256])
+@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
+@pytest.mark.parametrize("margin", [0, 10])
+@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
+@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
+@pytest.mark.parametrize("cp_size", [1, 2])
+@pytest.mark.parametrize("interleaved", [True, False])
+def test_fused_qkv_rope(
+    dtype: torch.dtype,
+    seq_length: int,
+    hidden_size: int,
+    rotary_percent: float,
+    margin: int,
+    tensor_format: str,
+    loss_func: Callable,
+    cp_size: int,
+    interleaved: bool,
+    start_positions: bool,
+) -> None:
+    if margin == 0 and start_positions == True:
+        # This makes sure that the `start_positions` offsets being applied
+        # are with the maximum length of the rope embeddings.
+        pytest.skip("Skipping test with margin=0 and start_positions=True")
+
+    if start_positions == True and cp_size > 1:
+        # `start_positions` is only supported for `cp_size=1` and inference.
+        pytest.skip("Skipping test with cp_size>1 and start_positions=True")
+
+    if seq_length - margin < 0:
+        pytest.skip("Skipping test with seq_length - margin < 0")
+
+    device = torch.device("cuda:0")
+    batch_size, head_num = 2, 64
+    t = torch.rand(
+        (seq_length - margin, batch_size, head_num, hidden_size * 6),
+        dtype=dtype,
+        device=device,
+    )
+
+    # Get arbitrary offsets to be used with RoPE for all the sequences
+    start_positions = (
+        torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
+        if start_positions
+        else None
+    )
+
+    if tensor_format == "bshd":
+        t = t.transpose(0, 1).contiguous()
+    t.requires_grad = True
+
+    rotary_pos_emb_q = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
+    emb_q = rotary_pos_emb_q(seq_length * cp_size)
+    rotary_pos_emb_k = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
+    emb_k = rotary_pos_emb_k(seq_length * cp_size)
+
+    for cp_rank in range(cp_size):
+        # unfused
+        # The fused kernel computes in float32 internally, so we force the unfused func to use
+        # float32 for more accurate comparison
+        t_clone = t.clone()
+        (query, key, value) = torch.split(
+            t_clone, [hidden_size * 4, hidden_size, hidden_size], dim=3
+        )
+        query = query.reshape(query.shape[0], query.shape[1], head_num * 4, hidden_size)
+
+        query_unfused = apply_rotary_pos_emb(
+            query,
+            emb_q,
+            tensor_format=tensor_format,
+            start_positions=start_positions,
+            interleaved=interleaved,
+            fused=True,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
+        ).to(dtype)
+
+        key_unfused = apply_rotary_pos_emb(
+            key,
+            emb_k,
+            tensor_format=tensor_format,
+            start_positions=start_positions,
+            interleaved=interleaved,
+            fused=True,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
+        ).to(dtype)
+
+        value_unfused = value
+        loss_unfused = loss_func([query_unfused, key_unfused, value_unfused])
+
+        if not isinstance(start_positions, torch.Tensor):
+            loss_unfused.backward()
+            grad_unfused = t.grad.detach().clone()
+            t.grad = None
+
+        # fused
+        query_fused, key_fused, value_fused = apply_fused_qkv_rotary_pos_emb(
+            t,
+            emb_q,
+            emb_k,
+            tensor_format=tensor_format,
+            start_positions=start_positions,
+            interleaved=interleaved,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
+            qkv_split_arg_list=[hidden_size * 4, hidden_size, hidden_size],
+        )
+        loss_fused = loss_func([query_fused, key_fused, value_fused])
+
+        if not isinstance(start_positions, torch.Tensor):
+            loss_fused.backward()
+            grad_fused = t.grad.detach().clone()
+            t.grad = None
+
+        torch.testing.assert_close(query_fused, query_unfused)
+        torch.testing.assert_close(key_fused, key_unfused)
+        torch.testing.assert_close(value_fused, value_unfused)
+        if not isinstance(start_positions, torch.Tensor):
+            torch.testing.assert_close(grad_fused, grad_unfused)
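The list-aware loss helpers are equivalent to applying the original single-tensor versions to each of Q, K, and V and summing. A quick self-contained check of that equivalence (illustration only, not part of the test file):

    import torch

    q, k, v = torch.randn(3, 4), torch.randn(3, 4), torch.randn(3, 4)
    # List form used by _overlapping_grad for the QKV outputs...
    list_loss = sum(t.sum() * 2 for t in [q, k, v])
    # ...matches the single-tensor form applied to the concatenation.
    flat_loss = torch.cat([q, k, v]).sum() * 2
    torch.testing.assert_close(list_loss, flat_loss)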
tests/pytorch/test_fusible_ops.py

@@ -22,6 +22,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.pytorch.ops.fused import (
     BackwardActivationBias,
+    BackwardAddRMSNorm,
     BackwardLinearAdd,
     BackwardLinearScale,
     ForwardLinearBiasActivation,
@@ -1545,7 +1546,10 @@ class TestBasicOps:
         torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)

-    @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
+    @pytest.mark.parametrize(
+        "activation",
+        ("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"),
+    )
     @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("quantization", _quantization_list)
@@ -1564,7 +1568,7 @@ class TestBasicOps:
         # Tensor dimensions
         in_shape = list(out_shape)
-        if activation in ("geglu", "reglu", "swiglu"):
+        if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"):
             in_shape[-1] *= 2

         # Skip invalid configurations
@@ -1591,14 +1595,26 @@ class TestBasicOps:
         y_ref: torch.Tensor
         if activation == "gelu":
             y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
-        elif activation == "relu":
-            y_ref = torch.nn.functional.relu(x_ref)
         elif activation == "geglu":
             x1, x2 = x_ref.chunk(2, dim=-1)
             y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2
+        elif activation == "qgelu":
+            y_ref = x_ref * torch.sigmoid(1.702 * x_ref)
+        elif activation == "qgeglu":
+            x1, x2 = x_ref.chunk(2, dim=-1)
+            y_ref = x1 * torch.sigmoid(1.702 * x1) * x2
+        elif activation == "relu":
+            y_ref = torch.nn.functional.relu(x_ref)
         elif activation == "reglu":
             x1, x2 = x_ref.chunk(2, dim=-1)
             y_ref = torch.nn.functional.relu(x1) * x2
+        elif activation == "srelu":
+            y_ref = torch.nn.functional.relu(x_ref) ** 2
+        elif activation == "sreglu":
+            x1, x2 = x_ref.chunk(2, dim=-1)
+            y_ref = torch.nn.functional.relu(x1) ** 2 * x2
+        elif activation == "silu":
+            y_ref = torch.nn.functional.silu(x_ref)
         elif activation == "swiglu":
             x1, x2 = x_ref.chunk(2, dim=-1)
             y_ref = torch.nn.functional.silu(x1) * x2
@@ -1610,9 +1626,14 @@ class TestBasicOps:
         recipe = make_recipe(quantization)

         make_op = dict(
             gelu=te_ops.GELU,
-            relu=te_ops.ReLU,
             geglu=te_ops.GEGLU,
+            qgelu=te_ops.QGELU,
+            qgeglu=te_ops.QGEGLU,
+            relu=te_ops.ReLU,
             reglu=te_ops.ReGLU,
+            srelu=te_ops.SReLU,
+            sreglu=te_ops.SReGLU,
+            silu=te_ops.SiLU,
             swiglu=te_ops.SwiGLU,
         )[activation]
         forward = te_ops.Sequential(
@@ -1742,25 +1763,44 @@ class TestBasicOps:
         torch.testing.assert_close(y_test, y_ref, **tols)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)

-    @pytest.mark.parametrize("prob", (0.1, 0.5, 0.75))
+    @pytest.mark.parametrize("prob", (0.0625, 0.5, 0.75))
     @pytest.mark.parametrize("is_training", (True, False))
-    @pytest.mark.parametrize("shape", ((101,), (2, 4, 16)))
+    @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling"))
+    @pytest.mark.parametrize("shape", ((101,), (2, 4, 16), (128, 128)))
     @pytest.mark.parametrize("dtype", _dtypes)
     def test_dropout(
         self,
         *,
         prob: float,
         is_training: bool,
+        quantization: Optional[str],
         shape: Iterable[int],
         dtype: torch.dtype,
         device: torch.device = "cuda",
     ):
+        # Skip invalid configurations
+        quantized_input = quantization is not None
+        maybe_skip_quantization(quantization, dims=shape, device=device)

         # Random data
-        x_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
-        x_test = x_ref.clone().requires_grad_()
-        dy_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
-        dy_test = dy_ref.clone()
+        # Note: Shift values to make sure inputs are non-zero
+        x_ref, x_test = make_reference_and_test_tensors(
+            shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+            test_is_quantized=quantized_input,
+        )
+        with torch.no_grad():
+            x_test += 1
+            x_ref.copy_(x_test)
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )

         # Apply dropout
         op = te_ops.Dropout(prob)
@@ -1768,17 +1808,20 @@ class TestBasicOps:
             op.train()
         else:
             op.eval()
-        y = op(x_test)
-        y.backward(dy_test)
+        y_test = op(x_test)
+        y_test.backward(dy_test)

         # Check values
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
         if is_training:
-            mask = ((y != 0) / (1 - prob)).to(dtype=dtype)
-            torch.testing.assert_close(y, x_ref * mask)
-            torch.testing.assert_close(x_test.grad, dy_ref * mask)
+            tols = dtype_tols(dtype)
+            mask = ((y_test != 0) / (1 - prob)).to(dtype=dtype)
+            torch.testing.assert_close(y_test, x_ref * mask, **tols)
+            torch.testing.assert_close(dx_test, dy_ref * mask, **tols)
         else:
-            torch.testing.assert_close(y, x_ref, rtol=0, atol=0)
-            torch.testing.assert_close(x_test.grad, dy_ref, rtol=0, atol=0)
+            torch.testing.assert_close(y_test, x_ref, rtol=0, atol=0)
+            torch.testing.assert_close(dx_test, dy_ref, rtol=0, atol=0)

         # Hypothesis testing for number of zeros
         # Note: A Bernoulli random variable with probability p has
@@ -1790,9 +1833,11 @@ class TestBasicOps:
         # p-value is less than 1% and we assume that the dropout
         # distribution is incorrect.
         if is_training:
-            prob_observed = 1 - torch.count_nonzero(y).item() / y.numel()
-            z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel())
-            assert abs(z_score) < 2.5758, "Number of zeros is outside 99% confidence interval"
+            prob_observed = 1 - torch.count_nonzero(y_test).item() / y_test.numel()
+            z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y_test.numel())
+            assert (
+                abs(z_score) < 2.5758
+            ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})"


 class TestFusedOps:
@@ -2220,6 +2265,94 @@ class TestFusedOps:
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)
         torch.testing.assert_close(db_test, b_ref.grad, **tols)

+    @pytest.mark.parametrize("weight_shape", ((19,), (64,)))
+    @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("zero_centered_gamma", (False, True))
+    def test_backward_add_rmsnorm(
+        self,
+        *,
+        weight_shape: Iterable[int],
+        in_shape: Iterable[int],
+        dtype: torch.dtype,
+        device: torch.device = "cuda",
+        eps: float = 0.3,
+        zero_centered_gamma: bool,
+    ) -> None:
+        """Fused backward RMSNorm + add"""
+
+        # Make input and weight shapes consistent
+        in_shape = list(in_shape)[:-1] + list(weight_shape)
+
+        # Random data
+        x_ref, x_test = make_reference_and_test_tensors(
+            in_shape,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        w_ref, w_test = make_reference_and_test_tensors(
+            weight_shape,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        dy1_ref, dy1_test = make_reference_and_test_tensors(
+            in_shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+        dy2_ref, dy2_test = make_reference_and_test_tensors(
+            in_shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Plain PyTorch implementation
+        inner_dims = tuple(range(len(in_shape) - len(weight_shape), len(in_shape)))
+        var_ref = x_ref.square().sum(dim=inner_dims, keepdim=True) / math.prod(weight_shape)
+        if zero_centered_gamma:
+            y1_ref = x_ref / torch.sqrt(eps + var_ref) * (1 + w_ref)
+        else:
+            y1_ref = x_ref / torch.sqrt(eps + var_ref) * w_ref
+        y2_ref = x_ref
+        (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()
+
+        # Implementation with fusible operations
+        model = te_ops.Sequential(
+            te_ops.MakeExtraOutput(),
+            te_ops.RMSNorm(
+                weight_shape,
+                eps=eps,
+                device=device,
+                dtype=dtype,
+                zero_centered_gamma=zero_centered_gamma,
+            ),
+        )
+        with torch.no_grad():
+            model[1].weight.copy_(w_test)
+            del w_test
+        y1_test, y2_test = model(x_test)
+        (y1_test * dy1_test + y2_test * dy2_test).sum().backward()
+
+        # Check that backward operations have been fused
+        backward_ops = model._module_groups[0]._backward_ops
+        assert len(backward_ops) == 1
+        assert isinstance(backward_ops[0][0], BackwardAddRMSNorm)
+
+        # Expected numerical error
+        tols = dtype_tols(dtype)
+
+        # Check results
+        y1_test = y1_test.to(dtype=torch.float64, device="cpu")
+        y2_test = y2_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+        dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(y1_test, y1_ref, **tols)
+        torch.testing.assert_close(y2_test, y2_ref, **tols)
+        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
+
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("quantization", _quantization_list)
     def test_backward_linear_add(
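The dropout check is a standard one-sample z-test: the zero count of a dropout output is Binomial(n, p), so the observed zero fraction should lie within 2.5758 standard errors of p at 99% confidence. A self-contained demo of the same statistic using plain PyTorch dropout (illustration, not from the test file):

    import math

    import torch

    prob, shape = 0.5, (128, 128)
    # Ones as input so every zero in the output comes from dropout itself.
    y = torch.nn.functional.dropout(torch.ones(shape), p=prob, training=True)
    prob_observed = 1 - torch.count_nonzero(y).item() / y.numel()
    # Standard error of a Binomial proportion; 2.5758 is the 99% two-sided z value.
    z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel())
    assert abs(z_score) < 2.5758, f"{prob=}, {prob_observed=}"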
tests/pytorch/test_numerics.py

@@ -41,16 +41,21 @@ from transformer_engine.pytorch import torch_version
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
 from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Quantizer,
+    Float8CurrentScalingQuantizer,
+)
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
 from transformer_engine.pytorch.utils import get_device_compute_capability
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
 from utils import ModelConfig, reset_rng_states, get_available_attention_backends

 # Only run FP8 tests on supported devices.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
 fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
 sm_80plus = get_device_compute_capability() >= (8, 0)
@@ -84,7 +89,18 @@ batch_sizes = [1, 2]
 all_boolean = [True, False]

-all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]
+all_activations = [
+    "gelu",
+    "geglu",
+    "qgelu",
+    "qgeglu",
+    "relu",
+    "reglu",
+    "srelu",
+    "sreglu",
+    "silu",
+    "swiglu",
+]
 all_normalizations = ["LayerNorm", "RMSNorm"]
@@ -114,15 +130,25 @@ if fp8_available:
     fp8_recipes.append(recipe.Float8CurrentScaling())
     fp8_recipes.append(recipe.DelayedScaling())

+use_cutlass_grouped_gemm = [False]
+# Only enable cutlass grouped gemm on Hopper
+if torch.cuda.get_device_capability() == (9, 0):
+    use_cutlass_grouped_gemm.append(True)
+

 def is_fused_attn_available(
-    config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True
+    config: ModelConfig,
+    dtype: torch.dtype,
+    qkv_layout="bshd_bshd_bshd",
+    is_training=True,
+    deterministic=False,
 ):
     _, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
         is_training=is_training,
+        deterministic=deterministic,
     )
     return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends
@@ -432,13 +458,16 @@ class TorchGroupedLinearWithPadding(nn.Module):
 _supported_act = {
-    "geglu": nn.GELU(approximate="tanh"),
     "gelu": nn.GELU(approximate="tanh"),
-    "reglu": nn.ReLU(),
-    "relu": nn.ReLU(),
-    "swiglu": nn.SiLU(),
+    "geglu": nn.GELU(approximate="tanh"),
+    "qgelu": TorchQuickGELU(),
+    "qgeglu": TorchQuickGELU(),
+    "relu": nn.ReLU(),
+    "reglu": nn.ReLU(),
+    "srelu": TorchSquaredRELU(),
+    "sreglu": TorchSquaredRELU(),
+    "silu": nn.SiLU(),
+    "swiglu": nn.SiLU(),
 }
@@ -830,7 +859,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
 @pytest.mark.parametrize("model", ["126m"])
 def test_gpt_checkpointing(dtype, bs, model):
     config = model_configs[model]
-    if not is_fused_attn_available(config, dtype):
+    if not is_fused_attn_available(config, dtype, deterministic=True):
         pytest.skip("No attention backend available.")
     outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
     outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
@@ -878,7 +907,9 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
 @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
 def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
     config = model_configs[model]
-    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+    if not is_fused_attn_available(
+        config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
+    ):
         pytest.skip("No attention backend available.")
     te_gpt = TransformerLayer(
@@ -991,7 +1022,9 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
 @pytest.mark.parametrize("mask_type", mask_types)
 def test_mha_accuracy(dtype, bs, model, mask_type):
     config = model_configs[model]
-    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+    if not is_fused_attn_available(
+        config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
+    ):
         pytest.skip("No attention backend available.")
     te_mha = MultiheadAttention(
@@ -1782,6 +1815,7 @@ def test_grouped_linear_accuracy(
     bias,
     delay_wgrad_compute,
     parallel_mode=None,
+    use_cutlass=False,
 ):
     fp8 = recipe is not None
     if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
@@ -1853,11 +1887,49 @@ def test_grouped_linear_accuracy(
         delay_wgrad_compute,
     )

-    # Shoule be bit-wise match
-    for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
-        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+    for o, o_ref in zip(outputs, outputs_ref):
+        if use_cutlass:
+            torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
+        else:
+            # cuBLAS implementation should be bit-wise match
+            torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+
+
+@pytest.mark.skipif(
+    torch.cuda.get_device_capability() != (9, 0),
+    reason="Only enable CUTLASS grouped gemm on Hopper",
+)
+@pytest.mark.parametrize("dtype", param_types, ids=str)
+@pytest.mark.parametrize("num_gemms", [3, 6])
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("model", ["126m"])
+@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
+@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
+def test_grouped_linear_accuracy_cutlass(
+    dtype,
+    num_gemms,
+    bs,
+    model,
+    fuse_wgrad_accumulation,
+    delay_wgrad_compute,
+):
+    os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
+    test_grouped_linear_accuracy(
+        dtype,
+        num_gemms,
+        bs,
+        model,
+        None,
+        False,
+        fuse_wgrad_accumulation,
+        False,
+        delay_wgrad_compute,
+        None,
+        use_cutlass=True,
+    )
+    os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
+

 @pytest.mark.parametrize("dtype", param_types, ids=str)
 @pytest.mark.parametrize("num_gemms", [3])
 @pytest.mark.parametrize("bs", [1])
@@ -2525,10 +2597,11 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
         (16, 10027, 128, 512),
     ],
 )
-@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("dtype", param_types, ids=str)
 @pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
 @pytest.mark.parametrize("accumulate", [False, True])
-def test_grouped_gemm(shape, dtype, layout, accumulate):
+@pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm)
+def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
     torch.manual_seed(0)
     z, m, k, n = shape
@@ -2563,6 +2636,9 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
     grad = True
     single_output = False

+    if use_cutlass:
+        os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
+
     # Force the sequential_linear and grouped_linear to use hipblaslt rather than hipblas
     if IS_HIP_EXTENSION:
         ori_force_rocm_gemm = os.environ.get("NVTE_FORCE_ROCM_GEMM", None)
@@ -2600,9 +2676,82 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
         else:
             del os.environ["NVTE_FORCE_ROCM_GEMM"]

-    # should be bit-wise match
     for o, o_ref in zip(out, out_ref):
-        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+        if not use_cutlass:
+            # cublas implementation should be bit-wise match
+            torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+        else:
+            torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2)
+
+    if use_cutlass:
+        os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
+
+
+@pytest.mark.parametrize("N", [32])
+@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize(
+    "input_quantizer",
+    [
+        Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"),
+        MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3),
+    ],
+)
+@pytest.mark.parametrize(
+    "out_quantizer",
+    [
+        Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"),
+        MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3),
+        Float8Quantizer(
+            torch.ones(1).cuda().squeeze(),
+            torch.ones(1).cuda().squeeze(),
+            tex.DType.kFloat8E4M3,
+        ),
+    ],
+)
+def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_quantizer):
+    # For MXFP8 and CurrentScaling, below unfused quantization should happen
+    # FP8 input --> cublas GEMM --> BF16 output --> Quantize to FP8 --> fp8 Output
+
+    # Skip invalid configurations
+    is_mxfp8_needed = isinstance(input_quantizer, MXFP8Quantizer) or isinstance(
+        out_quantizer, MXFP8Quantizer
+    )
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if is_mxfp8_needed and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+
+    inp_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype))
+    weight_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype))
+    outp_type = torch.float32
+    quantized_out, *_ = general_gemm(
+        weight_fp8,
+        inp_fp8,
+        get_workspace(),
+        outp_type,
+        quantization_params=out_quantizer,
+        bias=None,
+        use_split_accumulator=False,
+    )
+    out, *_ = general_gemm(
+        weight_fp8,
+        inp_fp8,
+        get_workspace(),
+        outp_type,
+        quantization_params=None,
+        bias=None,
+        use_split_accumulator=False,
+    )
+    expected_quantized_out = out_quantizer(out)
+
+    # Match results against PyTorch GEMM and allow for quantization tolerance
+    pytorch_out = torch.matmul(
+        inp_fp8.dequantize().to(torch.float64),
+        torch.transpose(weight_fp8.dequantize().to(torch.float64), 0, 1),
+    )
+    fp8_tols = dict(rtol=0.125, atol=0.0675)
+    torch.testing.assert_close(
+        pytorch_out.to(outp_type), expected_quantized_out.dequantize(), **fp8_tols
+    )
+    # Match results between quantization happening inside vs outside general_gemm
+    torch.testing.assert_close(expected_quantized_out.dequantize(), quantized_out.dequantize())


 @pytest.mark.parametrize(
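Both new CUTLASS tests select the backend purely through an environment variable and pop it afterwards; note that the pop only runs if the test body returns normally. A try/finally context manager is a safer variant of the same toggle (a sketch; the env_flag name is illustrative, not from the test file):

    import contextlib
    import os


    @contextlib.contextmanager
    def env_flag(name, value="1"):
        """Temporarily set an environment variable, restoring the old state on exit."""
        old = os.environ.get(name)
        os.environ[name] = value
        try:
            yield
        finally:
            # Restore even if the guarded block raises.
            if old is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old


    # with env_flag("NVTE_USE_CUTLASS_GROUPED_GEMM"):
    #     ...run the grouped-GEMM path...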
tests/pytorch/test_onnx_export.py

@@ -36,6 +36,7 @@ import transformer_engine_torch as tex
 from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.utils import get_default_init_method
+import tensorrt as trt

 # Global test configuration knobs.
@@ -64,6 +65,7 @@ if mxfp8_available:
     fp8_recipes.append(recipe.MXFP8BlockScaling())
 if fp8_available:
     fp8_recipes.append(recipe.DelayedScaling())
+    fp8_recipes.append(recipe.Float8CurrentScaling())
 fp8_recipes.append(None)

 supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
@@ -80,11 +82,11 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
     ],
     outputs=[PyCustomOpDef.dt_uint8],
 )
-def trt_fp8_quantize(t, scale):
+def trt_fp8_quantize(t, scale_inv):
     """FP8 quantization extension for ONNX Runtime."""
     x = torch.from_numpy(t).cuda()
     q = te.tensor.float8_tensor.Float8Quantizer(
-        scale=1 / torch.from_numpy(scale).cuda(),
+        scale=1 / torch.from_numpy(scale_inv).cuda(),
         amax=torch.zeros([1]).cuda(),
         fp8_dtype=tex.DType.kFloat8E4M3,
     )
@@ -100,11 +102,11 @@ def trt_fp8_quantize(t, scale_inv):
     ],
     outputs=[PyCustomOpDef.dt_float],
 )
-def trt_fp8_dequantize(t, scale):
+def trt_fp8_dequantize(t, scale_inv):
     """FP8 dequantization extension for ONNX Runtime."""
     x = torch.from_numpy(t).cuda()
     q = te.tensor.float8_tensor.Float8Quantizer(
-        scale=1 / torch.from_numpy(scale).cuda(),
+        scale=1 / torch.from_numpy(scale_inv).cuda(),
         amax=torch.zeros([1]).cuda(),
         fp8_dtype=tex.DType.kFloat8E4M3,
     )
@@ -113,7 +115,7 @@ def trt_fp8_dequantize(t, scale_inv):
 @onnx_op(
-    op_type="trt::TRT_MXFP8QuantizeLinear",
+    op_type="trt::TRT_MXFP8DynamicQuantize",
     domain="trt",
     inputs=[
         PyCustomOpDef.dt_float,
@@ -592,7 +594,9 @@ def _test_export_layernorm_linear(
         fname,
         inp,
         model,
-        atol=1e-3,
+        # For current scaling we use Float8Quantizer in tests + amax computed by hand,
+        # which has slightly different numerics than Float8CurrentScalingQuantizer.
+        atol=1e-3 if fp8_recipe.__class__ is not recipe.Float8CurrentScaling else 2e-2,
         is_fp8=fp8_recipe is not None,
         te_outputs=te_outputs,
     )
@@ -1139,3 +1143,64 @@ def test_export_ctx_manager(enabled):
     with te.onnx_export(enabled):
         assert is_in_onnx_export_mode() == enabled
     assert is_in_onnx_export_mode() == False
+
+
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+def test_trt_integration(fp8_recipe: recipe.Recipe):
+    model = te.TransformerLayer(
+        hidden_size=128,
+        ffn_hidden_size=128,
+        num_attention_heads=4,
+    ).eval()
+    if type(fp8_recipe) == recipe.Float8CurrentScaling:
+        # TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling
+        model = te.LayerNormMLP(128, 128)
+
+    inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)
+
+    with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+        out_ref = model(*inps)
+
+    onnx_fd, onnx_path = tempfile.mkstemp(suffix=".onnx")
+    os.close(onnx_fd)
+    try:
+        with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+            with te.onnx_export(enabled=True):
+                torch.onnx.export(
+                    model,
+                    inps,
+                    onnx_path,
+                    output_names=["output"],
+                    dynamo=True,
+                    custom_translation_table=te_translation_table,
+                )
+        os.system(f"trtexec --onnx={onnx_path} --saveEngine={onnx_path}.engine")
+
+        # Run TRT engine
+        logger = trt.Logger(trt.Logger.WARNING)
+        runtime = trt.Runtime(logger)
+        with open(onnx_path + ".engine", "rb") as f:
+            engine_data = f.read()
+        engine = runtime.deserialize_cuda_engine(engine_data)
+        context = engine.create_execution_context()
+        context.set_tensor_address(engine.get_tensor_name(0), inps[0].data_ptr())
+        stream = torch.cuda.Stream()
+        out = torch.zeros_like(out_ref)
+        context.set_tensor_address("output", out.data_ptr())
+        context.execute_async_v3(stream_handle=stream.cuda_stream)
+        stream.synchronize()
+
+        # Compare TRT and TE outputs
+        atol = 5e-2 if fp8_recipe is not None else 1e-4
+        rtol = 5e-2 if fp8_recipe is not None else 1e-4
+        torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol)
+    finally:
+        try:
+            os.remove(onnx_path)
+        except FileNotFoundError:
+            pass
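The loose FP8 tolerances here (atol up to 5e-2) follow from the format itself: E4M3 keeps only 3 mantissa bits, so a quantize/dequantize round trip carries relative error up to about 2^-4. A rough illustration, assuming a PyTorch build with the float8_e4m3fn dtype (2.1+); not part of the test file:

    import torch

    x = torch.rand(1024) + 0.5  # keep values away from zero and subnormals
    x_fp8 = x.to(torch.float8_e4m3fn).to(torch.float32)
    rel_err = ((x - x_fp8).abs() / x).max()
    print(f"max relative error: {rel_err:.4f}")  # about 0.06, i.e. one part in 2^4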
tests/pytorch/test_parallel_cross_entropy.py

@@ -6,6 +6,8 @@ import random
 import torch

 from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
+
+from utils import dtype_tols


 class TestParallelCrossEntropy:
@@ -18,19 +20,25 @@ class TestParallelCrossEntropy:
             label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
         )

-    def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
+    def generate_input(
+        self,
+        dtype: torch.dtype,
+        swap_dim: bool,
+        ignore_idx: bool,
+        device: torch.device = "cuda",
+    ):
         SQ = random.choice([64, 128])
         batch = random.choice([1, 2])
         vocab = random.choice([64000, 128000])
         ignore = random.sample(range(0, SQ - 1), 5)

+        # Generate random data
         if swap_dim:
-            self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
-            self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda()
+            self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype, device=device)
+            self.tar_test = torch.randint(0, vocab, (SQ, batch), device=device)
         else:
-            self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()
-            self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()
+            self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype, device=device)
+            self.tar_test = torch.randint(0, vocab, (batch, SQ), device=device)

         if ignore_idx:
             for i in ignore:
@@ -40,9 +48,14 @@ class TestParallelCrossEntropy:
             else:
                 self.tar_test[0][i] = -100

+        # Make copy of data for reference implementation
         self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
         self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))

+        # Enable autograd
+        self.input_test.requires_grad_()
+        self.input_ref.requires_grad_()
+
     def one_iteration_test(
         self,
         dtype: torch.dtype,
@@ -52,18 +65,20 @@ class TestParallelCrossEntropy:
         ignore_idx: bool = False,
     ):
+        # Random data
         self.generate_input(dtype, swap_dim, ignore_idx)

-        self.input_test.requires_grad_(True)
-        self.input_ref.requires_grad_(True)
-
+        # Forward pass
         test_loss = self.test_loss_func(
             self.input_test, self.tar_test, label_smoothing, reduce_loss, None
         )
-
         ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)

-        # Handle backward pass based on the test scenario
+        # Compute square to avoid trivial backward pass
+        test_loss = torch.square(test_loss)
+        ref_loss = torch.square(ref_loss)
+
+        # Backward pass
         if reduce_loss:
             test_loss.backward()
             ref_loss.backward()
@@ -71,16 +86,18 @@ class TestParallelCrossEntropy:
             test_loss.sum().backward()
             ref_loss.sum().backward()

-        test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
-
-        if ignore_idx:
-            print(test_loss, ref_loss)
-
-        # Compare gradients when backward pass was called
-        torch.testing.assert_close(
-            torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
-        )
+        # Check that loss and grad input match
+        tols = dtype_tols(dtype)
+        test_loss = test_loss.to(dtype=torch.float64, device="cpu")
+        ref_loss = ref_loss.to(dtype=torch.float64, device="cpu")
+        ref_loss = ref_loss.reshape(test_loss.size())
+        test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu")
+        ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu")
+        ref_grad_input = ref_grad_input.reshape(test_grad_input.size())
+        torch.testing.assert_close(test_loss, ref_loss, **tols)
+        torch.testing.assert_close(test_grad_input, ref_grad_input, **tols)

+        # Reset data
         self.input_test = None
         self.input_ref = None
         self.tar_test = None
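Why square the loss before backprop? Backpropagating through loss.sum() feeds an all-ones upstream gradient, so a loss kernel whose gradient is wrong by a constant scale could still pass. Squaring first makes the upstream gradient 2 * loss, which varies per element. A short demo of the idea with plain PyTorch cross entropy (a sketch, not from the test file):

    import torch

    logits = torch.randn(4, 10, requires_grad=True)
    target = torch.randint(0, 10, (4,))
    loss = torch.nn.functional.cross_entropy(logits, target, reduction="none")
    # Upstream gradient of each loss element is now 2 * loss, not 1.
    torch.square(loss).sum().backward()
    print(logits.grad.abs().sum())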
tests/pytorch/test_sanity.py

@@ -105,7 +105,18 @@ if is_bf16_compatible():  # bf16 requires sm_80 or higher
 all_boolean = [True, False]
 batch_sizes_with_zero = [0, 1, 2]

-all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu", "qgelu", "qgeglu"]
+all_activations = [
+    "gelu",
+    "geglu",
+    "qgelu",
+    "qgeglu",
+    "relu",
+    "reglu",
+    "srelu",
+    "sreglu",
+    "silu",
+    "swiglu",
+]

 all_normalizations = ["LayerNorm", "RMSNorm"]
tests/pytorch/utils.py

@@ -266,8 +266,8 @@ def get_available_attention_backends(
     )
     (
         use_flash_attention,
-        use_fused_attention,
         flash_attention_backend,
+        use_fused_attention,
         fused_attention_backend,
         use_unfused_attention,
         available_backends,
transformer_engine/common/CMakeLists.txt

@@ -102,6 +102,11 @@ if(USE_ROCM)
     message(STATUS "USE_HIPBLASLT ${USE_HIPBLASLT} USE_ROCBLAS ${USE_ROCBLAS}")
 endif()

+set(CUTLASS_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/include")
+set(CUTLASS_TOOLS_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/tools/util/include")
+
 # Python
 find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
@@ -128,6 +133,7 @@ if(USE_CUDA)
       transpose/quantize_transpose_vector_blockwise.cu
       transpose/swap_first_dims.cu
      activation/gelu.cu
+      dropout/dropout.cu
      fused_attn/flash_attn.cu
      fused_attn/context_parallel.cu
      fused_attn/kv_cache.cu
@@ -139,6 +145,7 @@
      fused_attn/fused_attn.cpp
      fused_attn/utils.cu
      gemm/cublaslt_gemm.cu
+      gemm/cutlass_grouped_gemm.cu
      normalization/common.cpp
      normalization/layernorm/ln_api.cpp
      normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
@@ -169,6 +176,10 @@
        comm_gemm_overlap/userbuffers/userbuffers-host.cpp
        comm_gemm_overlap/userbuffers/userbuffers.cu
        comm_gemm_overlap/comm_gemm_overlap.cpp)
+  if(NVTE_WITH_CUBLASMP)
+    list(APPEND transformer_engine_SOURCES
+         comm_gemm/comm_gemm.cpp)
+  endif()
   add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
 else()
   list(APPEND transformer_engine_SOURCES
@@ -192,10 +203,12 @@ else()
      transpose/quantize_transpose_vector_blockwise.cu
      transpose/swap_first_dims.cu
      activation/gelu.cu
+      dropout/dropout.cu
      activation/relu.cu
      activation/swiglu.cu
      gemm/cublaslt_gemm.cu
      gemm/hipblas_gemm.cu
+      gemm/cutlass_grouped_gemm.cu
      normalization/common.cpp
      normalization/layernorm/ln_api.cpp
      normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
@@ -226,6 +239,10 @@
        comm_gemm_overlap/userbuffers/userbuffers-host.cpp
        comm_gemm_overlap/userbuffers/userbuffers.cu
        comm_gemm_overlap/comm_gemm_overlap.cpp)
+  if(NVTE_WITH_CUBLASMP)
+    list(APPEND transformer_engine_SOURCES
+         comm_gemm/comm_gemm.cpp)
+  endif()

 # process source code files
 message("${message_line}")
 message(STATUS "CMAKE_CURRENT_SOURCE_DIR: ${CMAKE_CURRENT_SOURCE_DIR}")
@@ -272,7 +289,12 @@ if (USE_CUDA)
                           CUDNN::cudnn_all)
     target_include_directories(transformer_engine PRIVATE
                                ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+    target_include_directories(transformer_engine SYSTEM PRIVATE
+                               ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
     target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
+    target_include_directories(transformer_engine PRIVATE
+                               ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_INCLUDE_DIR})
 else()
     target_include_directories(transformer_engine PUBLIC
                                "${CMAKE_CURRENT_SOURCE_DIR}")
     # Aotriton is currently unsupported
@@ -313,11 +335,23 @@ if (NVTE_ENABLE_NVSHMEM)
     target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR})
 endif()

-option(NVTE_ENABLE_NVSHMEM "Compile with NVSHMEM library" OFF)
-if (NVTE_ENABLE_NVSHMEM)
-    add_subdirectory(nvshmem_api)
-    target_link_libraries(transformer_engine PUBLIC nvshmemapi)
-    target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR})
-endif()
+option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF)
+if (NVTE_WITH_CUBLASMP)
+    target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP)
+    target_include_directories(transformer_engine PRIVATE
+                               ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include)
+    find_library(CUBLASMP_LIB
+                 NAMES cublasmp libcublasmp
+                 PATHS ${CUBLASMP_DIR}
+                 PATH_SUFFIXES lib
+                 REQUIRED)
+    find_library(NVSHMEM_HOST_LIB
+                 NAMES nvshmem_host libnvshmem_host.so.3
+                 PATHS ${NVSHMEM_DIR}
+                 PATH_SUFFIXES lib
+                 REQUIRED)
+    target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB})
+    message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}")
+    message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}")
+endif()

 if (USE_CUDA)
transformer_engine/common/__init__.py
View file @
27ddce40
...
...
@@ -218,6 +218,11 @@ def _nvidia_cudart_include_dir() -> str:
except
ModuleNotFoundError
:
return
""
# Installing some nvidia-* packages, like nvshmem, create nvidia name, so "import nvidia"
# above doesn't through. However, they don't set "__file__" attribute.
if
nvidia
.
__file__
is
None
:
return
""
include_dir
=
Path
(
nvidia
.
__file__
).
parent
/
"cuda_runtime"
return
str
(
include_dir
)
if
include_dir
.
exists
()
else
""
...
...
@@ -295,6 +300,38 @@ def _load_nvrtc():
    return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)


@functools.lru_cache(maxsize=None)
def _load_curand():
    """Load cuRAND shared library."""
    # Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
    libs = glob.glob(f"{cuda_home}/**/libcurand{_get_sys_extension()}*", recursive=True)
    libs = list(filter(lambda x: not ("stub" in x), libs))
    libs.sort(reverse=True, key=os.path.basename)
    if libs:
        return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

    # Attempt to locate cuRAND in Python dist-packages
    found, handle = _load_nvidia_cuda_library("curand")
    if found:
        return handle

    # Attempt to locate cuRAND via ldconfig
    libs = subprocess.check_output(f"ldconfig -p | grep 'libcurand{_get_sys_extension()}'", shell=True)
    libs = libs.decode("utf-8").split("\n")
    sos = []
    for lib in libs:
        if "libcurand" in lib and "=>" in lib:
            sos.append(lib.split(">")[1].strip())
    if sos:
        return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)

    # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
    return ctypes.CDLL(f"libcurand{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)


@functools.lru_cache(maxsize=None)
def _load_core_library():
    """Load shared library with Transformer Engine C extensions"""
...
@@ -305,6 +342,7 @@ if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE
    try:
        _CUDNN_LIB_CTYPES = _load_cudnn()
        _NVRTC_LIB_CTYPES = _load_nvrtc()
        _CURAND_LIB_CTYPES = _load_curand()
        _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
        _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
        # Needed to find the correct headers for NVRTC kernels.
...
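The mode=ctypes.RTLD_GLOBAL argument used above matters: it makes the cuRAND symbols visible to libraries that are loaded afterwards, such as the TE core extension. A minimal C++ sketch of the same mechanism, purely illustrative and not part of this commit (the library name is whatever libcurand the loader can resolve):

#include <dlfcn.h>
#include <cstdio>

int main() {
  // Load cuRAND with RTLD_GLOBAL so its symbols are visible to libraries
  // that are dlopen()'d later and resolve cuRAND symbols lazily.
  void* handle = dlopen("libcurand.so", RTLD_NOW | RTLD_GLOBAL);
  if (handle == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  // Subsequently loaded libraries can now resolve cuRAND symbols globally.
  dlclose(handle);
  return 0;
}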
transformer_engine/common/comm_gemm/comm_gemm.cpp
0 → 100644
View file @ 27ddce40
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "transformer_engine/comm_gemm.h"

#include <cublasmp.h>
#include <cuda_runtime.h>
#include <nvshmem.h>

#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../common.h"
#include "../util/logging.h"

using namespace transformer_engine;

namespace {

// TODO: log warnings on failures of the *Destroy calls below, once TE has such ability.
// For now, just silently ignoring the errors, since the only diag available in TE is throwing
// exceptions, but these calls will typically be made from destructors, so cannot throw.

template <typename HandlePtr, typename CreateFn, typename DestroyFn, typename... Args>
auto CreateWithCudaCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... args) {
  using Handle = std::remove_pointer_t<HandlePtr>;
  HandlePtr raw{};
  NVTE_CHECK_CUDA(create_fn(&raw, std::forward<Args>(args)...));
  return std::unique_ptr<Handle, DestroyFn>(raw, destroy_fn);
}

using CudaStream = std::unique_ptr<std::remove_pointer_t<cudaStream_t>, decltype(&cudaStreamDestroy)>;
CudaStream CudaStreamCreate() {
  return CreateWithCudaCheck<cudaStream_t>(cudaStreamCreate, cudaStreamDestroy);
}

using CudaEvent = std::unique_ptr<std::remove_pointer_t<cudaEvent_t>, decltype(&cudaEventDestroy)>;
CudaEvent CudaEventCreate(unsigned flags) {
  return CreateWithCudaCheck<cudaEvent_t>(cudaEventCreateWithFlags, cudaEventDestroy, flags);
}

template <bool raw_last, typename HandlePtr, typename CreateFn, typename DestroyFn, typename... Args>
auto CreateWithCublasMpCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... args) {
  using Handle = std::remove_pointer_t<HandlePtr>;
  HandlePtr raw{};
  if constexpr (raw_last) {
    NVTE_CHECK_CUBLASMP(create_fn(std::forward<Args>(args)..., &raw));
  } else {
    NVTE_CHECK_CUBLASMP(create_fn(&raw, std::forward<Args>(args)...));
  }
  return std::unique_ptr<Handle, DestroyFn>(raw, destroy_fn);
}

using CublasMp = std::unique_ptr<std::remove_pointer_t<cublasMpHandle_t>, decltype(&cublasMpDestroy)>;
CublasMp CublasMpCreate(cudaStream_t stream) {
  return CreateWithCublasMpCheck<false, cublasMpHandle_t>(cublasMpCreate, cublasMpDestroy, stream);
}

using CublasMpGrid = std::unique_ptr<std::remove_pointer_t<cublasMpGrid_t>, decltype(&cublasMpGridDestroy)>;
CublasMpGrid CublasMpGridCreate(int64_t nprow, int64_t npcol, cublasMpGridLayout_t layout, ncclComm_t comm) {
  return CreateWithCublasMpCheck<true, cublasMpGrid_t>(cublasMpGridCreate, cublasMpGridDestroy, nprow, npcol, layout, comm);
}

using CublasMpMatrixDesc = std::unique_ptr<std::remove_pointer_t<cublasMpMatrixDescriptor_t>,
                                           decltype(&cublasMpMatrixDescriptorDestroy)>;
CublasMpMatrixDesc CublasMpMatrixDescCreate(int64_t m, int64_t n, int64_t mb, int64_t nb, int64_t rsrc,
                                            int64_t csrc, int64_t lld, cudaDataType_t type, cublasMpGrid_t grid) {
  return CreateWithCublasMpCheck<true, cublasMpMatrixDescriptor_t>(
      cublasMpMatrixDescriptorCreate, cublasMpMatrixDescriptorDestroy, m, n, mb, nb, rsrc, csrc, lld, type, grid);
}

using CublasMpMatmulDesc = std::unique_ptr<std::remove_pointer_t<cublasMpMatmulDescriptor_t>,
                                           decltype(&cublasMpMatmulDescriptorDestroy)>;
CublasMpMatmulDesc CublasMpMatmulDescCreate(cublasComputeType_t compute_type) {
  return CreateWithCublasMpCheck<false, cublasMpMatmulDescriptor_t>(
      cublasMpMatmulDescriptorCreate, cublasMpMatmulDescriptorDestroy, compute_type);
}

}  // namespace

struct NVTECommGemmCtx {
  int64_t nranks;
  int64_t rank;
  ncclComm_t comm;
  CudaStream stream;
  CudaEvent event;
  CublasMp cublas_mp;
  CublasMpGrid grid_col_major;
  CublasMpGrid grid_row_major;
  CublasMpMatrixDesc a_desc;
  CublasMpMatrixDesc b_desc;
  CublasMpMatrixDesc d_desc;
  CublasMpMatmulDesc matmul_desc;
  void* workspace;
  size_t workspace_size;
};

namespace {

int64_t block_size(NVTECommGemmCtx* ctx, int64_t global_size) {
  // Use non-cyclic layout to maximize opportunity for comm overlap.
  return (global_size + ctx->nranks - 1) / ctx->nranks;
}

void AgGemmInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,
                        const Tensor* a, const Tensor* b, const Tensor* d, bool transa, bool transb) {
  const auto a0 = a->flat_first_dim();
  const auto a1 = a->flat_last_dim();
  const auto b0 = b->flat_first_dim();
  const auto b1 = b->flat_last_dim();
  const auto d0 = d->flat_first_dim();
  const auto d1 = d->flat_last_dim();
  if (transa) {
    NVTE_CHECK(a1 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a1);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, k, block_size(ctx, m), 0, 0, k,
        get_cuda_dtype(a->dtype()), ctx->grid_row_major.get(), ctx->a_desc.get()));
  } else {
    NVTE_CHECK(a0 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a0);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, k, block_size(ctx, m), k, 0, 0, block_size(ctx, m),
        get_cuda_dtype(a->dtype()), ctx->grid_col_major.get(), ctx->a_desc.get()));
  }
  if (transb) {
    NVTE_CHECK(b0 == k, "Unsupported tensor dimension in B: expected ", k, ", got ", b0);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(n, k, block_size(ctx, n), k, 0, 0, block_size(ctx, n),
        get_cuda_dtype(b->dtype()), ctx->grid_col_major.get(), ctx->b_desc.get()));
  } else {
    NVTE_CHECK(b1 == k, "Unsupported tensor dimension in B: expected ", k, ", got ", b1);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, n, k, block_size(ctx, n), 0, 0, k,
        get_cuda_dtype(b->dtype()), ctx->grid_row_major.get(), ctx->b_desc.get()));
  }
  NVTE_CHECK(d0 == n, "Unsupported tensor dimension in D: expected ", n, ", got ", d0);
  *ldd = block_size(ctx, m);
  NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n, block_size(ctx, m), block_size(ctx, n), 0, 0, *ldd,
      get_cuda_dtype(d->dtype()), ctx->grid_col_major.get(), ctx->d_desc.get()));
}

void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,
                        const Tensor* a, const Tensor* b, const Tensor* d, bool transa, bool transb) {
  const auto a0 = a->flat_first_dim();
  const auto a1 = a->flat_last_dim();
  const auto b0 = b->flat_first_dim();
  const auto b1 = b->flat_last_dim();
  const auto d0 = d->flat_first_dim();
  const auto d1 = d->flat_last_dim();
  if (transa) {
    NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, block_size(ctx, k), m, 0, 0, block_size(ctx, k),
        get_cuda_dtype(a->dtype()), ctx->grid_col_major.get(), ctx->a_desc.get()));
  } else {
    NVTE_CHECK(a1 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a1);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, k, m, block_size(ctx, k), 0, 0, m,
        get_cuda_dtype(a->dtype()), ctx->grid_row_major.get(), ctx->a_desc.get()));
  }
  if (transb) {
    NVTE_CHECK(b1 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b1);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(n, k, block_size(ctx, n), block_size(ctx, k), 0, 0,
        block_size(ctx, n), get_cuda_dtype(b->dtype()), ctx->grid_row_major.get(), ctx->b_desc.get()));
  } else {
    NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, n, block_size(ctx, k), block_size(ctx, n), 0, 0,
        block_size(ctx, k), get_cuda_dtype(b->dtype()), ctx->grid_col_major.get(), ctx->b_desc.get()));
  }
  NVTE_CHECK(d1 == m, "Unsupported tensor dimension in D: expected ", m, ", got ", d1);
  *ldd = m;
  NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n, m, block_size(ctx, n), 0, 0, *ldd,
      get_cuda_dtype(d->dtype()), ctx->grid_row_major.get(), ctx->d_desc.get()));
}

void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,
                        const Tensor* a, const Tensor* b, const Tensor* d, bool transa, bool transb) {
  const auto a0 = a->flat_first_dim();
  const auto a1 = a->flat_last_dim();
  const auto b0 = b->flat_first_dim();
  const auto b1 = b->flat_last_dim();
  const auto d0 = d->flat_first_dim();
  const auto d1 = d->flat_last_dim();
  if (transa) {
    NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, block_size(ctx, k), m, 0, 0, block_size(ctx, k),
        get_cuda_dtype(a->dtype()), ctx->grid_col_major.get(), ctx->a_desc.get()));
  } else {
    NVTE_ERROR("N transpose flag is not supported for input A");
  }
  if (transb) {
    NVTE_ERROR("T transpose flag is not supported for input B");
  } else {
    NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, n, block_size(ctx, k), n, 0, 0, block_size(ctx, k),
        get_cuda_dtype(b->dtype()), ctx->grid_col_major.get(), ctx->b_desc.get()));
  }
  NVTE_CHECK(d1 == m, "Unsupported tensor dimension in D: expected ", m, ", got ", d1);
  *ldd = m;
  NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n * ctx->nranks, m, n, 0, 0, *ldd,
      get_cuda_dtype(d->dtype()), ctx->grid_row_major.get(), ctx->d_desc.get()));
  const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE;
  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
      CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, sizeof epilogue));
}

using InitMatricesFn = void (*)(NVTECommGemmCtx*, int64_t*, int64_t, int64_t, int64_t, const Tensor*,
                                const Tensor*, const Tensor*, bool, bool);

cublasMpMatmulAlgoType_t cublasmp_algo(NVTECommGemmAlgoType algo) {
  static const std::unordered_map<NVTECommGemmAlgoType, cublasMpMatmulAlgoType_t> s_map{
      {kNVTECommGemmAlgoDefault, CUBLASMP_MATMUL_ALGO_TYPE_DEFAULT},
      {kNVTECommGemmAlgoSplitP2P, CUBLASMP_MATMUL_ALGO_TYPE_SPLIT_P2P},
      {kNVTECommGemmAlgoSplitMulticast, CUBLASMP_MATMUL_ALGO_TYPE_SPLIT_MULTICAST},
      {kNVTECommGemmAlgoAtomicP2P, CUBLASMP_MATMUL_ALGO_TYPE_ATOMIC_P2P},
      {kNVTECommGemmAlgoAtomicMulticast, CUBLASMP_MATMUL_ALGO_TYPE_ATOMIC_MULTICAST},
  };
  auto it = s_map.find(algo);
  return it != s_map.end() ? it->second : static_cast<cublasMpMatmulAlgoType_t>(algo);
}

void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECommGemmAlgoType algo,
                   int64_t m, int64_t n, int64_t k, const Tensor* a, const Tensor* b, const Tensor* d,
                   const Tensor* bias, const Tensor* pre_act_out, bool transa, bool transb, bool grad,
                   bool accumulate, int comm_sm_count, cudaStream_t main_stream) {
  for (auto t : {a, b, d}) {
    NVTE_CHECK(is_tensor_scaling(t->scaling_mode),
               "Unsupported scaling mode: " + std::to_string(t->scaling_mode));
  }
  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorInit(ctx->matmul_desc.get(), CUBLAS_COMPUTE_32F));
  int64_t ldd{};
  init_matrices_fn(ctx, &ldd, m, n, k, a, b, d, transa, transb);
  const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
  const cublasOperation_t trans_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
      CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a, sizeof trans_a));
  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
      CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b, sizeof trans_b));
  cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo);
  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
      CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr, sizeof algo_attr));
  const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32;
  if (is_fp8_dtype(a->dtype())) {
    NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype");
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
        CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode, sizeof scale_mode));
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
        CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER, &a->scale_inv.dptr, sizeof(void*)));
  }
  if (is_fp8_dtype(b->dtype())) {
    NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype");
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
        CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode, sizeof scale_mode));
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
        CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER, &b->scale_inv.dptr, sizeof(void*)));
  }
  if (is_fp8_dtype(d->dtype())) {
    NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype");
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
        CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode, sizeof scale_mode));
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
        CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER, &d->scale.dptr, sizeof(void*)));
    if (d->amax.dptr) {
      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
          CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER, &d->amax.dptr, sizeof(void*)));
    }
  }
  // Might be set to ALLREDUCE before, need to OR with the new flags to set.
  cublasMpMatmulEpilogue_t epilogue{};
  size_t size_read{};
  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet(ctx->matmul_desc.get(),
      CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, sizeof epilogue, &size_read));
  NVTE_CHECK(size_read == sizeof epilogue);
  // (bias, gelu, grad) -> epilogue
  const std::map<std::tuple<bool, bool, bool>, cublasMpMatmulEpilogue_t> flags_to_epilogue{
      {{true, true, false}, CUBLASMP_MATMUL_EPILOGUE_GELU_AUX_BIAS},
      {{true, true, true}, CUBLASMP_MATMUL_EPILOGUE_DGELU_BGRAD},
      {{true, false, false}, CUBLASMP_MATMUL_EPILOGUE_BIAS},
      {{true, false, true}, CUBLASMP_MATMUL_EPILOGUE_BGRADB},
      {{false, true, false}, CUBLASMP_MATMUL_EPILOGUE_GELU_AUX},
      {{false, true, true}, CUBLASMP_MATMUL_EPILOGUE_DGELU},
  };
  if (auto it = flags_to_epilogue.find({bias ? bias->data.dptr != nullptr : false,
                                        pre_act_out ? pre_act_out->data.dptr != nullptr : false, grad});
      it != flags_to_epilogue.end()) {
    epilogue = static_cast<cublasMpMatmulEpilogue_t>(epilogue | it->second);
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
        CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, sizeof epilogue));
  }
  if (bias && bias->data.dptr) {
    cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype);
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
        CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type, sizeof bias_type));
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
        CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER, &bias->data.dptr, sizeof bias->data.dptr));
  }
  if (pre_act_out && pre_act_out->data.dptr) {
    cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype);
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
        CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof aux_type));
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
        CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER, &pre_act_out->data.dptr,
        sizeof pre_act_out->data.dptr));
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
        CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd, sizeof ldd));
    if (is_fp8_dtype(pre_act_out->dtype())) {
      NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype");
      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
          CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE, &scale_mode, sizeof scale_mode));
      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
          CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER, &pre_act_out->scale.dptr,
          sizeof(void*)));
      if (pre_act_out->amax.dptr) {
        NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
            CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER, &pre_act_out->amax.dptr,
            sizeof(void*)));
      }
    }
  }
  if (comm_sm_count) {
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(ctx->matmul_desc.get(),
        CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT, &comm_sm_count, sizeof comm_sm_count));
  }
  NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream));
  size_t wrksp_size_device{};
  size_t wrksp_size_host{};
  float alpha = 1.0;
  float beta = accumulate ? 1.0 : 0.0;
  std::tuple args{ctx->cublas_mp.get(), ctx->matmul_desc.get(), m, n, k, &alpha,
                  a->data.dptr, 1, 1, ctx->a_desc.get(),
                  b->data.dptr, 1, 1, ctx->b_desc.get(), &beta,
                  accumulate ? d->data.dptr : nullptr, 1, 1, accumulate ? ctx->d_desc.get() : nullptr,
                  d->data.dptr, 1, 1, ctx->d_desc.get()};
  NVTE_CHECK_CUBLASMP(std::apply(cublasMpMatmul_bufferSize,
      std::tuple_cat(args, std::tuple{&wrksp_size_device, &wrksp_size_host})));
  std::vector<uint8_t> workspace_host(wrksp_size_host);
  if (ctx->workspace_size < wrksp_size_device) {
    nvshmem_free(ctx->workspace);
    ctx->workspace = nvshmem_malloc(wrksp_size_device);
    ctx->workspace_size = wrksp_size_device;
  }
  NVTE_CHECK_CUBLASMP(std::apply(cublasMpMatmul,
      std::tuple_cat(args, std::tuple{ctx->workspace, ctx->workspace_size,
                                      workspace_host.data(), workspace_host.size()})));
  NVTE_CHECK_CUDA(cudaEventRecord(ctx->event.get(), main_stream));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(ctx->stream.get(), ctx->event.get(), 0));
}

}  // namespace

NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank) {
  NVTE_API_CALL(nvte_comm_gemm_ctx_create);
  auto stream = CudaStreamCreate();
  auto event = CudaEventCreate(cudaEventDisableTiming);
  auto cublas_mp = CublasMpCreate(stream.get());
  auto col_major = CublasMpGridCreate(nranks, 1, CUBLASMP_GRID_LAYOUT_COL_MAJOR, comm);
  auto row_major = CublasMpGridCreate(1, nranks, CUBLASMP_GRID_LAYOUT_ROW_MAJOR, comm);
  // Pre-creating matrix descriptors here, will be initialized with the actual params later.
  auto a_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get());
  auto b_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get());
  auto d_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get());
  auto matmul_desc = CublasMpMatmulDescCreate(CUBLAS_COMPUTE_32F);
  return new NVTECommGemmCtx{
      .nranks = nranks,
      .rank = rank,
      .comm = comm,
      .stream = std::move(stream),
      .event = std::move(event),
      .cublas_mp = std::move(cublas_mp),
      .grid_col_major = std::move(col_major),
      .grid_row_major = std::move(row_major),
      .a_desc = std::move(a_desc),
      .b_desc = std::move(b_desc),
      .d_desc = std::move(d_desc),
      .matmul_desc = std::move(matmul_desc),
  };
}

void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) {
  NVTE_API_CALL(nvte_comm_gemm_ctx_destroy);
  nvshmemx_sync_all_on_stream(ctx->stream.get());
  delete ctx;
}

void nvte_all_gather_gemm(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a,
                          const NVTETensor b, const NVTETensor d, const NVTETensor bias,
                          const NVTETensor pre_act_out, bool transa, bool transb, bool grad,
                          bool accumulate, int comm_sm_count, cudaStream_t main_stream,
                          NVTECommGemmAlgoType algo) {
  NVTE_API_CALL(nvte_all_gather_gemm);
  cublasmp_gemm(AgGemmInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a),
                convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias),
                convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate, comm_sm_count,
                main_stream);
}

void nvte_gemm_reduce_scatter(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a,
                              const NVTETensor b, const NVTETensor d, const NVTETensor bias,
                              const NVTETensor pre_act_out, bool transa, bool transb, bool grad,
                              bool accumulate, int comm_sm_count, cudaStream_t main_stream,
                              NVTECommGemmAlgoType algo) {
  NVTE_API_CALL(nvte_gemm_reduce_scatter);
  cublasmp_gemm(GemmRsInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a),
                convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias),
                convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate, comm_sm_count,
                main_stream);
}

void nvte_gemm_all_reduce(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a,
                          const NVTETensor b, const NVTETensor d, const NVTETensor bias,
                          const NVTETensor pre_act_out, bool transa, bool transb, bool grad,
                          bool accumulate, int comm_sm_count, cudaStream_t main_stream,
                          NVTECommGemmAlgoType algo) {
  NVTE_API_CALL(nvte_gemm_all_reduce);
  cublasmp_gemm(GemmArInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a),
                convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias),
                convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate, comm_sm_count,
                main_stream);
}

int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size) {
  NVTE_API_CALL(nvte_comm_gemm_numroc);
  return cublasMpNumroc(global_size, block_size(ctx, global_size), ctx->rank, 0, ctx->nranks);
}
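A rough usage sketch of this new API from host code. It assumes an already-initialized NCCL communicator and NVTETensor handles prepared elsewhere with the usual TE tensor APIs; the tensor construction is omitted, error handling is elided, and the function and variable names below are placeholders rather than part of this file:

#include <cuda_runtime.h>
#include <nccl.h>
#include <transformer_engine/comm_gemm.h>

// Sketch only: tensors a, b, d, bias, pre_act_out are assumed to be valid NVTETensor handles.
void run_ag_gemm_example(ncclComm_t nccl_comm, int nranks, int rank,
                         NVTETensor a, NVTETensor b, NVTETensor d,
                         NVTETensor bias, NVTETensor pre_act_out,
                         int64_t m, int64_t n, int64_t k, cudaStream_t stream) {
  NVTECommGemmCtx* ctx = nvte_comm_gemm_ctx_create(nccl_comm, nranks, rank);

  // Local row count for a globally distributed dimension (ScaLAPACK-style numroc).
  int64_t local_m = nvte_comm_gemm_numroc(ctx, m);
  (void)local_m;

  // All-gather + GEMM with the default cuBLASMp algorithm selection.
  nvte_all_gather_gemm(ctx, m, n, k, a, b, d, bias, pre_act_out,
                       /*transa=*/true, /*transb=*/false, /*grad=*/false,
                       /*accumulate=*/false, /*comm_sm_count=*/0, stream,
                       kNVTECommGemmAlgoDefault);

  nvte_comm_gemm_ctx_destroy(ctx);
}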
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
View file @ 27ddce40
...
@@ -153,10 +153,10 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
                          DType::kInt32);
  }

  // CUDA event creation
  cudaEventCreateWithFlags(&_start_compute, 0);
  cudaEventCreateWithFlags(&_stop_compute, 0);
  cudaEventCreateWithFlags(&_start_comm, 0);
  cudaEventCreateWithFlags(&_stop_comm, 0);
  NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_compute, 0));
  NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_compute, 0));
  NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_comm, 0));
  NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_comm, 0));

  /*
    Defining the launcher order between the communication and GEMM kernels
...
@@ -169,12 +169,12 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
  int runtime_version = 6;
#else
  int runtime_version = 0;
  cudaRuntimeGetVersion(&runtime_version);
  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&runtime_version));
#endif
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, 0);
  NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, 0));
  if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) {
    cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming);
    NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming));
  } else {
    _comm_launch_event = 0;
  }
...
@@ -185,9 +185,13 @@ CommOverlapCore::~CommOverlapCore() {
  cudaEventDestroy(_start_comm);
  cudaEventDestroy(_stop_compute);
  cudaEventDestroy(_start_compute);
  if (_comm_launch_event) cudaEventDestroy(_comm_launch_event);
  if (_comm_launch_event) {
    cudaEventDestroy(_comm_launch_event);
  }
  if (_atomic_gemm) cudaFree(_counter.dptr());
  if (_atomic_gemm) {
    cudaFree(_counter.dptr());
  }
  for (size_t i = 0; i < _stream_compute.size(); i++) {
    cudaStreamSynchronize(_stream_compute[i]);
...
@@ -723,17 +727,21 @@ void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStr
  int comm_bytes_per_rank = comm_bytes / _tp_size;

  // We use the reference to the overlap_gemm to get the stream to send and receive on, to ensure
  // the kernels don't finish until the previous GEMM is flushed.
  userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm, send_stream);
  userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm, recv_stream);
  userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank, _ub_comm, send_stream);
  userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank, _ub_comm, recv_stream);

  // We sync with the internal comm stream so the destructor can wait for the comm stream to finish
  // before freeing the ubuf.
  for (auto stream : {send_stream, recv_stream}) {
    NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, stream));
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
    // We sync with the comm stream so the destructor can wait for the comm stream to finish
    // before freeing the ubuf.
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _stop_comm, 0));
  }
  // Next we sync with the main stream.
  // We have to recapture an event off the comm stream to enable CUDA graph capture; otherwise the
  // comm stream would never be joined in the graph.
  NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
}

/***************************************************************************************************
...
@@ -829,7 +837,9 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
  cudaEventDestroy(_stop_recv);
  cudaEventDestroy(_stop_send);
  cudaStreamDestroy(_stream_recv);
  for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]);
  for (size_t i = 0; i < _stream_send.size(); i++) {
    cudaStreamDestroy(_stream_send[i]);
  }
}

TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
...
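The pattern applied throughout this file is simply to route every CUDA runtime call through an error-checking macro instead of discarding its return code. A generic sketch of such a check macro is shown below; TE's actual NVTE_CHECK_CUDA lives in its logging utilities, so this stand-in (and the helper function using it) is only an illustration of the idea:

#include <cuda_runtime.h>
#include <sstream>
#include <stdexcept>

// Illustrative stand-in for an error-checking macro around CUDA runtime calls.
#define CHECK_CUDA_EXAMPLE(call)                                            \
  do {                                                                      \
    const cudaError_t status_ = (call);                                     \
    if (status_ != cudaSuccess) {                                           \
      std::ostringstream oss_;                                              \
      oss_ << "CUDA error at " << __FILE__ << ":" << __LINE__ << ": "       \
           << cudaGetErrorString(status_);                                  \
      throw std::runtime_error(oss_.str());                                 \
    }                                                                       \
  } while (false)

void create_events_example(cudaEvent_t* start, cudaEvent_t* stop) {
  // Same shape as the constructor change above: return codes are no longer silently dropped.
  CHECK_CUDA_EXAMPLE(cudaEventCreateWithFlags(start, 0));
  CHECK_CUDA_EXAMPLE(cudaEventCreateWithFlags(stop, 0));
}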
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp
View file @ 27ddce40
...
@@ -515,7 +515,7 @@ void destroy_communicator_mpi(communicator *comm) {
}

int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) {
  if (comm->free_region > NVTE_MAX_REGIONS) return -1;
  if (comm->free_region >= NVTE_MAX_REGIONS) return -1;
  int hndl = comm->free_region;
  comm->peer_ptr[hndl] = reinterpret_cast<void **>(malloc(sizeof(void *) * (comm->nvsize)));
  size_t aligned_size = bytes;
...
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
View file @ 27ddce40
...
@@ -2436,6 +2436,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
  if (comm->push == 0) {
    kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]),
                                               reinterpret_cast<int *>(flagptr));
    NVTE_CHECK_CUDA(cudaGetLastError());
  } else {
    void *srcptr = reinterpret_cast<char *>(comm->mem_ptr[srchandler]) + srcoffset;
    void *dstptr = reinterpret_cast<char *>(comm->peer_ptr[dsthandler][peerlocal]) + dstoffset;
...
@@ -2633,8 +2634,11 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
        &(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), reinterpret_cast<int *>(flagptr),
        reinterpret_cast<int4 *>(srcptr), reinterpret_cast<int4 *>(dstptr),
        signalonly ? 0 : bytes / 16, comm->ub_timeout);
    if (!signalonly)
    NVTE_CHECK_CUDA(cudaGetLastError());
    if (!signalonly) {
      kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]));
      NVTE_CHECK_CUDA(cudaGetLastError());
    }
    if (comm->use_ce) {
      NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
    }
...
@@ -2649,30 +2653,33 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
        reinterpret_cast<int *>(0 ?  // temporary disable
                                    GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2)
                                    : nullptr));
    NVTE_CHECK_CUDA(cudaGetLastError());
  }
}

void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
                          const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
                          int tp_size, communicator *comm, cudaStream_t stream) {
                          int tp_size, int world_rank, communicator *comm, cudaStream_t stream) {
  int rank_round_tp = (world_rank / tp_size) * tp_size;
  for (int j = 1; j < tp_size; j++) {
    int i = (tp_rank + j) % tp_size;
    int send_offset = srcoffset + bytes_per_slice * tp_rank;
    int recv_offset = dstoffset + bytes_per_slice * tp_rank;
    userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i, stream);
    userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm,
                     rank_round_tp + i, stream);
  }
}

void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
                          const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
                          int tp_size, communicator *comm, cudaStream_t stream) {
                          int tp_size, int world_rank, communicator *comm, cudaStream_t stream) {
  int rank_round_tp = (world_rank / tp_size) * tp_size;
  for (int j = tp_size - 1; j > 0; j--) {
    int i = (tp_rank + j) % tp_size;
    int send_offset = srcoffset + bytes_per_slice * i;
    int recv_offset = dstoffset + bytes_per_slice * i;
    userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i, stream);
    userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm,
                     rank_round_tp + i, stream);
  }
}
...
@@ -2747,24 +2754,28 @@ void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
  dim3 block(1);
  dim3 grid(1);
  producer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
  dim3 block(1);
  dim3 grid(1);
  consumer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream) {
  dim3 block(1);
  dim3 grid(1);
  consumer_batch_kernel<<<grid, block, 0, stream>>>(atomic_ptr, first_chunk_i, num_chunks);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream) {
  dim3 block(1);
  dim3 grid(1);
  reset_counters_kernel<<<grid, block, 0, stream>>>(atomic_ptr, num_chunks, allgather);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

template <typename fp8type, int nvec>
...
@@ -2818,6 +2829,7 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in
  reduce_fp8_in_bf16_out_cuda<fp8type, nvec>
      <<<grid, block, 0, stream>>>(inputs, output, scale, num_inputs, input_size,
                                   num_aligned_elements_per_input, tot_input_size);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale,
...
@@ -2877,4 +2889,5 @@ void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cud
  dim3 grid(num_blocks);
  reduce_bf16_cuda<nvec><<<grid, block, 0, stream>>>(inputs, output, num_inputs, input_size,
                                                     num_aligned_elements_per_input, tot_input_size);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
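The extra world_rank parameter exists because userbuffers_send and userbuffers_recv address peers by global rank, while the loops above iterate over tensor-parallel ranks within one group. A small standalone host-side sketch of the index mapping these loops now perform, using made-up values purely for illustration:

#include <cstdio>

int main() {
  // Hypothetical example: 2 nodes x 4 TP ranks, and this process is world rank 6.
  const int tp_size = 4;
  const int world_rank = 6;
  const int tp_rank = world_rank % tp_size;                     // 2
  const int rank_round_tp = (world_rank / tp_size) * tp_size;   // 4: first global rank of this TP group

  // Same traversal order as userbuffers_send_all: every other member of the
  // TP group, addressed by *global* rank (rank_round_tp + i).
  for (int j = 1; j < tp_size; ++j) {
    const int i = (tp_rank + j) % tp_size;
    std::printf("send to tp peer %d -> global rank %d\n", i, rank_round_tp + i);
  }
  return 0;
}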
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h
View file @ 27ddce40
...
@@ -27,7 +27,7 @@
using ExtAllgatherOp = std::function<void(void *, size_t, void *, size_t, ExtComm)>;
using ExtBarrierOp = std::function<void(ExtComm)>;

#define NVTE_MAX_REGIONS 16
#define NVTE_MAX_REGIONS 32
#define NVTE_MAX_SMS 32
#define NVTE_MAX_OPS 32
#define NVTE_MAX_PEERS 8192
...
@@ -314,10 +314,10 @@ void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cuda
void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
                          const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
                          int tp_size, communicator *comm, cudaStream_t stream);
                          int tp_size, int world_rank, communicator *comm, cudaStream_t stream);
void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
                          const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
                          int tp_size, communicator *comm, cudaStream_t stream);
                          int tp_size, int world_rank, communicator *comm, cudaStream_t stream);

#endif  // TRANSFORMER_ENGINE_USERBUFFERS_H_
transformer_engine/common/common.cu
View file @ 27ddce40
...
@@ -26,12 +26,31 @@ __global__ void __launch_bounds__(1)
}

}  // namespace

cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
  switch (t) {
    case DType::kFloat16:
      return CUDA_R_16F;
    case DType::kFloat32:
      return CUDA_R_32F;
    case DType::kBFloat16:
      return CUDA_R_16BF;
    case DType::kFloat8E4M3:
      return CUDA_R_8F_E4M3;
    case DType::kFloat8E5M2:
      return CUDA_R_8F_E5M2;
    default:
      NVTE_ERROR("Invalid type");
  }
}

void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
  if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) {
    NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv.");
    update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>(
        reinterpret_cast<const float *>(t->scale.dptr), reinterpret_cast<float *>(t->scale_inv.dptr));
    NVTE_CHECK_CUDA(cudaGetLastError());
  }
}
...
@@ -73,6 +92,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
    dim3 grid(numBlocks, 1, 1);                                               \
    memset_kernel<vectorizedType>                                             \
        <<<grid, kThreadsPerBlock, 0, stream>>>(ptr, value, size_in_bytes);   \
    NVTE_CHECK_CUDA(cudaGetLastError());                                      \
    return;                                                                   \
  }
...
@@ -83,7 +103,7 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream
  if (size_in_bytes > 4096) {
    // Use cudaMemsetAsync for larger sizes.
    cudaMemsetAsync(ptr, value, size_in_bytes, stream);
    NVTE_CHECK_CUDA(cudaMemsetAsync(ptr, value, size_in_bytes, stream));
    return;
  }
...
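nvte_memset keeps two paths: buffers up to 4096 bytes go through the vectorized memset kernel, and anything larger is handed to cudaMemsetAsync, whose return code is now checked as well. A condensed sketch of that dispatch shape, with the small-path kernel launch elided since it is internal to TE:

#include <cuda_runtime.h>
#include <cstddef>

// Illustrative dispatch shape only; the real implementation launches its own
// vectorized memset kernel on the small path.
cudaError_t memset_dispatch_example(void* ptr, int value, size_t size_in_bytes,
                                    cudaStream_t stream) {
  constexpr size_t kThreshold = 4096;
  if (size_in_bytes > kThreshold) {
    // Large fills: defer to the CUDA runtime and surface its status.
    return cudaMemsetAsync(ptr, value, size_in_bytes, stream);
  }
  // Small fills: a custom kernel launch would go here (elided in this sketch),
  // followed by a cudaGetLastError() check as in the diff above.
  return cudaSuccess;
}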
transformer_engine/common/common.h
View file @ 27ddce40
...
@@ -276,6 +276,8 @@ struct QuantizationConfig {
  };
};

cudaDataType_t get_cuda_dtype(const transformer_engine::DType t);

template <typename T>
constexpr T DIVUP(const T &x, const T &y) {
  return (((x) + ((y)-1)) / (y));
...
@@ -395,9 +397,19 @@ struct BitsNumber {
template <typename T>
struct TypeInfo {
#if FP4_TYPE_SUPPORTED
  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8, fp4e2m1>;
  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8, fp4e2m1
#if CUDA_VERSION >= 12080
                           , fp8e8m0
#endif
                           >;
#else
  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8>;
  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8
#if CUDA_VERSION >= 12080
                           , fp8e8m0
#endif
                           >;
#endif

  template <typename U, DType current>
...
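DIVUP is the usual ceiling-division helper, and the dropout kernels in the new file below use it to size their launch grids. A short, self-contained example of the arithmetic, with values made up purely for illustration:

#include <cstdio>

template <typename T>
constexpr T divup_example(const T& x, const T& y) {
  // Ceiling division: smallest q such that q * y >= x.
  return (x + y - 1) / y;
}

int main() {
  // e.g. 1,000,000 elements processed in chunks of 16, launched with 128-thread blocks:
  constexpr long long num_chunks = divup_example(1000000LL, 16LL);    // 62500
  constexpr long long num_blocks = divup_example(num_chunks, 128LL);  // 489
  std::printf("chunks=%lld blocks=%lld\n", num_chunks, num_blocks);
  return 0;
}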
transformer_engine/common/dropout/dropout.cu
0 → 100644
View file @ 27ddce40
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>

#include <cmath>

#include "../common.h"
#include "../utils.cuh"
#include "transformer_engine/dropout.h"

namespace transformer_engine {

namespace {

// RNG kernels process chunks of 16 entries
constexpr size_t rng_chunk_size = 16;

// CUDA block size
constexpr size_t block_size = 128;

// Vector class to help with vectorized memory accesses
template <typename T, size_t kSize>
union Vector {
  using StorageType = typename BytesToType<sizeof(T) * kSize>::Type;
  StorageType storage;
  T entries[kSize];
};

/* Byte-wise less-than comparison
 *
 * Results are stored in each byte's most-significant bit (MSB). All
 * other bits are zero.
 */
__device__ __forceinline__ uint32_t bytewise_less_than(uint32_t a, uint32_t b) {
  // Compare low bits by masking MSBs and subtracting. The resulting
  // MSBs are 0 if the low bits of a are less than the low bits of b.
  uint32_t result = (a | 0x80808080) - (b & 0x7F7F7F7F);
  // Bitwise logical op to get answer in MSBs
  // Equivalent logic: result = (a == b) ? !result : b
  asm("lop3.b32 %0, %1, %2, %3, 0x4D;\n\t" : "=r"(result) : "r"(a), "r"(b), "r"(result));
  // Mask out everything except MSBs and return
  result &= 0x80808080;
  return result;
}

/* Generate dropout mask with 16 bits.
 *
 * 1 corresponds to keep and 0 to drop.
 *
 * Consumes 4 values from cuRAND Philox generator.
 */
__device__ __forceinline__ uint16_t make_16bit_mask(uint64_t chunk_idx, uint64_t rng_seed,
                                                    uint64_t rng_offset, uint32_t bytewise_drop_prob) {
  // Generate random bits
  curandStatePhilox4_32_10_t state;
  curand_init(rng_seed, chunk_idx, rng_offset, &state);
  const uint4 rand_bits = curand4(&state);

  // Compute mask
  // Note: bytewise_less_than fills MSBs (bits 7, 15, 23, 31). By
  // shifting 2 bits after every call, every other bit will be filled.
  uint32_t result = bytewise_less_than(rand_bits.x, bytewise_drop_prob);
  result = (result >> 2) | bytewise_less_than(rand_bits.y, bytewise_drop_prob);
  result = (result >> 2) | bytewise_less_than(rand_bits.z, bytewise_drop_prob);
  result = (result >> 2) | bytewise_less_than(rand_bits.w, bytewise_drop_prob);

  // Consolidate mask in lowest 16 bits
  result |= result >> 17;

  // Flip bits so 0 corresponds to drop
  result = ~result;
  return result;
}

// Dropout forward with FP16/BF16 input and output.
template <typename T>
__global__ void __launch_bounds__(block_size)
    dropout_kernel_fwd_f16(const T *__restrict__ input_ptr, T *__restrict__ output_ptr,
                           uint8_t *__restrict__ mask_ptr, const uint64_t *__restrict__ rng_state_ptr,
                           size_t num_chunks, uint32_t bytewise_drop_prob, float scale) {
  static_assert(sizeof(T) == 2);

  // Each thread processes a chunk of 16 entries
  const size_t gid = threadIdx.x + blockIdx.x * block_size;
  const size_t nthreads = gridDim.x * block_size;
  for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) {
    // Generate dropout mask
    auto local_mask = make_16bit_mask(chunk_idx, rng_state_ptr[0], rng_state_ptr[1], bytewise_drop_prob);
    reinterpret_cast<uint16_t *>(mask_ptr)[chunk_idx] = local_mask;

    // Read input data
    using VectorType = Vector<T, rng_chunk_size>;
    VectorType local_data;
    local_data = reinterpret_cast<const VectorType *>(input_ptr)[chunk_idx];

    // Apply dropout based on mask
#pragma unroll
    for (size_t i = 0; i < rng_chunk_size; i++) {
      float val = static_cast<float>(local_data.entries[i]);
      if ((local_mask & 0x1) == 0) {
        val = 0;
      }
      val *= scale;
      local_data.entries[i] = static_cast<T>(val);
      local_mask >>= 1;
    }

    // Write output data
    reinterpret_cast<VectorType *>(output_ptr)[chunk_idx] = local_data;
  }
}

// Dropout forward with FP8 input and FP16/BF16 output.
template <typename InputType, typename OutputType>
__global__ void __launch_bounds__(block_size)
    dropout_kernel_fwd_fp8(const InputType *__restrict__ input_ptr,
                           const float *__restrict__ input_scale_inv_ptr,
                           OutputType *__restrict__ output_ptr, uint8_t *__restrict__ mask_ptr,
                           const uint64_t *__restrict__ rng_state_ptr, size_t num_chunks,
                           uint32_t bytewise_drop_prob, float scale) {
  static_assert(sizeof(InputType) == 1);
  static_assert(sizeof(OutputType) == 2);
  const float input_scale_inv = *input_scale_inv_ptr;

  // Each thread processes a chunk of 16 entries
  const size_t gid = threadIdx.x + blockIdx.x * block_size;
  const size_t nthreads = gridDim.x * block_size;
  for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) {
    // Generate dropout mask
    auto local_mask = make_16bit_mask(chunk_idx, rng_state_ptr[0], rng_state_ptr[1], bytewise_drop_prob);
    reinterpret_cast<uint16_t *>(mask_ptr)[chunk_idx] = local_mask;

    // Read input data
    using InputVectorType = Vector<InputType, rng_chunk_size>;
    InputVectorType local_input;
    local_input = reinterpret_cast<const InputVectorType *>(input_ptr)[chunk_idx];

    // Apply dropout based on mask
    using OutputVectorType = Vector<OutputType, rng_chunk_size>;
    OutputVectorType local_output;
#pragma unroll
    for (size_t i = 0; i < rng_chunk_size; i++) {
      float val = static_cast<float>(local_input.entries[i]);
      val *= input_scale_inv;
      if ((local_mask & 0x1) == 0) {
        val = 0;
      }
      val *= scale;
      local_output.entries[i] = static_cast<OutputType>(val);
      local_mask >>= 1;
    }

    // Write output data
    reinterpret_cast<OutputVectorType *>(output_ptr)[chunk_idx] = local_output;
  }
}

// Apply dropout mask and scale.
template <typename T>
__global__ void __launch_bounds__(block_size)
    apply_dropout_mask(const T *__restrict__ input_ptr, const uint8_t *__restrict__ mask_ptr,
                       T *__restrict__ output_ptr, size_t num_chunks, float scale) {
  // Each thread processes a chunk of 8 entries.
  const size_t gid = threadIdx.x + blockIdx.x * block_size;
  const size_t nthreads = gridDim.x * block_size;
  constexpr size_t chunk_size = 8;
  for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) {
    // Read dropout mask
    uint8_t local_mask = mask_ptr[chunk_idx];

    // Read input data
    using VectorType = Vector<T, chunk_size>;
    VectorType local_data;
    local_data = reinterpret_cast<const VectorType *>(input_ptr)[chunk_idx];

    // Apply dropout based on mask
#pragma unroll
    for (size_t i = 0; i < chunk_size; i++) {
      float val = static_cast<float>(local_data.entries[i]);
      if ((local_mask & 0x1) == 0) {
        val = 0;
      }
      val *= scale;
      local_data.entries[i] = static_cast<T>(val);
      local_mask >>= 1;
    }

    // Write output data
    reinterpret_cast<VectorType *>(output_ptr)[chunk_idx] = local_data;
  }
}

}  // namespace

void dropout_fwd(const Tensor &input, Tensor &output, Tensor &mask, Tensor &rng_state,
                 float dropout_probability, cudaStream_t stream) {
  // Check tensors
  const size_t numel = input.numel();
  NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Input tensor must be FP16/BF16 tensor or tensor-scaled FP8 tensor, ",
             "but scaling mode is ", to_string(input.scaling_mode), ".");
  NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Output tensor must be FP16/BF16 tensor, ", "but scaling mode is ",
             to_string(output.scaling_mode), ".");
  NVTE_CHECK(mask.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, "Mask tensor must be plain tensor, ",
             "but scaling mode is ", to_string(mask.scaling_mode), ".");
  NVTE_CHECK(rng_state.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "RNG state tensor must be INT64 tensor with two entries, ", "but scaling mode is ",
             to_string(rng_state.scaling_mode), ".");
  NVTE_CHECK(output.dtype() == DType::kFloat16 || output.dtype() == DType::kBFloat16,
             "Output tensor must be FP16/BF16 tensor, but dtype is ", to_string(output.dtype()), ".");
  NVTE_CHECK(rng_state.dtype() == DType::kInt64,
             "RNG state tensor must be INT64 tensor with two entries, but dtype is ",
             to_string(rng_state.dtype()), ".");
  NVTE_CHECK(numel % 16 == 0,
             "Input tensor number of elements must be divisible by 16, but shape is ", input.shape(), ".");
  NVTE_CHECK(numel == output.numel(), "Input tensor (shape=", input.shape(),
             ") and output tensor (shape=", output.shape(), ") do not match.");
  NVTE_CHECK(typeToNumBits(mask.dtype()) * mask.numel() == numel, "Mask tensor must have ", numel,
             " bits, but found dtype=", to_string(mask.dtype()), " and shape=", mask.shape(), ".");
  NVTE_CHECK(rng_state.numel() == 2, "RNG state tensor must be INT64 tensor with two entries, ",
             "but shape is ", rng_state.shape(), ".");
  NVTE_CHECK(input.data.dptr != nullptr, "Input tensor is missing data.");
  NVTE_CHECK(output.data.dptr != nullptr, "Output tensor is missing data.");
  NVTE_CHECK(mask.data.dptr != nullptr, "Mask tensor is missing data.");
  NVTE_CHECK(rng_state.data.dptr != nullptr, "RNG state tensor is missing data.");

  // Convert dropout probability to scale and 8-bit representation
  NVTE_CHECK(dropout_probability >= 0 && dropout_probability < 1,
             "Invalid dropout probability (", dropout_probability, ").");
  const float scale = 1 / (1 - dropout_probability);
  uint32_t bytewise_drop_prob = static_cast<uint32_t>(std::floor(dropout_probability * 256));
  bytewise_drop_prob |= bytewise_drop_prob << 8;
  bytewise_drop_prob |= bytewise_drop_prob << 16;

  // CUDA config
  const size_t num_chunks = numel / rng_chunk_size;
  const size_t num_blocks = DIVUP(num_chunks, block_size);

  // Launch kernel depending on input dtype
  if (input.dtype() == DType::kFloat16 || input.dtype() == DType::kBFloat16) {
    NVTE_CHECK(input.dtype() == output.dtype(), "Input tensor (dtype=", to_string(input.dtype()),
               ") and output tensor (dtype=", to_string(output.dtype()), ") do not match.");
    TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
        input.dtype(), DType,
        dropout_kernel_fwd_f16<DType><<<num_blocks, block_size, 0, stream>>>(
            reinterpret_cast<const DType *>(input.data.dptr),
            reinterpret_cast<DType *>(output.data.dptr),
            reinterpret_cast<uint8_t *>(mask.data.dptr),
            reinterpret_cast<const uint64_t *>(rng_state.data.dptr), num_chunks,
            bytewise_drop_prob, scale););
    NVTE_CHECK_CUDA(cudaGetLastError());
  } else if (input.dtype() == DType::kFloat8E4M3 || input.dtype() == DType::kFloat8E5M2) {
    NVTE_CHECK(input.scale_inv.dptr != nullptr, "Input tensor scale-inverse is not allocated.");
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
        input.dtype(), InputType,
        TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
            output.dtype(), OutputType,
            dropout_kernel_fwd_fp8<InputType, OutputType><<<num_blocks, block_size, 0, stream>>>(
                reinterpret_cast<const InputType *>(input.data.dptr),
                reinterpret_cast<const float *>(input.scale_inv.dptr),
                reinterpret_cast<OutputType *>(output.data.dptr),
                reinterpret_cast<uint8_t *>(mask.data.dptr),
                reinterpret_cast<const uint64_t *>(rng_state.data.dptr), num_chunks,
                bytewise_drop_prob, scale);););
    NVTE_CHECK_CUDA(cudaGetLastError());
  } else {
    NVTE_ERROR("Input tensor must be FP16/BF16 tensor or tensor-scaled FP8 tensor, ",
               "but dtype is ", to_string(input.dtype()), ".");
  }
}

void dropout_bwd(const Tensor &grad_output, const Tensor &mask, Tensor &grad_input,
                 float dropout_probability, cudaStream_t stream) {
  // Check tensors
  const size_t numel = grad_output.numel();
  NVTE_CHECK(grad_output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Grad output tensor must be FP16/BF16 tensor, ", "but scaling mode is ",
             to_string(grad_output.scaling_mode), ".");
  NVTE_CHECK(grad_input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Grad input tensor must be FP16/BF16 tensor, ", "but scaling mode is ",
             to_string(grad_input.scaling_mode), ".");
  NVTE_CHECK(mask.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Mask tensor must be a plain tensor, but scaling mode is ",
             to_string(mask.scaling_mode), ".");
  NVTE_CHECK(grad_output.dtype() == DType::kFloat16 || grad_output.dtype() == DType::kBFloat16,
             "Grad output tensor must be FP16/BF16 tensor, but dtype is ",
             to_string(grad_output.dtype()), ".");
  NVTE_CHECK(grad_output.dtype() == grad_input.dtype(),
             "Grad output tensor (dtype=", to_string(grad_output.dtype()),
             ") and grad input tensor (dtype=", to_string(grad_input.dtype()), ") do not match.");
  NVTE_CHECK(numel % 16 == 0,
             "Grad output tensor number of elements must be divisible by 16, but shape is ",
             grad_output.shape(), ".");
  NVTE_CHECK(numel == grad_input.numel(), "Grad output tensor (shape=", grad_output.shape(),
             ") and grad input tensor (shape=", grad_input.shape(), ") do not match.");
  NVTE_CHECK(typeToNumBits(mask.dtype()) * mask.numel() == numel, "Mask tensor must have ", numel,
             " bits, but found dtype=", to_string(mask.dtype()), " and shape=", mask.shape(), ".");
  NVTE_CHECK(grad_output.data.dptr != nullptr, "Grad output tensor is missing data.");
  NVTE_CHECK(grad_input.data.dptr != nullptr, "Grad input tensor is missing data.");
  NVTE_CHECK(mask.data.dptr != nullptr, "Mask tensor is missing data.");

  // Convert dropout probability to scale
  NVTE_CHECK(dropout_probability >= 0 && dropout_probability < 1,
             "Invalid dropout probability (", dropout_probability, ").");
  const float scale = 1 / (1 - dropout_probability);

  // CUDA config
  const size_t num_chunks = numel / 8;
  const size_t num_blocks = DIVUP(num_chunks, block_size);

  // Launch kernel
  TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
      grad_output.dtype(), DType,
      apply_dropout_mask<DType><<<num_blocks, block_size, 0, stream>>>(
          reinterpret_cast<const DType *>(grad_output.data.dptr),
          reinterpret_cast<const uint8_t *>(mask.data.dptr),
          reinterpret_cast<DType *>(grad_input.data.dptr), num_chunks, scale););
  NVTE_CHECK_CUDA(cudaGetLastError());
}

}  // namespace transformer_engine

void nvte_dropout_fwd(const NVTETensor input, NVTETensor output, NVTETensor mask,
                      NVTETensor rng_state, float dropout_probability, cudaStream_t stream) {
  NVTE_API_CALL(nvte_dropout_fwd);
  using namespace transformer_engine;
  dropout_fwd(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output),
              *convertNVTETensorCheck(mask), *convertNVTETensorCheck(rng_state),
              dropout_probability, stream);
}

void nvte_dropout_bwd(const NVTETensor grad_output, const NVTETensor mask, NVTETensor grad_input,
                      float dropout_probability, cudaStream_t stream) {
  NVTE_API_CALL(nvte_dropout_bwd);
  using namespace transformer_engine;
  dropout_bwd(*convertNVTETensorCheck(grad_output), *convertNVTETensorCheck(mask),
              *convertNVTETensorCheck(grad_input), dropout_probability, stream);
}
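One subtlety worth calling out: the forward kernels compare each random byte against an 8-bit threshold, so the drop probability is effectively quantized to floor(p * 256) / 256, while the rescale factor 1 / (1 - p) still uses the requested p. A tiny host-side sketch of that conversion, illustrative only:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const float dropout_probability = 0.1f;

  // Same conversion as dropout_fwd: one byte threshold, replicated into all
  // four bytes of a 32-bit word so it can be compared byte-wise.
  uint32_t bytewise_drop_prob = static_cast<uint32_t>(std::floor(dropout_probability * 256));
  bytewise_drop_prob |= bytewise_drop_prob << 8;
  bytewise_drop_prob |= bytewise_drop_prob << 16;

  const float effective_prob = std::floor(dropout_probability * 256) / 256.0f;
  const float scale = 1.0f / (1.0f - dropout_probability);

  std::printf("threshold word = 0x%08x, effective p = %f, scale = %f\n",
              bytewise_drop_prob, effective_prob, scale);
  return 0;
}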