OpenDAS / TransformerEngine / Commits

Commit f8c2af4c authored May 21, 2025 by yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine

Parents: e92773a3, 1d903f5e
Changes: 211

Showing 20 changed files with 1970 additions and 139 deletions (+1970 / -139)
tests/pytorch/fused_attn/test_fused_attn_with_cp.py                  +1    -1
tests/pytorch/fused_attn/test_kv_cache.py                            +10   -9
tests/pytorch/test_float8blockwisetensor.py                          +104  -0
tests/pytorch/test_float8tensor.py                                   +2    -2
tests/pytorch/test_fused_optimizer.py                                +8    -7
tests/pytorch/test_fused_rope.py                                     +69   -13
tests/pytorch/test_multi_tensor.py                                   +1    -1
tests/pytorch/test_numerics.py                                       +13   -6
tests/pytorch/test_parallel_cross_entropy.py                         +31   -3
tests/pytorch/test_sanity.py                                         +32   -8
transformer_engine/__init__.py                                       +2    -2
transformer_engine/common/CMakeLists.txt                             +20   -0
transformer_engine/common/__init__.py                                +238  -41
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp    +114  -44
transformer_engine/common/common.cu                                  +71   -0
transformer_engine/common/common.h                                   +46   -2
transformer_engine/common/fused_attn/context_parallel.cu             +743  -0
transformer_engine/common/fused_attn/flash_attn.cu                   +153  -0
transformer_engine/common/fused_attn/fused_attn.cpp                  +15   -0
transformer_engine/common/fused_attn/kv_cache.cu                     +297  -0
tests/pytorch/fused_attn/test_fused_attn_with_cp.py

@@ -11,7 +11,7 @@ from transformer_engine.pytorch.utils import (
     get_device_compute_capability,
     get_cudnn_version,
 )
-from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils
+from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
 from test_fused_attn import ModelConfig
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
tests/pytorch/fused_attn/test_kv_cache.py

@@ -11,6 +11,12 @@ import math
 import pytest
 import torch
+from test_fused_attn import (
+    ModelConfig,
+    reset_rng_states,
+    _get_attention_backends,
+)
 from torch.distributions import Exponential
 from transformer_engine.pytorch import make_graphed_callables
 from transformer_engine.common import recipe
@@ -18,20 +24,15 @@ from transformer_engine.pytorch import fp8_autocast, fp8_model_init
 from transformer_engine.pytorch.transformer import (
     TransformerLayer,
 )
-from transformer_engine.pytorch.attention import DotProductAttention
-from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
-from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
+from transformer_engine.pytorch.attention import DotProductAttention, InferenceParams
+from transformer_engine.pytorch.attention.dot_product_attention.utils import (
+    FlashAttentionUtils as fa_utils,
+)
 from transformer_engine.pytorch.utils import (
     get_device_compute_capability,
     init_method_normal,
     scaled_init_method_normal,
     is_bf16_compatible,
 )
-from test_fused_attn import (
-    ModelConfig,
-    reset_rng_states,
-    _get_attention_backends,
-)

 # Initialize RNG state
 seed = 1234
tests/pytorch/test_float8blockwisetensor.py

@@ -392,6 +392,110 @@ class TestFloat8BlockwiseTensor:
         with pytest.raises(AssertionError):
             torch.testing.assert_close(x_view.dequantize(), -x_hp, **_tols[fp8_dtype])

+    @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=str)
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=str)
+    @pytest.mark.parametrize(
+        "dims", [[16, 16, 512], [16, 16, 512, 16], [12, 7, 11], [13, 14, 16], [2, 3, 5]]
+    )
+    def test_view_and_reshape_1D(
+        self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: List[int]
+    ) -> None:
+        """Test view operations that preserve tensor shape"""
+        device = "cuda"
+
+        def is_bitwise_equal(a, b):
+            if a.numel() != b.numel():
+                return False
+            a_flat = a.reshape(-1).view(torch.uint8)
+            b_flat = b.reshape(-1).view(torch.uint8)
+            return torch.all((a_flat ^ b_flat) == 0)
+
+        x_hp = torch.rand(dims, dtype=dtype, device=device)
+        quantizer = Float8BlockQuantizer(
+            fp8_dtype=fp8_dtype,
+            rowwise=True,
+            columnwise=True,
+            block_scaling_dim=1,
+        )
+        x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
+        quantizer.update_quantized(x_hp.clone(), x_fp8)
+
+        # Test view, high dimension tensor -> 2D tensor
+        x_hp_view = x_hp.view(-1, dims[-1]).contiguous()
+        x_fp8_view = x_fp8.view(-1, dims[-1])
+        # Check the dequantized result
+        torch.testing.assert_close(
+            x_fp8_view.dequantize().contiguous(), x_hp_view, **_tols[fp8_dtype]
+        )
+        # Check the bitwise equality of the inner data
+        assert is_bitwise_equal(x_fp8_view._rowwise_data, x_fp8._rowwise_data)
+        assert is_bitwise_equal(x_fp8_view._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
+        # Check the data ptr
+        assert x_fp8_view._rowwise_data.data_ptr() == x_fp8._rowwise_data.data_ptr()
+        assert x_fp8_view._rowwise_scale_inv.data_ptr() == x_fp8._rowwise_scale_inv.data_ptr()
+
+        # Test reshape high dimension tensor -> 2D tensor
+        x_hp_reshape = x_hp.reshape(-1, dims[-1]).contiguous()
+        x_fp8_reshape = x_fp8.reshape(-1, dims[-1])
+        # Check the dequantized result
+        torch.testing.assert_close(
+            x_fp8_reshape.dequantize().contiguous(), x_hp_reshape, **_tols[fp8_dtype]
+        )
+        # Check the bitwise equality of the inner data
+        assert is_bitwise_equal(x_fp8_reshape._rowwise_data, x_fp8._rowwise_data)
+        assert is_bitwise_equal(x_fp8_reshape._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
+
+    @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=str)
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=str)
+    @pytest.mark.parametrize("dims", [[16, 16, 512, 16], [2, 512, 512, 128], [3, 13, 14, 16]])
+    def test_view_and_reshape_2D(
+        self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: List[int]
+    ) -> None:
+        """Test view operations that preserve tensor shape"""
+        device = "cuda"
+
+        def is_bitwise_equal(a, b):
+            if a.numel() != b.numel():
+                return False
+            a_flat = a.reshape(-1).view(torch.uint8)
+            b_flat = b.reshape(-1).view(torch.uint8)
+            return torch.all((a_flat ^ b_flat) == 0)
+
+        x_hp = torch.rand(dims, dtype=dtype, device=device)
+        quantizer = Float8BlockQuantizer(
+            fp8_dtype=fp8_dtype,
+            rowwise=True,
+            columnwise=True,
+            block_scaling_dim=2,
+        )
+        x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
+        quantizer.update_quantized(x_hp.clone(), x_fp8)
+
+        # Test view, high dimension tensor -> 2D tensor
+        x_hp_view = x_hp.view(-1, dims[-2], dims[-1]).contiguous()
+        x_fp8_view = x_fp8.view(-1, dims[-2], dims[-1])
+        # Check the dequantized result
+        torch.testing.assert_close(
+            x_fp8_view.dequantize().contiguous(), x_hp_view, **_tols[fp8_dtype]
+        )
+        # Check the bitwise equality of the inner data
+        assert is_bitwise_equal(x_fp8_view._rowwise_data, x_fp8._rowwise_data)
+        assert is_bitwise_equal(x_fp8_view._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
+        # Check the data ptr
+        assert x_fp8_view._rowwise_data.data_ptr() == x_fp8._rowwise_data.data_ptr()
+        assert x_fp8_view._rowwise_scale_inv.data_ptr() == x_fp8._rowwise_scale_inv.data_ptr()
+
+        # Test reshape high dimension tensor -> 2D tensor
+        x_hp_reshape = x_hp.reshape(-1, dims[-2], dims[-1]).contiguous()
+        x_fp8_reshape = x_fp8.reshape(-1, dims[-2], dims[-1])
+        # Check the dequantized result
+        torch.testing.assert_close(
+            x_fp8_reshape.dequantize().contiguous(), x_hp_reshape, **_tols[fp8_dtype]
+        )
+        # Check the bitwise equality of the inner data
+        assert is_bitwise_equal(x_fp8_reshape._rowwise_data, x_fp8._rowwise_data)
+        assert is_bitwise_equal(x_fp8_reshape._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
+
     @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
     @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
     @pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
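Note: the new test_view_and_reshape tests rely on standard PyTorch view semantics — a shape-preserving view or reshape of a contiguous tensor reuses the same storage. A minimal standalone sketch of that invariant on plain PyTorch tensors (no Transformer Engine types involved), mirroring the is_bitwise_equal helper above:

import torch

# Plain-PyTorch sketch: a view of a contiguous tensor shares storage, so the
# data pointer and the raw bytes are unchanged.
x = torch.rand(16, 16, 512)
x_view = x.view(-1, x.shape[-1])
assert x_view.data_ptr() == x.data_ptr()  # same storage, no copy

# Bitwise comparison through a uint8 reinterpretation, as in is_bitwise_equal.
a = x.reshape(-1).view(torch.uint8)
b = x_view.reshape(-1).view(torch.uint8)
assert torch.all((a ^ b) == 0)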
tests/pytorch/test_float8tensor.py

@@ -18,7 +18,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
     Float8CurrentScalingQuantizer,
 )
 from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
-from transformer_engine.pytorch.utils import non_tn_fp8_gemm_supported
+from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
 import transformer_engine_torch as tex
 from references.ref_per_tensor_cs import ref_per_tensor_cs_cast
@@ -400,7 +400,7 @@ class TestCurrentScalingFloat8Tensor:
         """Check numerical error when casting to FP8"""
         # Skip invalid configurations
-        if non_tn_fp8_gemm_supported() and return_transpose:
+        if is_non_tn_fp8_gemm_supported() and return_transpose:
             pytest.skip("FP8 transpose is neither needed nor supported on current system")

         # Initialize random high precision data
tests/pytorch/test_fused_optimizer.py

@@ -12,10 +12,11 @@ from torch import nn
 from torch.testing._internal.common_device_type import largeTensorTest
 import transformer_engine.pytorch as te
 from transformer_engine.common.recipe import DelayedScaling
-from transformer_engine.pytorch.attention import MultiheadAttention
+from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
 from transformer_engine.pytorch import fp8_model_init
 from transformer_engine.pytorch.utils import is_bf16_compatible
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.utils import gpu_autocast_ctx

 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -596,7 +597,7 @@ class AdamTest:
         gt_ = gt.clone()

         # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model(x)
             loss = ((gt - y) ** 2).mean()
@@ -605,7 +606,7 @@ class AdamTest:
         scaler.update()

         # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model_(x)
             loss_ = ((gt_ - y) ** 2).mean()
@@ -647,7 +648,7 @@ class AdamTest:
         gt_ = gt.clone()

         # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model(x)
             loss = ((gt - y) ** 2).mean()
@@ -656,7 +657,7 @@ class AdamTest:
         scaler.update()

         # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model_(x)
             loss_ = ((gt_ - y) ** 2).mean()
@@ -705,7 +706,7 @@ class AdamTest:
         gt_ = gt.clone()

         # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model(x)
             loss = ((gt - y) ** 2).mean()
@@ -714,7 +715,7 @@ class AdamTest:
         scaler.update()

         # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model_(x)
             loss_ = ((gt_ - y) ** 2).mean()
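Note: the optimizer tests switch from torch.cuda.amp.autocast to a gpu_autocast_ctx helper imported from transformer_engine.pytorch.utils. The helper's body is not part of this diff; assuming it exists mainly to avoid the deprecated torch.cuda.amp.autocast spelling, a hypothetical stand-in would look like the sketch below (gpu_autocast_ctx_sketch is an illustrative name, not the TE implementation):

import torch

def gpu_autocast_ctx_sketch(enabled: bool = True):
    # Hypothetical stand-in: the non-deprecated spelling of CUDA autocast
    # in recent PyTorch releases.
    return torch.amp.autocast(device_type="cuda", enabled=enabled)

model = torch.nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device="cuda")
with gpu_autocast_ctx_sketch(enabled=True):
    y = model(x)  # matmul runs under autocast (fp16 by default on CUDA)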
tests/pytorch/test_fused_rope.py

 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
-from typing import Callable, Tuple, Union
 import math
-import pytest
 import torch
+from typing import Callable, Tuple, Union
+import pytest
-from transformer_engine.pytorch.dot_product_attention.rope import (
+from transformer_engine.pytorch.attention.rope import (
     RotaryPositionEmbedding,
     apply_rotary_pos_emb,
 )
@@ -22,6 +22,7 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
     return torch.sum(output * t)


+@pytest.mark.parametrize("start_positions", [True, False])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("seq_length", [2048, 4096])
 @pytest.mark.parametrize("hidden_size", [128, 256])
@@ -43,7 +44,17 @@ def test_fused_rope(
     loss_func: Callable,
     cp_size: int,
     interleaved: bool,
+    start_positions: bool,
 ) -> None:
+    if margin == 0 and start_positions == True:
+        # This makes sure that the `start_positions` offsets being applied
+        # are with the maximum length of the rope embeddings.
+        pytest.skip("Skipping test with margin=0 and start_positions=True")
+
+    if start_positions == True and cp_size > 1:
+        # `start_positions` is only supported for `cp_size=1` and inference.
+        pytest.skip("Skipping test with cp_size>1 and start_positions=True")
+
     device = torch.device("cuda:0")
     batch_size, head_num = 2, 64
     t = torch.rand(
@@ -51,6 +62,14 @@ def test_fused_rope(
         dtype=dtype,
         device=device,
     )
+
+    # Get arbitrary offsets to be used with RoPE for all the sequences
+    start_positions = (
+        torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
+        if start_positions
+        else None
+    )
+
     if tensor_format == "bshd":
         t = t.transpose(0, 1).contiguous()
     if transpose:
@@ -69,14 +88,18 @@ def test_fused_rope(
         t.float(),
         emb,
         tensor_format=tensor_format,
+        start_positions=start_positions,
         interleaved=interleaved,
         fused=False,
         cp_size=cp_size,
         cp_rank=cp_rank,
     ).to(dtype)
     loss_unfused = loss_func(output_unfused)
-    loss_unfused.backward()
-    grad_unfused = t.grad.detach().clone()
+
+    if not isinstance(start_positions, torch.Tensor):
+        loss_unfused.backward()
+        grad_unfused = t.grad.detach().clone()
+
     t.grad = None

     # fused
@@ -84,21 +107,29 @@ def test_fused_rope(
         t,
         emb,
         tensor_format=tensor_format,
+        start_positions=start_positions,
         interleaved=interleaved,
         fused=True,
         cp_size=cp_size,
         cp_rank=cp_rank,
     )
     loss_fused = loss_func(output_fused)
-    loss_fused.backward()
-    grad_fused = t.grad.detach().clone()
+
+    if not isinstance(start_positions, torch.Tensor):
+        loss_fused.backward()
+        grad_fused = t.grad.detach().clone()
+
     t.grad = None

     torch.testing.assert_close(output_fused, output_unfused)
-    torch.testing.assert_close(grad_fused, grad_unfused)
+    if not isinstance(start_positions, torch.Tensor):
+        torch.testing.assert_close(grad_fused, grad_unfused)
     assert output_fused.is_contiguous()


+@pytest.mark.parametrize("margin", [10])
+@pytest.mark.parametrize("start_positions", [True, False])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("hidden_size", [128, 256])
 @pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@@ -114,10 +145,25 @@ def test_fused_rope_thd(
     loss_func: Callable,
     cp_size: int,
     interleaved: bool,
+    start_positions: bool,
+    margin: int,
 ) -> None:
+    if start_positions == True and cp_size > 1:
+        # `start_positions` is only supported for `cp_size=1` and inference.
+        pytest.skip("Skipping test with cp_size>1 and start_positions=True")
+
     device = torch.device("cuda:0")
     batch_size, head_num = 2, 64
     cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]
+
+    # Get arbitrary offsets to be used with RoPE for all the sequences
+    start_positions = (
+        torch.randint(0, margin, (len(cu_seqlens) - 1,), dtype=torch.int32, device=device)
+        if start_positions
+        else None
+    )
+
     if cp_size > 1:
         cu_seqlens_padded = [0]
         for i in range(1, len(cu_seqlens)):
@@ -152,6 +198,7 @@ def test_fused_rope_thd(
     output_unfused = apply_rotary_pos_emb(
         t.float(),
         emb,
+        start_positions=start_positions,
         tensor_format="thd",
         interleaved=interleaved,
         fused=False,
@@ -160,14 +207,17 @@ def test_fused_rope_thd(
         cp_rank=cp_rank,
     ).to(dtype)
     loss_unfused = loss_func(output_unfused)
-    loss_unfused.backward()
-    grad_unfused = t.grad.detach().clone()
+
+    if not isinstance(start_positions, torch.Tensor):
+        loss_unfused.backward()
+        grad_unfused = t.grad.detach().clone()
+
     t.grad = None

     # fused
     output_fused = apply_rotary_pos_emb(
         t,
         emb,
+        start_positions=start_positions,
         interleaved=interleaved,
         fused=True,
         tensor_format="thd",
@@ -176,9 +226,15 @@ def test_fused_rope_thd(
         cp_rank=cp_rank,
     )
     loss_fused = loss_func(output_fused)
-    loss_fused.backward()
-    grad_fused = t.grad.detach().clone()
+
+    if not isinstance(start_positions, torch.Tensor):
+        loss_fused.backward()
+        grad_fused = t.grad.detach().clone()
+
     t.grad = None

     torch.testing.assert_close(output_fused, output_unfused)
-    torch.testing.assert_close(grad_fused, grad_unfused)
+    if not isinstance(start_positions, torch.Tensor):
+        torch.testing.assert_close(grad_fused, grad_unfused)
     assert output_fused.is_contiguous()
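Note: the new start_positions argument offsets each sequence's position indices before the rotary embedding is applied. A plain (non-fused, non-TE) sketch of rotary embedding with a position offset, only to illustrate what such an offset means; the helper below is illustrative and is not the TE kernel:

import torch

def rope_with_offset(x: torch.Tensor, start_position: int, base: float = 10000.0):
    """Illustrative RoPE (half-split variant): x is [seq, dim] with even dim;
    positions begin at start_position instead of 0."""
    seq, dim = x.shape
    half = dim // 2
    inv_freq = 1.0 / (base ** (torch.arange(half, dtype=torch.float32) / half))
    pos = torch.arange(start_position, start_position + seq, dtype=torch.float32)
    ang = torch.outer(pos, inv_freq)          # [seq, half]
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[:, :half], x[:, half:]
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)

x = torch.randn(8, 16)
y0 = rope_with_offset(x, start_position=0)  # standard RoPE
y5 = rope_with_offset(x, start_position=5)  # same data, positions shifted by 5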
tests/pytorch/test_multi_tensor.py

@@ -160,7 +160,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tensor
         normab = torch.cat((a.norm().view(1), b.norm().view(1)))
         norm_per_tensor = norm_per_tensor.view(-1, 2)
     else:
-        norm, _ = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], True)
+        norm, _ = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], False)
     reference = torch.full(
         [(sizea + sizeb) * repeat], val, dtype=torch.float32, device=device
tests/pytorch/test_numerics.py

@@ -7,7 +7,6 @@ import math
 import os
 from typing import Dict, List, Tuple, Optional
 import pytest
-import copy
 import random
 import torch
@@ -40,12 +39,12 @@ from transformer_engine.pytorch import (
     Fp8Unpadding,
 )
 from transformer_engine.pytorch import torch_version
-from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
+from transformer_engine.pytorch.attention.inference import InferenceParams
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
 from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
-from transformer_engine.pytorch.utils import get_device_compute_capability
+from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
@@ -135,18 +134,20 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
 def assert_allclose(
-    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
+    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float = None, rtol: float = None
 ) -> bool:
     """Ensures two lists are equal."""
     assert len(l1) == len(l2), "Unequal number of outputs."
     for i, (t1, t2) in enumerate(zip(l1, l2)):
-        tols = dict(atol=atol)
+        tols = dtype_tols(t2.dtype)
         if rtol is not None:
             tols["rtol"] = rtol
+        if atol is not None:
+            tols["atol"] = atol
         result = torch.allclose(t1, t2, **tols)
         if not result:
             diff = torch.abs(t1 - t2)
-            tol = atol + (rtol * torch.abs(t2))
+            tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
             exceed_mask = diff > tol
             if exceed_mask.any():
                 indices = torch.nonzero(exceed_mask, as_tuple=True)
@@ -2304,6 +2305,12 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
         pytest.skip("FusedAttention and FlashAttention do not support FP32")
     if use_RoPE:
         pytest.skip("KV cache does not support starting positions for RoPE")
+    if (
+        backend == "FusedAttention"
+        and get_device_compute_capability() == (8, 9)
+        and get_cudnn_version() < (9, 11, 0)
+    ):
+        pytest.skip("Skip KV cache for sm89 and cuDNN < 9.11")

     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
tests/pytorch/test_parallel_cross_entropy.py

@@ -19,11 +19,12 @@ class TestParallelCrossEntropy:
             label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
         )

-    def generate_input(self, dtype: torch.dtype, swap_dim: bool):
+    def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
         SQ = random.choice([64, 128])
         batch = random.choice([1, 2])
         vocab = random.choice([64000, 128000])
+        ignore = random.sample(range(0, SQ - 1), 5)

         if swap_dim:
             self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
@@ -32,14 +33,27 @@ class TestParallelCrossEntropy:
             self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()

         self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()
+        if ignore_idx:
+            for i in ignore:
+                # Ignore 5 indices
+                if swap_dim:
+                    self.tar_test[i][0] = -100
+                else:
+                    self.tar_test[0][i] = -100

         self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
         self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))

     def one_iteration_test(
-        self, dtype: torch.dtype, swap_dim: bool, label_smoothing: float, reduce_loss: bool
+        self,
+        dtype: torch.dtype,
+        swap_dim: bool,
+        label_smoothing: float,
+        reduce_loss: bool,
+        ignore_idx: bool = False,
     ):

-        self.generate_input(dtype, swap_dim)
+        self.generate_input(dtype, swap_dim, ignore_idx)
         self.input_test.requires_grad_(True)
         self.input_ref.requires_grad_(True)
@@ -57,6 +71,8 @@ class TestParallelCrossEntropy:
         test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss

         torch.testing.assert_close(test_loss, ref_loss, check_dtype=False)
+        if ignore_idx:
+            print(test_loss, ref_loss)

         if reduce_loss:
             torch.testing.assert_close(
                 torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
@@ -106,3 +122,15 @@ class TestParallelCrossEntropy:
         self.one_iteration_test(
             dtype=torch.float32, swap_dim=False, label_smoothing=0, reduce_loss=False
         )
+
+    def test_ignore_idx(self):
+        self.generate_iters(5)
+        self.generate_infra(False, 0)
+        for i in range(self.iters):
+            self.one_iteration_test(
+                dtype=torch.float32,
+                swap_dim=random.choice([True, False]),
+                label_smoothing=0,
+                reduce_loss=False,
+                ignore_idx=True,
+            )
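Note: the new ignore_idx path marks selected targets with -100, which matches the default ignore_index of PyTorch's cross entropy, so those positions contribute neither loss nor gradient. A minimal illustration with the stock PyTorch op:

import torch
import torch.nn.functional as F

logits = torch.randn(6, 10, requires_grad=True)
targets = torch.randint(0, 10, (6,))
targets[0] = -100  # -100 is F.cross_entropy's default ignore_index

loss = F.cross_entropy(logits, targets, reduction="none")
assert loss[0].item() == 0.0               # the ignored position contributes nothing
loss.sum().backward()
assert torch.all(logits.grad[0] == 0)      # and produces zero gradient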
tests/pytorch/test_sanity.py

@@ -373,7 +373,9 @@ def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
     torch.cuda.synchronize()


-def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
+def _test_sanity_common(
+    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
+):
     if skip_dgrad and skip_wgrad:
         pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
@@ -389,7 +391,11 @@ def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad
     use_fp8 = fp8_recipe is not None
     with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
-        te_out = block(te_inp)
+        if not microbatching:
+            te_out = block(te_inp)
+        else:
+            _ = block(te_inp, is_first_microbatch=True)
+            te_out = block(te_inp, is_first_microbatch=False)
     if isinstance(te_out, tuple):
         te_out = te_out[0]
     loss = te_out.sum()
@@ -443,8 +449,16 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 @pytest.mark.parametrize("skip_dgrad", all_boolean)
 @pytest.mark.parametrize("normalization", all_normalizations)
+@pytest.mark.parametrize("microbatching", all_boolean)
 def test_sanity_layernorm_linear(
-    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, normalization
+    dtype,
+    fp8_recipe,
+    model,
+    skip_wgrad,
+    zero_centered_gamma,
+    skip_dgrad,
+    normalization,
+    microbatching,
 ):
     config = model_configs[model]
@@ -470,7 +484,7 @@ def test_sanity_layernorm_linear(
         params_dtype=dtype,
         device="cuda",
     )
-    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
+    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


 @pytest.mark.parametrize("dtype", param_types)
@@ -478,7 +492,8 @@ def test_sanity_layernorm_linear(
 @pytest.mark.parametrize("model", ["small", "weird"])
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
 @pytest.mark.parametrize("skip_dgrad", all_boolean)
-def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
+@pytest.mark.parametrize("microbatching", all_boolean)
+def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
     config = model_configs[model]

     if fp8_recipe is not None:
@@ -501,7 +516,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
         params_dtype=dtype,
         device="cuda",
     )
-    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
+    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


 @pytest.mark.parametrize("dtype", param_types)
@@ -600,8 +615,17 @@ def test_sanity_grouped_linear(
 @pytest.mark.parametrize("skip_dgrad", all_boolean)
 @pytest.mark.parametrize("activation", all_activations)
 @pytest.mark.parametrize("normalization", all_normalizations)
+@pytest.mark.parametrize("microbatching", all_boolean)
 def test_sanity_layernorm_mlp(
-    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation, normalization
+    dtype,
+    fp8_recipe,
+    model,
+    skip_wgrad,
+    zero_centered_gamma,
+    skip_dgrad,
+    activation,
+    normalization,
+    microbatching,
 ):
     config = model_configs[model]
@@ -630,7 +654,7 @@ def test_sanity_layernorm_mlp(
         params_dtype=dtype,
         device="cuda",
     )
-    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
+    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


 @pytest.mark.parametrize("dtype", param_types)
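Note: the microbatching variants call each module twice inside one fp8_autocast region, first with is_first_microbatch=True and then with False; in Transformer Engine this flag lets a module cache its FP8-cast weights on the first microbatch of a step and reuse them afterwards. A minimal usage sketch (requires an FP8-capable GPU; the layer sizes are illustrative):

import torch
import transformer_engine.pytorch as te

layer = te.Linear(128, 128, params_dtype=torch.bfloat16, device="cuda")
inp = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")

with te.fp8_autocast(enabled=True):
    _ = layer(inp, is_first_microbatch=True)     # cast weights to FP8 and cache them
    out = layer(inp, is_first_microbatch=False)  # reuse the cached FP8 weights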
transformer_engine/__init__.py

@@ -11,12 +11,12 @@ import transformer_engine.common
 try:
     from . import pytorch
-except (ImportError, StopIteration) as e:
+except ImportError as e:
     pass

 try:
     from . import jax
-except (ImportError, StopIteration) as e:
+except ImportError as e:
     pass

 __version__ = str(metadata.version("transformer_engine"))
transformer_engine/common/CMakeLists.txt

@@ -111,6 +111,11 @@ if(USE_CUDA)
       cudnn_utils.cpp
       transformer_engine.cpp
       common.cu
+      multi_tensor/adam.cu
+      multi_tensor/compute_scale.cu
+      multi_tensor/l2norm.cu
+      multi_tensor/scale.cu
+      multi_tensor/sgd.cu
       transpose/cast_transpose.cu
       transpose/transpose.cu
       transpose/cast_transpose_fusion.cu
@@ -148,6 +153,7 @@ if(USE_CUDA)
       fused_rope/fused_rope.cu
       recipe/current_scaling.cu
       recipe/delayed_scaling.cu
+      recipe/fp8_block_scaling.cu
       comm_gemm_overlap/userbuffers/ipcsocket.cc
       comm_gemm_overlap/userbuffers/userbuffers-host.cpp
       comm_gemm_overlap/userbuffers/userbuffers.cu
@@ -158,6 +164,11 @@ else()
       cudnn_utils.cpp
       transformer_engine.cpp
       common.cu
+      multi_tensor/adam.cu
+      multi_tensor/compute_scale.cu
+      multi_tensor/l2norm.cu
+      multi_tensor/scale.cu
+      multi_tensor/sgd.cu
       transpose/cast_transpose.cu
       transpose/transpose.cu
       transpose/cast_transpose_fusion.cu
@@ -191,6 +202,7 @@ else()
       fused_rope/fused_rope.cu
       recipe/current_scaling.cu
       recipe/delayed_scaling.cu
+      recipe/fp8_block_scaling.cu
       comm_gemm_overlap/userbuffers/ipcsocket.cc
       comm_gemm_overlap/userbuffers/userbuffers-host.cpp
       comm_gemm_overlap/userbuffers/userbuffers.cu
@@ -345,6 +357,14 @@ target_include_directories(transformer_engine PRIVATE
 set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
                             fused_softmax/scaled_upper_triang_masked_softmax.cu
                             fused_softmax/scaled_aligned_causal_masked_softmax.cu
+                            multi_tensor/adam.cu
+                            multi_tensor/compute_scale.cu
+                            multi_tensor/l2norm.cu
+                            multi_tensor/scale.cu
+                            multi_tensor/sgd.cu
+                            fused_attn/flash_attn.cu
+                            fused_attn/context_parallel.cu
+                            fused_attn/kv_cache.cu
                             PROPERTIES COMPILE_OPTIONS "--use_fast_math")

 option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
transformer_engine/common/__init__.py

@@ -9,28 +9,193 @@ import glob
 import sysconfig
 import subprocess
 import ctypes
+import logging
 import os
 import platform
+import importlib
+import functools
 from pathlib import Path
-import transformer_engine
+from importlib.metadata import version, metadata, PackageNotFoundError
+
+_logger = logging.getLogger(__name__)


-def is_package_installed(package):
-    """Checks if a pip package is installed."""
-    return (
-        subprocess.run(
-            [sys.executable, "-m", "pip", "show", package], capture_output=True, check=False
-        ).returncode
-        == 0
-    )
+@functools.lru_cache(maxsize=None)
+def _is_pip_package_installed(package):
+    """Check if the given package is installed via pip."""
+    # This is needed because we only want to return true
+    # if the python package is installed via pip, and not
+    # if it's importable in the current directory due to
+    # the presence of the shared library module.
+    try:
+        metadata(package)
+    except PackageNotFoundError:
+        return False
+    return True
+
+
+@functools.lru_cache(maxsize=None)
+def _find_shared_object_in_te_dir(te_path: Path, prefix: str):
+    """
+    Find a shared object file of given prefix in the top level TE directory.
+    Only the following locations are searched to avoid stray SOs and build
+    artifacts:
+    1. The given top level directory (editable install).
+    2. `transformer_engine` named directories (source install).
+    3. `wheel_lib` named directories (PyPI install).
+
+    Returns None if no shared object files are found.
+    Raises an error if multiple shared object files are found.
+    """
+    # Ensure top level dir exists and has the module. before searching.
+    if not te_path.exists() or not (te_path / "transformer_engine").exists():
+        return None
+
+    files = []
+    search_paths = (
+        te_path,
+        te_path / "transformer_engine",
+        te_path / "transformer_engine/wheel_lib",
+        te_path / "wheel_lib",
+    )
+
+    # Search.
+    for dirname, _, names in os.walk(te_path):
+        if Path(dirname) in search_paths:
+            for name in names:
+                if name.startswith(prefix) and name.endswith(f".{_get_sys_extension()}"):
+                    files.append(Path(dirname, name))
+
+    if len(files) == 0:
+        return None
+    if len(files) == 1:
+        return files[0]
+    raise RuntimeError(f"Multiple files found: {files}")
+
+
+@functools.lru_cache(maxsize=None)
+def _get_shared_object_file(library: str) -> Path:
+    """
+    Return the path of the shared object file for the given TE
+    library, one of 'core', 'torch', or 'jax'.
+
+    Several factors affect finding the correct location of the shared object:
+    1. System and environment.
+    2. If the installation is from source or via PyPI.
+        - Source installed .sos are placed in top level dir
+        - Wheel/PyPI installed .sos are placed in 'wheel_lib' dir to avoid conflicts.
+    3. For source installations, is the install editable/inplace?
+    4. The user directory from where TE is being imported.
+    """
+    # Check provided input and determine the correct prefix for .so.
+    assert library in ("core", "torch", "jax"), f"Unsupported TE library {library}."
+    if library == "core":
+        so_prefix = "libtransformer_engine"
+    else:
+        so_prefix = f"transformer_engine_{library}"
+
+    # Check TE install location (will be local if TE is available in current dir for import).
+    te_install_dir = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent
+    so_path_in_install_dir = _find_shared_object_in_te_dir(te_install_dir, so_prefix)
+
+    # Check default python package install location in system.
+    site_packages_dir = Path(sysconfig.get_paths()["purelib"])
+    so_path_in_default_dir = _find_shared_object_in_te_dir(site_packages_dir, so_prefix)
+
+    # Case 1: Typical user workflow: Both locations are the same, return any result.
+    if te_install_dir == site_packages_dir:
+        assert (
+            so_path_in_install_dir is not None
+        ), f"Could not find shared object file for Transformer Engine {library} lib."
+        return so_path_in_install_dir
+
+    # Case 2: ERR! Both locations are different but returned a valid result.
+    # NOTE: Unlike for source installations, pip does not wipe out artifacts from
+    # editable builds. In case developers are executing inside a TE directory via
+    # an inplace build, and then move to a regular build, the local shared object
+    # file will be incorrectly picked up without the following logic.
+    if so_path_in_install_dir is not None and so_path_in_default_dir is not None:
+        raise RuntimeError(
+            f"Found multiple shared object files: {so_path_in_install_dir} and"
+            f" {so_path_in_default_dir}. Remove local shared objects installed"
+            f" here {so_path_in_install_dir} or change the working directory to"
+            "execute from outside TE."
+        )
+
+    # Case 3: Typical dev workflow: Editable install
+    if so_path_in_install_dir is not None:
+        return so_path_in_install_dir
+
+    # Case 4: Executing from inside a TE directory without an inplace build available.
+    if so_path_in_default_dir is not None:
+        return so_path_in_default_dir
+
+    raise RuntimeError(f"Could not find shared object file for Transformer Engine {library} lib.")
+
+
+@functools.lru_cache(maxsize=None)
+def load_framework_extension(framework: str):
+    """
+    Load shared library with Transformer Engine framework bindings
+    and check verify correctness if installed via PyPI.
+    """
+    # Supported frameworks.
+    assert framework in ("jax", "torch"), f"Unsupported framework {framework}"
+
+    # Name of the framework extension library.
+    module_name = f"transformer_engine_{framework}"
+
+    # Name of the pip extra dependency for framework extensions from PyPI.
+    extra_dep_name = module_name
+    if framework == "torch":
+        extra_dep_name = "pytorch"
+
+    # If the framework extension pip package is installed, it means that TE is installed via
+    # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
+    # extension are all installed via PyPI and have matching version.
+    if _is_pip_package_installed(module_name):
+        assert _is_pip_package_installed("transformer_engine"), "Could not find `transformer-engine`."
+        assert _is_pip_package_installed(
+            "transformer_engine_cu12"
+        ), "Could not find `transformer-engine-cu12`."
+        assert (
+            version(module_name)
+            == version("transformer-engine")
+            == version("transformer-engine-cu12")
+        ), (
+            "TransformerEngine package version mismatch. Found"
+            f" {module_name} v{version(module_name)}, transformer-engine"
+            f" v{version('transformer-engine')}, and transformer-engine-cu12"
+            f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
+            f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"
+        )
+
+    # If the core package is installed via PyPI, log if
+    # the framework extension is not found from PyPI.
+    # Note: Should we error? This is a rare use case.
+    if _is_pip_package_installed("transformer-engine-cu12"):
+        if not _is_pip_package_installed(module_name):
+            _logger.info(
+                "Could not find package %s. Install transformer-engine using "
+                f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'",
+                module_name,
+            )
+
+    # After all checks are completed, load the shared object file.
+    spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework))
+    solib = importlib.util.module_from_spec(spec)
+    sys.modules[module_name] = solib
+    spec.loader.exec_module(solib)


-def get_te_path():
-    """Find Transformer Engine install path using pip"""
-    return Path(transformer_engine.__path__[0]).parent
-
-
+@functools.lru_cache(maxsize=None)
 def _get_sys_extension():
     system = platform.system()

     if system == "Linux":
@@ -45,20 +210,47 @@ def _get_sys_extension():
     return extension


-def _load_cudnn():
-    """Load CUDNN shared library."""
-    # Attempt to locate cuDNN in Python dist-packages
-    lib_path = glob.glob(
+@functools.lru_cache(maxsize=None)
+def _load_nvidia_cuda_library(lib_name: str):
+    """
+    Attempts to load shared object file installed via pip.
+
+    `lib_name`: Name of package as found in the `nvidia` dir in python environment.
+    """
+
+    so_paths = glob.glob(
         os.path.join(
             sysconfig.get_path("purelib"),
-            f"nvidia/cudnn/lib/libcudnn.{_get_sys_extension()}.*[0-9]",
+            f"nvidia/{lib_name}/lib/lib*.{_get_sys_extension()}.*[0-9]",
         )
     )
-    if lib_path:
-        assert (
-            len(lib_path) == 1
-        ), f"Found {len(lib_path)} libcudnn.{_get_sys_extension()}.x in nvidia-cudnn-cuXX."
-        return ctypes.CDLL(lib_path[0], mode=ctypes.RTLD_GLOBAL)
+
+    path_found = len(so_paths) > 0
+    ctypes_handles = []
+
+    if path_found:
+        for so_path in so_paths:
+            ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))
+
+    return path_found, ctypes_handles
+
+
+@functools.lru_cache(maxsize=None)
+def _nvidia_cudart_include_dir():
+    """Returns the include directory for cuda_runtime.h if exists in python environment."""
+    try:
+        import nvidia
+    except ModuleNotFoundError:
+        return ""
+
+    include_dir = Path(nvidia.__file__).parent / "cuda_runtime"
+    return str(include_dir) if include_dir.exists() else ""
+
+
+@functools.lru_cache(maxsize=None)
+def _load_cudnn():
+    """Load CUDNN shared library."""
     # Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
     cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
@@ -75,28 +267,16 @@ def _load_cudnn():
     if libs:
         return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

+    # Attempt to locate cuDNN in Python dist-packages
+    found, handle = _load_nvidia_cuda_library("cudnn")
+    if found:
+        return handle
+
     # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
     return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)


-def _load_library():
-    """Load shared library with Transformer Engine C extensions"""
-    so_path = get_te_path() / "transformer_engine" / f"libtransformer_engine.{_get_sys_extension()}"
-    if not so_path.exists():
-        so_path = (
-            get_te_path()
-            / "transformer_engine"
-            / "wheel_lib"
-            / f"libtransformer_engine.{_get_sys_extension()}"
-        )
-    if not so_path.exists():
-        so_path = get_te_path() / f"libtransformer_engine.{_get_sys_extension()}"
-    assert so_path.exists(), f"Could not find libtransformer_engine.{_get_sys_extension()}"
-    return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
-
-
+@functools.lru_cache(maxsize=None)
 def _load_nvrtc():
     """Load NVRTC shared library."""
     # Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
@@ -107,6 +287,11 @@ def _load_nvrtc():
     if libs:
         return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

+    # Attempt to locate NVRTC in Python dist-packages
+    found, handle = _load_nvidia_cuda_library("cuda_nvrtc")
+    if found:
+        return handle
+
     # Attempt to locate NVRTC via ldconfig
     libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True)
     libs = libs.decode("utf-8").split("\n")
@@ -123,10 +308,22 @@ def _load_nvrtc():
     return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)


+@functools.lru_cache(maxsize=None)
+def _load_core_library():
+    """Load shared library with Transformer Engine C extensions"""
+    return ctypes.CDLL(_get_shared_object_file("core"), mode=ctypes.RTLD_GLOBAL)
+
+
 if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
     try:
         _CUDNN_LIB_CTYPES = _load_cudnn()
         _NVRTC_LIB_CTYPES = _load_nvrtc()
+        _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
+        _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
+
+        # Needed to find the correct headers for NVRTC kernels.
+        if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
+            os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir()
     except OSError:
         pass

-    _TE_LIB_CTYPES = _load_library()
+    _TE_LIB_CTYPES = _load_core_library()
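Note: the rewritten loader decides whether TE came from PyPI by querying installed-package metadata instead of shelling out to pip show. A standalone sketch of that check using only the standard library (it mirrors _is_pip_package_installed above; the package names passed in are just examples):

from importlib.metadata import metadata, PackageNotFoundError

def is_pip_installed(package: str) -> bool:
    # True only if the package has distribution metadata (i.e. was installed
    # via pip/PyPI), not merely importable from the current directory.
    try:
        metadata(package)
    except PackageNotFoundError:
        return False
    return True

print(is_pip_installed("transformer_engine_torch"))  # False for a pure source checkout
print(is_pip_installed("definitely-not-installed"))  # False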
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
View file @
f8c2af4c
...
@@ -21,12 +21,18 @@
#define HALF_BYTES 2
#define UB_MAX_SM 32
-#define AS_VECTOR(shape) std::vector<size_t>(shape.data, shape.data + shape.ndim)

using namespace std::placeholders;

namespace transformer_engine {

+namespace {
+
+std::vector<size_t> shape_to_vector(const NVTEShape &shape) {
+  return std::vector<size_t>(shape.data, shape.data + shape.ndim);
+}
+
+}  // namespace
+
/***************************************************************************************************
 * Comm+GEMM Overlap Common Core
 **************************************************************************************************/
...
@@ -147,13 +153,50 @@ CommOverlapCore::~CommOverlapCore() {

TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset,
                                                const std::vector<size_t> &chunk_shape) {
-  TensorWrapper chunk;
+  const auto scaling_mode = source.scaling_mode();
+
+  // Tensor dimensions
+  std::vector<size_t> shape = shape_to_vector(source.shape());
+  auto flatten_shape_to_2d = [](const std::vector<size_t> &shape) -> std::pair<size_t, size_t> {
+    if (shape.empty()) {
+      return {1, 1};
+    }
+    size_t height = 1;
+    for (size_t i = 0; i < shape.size() - 1; ++i) {
+      height *= shape[i];
+    }
+    return {height, shape.back()};
+  };
+  size_t height, width, chunk_height, chunk_width;
+  std::tie(height, width) = flatten_shape_to_2d(shape);
+  std::tie(chunk_height, chunk_width) = flatten_shape_to_2d(chunk_shape);
+
+  // Check tensor dimensions
+#define NVTE_DIM_CHECK(cond, message) \
+  NVTE_CHECK(cond, message, " (tensor shape=", shape, ", chunk shape=", chunk_shape, \
+             ", chunk offset=", chunk_offset, ")")
+  NVTE_DIM_CHECK(height > 0 && width > 0, "Attempted to get chunk from empty tensor");
+  NVTE_DIM_CHECK(chunk_height > 0 && chunk_width > 0, "Attempted to get empty tensor chunk");
+  NVTE_DIM_CHECK(chunk_height <= height && chunk_width <= width,
+                 "Attempted to get out-of-bounds tensor chunk");
+  if (scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
+    // MXFP8 scale-inverses are padded to a 2D matrix with dims that
+    // are divisible by 128. UB doesn't handle this padding yet.
+    NVTE_DIM_CHECK(height % 128 == 0 && width % 128 == 0,
+                   "Userbuffers requires MXFP8 tensor dims that are divisible by 128");
+    NVTE_DIM_CHECK(chunk_height % 128 == 0 && chunk_width % 128 == 0,
+                   "Userbuffers requires MXFP8 tensor chunk dims that are divisible by 128");
+  }
+#undef NVTE_DIM_CHECK
+
+  // Construct tensor chunk
+  TensorWrapper chunk(scaling_mode);
  for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) {
    auto param_type = static_cast<NVTETensorParam>(param_id);
    auto param = source.get_parameter(param_type);
    auto param_dptr = reinterpret_cast<char *>(param.data_ptr);
    auto param_dtype = static_cast<DType>(param.dtype);
-    auto param_shape = AS_VECTOR(param.shape);
+    auto param_shape = shape_to_vector(param.shape);
    if (param_dptr != nullptr) {
      if (param_type == NVTETensorParam::kNVTERowwiseData ||
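The flatten_shape_to_2d helper above folds every leading dimension into a single "height" and keeps the last dimension as the "width". Below is a minimal, hedged standalone sketch of the same folding rule; the shape values are made up purely for illustration.

#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Same folding rule as the lambda above: height = product of all leading dims,
// width = last dim; an empty shape is treated as a 1x1 tensor.
std::pair<size_t, size_t> flatten_shape_to_2d(const std::vector<size_t> &shape) {
  if (shape.empty()) return {1, 1};
  size_t height = 1;
  for (size_t i = 0; i + 1 < shape.size(); ++i) height *= shape[i];
  return {height, shape.back()};
}

int main() {
  // Hypothetical [batch, seq, hidden] = [4, 128, 1024] tensor flattens to 512 x 1024.
  auto [h, w] = flatten_shape_to_2d({4, 128, 1024});
  std::cout << h << " x " << w << "\n";  // prints "512 x 1024"
}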
...
@@ -163,8 +206,8 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
        param_shape = chunk_shape;
        if (param_type == NVTETensorParam::kNVTEColumnwiseData &&
-            source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
-          // Columnwise shape for non-block scaled tensors shifts the last dimension to the front
+            source.scaling_mode() == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) {
+          // Columnwise shape for FP8 tensor-scaled tensors shifts the last dimension to the front
          auto last_dim = param_shape.back();
          param_shape.pop_back();
          param_shape.insert(param_shape.begin(), last_dim);
...
@@ -172,18 +215,16 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
      } else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING &&
                 (param_type == NVTETensorParam::kNVTERowwiseScaleInv ||
                  param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) {
-        // Calculate block scaling offset and size
-        auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
-                                          ? source.shape().data[0]
-                                          : source.columnwise_shape().data[0];
-        auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
-                                         ? chunk_shape.front()
-                                         : chunk_shape.back();
-        auto chunk_scale_start = chunk_offset / 32;
-        auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32;
-        auto chunk_scale_size = chunk_scale_end - chunk_scale_start;
-        param_dptr += chunk_scale_start * typeToSize(param_dtype);
-        param_shape = std::vector<size_t>{chunk_scale_size};
+        // Calculate offset and size for MXFP8 scale-invs
+        size_t chunk_scale_height = chunk_height;
+        size_t chunk_scale_width = chunk_width;
+        if (param_type == NVTETensorParam::kNVTERowwiseScaleInv) {
+          chunk_scale_width /= 32;
+        } else {
+          chunk_scale_height /= 32;
+        }
+        param_dptr += (chunk_offset / 32) * typeToSize(param_dtype);
+        param_shape = {chunk_scale_height, chunk_scale_width};
      }

      // Set chunked source parameters into the chunked tensor output
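For reference, the /32 divisions above reflect MXFP8's 1-D scaling granularity of one scale factor per 32 elements. A hedged sketch of the resulting scale-inverse chunk dimensions; the 128-divisible chunk shape used here is hypothetical.

#include <cstddef>
#include <iostream>

int main() {
  // Hypothetical MXFP8 chunk of 256 x 512 elements (dims divisible by 128, as checked above).
  size_t chunk_height = 256, chunk_width = 512;

  // Rowwise scales cover 32 consecutive elements along the width,
  // columnwise scales cover 32 consecutive elements along the height.
  size_t rowwise_scale_h = chunk_height, rowwise_scale_w = chunk_width / 32;  // 256 x 16
  size_t colwise_scale_h = chunk_height / 32, colwise_scale_w = chunk_width;  // 8 x 512

  std::cout << rowwise_scale_h << "x" << rowwise_scale_w << ", "
            << colwise_scale_h << "x" << colwise_scale_w << "\n";
}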
...
@@ -434,10 +475,21 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
  size_t k = transa ? A.size(1) : A.size(0);
  size_t n = _ubuf.size(0);
  size_t m_chunk = m / _num_splits;
+  const std::vector<size_t> input_a_chunk_shape =
+      (transa ? std::vector<size_t>{m_chunk, k} : std::vector<size_t>{k, m_chunk});
+  const std::vector<size_t> output_chunk_shape = {n, m_chunk};
  size_t input_a_chunk_size = m_chunk * k;
  size_t output_chunk_size = n * m_chunk;
  size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();

+  // Helper function to get bias chunk if needed
+  auto maybe_get_bias_chunk = [this, &bias, m_chunk](size_t chunk_id) -> TensorWrapper {
+    if (bias.dptr() == nullptr) {
+      return TensorWrapper();
+    }
+    return get_tensor_chunk(bias, chunk_id * m_chunk, {m_chunk});
+  };
+
  // Catch up the default torch stream
  NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
  for (size_t i = 0; i < _stream_compute.size(); i++) {
...
@@ -449,12 +501,13 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
  char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
  if (_rs_overlap_first_gemm) {
-    auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k});
-    auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
+    auto input_a_chunk = get_tensor_chunk(A, 0, input_a_chunk_shape);
+    auto output_chunk = get_buffer_chunk_like(D, 0, output_chunk_shape);
+    auto bias_chunk = maybe_get_bias_chunk(0);
    auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
    if (_ub_stream_nums == 1) {
-      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
                       pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
                       use_split_accumulator, _math_sms, _stream_compute[0]);
    } else {
...
@@ -464,18 +517,19 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
    }

    for (int i = 1; i < _num_splits; i++) {
-      input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
-      output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
+      input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape);
+      output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape);
+      bias_chunk = maybe_get_bias_chunk(i);
      workspace_chunk = get_tensor_chunk(
          workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
      if (_ub_stream_nums == 1) {
-        nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+        nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
                         pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                         accumulate, use_split_accumulator, _math_sms,
                         _stream_compute[i % _stream_compute.size()]);
      } else {
-        nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+        nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
                         pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                         accumulate, use_split_accumulator, _math_sms,
                         _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
...
@@ -519,13 +573,14 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
    }
  } else {
    for (int i = 0; i < _num_splits; i++) {
-      auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
-      auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
+      auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape);
+      auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape);
+      auto bias_chunk = maybe_get_bias_chunk(i);
      auto workspace_chunk = get_tensor_chunk(
          workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
      if (_ub_stream_nums == 1) {
-        nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+        nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
                         pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                         accumulate, use_split_accumulator, _math_sms,
                         _stream_compute[i % _stream_compute.size()]);
...
@@ -605,14 +660,17 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
  void *buffer_ptr;
  _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
  if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
-  _ubuf = TensorWrapper(buffer_ptr, {buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]},
-                        buffer_dtype);
+  _ubuf = TensorWrapper(
+      buffer_ptr,
+      std::vector<size_t>{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]},
+      buffer_dtype);

  // Create tensor chunks for easy management
  char *ubuf_byte_ptr = reinterpret_cast<char *>(buffer_ptr);
  for (int i = 0; i < _num_ubuf_chunks; i++) {
    _ubufs.push_back(TensorWrapper(reinterpret_cast<void *>(ubuf_byte_ptr),
-                                   {buffer_shape[0] / tp_size, buffer_shape[1]}, buffer_dtype));
+                                   std::vector<size_t>{buffer_shape[0] / tp_size, buffer_shape[1]},
+                                   buffer_dtype));
    ubuf_byte_ptr += buffer_chunk_bytes;
  }
...
@@ -661,7 +719,7 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
                                                         size_t chunk_id) {
  // Start with a chunk of the source tensor
-  auto chunk = get_tensor_chunk(source, 0, AS_VECTOR(_ubufs[chunk_id].shape()));
+  auto chunk = get_tensor_chunk(source, 0, shape_to_vector(_ubufs[chunk_id].shape()));

  // Update chunk with offset data pointers from the communication buffer
  if (chunk.dptr() != nullptr) {
...
@@ -711,7 +769,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
-  auto input_b = get_buffer_chunk_like(B, 0, AS_VECTOR(B.shape()));
+  auto input_b = get_buffer_chunk_like(B, 0, shape_to_vector(B.shape()));
  size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
  auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
...
@@ -798,8 +856,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
  // Get communication and GEMM output chunk sizes
  const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
  const bool do_gelu = pre_gelu_out.numel() > 0;
-  size_t input_chunk_size = n_chunk * k;
-  size_t output_chunk_size = n_chunk * m;
  size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();

  NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
...
@@ -810,10 +866,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
  }

  if (_aggregate) {
    const int num_steps = _tp_size / 2;
-#ifndef __HIP_PLATFORM_AMD__
-    input_chunk_size *= 2;
-    output_chunk_size *= 2;
-#endif
+    // Chunk dims
+    std::vector<size_t> input_b_chunk_shape =
+        (transb ? std::vector<size_t>{k, 2 * n_chunk} : std::vector<size_t>{2 * n_chunk, k});
+    std::vector<size_t> output_chunk_shape = {2 * n_chunk, k};
+    size_t input_b_chunk_size = 2 * n_chunk * k;
+    size_t output_chunk_size = 2 * n_chunk * m;

    // Initial 1X input chunk exchange between neighboring peers
    int send_chunk_id = _tp_id;
...
@@ -842,8 +901,9 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
      // GEMM
      auto input_b_chunk =
-          get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk * 2, k});
-      auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m});
+          get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
+      auto output_chunk =
+          get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
      auto aux_chunk =
          (do_gelu)
              ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k})
...
@@ -882,6 +942,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
      }
    }
  } else {
+    // Chunk dims
+    std::vector<size_t> input_b_chunk_shape =
+        (transb ? std::vector<size_t>{k, n_chunk} : std::vector<size_t>{n_chunk, k});
+    std::vector<size_t> output_chunk_shape = {n_chunk, m};
+    size_t input_b_chunk_size = n_chunk * k;
+    size_t output_chunk_size = n_chunk * m;
+
    for (int i = 0; i < _tp_size; i++) {
      // Set the userbuffer id. Buffer under send is the input for the current
      // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
...
@@ -893,8 +960,10 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
      int recv_offset = comm_bytes * recv_chunk_id;

      // GEMM
-      auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k});
-      auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m});
+      auto input_b_chunk =
+          get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
+      auto output_chunk =
+          get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
      auto aux_chunk =
          (do_gelu)
              ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k})
...
@@ -972,7 +1041,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(
  // Atomic GEMM
  // Process GEMM chunks in the order that AG+GEMM places the output chunks.
-  auto output_d = get_buffer_chunk_like(D, 0, AS_VECTOR(D.shape()));
+  auto output_d = get_buffer_chunk_like(D, 0, shape_to_vector(D.shape()));
  nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
                          transa, transb, grad, workspace.data(), accumulate, use_split_accumulator,
                          _math_sms, 0, _tp_size, true, _counter.data(), stream_main);
...
@@ -1053,6 +1122,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
    auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k});
    auto output_chunk = get_buffer_chunk_by_id(D, i);
    auto workspace_chunk =
        get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
...
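To make the chunking arithmetic in split_overlap_rs concrete, here is a small hedged sketch of how the per-split offsets into A and D follow from the sizes defined above; the m, k, n and split-count values are hypothetical.

#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical GEMM sizes and split count, mirroring split_overlap_rs above.
  size_t m = 4096, k = 1024, n = 2048, num_splits = 4;
  size_t m_chunk = m / num_splits;          // portion of the m dimension handled per split
  size_t input_a_chunk_size = m_chunk * k;  // elements of A consumed per split
  size_t output_chunk_size = n * m_chunk;   // elements of D produced per split

  for (size_t i = 0; i < num_splits; ++i) {
    std::printf("split %zu: A offset %zu, D offset %zu\n", i, i * input_a_chunk_size,
                i * output_chunk_size);
  }
}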
transformer_engine/common/common.cu    View file @ f8c2af4c
...
@@ -35,6 +35,65 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
  }
}

+namespace {
+
+constexpr size_t kThreadsPerBlock = 256;
+
+template <typename TVectorized>
+__global__ void __launch_bounds__(kThreadsPerBlock)
+    memset_kernel(void *__restrict__ ptr, int value, size_t size_in_bytes) {
+  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx * sizeof(TVectorized) >= size_in_bytes) {
+    return;  // Out of bounds
+  }
+  if ((idx + 1) * sizeof(TVectorized) > size_in_bytes) {
+    // If the buffer size is not an even multiple of the vectorization, manually set the
+    // remaining bytes unvectorized.
+    size_t remaining_bytes = size_in_bytes - idx * sizeof(TVectorized);
+    memset(reinterpret_cast<uint8_t *>(ptr) + idx * sizeof(TVectorized), value, remaining_bytes);
+    return;
+  }
+  union {
+    TVectorized value;
+    uint8_t data[sizeof(TVectorized)];
+  } data;
+  for (size_t i = 0; i < sizeof(TVectorized); ++i) {
+    data.data[i] = static_cast<uint8_t>(value);
+  }
+  reinterpret_cast<TVectorized *>(ptr)[idx] = data.value;
+}
+
+}  // namespace
+
+#define MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, vectorizedType, stream) \
+  if (size_in_bytes >= sizeof(vectorizedType) &&                                             \
+      reinterpret_cast<size_t>(ptr) % sizeof(vectorizedType) == 0) {                         \
+    size_t numBlocks = DIVUP(size_in_bytes, kThreadsPerBlock * sizeof(vectorizedType));      \
+    dim3 grid(numBlocks, 1, 1);                                                              \
+    memset_kernel<vectorizedType>                                                            \
+        <<<grid, kThreadsPerBlock, 0, stream>>>(ptr, value, size_in_bytes);                  \
+    return;                                                                                  \
+  }
+
+extern "C" {
+
+void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_memset);
+  NVTE_CHECK(ptr != nullptr, "Pointer for memset must be allocated.");
+
+  if (size_in_bytes > 4096) {
+    // Use cudaMemsetAsync for larger sizes.
+    cudaMemsetAsync(ptr, value, size_in_bytes, stream);
+    return;
+  }
+
+  MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float4, stream);
+  MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float2, stream);
+  MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float, stream);
+  MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, uint8_t, stream);
+}
+
+}  // extern "C"
+
void checkCuDriverContext(CUstream stream) {
#ifdef __HIP_PLATFORM_AMD__
  return;
...
@@ -144,4 +203,16 @@ bool is_supported_by_CC_100() {
#endif
}

+std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
+                                                        size_t outer_size, size_t inner_size) {
+  std::vector<std::vector<Tensor *>> ret;
+  for (size_t i = 0; i < outer_size; ++i) {
+    ret.emplace_back();
+    for (size_t j = 0; j < inner_size; ++j) {
+      ret.back().push_back(reinterpret_cast<Tensor *>(nvte_tensors[i][j]));
+    }
+  }
+  return ret;
+}
+
}  // namespace transformer_engine
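A minimal, hedged usage sketch of the new nvte_memset entry point, assuming a raw CUDA allocation and linking against the Transformer Engine core library; the buffer size is arbitrary. Requests over 4096 bytes fall back to cudaMemsetAsync, as shown above, while smaller requests go through the vectorized kernel.

#include <cstddef>
#include <cuda_runtime.h>

// Declaration taken from the diff above; the exporting header is assumed here.
extern "C" void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream);

int main() {
  void *buf = nullptr;
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaMalloc(&buf, 2048);             // small buffer: dispatches to the vectorized kernel
  nvte_memset(buf, 0, 2048, stream);  // asynchronous zero-fill on the given stream
  cudaStreamSynchronize(stream);
  cudaFree(buf);
  cudaStreamDestroy(stream);
  return 0;
}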
transformer_engine/common/common.h    View file @ f8c2af4c
...
@@ -116,7 +116,7 @@ struct Tensor {
        columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
        scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}

-  int numel() const {
+  size_t numel() const {
    size_t acc = 1;
    for (const auto dim : shape()) {
      acc *= dim;
...
@@ -138,6 +138,14 @@ struct Tensor {
    return data.dtype;
  }

+  size_t dim() const {
+    if (!has_data() && has_columnwise_data()) {
+      return columnwise_data.shape.size();
+    } else {
+      return data.shape.size();
+    }
+  }
+
  std::vector<size_t> shape() const {
    /* Note: We sometimes experience spurious compiler errors
     * (-Wstringop-overflow) from this function. It appears that GCC
...
@@ -243,6 +251,7 @@ constexpr T DIVUP(const T &x, const T &y) {
}

using byte = uint8_t;
+using int16 = int16_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
...
@@ -271,6 +280,7 @@ constexpr inline const char *type_name() noexcept;
    return #T;                                           \
  }
TRANSFORMER_ENGINE_TYPE_NAME(uint8_t)
+TRANSFORMER_ENGINE_TYPE_NAME(int16_t)
TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
TRANSFORMER_ENGINE_TYPE_NAME(float)
...
@@ -327,7 +337,7 @@ struct TypeExtrema {

template <typename T>
struct TypeInfo {
-  using types = std::tuple<byte, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;

  template <typename U, DType current>
  struct Helper {
...
@@ -364,6 +374,10 @@ struct TypeInfo {
      using type = unsigned char;                     \
      { __VA_ARGS__ }                                 \
    } break;                                          \
+    case DType::kInt16: {                             \
+      using type = int16_t;                           \
+      { __VA_ARGS__ }                                 \
+    } break;                                          \
    case DType::kInt32: {                             \
      using type = int32_t;                           \
      { __VA_ARGS__ }                                 \
...
@@ -400,6 +414,33 @@ struct TypeInfo {
      NVTE_ERROR("Invalid type.");                    \
  }

+#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \
+  switch (dtype) {                                             \
+    using namespace transformer_engine;                        \
+    case DType::kFloat32: {                                    \
+      using type = float;                                      \
+      { __VA_ARGS__ }                                          \
+    } break;                                                   \
+    case DType::kFloat16: {                                    \
+      using type = fp16;                                       \
+      { __VA_ARGS__ }                                          \
+    } break;                                                   \
+    case DType::kBFloat16: {                                   \
+      using type = bf16;                                       \
+      { __VA_ARGS__ }                                          \
+    } break;                                                   \
+    case DType::kFloat8E4M3: {                                 \
+      using type = fp8e4m3;                                    \
+      { __VA_ARGS__ }                                          \
+    } break;                                                   \
+    case DType::kFloat8E5M2: {                                 \
+      using type = fp8e5m2;                                    \
+      { __VA_ARGS__ }                                          \
+    } break;                                                   \
+    default:                                                   \
+      NVTE_ERROR("Invalid type.");                             \
+  }
+
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \
  switch (dtype) {                                              \
    using namespace transformer_engine;                         \
...
@@ -599,6 +640,9 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,

bool is_supported_by_CC_100();

+std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
+                                                        size_t outer_size, size_t inner_size);
+
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_COMMON_H_
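An illustrative, non-authoritative sketch of how the new TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT macro is typically consumed: the runtime dtype enum selects a concrete `type` alias and the body is instantiated once per case. The launch_cast_kernel helper below is hypothetical, and the snippet assumes the common.h header above is included.

#include <cstddef>
#include <cuda_runtime.h>
#include "common.h"  // assumed path to the header above (DType, fp16/bf16/fp8 aliases, macro)

// Hypothetical helper templated on the element type selected by the switch.
template <typename T>
void launch_cast_kernel(const void *in, void *out, size_t n, cudaStream_t stream) {
  // ... launch a kernel specialized for T (omitted in this sketch) ...
}

void dispatch_cast(transformer_engine::DType dtype, const void *in, void *out, size_t n,
                   cudaStream_t stream) {
  // The macro expands to a switch over dtype; inside each case, `type` aliases the
  // matching C++ type (float, fp16, bf16, fp8e4m3 or fp8e5m2).
  TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type,
                                       launch_cast_kernel<type>(in, out, n, stream););
}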
transformer_engine/pytorch/csrc/thd_utils.cuh → transformer_engine/common/fused_attn/context_parallel.cu    View file @ f8c2af4c
...
@@ -3,21 +3,25 @@
 *
 * See LICENSE for license information.
 ************************************************************************/
-#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_CUH_
-#define TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_CUH_

#include <assert.h>
#include <cuda.h>
#include <cuda_bf16.h>

+#include "../common.h"
+#include "transformer_engine/fused_attn.h"
+
+namespace transformer_engine {
+namespace context_parallel {
+
struct LseCorrectionFunctor {
-  __forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx,
+  __forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx,
                                             size_t half_idx) {
-    double val = lse[idx];
+    float val = lse[idx];
    float val_per_step = half_lse[half_idx];
-    double max_scale = max(val, val_per_step);
-    double min_scale = min(val, val_per_step);
-    lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale));
+    float max_scale = max(val, val_per_step);
+    float min_scale = min(val, val_per_step);
+    lse[idx] = max_scale + log1pf(expf(min_scale - max_scale));
  }
};
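The functor implements the usual numerically stable merge of two log-sum-exp values (now in fp32 with log1pf/expf instead of double precision); in math form:

\[
\mathrm{lse}_{\text{new}}
  = \log\!\bigl(e^{\mathrm{lse}} + e^{\mathrm{lse}_{\text{step}}}\bigr)
  = \max(\mathrm{lse}, \mathrm{lse}_{\text{step}})
    + \log\!\bigl(1 + e^{\,\min(\mathrm{lse}, \mathrm{lse}_{\text{step}}) - \max(\mathrm{lse}, \mathrm{lse}_{\text{step}})}\bigr)
\]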
...
@@ -49,16 +53,13 @@ struct AddFunctor {
#pragma unroll
    for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) {
-      p_[i] += p[i];
+      p_[i] = p_[i] + p[i];
    }

    reinterpret_cast<float4 *>(token)[idx] = d_;
  }
};

-namespace transformer_engine {
-namespace fused_attn {
-
/***************************************************************************************************
 * Support THD format for Context Parallel: Binary search an array for a target value
 **************************************************************************************************/
...
@@ -107,6 +108,7 @@ __global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int b
/***************************************************************************************************
 * Support THD format for Context Parallel: Read the half of a THD tensor
 **************************************************************************************************/
__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch,
                                            int hidden_size_in_bytes, int half_idx,
                                            int dim_size_of_token) {
...
@@ -148,8 +150,8 @@ __global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_se
 * Support THD format for Context Parallel: softmax_lse related operations
 **************************************************************************************************/

-template <typename lse_dtype, bool lse_packed, typename Functor>
-__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch,
+template <bool lse_packed, typename Functor>
+__global__ void thd_lse_kernel(float *lse, float *half_lse, int *cu_seqlens, int batch,
                               int num_heads, int lse_seqlen, int second_half_lse_seqlen) {
  extern __shared__ int cu_seqlens_s[];
  for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
...
@@ -218,7 +220,7 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
      idx = row * lse_seqlen + col + seq_len * only_second_half;
      idx_per_step = row * lse_per_step_seqlen + col;
    }
-    float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]);
+    float lse_corrected_exp = expf(lse_per_step[idx_per_step] - lse[idx]);

    idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
    idx = (idx * num_heads + head_id) * dim_per_head;
...
@@ -232,7 +234,10 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
      dtype *p_per_step = reinterpret_cast<dtype *>(&data_per_step);
      dtype *p = reinterpret_cast<dtype *>(&data);
      for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) {
-        p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp);
+        p[k] = p[k] + (p_per_step[k] == static_cast<dtype>(0.f)
+                           ? static_cast<dtype>(0.f)
+                           : static_cast<dtype>(static_cast<float>(p_per_step[k]) *
+                                                lse_corrected_exp));
      }
      reinterpret_cast<float4 *>(cur_out)[j] = data;
    }
@@ -297,6 +302,442 @@ __global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, in
...
@@ -297,6 +302,442 @@ __global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, in
}
}
}
}
}
// namespace fused_attn
/***************************************************************************************************
* Support THD format for Context Parallel: Read the half of a THD tensor
**************************************************************************************************/
void
thd_read_half_tensor
(
const
Tensor
&
tensor
,
const
Tensor
&
cu_seqlens
,
Tensor
&
half
,
int
half_idx
,
cudaStream_t
stream
)
{
using
namespace
transformer_engine
;
NVTE_CHECK
(
tensor
.
dim
()
==
3
||
tensor
.
dim
()
==
4
);
NVTE_CHECK
(
cu_seqlens
.
dtype
()
==
DType
::
kInt32
);
auto
cu_seqlens_shape
=
cu_seqlens
.
shape
();
auto
tensor_shape
=
tensor
.
shape
();
NVTE_CHECK
(
cu_seqlens
.
dim
()
==
1
);
NVTE_CHECK
(
cu_seqlens_shape
[
0
]
>=
2
);
// Shapes of q and dq are [t, h, d], so the dimension of "t" is 0
// Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1
int
seq_dim
=
tensor
.
dim
()
==
3
?
0
:
1
;
int
batch
=
cu_seqlens_shape
[
0
]
-
1
;
int
num_heads
=
tensor_shape
[
seq_dim
+
1
];
int
dim_per_head
=
tensor_shape
[
seq_dim
+
2
];
int
hidden_size_in_bytes
=
num_heads
*
dim_per_head
*
typeToSize
(
tensor
.
dtype
());
// For 128-bits load/store
NVTE_CHECK
(
hidden_size_in_bytes
%
16
==
0
);
// Launch Kernel
constexpr
unsigned
int
block
=
256
;
unsigned
int
grid_x
=
(
tensor_shape
[
seq_dim
]
/
2
*
32
+
block
-
1
)
/
block
;
unsigned
int
grid_y
=
1
;
for
(
int
i
=
0
;
i
<
seq_dim
;
i
++
)
{
grid_y
*=
tensor_shape
[
i
];
}
dim3
grid
=
{
grid_x
,
grid_y
};
thd_read_half_tensor_kernel
<<<
grid
,
block
,
sizeof
(
int
)
*
(
batch
+
1
),
stream
>>>
(
half
.
data
.
dptr
,
tensor
.
data
.
dptr
,
reinterpret_cast
<
int
*>
(
cu_seqlens
.
data
.
dptr
),
batch
,
hidden_size_in_bytes
,
half_idx
,
tensor_shape
[
seq_dim
]);
}
/***************************************************************************************************
* Support THD format for Context Parallel: softmax_lse related operations
**************************************************************************************************/
void
thd_second_half_lse_correction
(
Tensor
lse
,
const
Tensor
&
lse_per_step
,
const
Tensor
&
cu_seqlens
,
bool
lse_packed
,
cudaStream_t
stream
)
{
using
namespace
transformer_engine
;
NVTE_CHECK
(
lse
.
dtype
()
==
DType
::
kFloat32
);
NVTE_CHECK
(
lse_per_step
.
dtype
()
==
DType
::
kFloat32
);
NVTE_CHECK
(
cu_seqlens
.
dtype
()
==
DType
::
kInt32
);
NVTE_CHECK
(
cu_seqlens
.
dim
()
==
1
);
int
batch
,
num_heads
,
lse_seqlen
,
second_half_lse_seqlen
;
auto
cu_seqlens_shape
=
cu_seqlens
.
shape
();
auto
lse_shape
=
lse
.
shape
();
auto
lse_per_step_shape
=
lse_per_step
.
shape
();
if
(
lse_packed
)
{
NVTE_CHECK
(
lse
.
dim
()
==
2
);
NVTE_CHECK
(
lse_per_step
.
dim
()
==
2
);
batch
=
cu_seqlens_shape
[
0
]
-
1
;
num_heads
=
lse_shape
[
0
];
lse_seqlen
=
lse_shape
[
1
];
second_half_lse_seqlen
=
lse_per_step_shape
[
1
];
NVTE_CHECK
(
lse_per_step_shape
[
0
]
==
num_heads
);
NVTE_CHECK
(
second_half_lse_seqlen
>=
lse_seqlen
/
2
);
}
else
{
NVTE_CHECK
(
lse
.
dim
()
==
3
);
NVTE_CHECK
(
lse_per_step
.
dim
()
==
3
);
batch
=
lse_shape
[
0
];
num_heads
=
lse_shape
[
1
];
lse_seqlen
=
lse_shape
[
2
];
second_half_lse_seqlen
=
lse_per_step_shape
[
2
];
NVTE_CHECK
(
lse_per_step_shape
[
0
]
==
batch
);
NVTE_CHECK
(
lse_per_step_shape
[
1
]
==
num_heads
);
NVTE_CHECK
(
second_half_lse_seqlen
==
lse_seqlen
/
2
);
NVTE_CHECK
(
cu_seqlens_shape
[
0
]
==
batch
+
1
);
}
constexpr
unsigned
int
block
=
256
;
unsigned
int
grid_x
=
(
lse_seqlen
/
2
+
block
-
1
)
/
block
;
unsigned
int
grid_y
=
num_heads
;
dim3
grid
=
{
grid_x
,
grid_y
};
if
(
lse_packed
)
{
thd_lse_kernel
<
true
,
LseCorrectionFunctor
><<<
grid
,
block
,
sizeof
(
int
)
*
(
batch
+
1
),
stream
>>>
(
reinterpret_cast
<
float
*>
(
lse
.
data
.
dptr
),
reinterpret_cast
<
float
*>
(
lse_per_step
.
data
.
dptr
),
reinterpret_cast
<
int
*>
(
cu_seqlens
.
data
.
dptr
),
batch
,
num_heads
,
lse_seqlen
,
second_half_lse_seqlen
);
}
else
{
thd_lse_kernel
<
false
,
LseCorrectionFunctor
><<<
grid
,
block
,
sizeof
(
int
)
*
(
batch
+
1
),
stream
>>>
(
reinterpret_cast
<
float
*>
(
lse
.
data
.
dptr
),
reinterpret_cast
<
float
*>
(
lse_per_step
.
data
.
dptr
),
reinterpret_cast
<
int
*>
(
cu_seqlens
.
data
.
dptr
),
batch
,
num_heads
,
lse_seqlen
,
second_half_lse_seqlen
);
}
}
+void thd_read_second_half_lse(const Tensor &lse, const Tensor &cu_seqlens, Tensor &half_lse,
+                              bool lse_packed, int second_half_lse_seqlen, cudaStream_t stream) {
+  using namespace transformer_engine;
+  NVTE_CHECK(lse.dtype() == DType::kFloat32);
+  NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
+  NVTE_CHECK(cu_seqlens.dim() == 1);
+
+  int batch, num_heads, lse_seqlen;
+  auto cu_seqlens_shape = cu_seqlens.shape();
+  auto lse_shape = lse.shape();
+
+  if (lse_packed) {
+    NVTE_CHECK(lse.dim() == 2);
+
+    batch = cu_seqlens_shape[0] - 1;
+    num_heads = lse_shape[0];
+    lse_seqlen = lse_shape[1];
+
+    NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2);
+  } else {
+    NVTE_CHECK(lse.dim() == 3);
+
+    batch = lse_shape[0];
+    num_heads = lse_shape[1];
+    lse_seqlen = lse_shape[2];
+
+    NVTE_CHECK(cu_seqlens_shape[0] == batch + 1);
+    NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2);
+  }
+
+  constexpr unsigned int block = 256;
+  unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
+  unsigned int grid_y = num_heads;
+  dim3 grid = {grid_x, grid_y};
+
+  if (lse_packed) {
+    thd_lse_kernel<true, ReadLseFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
+        reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
+        reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
+        second_half_lse_seqlen);
+  } else {
+    thd_lse_kernel<false, ReadLseFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
+        reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
+        reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
+        second_half_lse_seqlen);
+  }
+}
+
+/***************************************************************************************************
+ * Support THD format for Context Parallel: Out correction in forward
+ **************************************************************************************************/
+
+template <typename dtype, int only_second_half>
+static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, const Tensor &lse,
+                                      const Tensor &lse_per_step, const Tensor &cu_seqlens,
+                                      bool lse_packed, cudaStream_t stream) {
+  using namespace transformer_engine;
+  NVTE_CHECK(out.dtype() == out_per_step.dtype());
+  NVTE_CHECK(lse.dtype() == DType::kFloat32);
+  NVTE_CHECK(lse_per_step.dtype() == DType::kFloat32);
+  NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
+
+  auto out_shape = out.shape();
+  auto lse_shape = lse.shape();
+  auto out_per_step_shape = out_per_step.shape();
+  auto lse_per_step_shape = lse_per_step.shape();
+  auto cu_seqlens_shape = cu_seqlens.shape();
+
+  int total_tokens = out_shape[0];
+  int num_heads = out_shape[1];
+  int dim_per_head = out_shape[2];
+
+  NVTE_CHECK(out_per_step_shape[0] == total_tokens / (only_second_half + 1));
+  NVTE_CHECK(out_per_step_shape[1] == num_heads);
+  NVTE_CHECK(out_per_step_shape[2] == dim_per_head);
+
+  int batch, lse_seqlen, lse_per_step_seqlen;
+  if (lse_packed) {
+    batch = cu_seqlens_shape[0] - 1;
+    lse_seqlen = lse_shape[1];
+    lse_per_step_seqlen = lse_per_step_shape[1];
+
+    NVTE_CHECK(lse_shape[0] == num_heads);
+    NVTE_CHECK(lse_seqlen >= total_tokens);
+    NVTE_CHECK(lse_per_step_shape[0] == num_heads);
+    NVTE_CHECK(lse_per_step_seqlen >= lse_seqlen / (only_second_half + 1));
+  } else {
+    batch = lse_shape[0];
+    lse_seqlen = lse_shape[2];
+    lse_per_step_seqlen = lse_per_step_shape[2];
+
+    NVTE_CHECK(lse_shape[1] == num_heads);
+    NVTE_CHECK(lse_per_step_shape[0] == batch);
+    NVTE_CHECK(lse_per_step_shape[1] == num_heads);
+    NVTE_CHECK(lse_per_step_seqlen == lse_seqlen / (only_second_half + 1));
+    NVTE_CHECK(cu_seqlens_shape[0] == batch + 1);
+  }
+
+  constexpr int tile = 16;
+  constexpr int block = 512;
+  unsigned int grid_x =
+      (static_cast<size_t>(total_tokens) / (only_second_half + 1) * tile + block - 1) / block;
+  dim3 grid = {grid_x, (unsigned int)num_heads};
+
+  if (lse_packed) {
+    thd_out_correction_kernel<dtype, only_second_half, tile, true>
+        <<<grid, block, sizeof(int) * (batch + 1), stream>>>(
+            reinterpret_cast<dtype *>(out.data.dptr),
+            reinterpret_cast<dtype *>(out_per_step.data.dptr),
+            reinterpret_cast<float *>(lse.data.dptr),
+            reinterpret_cast<float *>(lse_per_step.data.dptr),
+            reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
+            lse_seqlen, lse_per_step_seqlen);
+  } else {
+    thd_out_correction_kernel<dtype, only_second_half, tile, false>
+        <<<grid, block, sizeof(int) * (batch + 1), stream>>>(
+            reinterpret_cast<dtype *>(out.data.dptr),
+            reinterpret_cast<dtype *>(out_per_step.data.dptr),
+            reinterpret_cast<float *>(lse.data.dptr),
+            reinterpret_cast<float *>(lse_per_step.data.dptr),
+            reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
+            lse_seqlen, lse_per_step_seqlen);
+  }
+}
+
+void thd_out_correction(Tensor out, const Tensor &out_per_step, const Tensor &lse,
+                        const Tensor &lse_per_step, const Tensor &cu_seqlens,
+                        bool only_second_half, bool lse_packed, cudaStream_t stream) {
+  using namespace transformer_engine;
+  if (only_second_half) {
+    TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+        out.dtype(), dtype,
+        thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
+                                            lse_packed, stream););
+  } else {
+    TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+        out.dtype(), dtype,
+        thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
+                                            lse_packed, stream););
+  }
+}
+/***************************************************************************************************
+ * Support THD format for Context Parallel: Gradients correction in backward
+ **************************************************************************************************/
+
+template <typename dtype, typename Functor_0, typename Functor_1, int functor_idx>
+static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step,
+                                       const Tensor &cu_seqlens, cudaStream_t stream) {
+  using namespace transformer_engine;
+  NVTE_CHECK(grad.dim() == 3 || grad.dim() == 4);
+  NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
+  NVTE_CHECK(cu_seqlens.dim() == 1);
+
+  auto grad_shape = grad.shape();
+  auto cu_seqlens_shape = cu_seqlens.shape();
+  auto grad_per_step_shape = grad_per_step.shape();
+
+  // Shape of dq is [t, h, d], so the dimension of "t" is 0
+  // Shape of dkv is [2, t, h, d], so the dimension of "t" is 1
+  int seq_dim = grad.dim() == 3 ? 0 : 1;
+
+  int total_tokens = grad_shape[seq_dim];
+  int num_heads = grad_shape[seq_dim + 1];
+  int dim_per_head = grad_shape[seq_dim + 2];
+  int batch = cu_seqlens_shape[0] - 1;
+
+  if constexpr (functor_idx < 2) {
+    NVTE_CHECK(grad_per_step_shape[seq_dim] == total_tokens / 2);
+  } else {
+    NVTE_CHECK(grad_per_step_shape[seq_dim] == total_tokens);
+  }
+  NVTE_CHECK(grad_per_step_shape[seq_dim + 1] == num_heads);
+  NVTE_CHECK(grad_per_step_shape[seq_dim + 2] == dim_per_head);
+
+  size_t hidden_size = num_heads * dim_per_head;
+  NVTE_CHECK((hidden_size * typeToSize(grad.dtype())) % 16 == 0);
+
+  constexpr unsigned int block = 256;
+  unsigned int grid_x;
+  if constexpr (functor_idx < 2) {
+    grid_x = (total_tokens / 2 * 32 + block - 1) / block;
+  } else {
+    grid_x = (total_tokens * 32 + block - 1) / block;
+  }
+  unsigned int grid_y = 1;
+  for (int i = 0; i < seq_dim; i++) {
+    grid_y *= grad_shape[i];
+  }
+  dim3 grid = {grid_x, grid_y};
+
+  thd_grad_correction_kernel<dtype, Functor_0, Functor_1, functor_idx, 32>
+      <<<grid, block, sizeof(int) * (batch + 1), stream>>>(
+          reinterpret_cast<dtype *>(grad.data.dptr),
+          reinterpret_cast<dtype *>(grad_per_step.data.dptr),
+          reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, hidden_size, total_tokens);
+}
+
+template <typename dtype>
+static void thd_grad_dispatcher(Tensor grad, const Tensor &grad_per_step, const Tensor &cu_seqlens,
+                                const std::string &first_half, const std::string &second_half,
+                                cudaStream_t stream) {
+  using namespace transformer_engine;
+  if (first_half == "add" && second_half == "none") {
+    thd_grad_correction_helper<dtype, AddFunctor<dtype>, EmptyFunctor, 0>(grad, grad_per_step,
+                                                                          cu_seqlens, stream);
+  } else if (first_half == "copy" && second_half == "none") {
+    thd_grad_correction_helper<dtype, CopyFunctor, EmptyFunctor, 0>(grad, grad_per_step,
+                                                                    cu_seqlens, stream);
+  } else if (first_half == "none" && second_half == "add") {
+    thd_grad_correction_helper<dtype, EmptyFunctor, AddFunctor<dtype>, 1>(grad, grad_per_step,
+                                                                          cu_seqlens, stream);
+  } else if (first_half == "none" && second_half == "copy") {
+    thd_grad_correction_helper<dtype, EmptyFunctor, CopyFunctor, 1>(grad, grad_per_step,
+                                                                    cu_seqlens, stream);
+  } else if (first_half == "add" && second_half == "copy") {
+    thd_grad_correction_helper<dtype, AddFunctor<dtype>, CopyFunctor, 2>(grad, grad_per_step,
+                                                                         cu_seqlens, stream);
+  } else if (first_half == "copy" && second_half == "add") {
+    thd_grad_correction_helper<dtype, CopyFunctor, AddFunctor<dtype>, 2>(grad, grad_per_step,
+                                                                         cu_seqlens, stream);
+  } else {
+    NVTE_ERROR("Unsupported Functor of first half and second_half\n");
+  }
+}
+
+void thd_grad_correction(Tensor grad, const Tensor &grad_per_step, const Tensor &cu_seqlens,
+                         const std::string &first_half, const std::string &second_half,
+                         cudaStream_t stream) {
+  using namespace transformer_engine;
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+      grad.dtype(), dtype,
+      thd_grad_dispatcher<dtype>(grad, grad_per_step, cu_seqlens, first_half, second_half,
+                                 stream););
+}
+
+/***************************************************************************************************
+ * Support THD format for Context Parallel: Generate partitioned indices for input tokens
+ **************************************************************************************************/
+
+void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int total_tokens,
+                                 int world_size, int rank, cudaStream_t stream) {
+  using namespace transformer_engine;
+  NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
+  NVTE_CHECK(cu_seqlens.dim() == 1);
+
+  auto cu_seqlens_shape = cu_seqlens.shape();
+  auto output_shape = output.shape();
+
+  NVTE_CHECK(cu_seqlens_shape[0] >= 2);
+  NVTE_CHECK(rank >= 0 && rank < world_size);
+  NVTE_CHECK(world_size > 0);
+  NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0);
+
+  int batch = cu_seqlens_shape[0] - 1;
+
+  constexpr unsigned int block = 256;
+  unsigned int grid = (output_shape[0] + block - 1) / block;
+  thd_partition_indices_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
+      reinterpret_cast<int *>(output.data.dptr), reinterpret_cast<int *>(cu_seqlens.data.dptr),
+      batch, total_tokens, world_size, rank);
+}
+
+}  // namespace context_parallel
}  // namespace transformer_engine

-#endif
+void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens,
+                                  NVTETensor half, int half_idx, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_thd_read_half_tensor);
+  using namespace transformer_engine;
+  context_parallel::thd_read_half_tensor(*reinterpret_cast<Tensor *>(tensor),
+                                         *reinterpret_cast<Tensor *>(cu_seqlens),
+                                         *reinterpret_cast<Tensor *>(half), half_idx, stream);
+}
+
+void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
+                                            const NVTETensor &cu_seqlens, int lse_packed,
+                                            cudaStream_t stream) {
+  NVTE_API_CALL(nvte_thd_second_half_lse_correction);
+  using namespace transformer_engine;
+  context_parallel::thd_second_half_lse_correction(*reinterpret_cast<Tensor *>(lse),
+                                                   *reinterpret_cast<Tensor *>(lse_per_step),
+                                                   *reinterpret_cast<Tensor *>(cu_seqlens),
+                                                   lse_packed, stream);
+}
+
+void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
+                                      NVTETensor half_lse, int lse_packed,
+                                      int second_half_lse_seqlen, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_thd_read_second_half_lse);
+  using namespace transformer_engine;
+  context_parallel::thd_read_second_half_lse(*reinterpret_cast<Tensor *>(lse),
+                                             *reinterpret_cast<Tensor *>(cu_seqlens),
+                                             *reinterpret_cast<Tensor *>(half_lse), lse_packed,
+                                             second_half_lse_seqlen, stream);
+}
+
+void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
+                                const NVTETensor &lse, const NVTETensor &lse_per_step,
+                                const NVTETensor &cu_seqlens, int only_second_half, int lse_packed,
+                                cudaStream_t stream) {
+  NVTE_API_CALL(nvte_thd_out_correction);
+  using namespace transformer_engine;
+  context_parallel::thd_out_correction(*reinterpret_cast<Tensor *>(out),
+                                       *reinterpret_cast<Tensor *>(out_per_step),
+                                       *reinterpret_cast<Tensor *>(lse),
+                                       *reinterpret_cast<Tensor *>(lse_per_step),
+                                       *reinterpret_cast<Tensor *>(cu_seqlens), only_second_half,
+                                       lse_packed, stream);
+}
+
+void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
+                                 const NVTETensor &cu_seqlens, const char *first_half,
+                                 const char *second_half, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_thd_grad_correction);
+  using namespace transformer_engine;
+  std::string first_half_str(first_half);
+  std::string second_half_str(second_half);
+  context_parallel::thd_grad_correction(*reinterpret_cast<Tensor *>(grad),
+                                        *reinterpret_cast<Tensor *>(grad_per_step),
+                                        *reinterpret_cast<Tensor *>(cu_seqlens), first_half_str,
+                                        second_half_str, stream);
+}
+
+void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
+                                         int total_tokens, int world_size, int rank,
+                                         cudaStream_t stream) {
+  NVTE_API_CALL(nvte_thd_get_partitioned_indices);
+  using namespace transformer_engine;
+  context_parallel::thd_get_partitioned_indices(*reinterpret_cast<Tensor *>(cu_seqlens),
+                                                *reinterpret_cast<Tensor *>(output), total_tokens,
+                                                world_size, rank, stream);
+}
transformer_engine/common/fused_attn/flash_attn.cu    0 → 100644    View file @ f8c2af4c
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "../common.h"
#include "transformer_engine/fused_attn.h"

namespace transformer_engine {
namespace flash_attention {

constexpr int warp_size = 32;
constexpr int type_size = 2;  // FP16 or BF16
constexpr int nvec = sizeof(uint64_t) / type_size;
constexpr int load_size = warp_size * nvec;
constexpr int block_size = 512;

template <typename T>
__launch_bounds__(block_size) __global__
    void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z,
                            const size_t W) {
  const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
  const int id_in_warp = threadIdx.x % warp_size;
  const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec;
  const T *my_input = qkvi + offset_input;

  const size_t s = warpid / B;
  if (s >= S) return;

  const size_t b = warpid % B;

  const size_t offset_output = blockIdx.y * B * S * Z * W + (s + b * S) * W * Z + id_in_warp * nvec;

  T *my_output = qkv + offset_output;

  for (int i = 0; i < Z; ++i) {
    uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size);
    *out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size * 3);
  }
}

template <typename T>
__launch_bounds__(block_size) __global__
    void prepare_kernel_bwd(const T *q, const T *k, const T *v, T *qkv, const size_t B,
                            const size_t S, const size_t Z, const size_t W) {
  const T *input = blockIdx.y == 0 ? q : (blockIdx.y == 1 ? k : v);

  const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
  const int id_in_warp = threadIdx.x % warp_size;
  const size_t offset_input = warpid * W * Z + id_in_warp * nvec;
  const T *my_input = input + offset_input;

  const size_t b = warpid / S;
  if (b >= B) return;

  const size_t s = warpid % S;

  const size_t offset_output = (b + s * B) * 3 * W * Z + id_in_warp * nvec + blockIdx.y * W;

  T *my_output = qkv + offset_output;

  for (int i = 0; i < Z; ++i) {
    uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size * 3);
    *out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size);
  }
}

void prepare_flash_attn_fwd(Tensor qkvi, Tensor qkv, cudaStream_t stream) {
  using namespace transformer_engine;
  NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor.");
  NVTE_CHECK(qkvi.dtype() == DType::kFloat16 || qkvi.dtype() == DType::kBFloat16);

  auto qkvi_shape = qkvi.shape();
  NVTE_CHECK(qkvi_shape[3] % load_size == 0);
  NVTE_CHECK(qkvi_shape[3] == load_size);

  // [s, b, n, h * 3] -> [3, b, s, n, h]
  std::vector<uint64_t> shape = {3, qkvi_shape[1], qkvi_shape[0], qkvi_shape[2], qkvi_shape[3]};

  size_t warps = qkvi_shape[0] * qkvi_shape[1];
  size_t warps_per_block = block_size / warp_size;
  size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
  dim3 grid(blocks, 3);
  int threads = block_size;

  TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
      qkvi.dtype(), dtype,
      prepare_kernel_fwd<dtype><<<grid, threads, 0, stream>>>(
          reinterpret_cast<dtype *>(qkvi.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
          shape[1], shape[2], shape[3], shape[4]););
}
void
prepare_flash_attn_bwd
(
Tensor
q
,
Tensor
k
,
Tensor
v
,
Tensor
qkv
,
cudaStream_t
stream
)
{
using
namespace
transformer_engine
;
NVTE_CHECK
(
q
.
dim
()
==
4
,
"Expected 4-dim tensor."
);
NVTE_CHECK
(
k
.
dim
()
==
4
,
"Expected 4-dim tensor."
);
NVTE_CHECK
(
v
.
dim
()
==
4
,
"Expected 4-dim tensor."
);
NVTE_CHECK
(
q
.
dtype
()
==
DType
::
kFloat16
||
q
.
dtype
()
==
DType
::
kBFloat16
);
NVTE_CHECK
(
k
.
dtype
()
==
q
.
dtype
());
NVTE_CHECK
(
v
.
dtype
()
==
q
.
dtype
());
auto
q_shape
=
q
.
shape
();
auto
k_shape
=
k
.
shape
();
auto
v_shape
=
v
.
shape
();
NVTE_CHECK
(
q_shape
[
3
]
%
load_size
==
0
);
NVTE_CHECK
(
q_shape
[
3
]
==
load_size
);
NVTE_CHECK
(
k_shape
[
3
]
%
load_size
==
0
);
NVTE_CHECK
(
k_shape
[
3
]
==
load_size
);
NVTE_CHECK
(
v_shape
[
3
]
%
load_size
==
0
);
NVTE_CHECK
(
v_shape
[
3
]
==
load_size
);
// 3 x [s, b, n, h] -> [b, s, n, 3 * h]
std
::
vector
<
uint64_t
>
shape
=
{
q_shape
[
1
],
q_shape
[
0
],
q_shape
[
2
],
3
*
q_shape
[
3
]};
size_t
warps
=
q_shape
[
0
]
*
q_shape
[
1
];
size_t
warps_per_block
=
block_size
/
warp_size
;
size_t
blocks
=
(
warps
+
warps_per_block
-
1
)
/
warps_per_block
;
dim3
grid
(
blocks
,
3
);
int
threads
=
block_size
;
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT
(
q
.
dtype
(),
dtype
,
prepare_kernel_bwd
<
dtype
><<<
grid
,
threads
,
0
,
stream
>>>
(
reinterpret_cast
<
dtype
*>
(
q
.
data
.
dptr
),
reinterpret_cast
<
dtype
*>
(
k
.
data
.
dptr
),
reinterpret_cast
<
dtype
*>
(
v
.
data
.
dptr
),
reinterpret_cast
<
dtype
*>
(
qkv
.
data
.
dptr
),
q_shape
[
0
],
q_shape
[
1
],
q_shape
[
2
],
q_shape
[
3
]););
}
}
// namespace flash_attention
}
// namespace transformer_engine
void
nvte_prepare_flash_attn_fwd
(
NVTETensor
qkvi
,
NVTETensor
qkv
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_prepare_flash_attn_fwd
);
using
namespace
transformer_engine
;
flash_attention
::
prepare_flash_attn_fwd
(
*
reinterpret_cast
<
Tensor
*>
(
qkvi
),
*
reinterpret_cast
<
Tensor
*>
(
qkv
),
stream
);
}
void
nvte_prepare_flash_attn_bwd
(
NVTETensor
q
,
NVTETensor
k
,
NVTETensor
v
,
NVTETensor
qkv
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_prepare_flash_attn_bwd
);
using
namespace
transformer_engine
;
flash_attention
::
prepare_flash_attn_bwd
(
*
reinterpret_cast
<
Tensor
*>
(
q
),
*
reinterpret_cast
<
Tensor
*>
(
k
),
*
reinterpret_cast
<
Tensor
*>
(
v
),
*
reinterpret_cast
<
Tensor
*>
(
qkv
),
stream
);
}
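Editor's note: the two prepare kernels above only permute memory layouts. The forward path reshapes an interleaved [s, b, n, 3*h] QKV tensor into [3, b, s, n, h]; the backward path packs three [s, b, n, h] tensors into [b, s, n, 3*h]. A plain CPU sketch of the forward permutation (an illustration under the same indexing assumptions, not part of the file; prepare_fwd_reference is a hypothetical name) is:

// Editor's illustration: CPU reference for the layout change done by
// prepare_kernel_fwd, assuming qkvi is [S, B, N, 3*H] with a per-head
// [q | k | v] layout in the last dimension, and qkv is [3, B, S, N, H].
#include <cstddef>
#include <vector>

template <typename T>
void prepare_fwd_reference(const std::vector<T> &qkvi, std::vector<T> &qkv,
                           size_t S, size_t B, size_t N, size_t H) {
  for (size_t s = 0; s < S; ++s)
    for (size_t b = 0; b < B; ++b)
      for (size_t n = 0; n < N; ++n)
        for (size_t m = 0; m < 3; ++m)  // 0 = q, 1 = k, 2 = v
          for (size_t h = 0; h < H; ++h) {
            const size_t src = ((s * B + b) * N + n) * 3 * H + m * H + h;
            const size_t dst = (((m * B + b) * S + s) * N + n) * H + h;
            qkv[dst] = qkvi[src];
          }
}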
transformer_engine/common/fused_attn/fused_attn.cpp
View file @
f8c2af4c
...
@@ -1006,3 +1006,18 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
      NVTE_ERROR("Invalid combination of data type and sequence length for fused attention.\n");
  }
}

uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
                                       cudaStream_t stream) {
  NVTE_API_CALL(nvte_get_runtime_num_segments);
  using namespace transformer_engine::fused_attn;
  return GetRuntimeNumSegments(cu_seqlen, workspace, len, stream);
}

void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed,
                                   size_t q_max_seqlen, size_t kv_max_seqlen,
                                   NVTE_Fused_Attn_Backend backend, cudaStream_t stream) {
  NVTE_API_CALL(nvte_populate_rng_state_async);
  using namespace transformer_engine::fused_attn;
  PopulateRngStateAsync(rng_state_dst, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
}
transformer_engine/pytorch/csrc/kv_cache.cuh → transformer_engine/common/fused_attn/kv_cache.cu
View file @
f8c2af4c
...
@@ -3,48 +3,15 @@
 *
 * See LICENSE for license information.
 ************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
#define TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
namespace transformer_engine {
#include "../common.h"
namespace fused_attn {
#include "transformer_engine/fused_attn.h"
template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
                                           int b, int max_seq_len, int h, int d) {
  // tensor: thd; new_tensor: bshd
  // cu_seqlens: [b + 1]
  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
    int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
    int thd_offset = cu_seqlens[batch_idx] * h * d;
    int bshd_offset = batch_idx * max_seq_len * h * d;
    scalar_t *thd_token = tensor + thd_offset;
    scalar_t *bshd_token = new_tensor + bshd_offset;
    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
      *(bshd_token + i) = *(thd_token + i);
    }
  }
}
template <typename scalar_t>
namespace transformer_engine {
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
namespace kv_cache {
                                           int b, int max_seq_len, int h, int d) {
  // tensor: bshd; new_tensor: thd
  // cu_seqlens: [b + 1]
  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
    int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
    int num_elts = seqlen * h * d;
    int bshd_offset = batch_idx * max_seq_len * h * d;
    int thd_offset = cu_seqlens[batch_idx] * h * d;
    scalar_t *bshd_token = tensor + bshd_offset;
    scalar_t *thd_token = new_tensor + thd_offset;
    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
      *(thd_token + i) = *(bshd_token + i);
    }
  }
}
template <typename scalar_t>
template <typename dtype>
__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices,
__global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices,
                                        int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
                                        int d_v, int b, int max_seq_len) {
  // k_cache, v_cache: bshd
...
@@ -75,11 +42,11 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in
  }
}
template <typename scalar_t>
template <typename dtype>
__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache,
__global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cache, dtype *v_cache,
                                        scalar_t *v_cache, int *page_table, int *cu_new_lens,
                                        int *page_table, int *cu_new_lens,
                                        int *cu_cached_lens,
                                        NVTE_QKV_Format qkv_format, int h_kv,
                                        int d_k, int d_v,
                                        int b, int max_ctx_len, int max_seq_len,
                                        int max_pages_per_seq, bool is_non_paged) {
  // new_k, new_v: qkv_format; k_cache, v_cache: bshd
  // cu_new_lens, cu_cached_lens: [b + 1]
...
@@ -140,6 +107,191 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar
    }
  }
}
}  // namespace fused_attn
template <typename dtype>
void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache,
                               Tensor page_table, Tensor cu_new_lens, Tensor cu_cached_lens,
                               NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, int b,
                               int max_ctx_len, int max_seq_len, int max_pages_per_seq,
                               bool is_non_paged, cudaStream_t stream) {
  if (new_k.has_data() && new_v.has_data() && k_cache.has_data() && v_cache.has_data()) {
    if (is_non_paged) {
      reindex_kv_cache_kernel<<<16, 256, 0, stream>>>(
          reinterpret_cast<dtype *>(k_cache.data.dptr),
          reinterpret_cast<dtype *>(v_cache.data.dptr),
          reinterpret_cast<int *>(page_table.data.dptr),
          reinterpret_cast<int *>(cu_new_lens.data.dptr),
          reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
    }
    copy_to_kv_cache_kernel<<<16, 256, 0, stream>>>(
        reinterpret_cast<dtype *>(new_k.data.dptr), reinterpret_cast<dtype *>(new_v.data.dptr),
        reinterpret_cast<dtype *>(k_cache.data.dptr), reinterpret_cast<dtype *>(v_cache.data.dptr),
        reinterpret_cast<int *>(page_table.data.dptr),
        reinterpret_cast<int *>(cu_new_lens.data.dptr),
        reinterpret_cast<int *>(cu_cached_lens.data.dptr), qkv_format, h_kv, d_k, d_v, b,
        max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
  }
}

void copy_to_kv_cache(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache,
                      Tensor page_table, Tensor cu_new_lens, Tensor cu_cached_lens,
                      NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len,
                      int max_pages_per_seq, bool is_non_paged, cudaStream_t stream) {
  int h_kv = new_k.shape()[new_k.dim() - 2];
  int d_k = new_k.shape()[new_k.dim() - 1];
  int d_v = new_v.shape()[new_v.dim() - 1];
  NVTE_CHECK(k_cache.dtype() == v_cache.dtype() && new_k.dtype() == new_v.dtype() &&
                 new_k.dtype() == k_cache.dtype(),
             "new_k, new_v, k_cache and v_cache must be of the same data type.");
  NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
                 qkv_format == NVTE_QKV_Format::NVTE_SBHD ||
                 qkv_format == NVTE_QKV_Format::NVTE_THD,
             "qkv_format must be {BSHD, SBHD, THD}.");
  TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
      k_cache.dtype(), dtype,
      copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
                                       cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
                                       max_seq_len, max_pages_per_seq, is_non_paged, stream););
}

template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
                                           int b, int max_seq_len, int h, int d) {
  // tensor: thd; new_tensor: bshd
  // cu_seqlens: [b + 1]
  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
    int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
    int thd_offset = cu_seqlens[batch_idx] * h * d;
    int bshd_offset = batch_idx * max_seq_len * h * d;
    scalar_t *thd_token = tensor + thd_offset;
    scalar_t *bshd_token = new_tensor + bshd_offset;
    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
      *(bshd_token + i) = *(thd_token + i);
    }
  }
}

template <typename scalar_t>
void convert_thd_to_bshd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_seqlens, int b,
                                  int max_seq_len, int h, int d, cudaStream_t stream) {
  using namespace transformer_engine;
  convert_thd_to_bshd_kernel<<<16, 256, 0, stream>>>(
      reinterpret_cast<scalar_t *>(tensor.data.dptr),
      reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
      reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
}

void convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int b,
                         int max_seq_len, cudaStream_t stream) {
  using namespace transformer_engine;
  auto tensor_shape = tensor.shape();
  TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
      new_tensor.dtype(), dtype,
      convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len,
                                          tensor_shape[1], tensor_shape[2], stream););
}

template <typename scalar_t>
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
                                           int b, int max_seq_len, int h, int d) {
  // tensor: bshd; new_tensor: thd
  // cu_seqlens: [b + 1]
  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
    int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
    int num_elts = seqlen * h * d;
    int bshd_offset = batch_idx * max_seq_len * h * d;
    int thd_offset = cu_seqlens[batch_idx] * h * d;
    scalar_t *bshd_token = tensor + bshd_offset;
    scalar_t *thd_token = new_tensor + thd_offset;
    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
      *(thd_token + i) = *(bshd_token + i);
    }
  }
}

template <typename scalar_t>
void convert_bshd_to_thd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_seqlens, int b,
                                  int max_seq_len, int h, int d, cudaStream_t stream) {
  using namespace transformer_engine;
  convert_bshd_to_thd_kernel<<<16, 256, 0, stream>>>(
      reinterpret_cast<scalar_t *>(tensor.data.dptr),
      reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
      reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
}

void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int t,
                         cudaStream_t stream) {
  using namespace transformer_engine;
  auto tensor_shape = tensor.shape();
  TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
      tensor.dtype(), dtype,
      convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, tensor_shape[0],
                                          tensor_shape[1], tensor_shape[2], tensor_shape[3],
                                          stream););
}

}  // namespace kv_cache
}  // namespace transformer_engine
}  // namespace transformer_engine
#endif
/***************************************************************************************************
* KV Cache: Copy new KV tokens to the KV cache
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
* 2. cu_new_lens and cu_cached_lens are in shape [b + 1]; cu_cached_lens include the added lens
* in current step
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
* max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
* Set is_non_paged = True/False to indicate as such.
 * 4. is_non_paged = True also re-indexes the KV cache, e.g. the initial batch indices [0, 3, 1, 2]
 * become [0, 1, 1, 2]. The page_table = batch_indices.unsqueeze(1) is however unchanged.
 * batch_indices_post can be used for monotonic indexing, i.e. [0, 1, 2, 3]. batch_indices is
* preserved for the next layer in the same iteration.
* 5. Only supports same page_table for k_cache and v_cache
* 6. Only pad_between_seqs = False when qkv_format = thd, i.e. there should be no pad tokens
* between sequences in new_k and new_v such as [a a a 0..0 b b 0..0 c 0..0].
**************************************************************************************************/
void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
                           NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
                           NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
                           int max_ctx_len, int max_seq_len, int max_pages_per_seq,
                           int is_non_paged, cudaStream_t stream) {
  NVTE_API_CALL(nvte_copy_to_kv_cache);
  using namespace transformer_engine;
  kv_cache::copy_to_kv_cache(
      *reinterpret_cast<Tensor *>(new_k), *reinterpret_cast<Tensor *>(new_v),
      *reinterpret_cast<Tensor *>(k_cache), *reinterpret_cast<Tensor *>(v_cache),
      *reinterpret_cast<Tensor *>(page_table), *reinterpret_cast<Tensor *>(cu_new_lens),
      *reinterpret_cast<Tensor *>(cu_cached_lens), qkv_format, b, max_ctx_len, max_seq_len,
      max_pages_per_seq, is_non_paged, stream);
}
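Editor's note: per the comment block above, cu_cached_lens already counts the tokens added in the current step, so each sequence's new tokens land at cache positions [cached_len - new_len, cached_len). A host-side sketch of that append for the non-paged bshd case (an illustration only; append_to_cache_reference and the assumed [b, max_new_len, h_kv, d_k] / [b, max_seq_len, h_kv, d_k] layouts are this note's assumptions, not code from the commit) is:

// Editor's illustration: append new K tokens into a non-paged bshd cache.
#include <cstddef>
#include <vector>

template <typename T>
void append_to_cache_reference(const std::vector<T> &new_k, std::vector<T> &k_cache,
                               const std::vector<int> &cu_new_lens,
                               const std::vector<int> &cu_cached_lens, int b, int max_new_len,
                               int max_seq_len, int h_kv, int d_k) {
  const int token = h_kv * d_k;  // elements per token
  for (int bi = 0; bi < b; ++bi) {
    const int new_len = cu_new_lens[bi + 1] - cu_new_lens[bi];
    const int cached_len = cu_cached_lens[bi + 1] - cu_cached_lens[bi];
    for (int s = 0; s < new_len; ++s) {
      const int dst_s = cached_len - new_len + s;  // append position in the cache
      for (int e = 0; e < token; ++e) {
        k_cache[(static_cast<std::size_t>(bi) * max_seq_len + dst_s) * token + e] =
            new_k[(static_cast<std::size_t>(bi) * max_new_len + s) * token + e];
      }
    }
  }
}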
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd
**************************************************************************************************/
void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
                              int b, int max_seq_len, cudaStream_t stream) {
  NVTE_API_CALL(nvte_convert_thd_to_bshd);
  using namespace transformer_engine;
  kv_cache::convert_thd_to_bshd(*reinterpret_cast<Tensor *>(tensor),
                                *reinterpret_cast<Tensor *>(cu_seqlens),
                                *reinterpret_cast<Tensor *>(new_tensor), b, max_seq_len, stream);
}
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd
**************************************************************************************************/
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
                              int t, cudaStream_t stream) {
  NVTE_API_CALL(nvte_convert_bshd_to_thd);
  using namespace transformer_engine;
  kv_cache::convert_bshd_to_thd(*reinterpret_cast<Tensor *>(tensor),
                                *reinterpret_cast<Tensor *>(cu_seqlens),
                                *reinterpret_cast<Tensor *>(new_tensor), t, stream);
}
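Editor's note: the two conversion entry points above wrap convert_thd_to_bshd / convert_bshd_to_thd, which simply gather or scatter each sequence's tokens using cu_seqlens. A CPU sketch of the thd -> bshd direction mirroring the kernel (an illustration only; thd_to_bshd_reference is a hypothetical name) is:

// Editor's illustration: scatter a packed thd tensor [total_tokens, h, d]
// into a padded bshd tensor [b, max_seq_len, h, d] using cu_seqlens.
#include <cstddef>
#include <vector>

template <typename T>
void thd_to_bshd_reference(const std::vector<T> &tensor, std::vector<T> &new_tensor,
                           const std::vector<int> &cu_seqlens, int b, int max_seq_len, int h,
                           int d) {
  const int token = h * d;
  for (int bi = 0; bi < b; ++bi) {
    const int seqlen = cu_seqlens[bi + 1] - cu_seqlens[bi];
    for (int i = 0; i < seqlen * token; ++i) {
      new_tensor[static_cast<std::size_t>(bi) * max_seq_len * token + i] =
          tensor[static_cast<std::size_t>(cu_seqlens[bi]) * token + i];
    }
  }
}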