OpenDAS / TransformerEngine / Commits

Commit f8c2af4c, authored May 21, 2025 by yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine

Parents: e92773a3, 1d903f5e
Changes: 211; showing 20 changed files with 1970 additions and 139 deletions (+1970, -139)
tests/pytorch/fused_attn/test_fused_attn_with_cp.py                 +1   -1
tests/pytorch/fused_attn/test_kv_cache.py                           +10  -9
tests/pytorch/test_float8blockwisetensor.py                         +104 -0
tests/pytorch/test_float8tensor.py                                  +2   -2
tests/pytorch/test_fused_optimizer.py                               +8   -7
tests/pytorch/test_fused_rope.py                                    +69  -13
tests/pytorch/test_multi_tensor.py                                  +1   -1
tests/pytorch/test_numerics.py                                      +13  -6
tests/pytorch/test_parallel_cross_entropy.py                        +31  -3
tests/pytorch/test_sanity.py                                        +32  -8
transformer_engine/__init__.py                                      +2   -2
transformer_engine/common/CMakeLists.txt                            +20  -0
transformer_engine/common/__init__.py                               +238 -41
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp   +114 -44
transformer_engine/common/common.cu                                 +71  -0
transformer_engine/common/common.h                                  +46  -2
transformer_engine/common/fused_attn/context_parallel.cu            +743 -0
transformer_engine/common/fused_attn/flash_attn.cu                  +153 -0
transformer_engine/common/fused_attn/fused_attn.cpp                 +15  -0
transformer_engine/common/fused_attn/kv_cache.cu                    +297 -0
tests/pytorch/fused_attn/test_fused_attn_with_cp.py

@@ -11,7 +11,7 @@ from transformer_engine.pytorch.utils import (
     get_device_compute_capability,
     get_cudnn_version,
 )
-from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils
+from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
 from test_fused_attn import ModelConfig
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
...
tests/pytorch/fused_attn/test_kv_cache.py

@@ -11,6 +11,12 @@ import math
 import pytest
 import torch
+from test_fused_attn import (
+    ModelConfig,
+    reset_rng_states,
+    _get_attention_backends,
+)
+from torch.distributions import Exponential
 from transformer_engine.pytorch import make_graphed_callables
 from transformer_engine.common import recipe
...
@@ -18,20 +24,15 @@ from transformer_engine.pytorch import fp8_autocast, fp8_model_init
 from transformer_engine.pytorch.transformer import (
     TransformerLayer,
 )
-from transformer_engine.pytorch.attention import DotProductAttention
-from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
-from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
+from transformer_engine.pytorch.attention import DotProductAttention, InferenceParams
+from transformer_engine.pytorch.attention.dot_product_attention.utils import (
+    FlashAttentionUtils as fa_utils,
+)
 from transformer_engine.pytorch.utils import (
     get_device_compute_capability,
     init_method_normal,
     scaled_init_method_normal,
     is_bf16_compatible,
 )
-from test_fused_attn import (
-    ModelConfig,
-    reset_rng_states,
-    _get_attention_backends,
-)

 # Initialize RNG state
 seed = 1234
...
tests/pytorch/test_float8blockwisetensor.py

@@ -392,6 +392,110 @@ class TestFloat8BlockwiseTensor:
         with pytest.raises(AssertionError):
             torch.testing.assert_close(x_view.dequantize(), -x_hp, **_tols[fp8_dtype])

+    @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=str)
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=str)
+    @pytest.mark.parametrize(
+        "dims", [[16, 16, 512], [16, 16, 512, 16], [12, 7, 11], [13, 14, 16], [2, 3, 5]]
+    )
+    def test_view_and_reshape_1D(
+        self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: List[int]
+    ) -> None:
+        """Test view operations that preserve tensor shape"""
+        device = "cuda"
+
+        def is_bitwise_equal(a, b):
+            if a.numel() != b.numel():
+                return False
+            a_flat = a.reshape(-1).view(torch.uint8)
+            b_flat = b.reshape(-1).view(torch.uint8)
+            return torch.all((a_flat ^ b_flat) == 0)
+
+        x_hp = torch.rand(dims, dtype=dtype, device=device)
+        quantizer = Float8BlockQuantizer(
+            fp8_dtype=fp8_dtype,
+            rowwise=True,
+            columnwise=True,
+            block_scaling_dim=1,
+        )
+        x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
+        quantizer.update_quantized(x_hp.clone(), x_fp8)
+
+        # Test view, high dimension tensor -> 2D tensor
+        x_hp_view = x_hp.view(-1, dims[-1]).contiguous()
+        x_fp8_view = x_fp8.view(-1, dims[-1])
+        # Check the dequantized result
+        torch.testing.assert_close(
+            x_fp8_view.dequantize().contiguous(), x_hp_view, **_tols[fp8_dtype]
+        )
+        # Check the bitwise equality of the inner data
+        assert is_bitwise_equal(x_fp8_view._rowwise_data, x_fp8._rowwise_data)
+        assert is_bitwise_equal(x_fp8_view._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
+        # Check the data ptr
+        assert x_fp8_view._rowwise_data.data_ptr() == x_fp8._rowwise_data.data_ptr()
+        assert x_fp8_view._rowwise_scale_inv.data_ptr() == x_fp8._rowwise_scale_inv.data_ptr()
+
+        # Test reshape high dimension tensor -> 2D tensor
+        x_hp_reshape = x_hp.reshape(-1, dims[-1]).contiguous()
+        x_fp8_reshape = x_fp8.reshape(-1, dims[-1])
+        # Check the dequantized result
+        torch.testing.assert_close(
+            x_fp8_reshape.dequantize().contiguous(), x_hp_reshape, **_tols[fp8_dtype]
+        )
+        # Check the bitwise equality of the inner data
+        assert is_bitwise_equal(x_fp8_reshape._rowwise_data, x_fp8._rowwise_data)
+        assert is_bitwise_equal(x_fp8_reshape._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
+
+    @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=str)
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=str)
+    @pytest.mark.parametrize("dims", [[16, 16, 512, 16], [2, 512, 512, 128], [3, 13, 14, 16]])
+    def test_view_and_reshape_2D(
+        self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: List[int]
+    ) -> None:
+        """Test view operations that preserve tensor shape"""
+        device = "cuda"
+
+        def is_bitwise_equal(a, b):
+            if a.numel() != b.numel():
+                return False
+            a_flat = a.reshape(-1).view(torch.uint8)
+            b_flat = b.reshape(-1).view(torch.uint8)
+            return torch.all((a_flat ^ b_flat) == 0)
+
+        x_hp = torch.rand(dims, dtype=dtype, device=device)
+        quantizer = Float8BlockQuantizer(
+            fp8_dtype=fp8_dtype,
+            rowwise=True,
+            columnwise=True,
+            block_scaling_dim=2,
+        )
+        x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
+        quantizer.update_quantized(x_hp.clone(), x_fp8)
+
+        # Test view, high dimension tensor -> 2D tensor
+        x_hp_view = x_hp.view(-1, dims[-2], dims[-1]).contiguous()
+        x_fp8_view = x_fp8.view(-1, dims[-2], dims[-1])
+        # Check the dequantized result
+        torch.testing.assert_close(
+            x_fp8_view.dequantize().contiguous(), x_hp_view, **_tols[fp8_dtype]
+        )
+        # Check the bitwise equality of the inner data
+        assert is_bitwise_equal(x_fp8_view._rowwise_data, x_fp8._rowwise_data)
+        assert is_bitwise_equal(x_fp8_view._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
+        # Check the data ptr
+        assert x_fp8_view._rowwise_data.data_ptr() == x_fp8._rowwise_data.data_ptr()
+        assert x_fp8_view._rowwise_scale_inv.data_ptr() == x_fp8._rowwise_scale_inv.data_ptr()
+
+        # Test reshape high dimension tensor -> 2D tensor
+        x_hp_reshape = x_hp.reshape(-1, dims[-2], dims[-1]).contiguous()
+        x_fp8_reshape = x_fp8.reshape(-1, dims[-2], dims[-1])
+        # Check the dequantized result
+        torch.testing.assert_close(
+            x_fp8_reshape.dequantize().contiguous(), x_hp_reshape, **_tols[fp8_dtype]
+        )
+        # Check the bitwise equality of the inner data
+        assert is_bitwise_equal(x_fp8_reshape._rowwise_data, x_fp8._rowwise_data)
+        assert is_bitwise_equal(x_fp8_reshape._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
+
     @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
     @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
     @pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
...
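As an aside, the aliasing behaviour these new tests exercise can be reproduced with a short standalone sketch. The constructor arguments, `make_empty`/`update_quantized`, and the private `_rowwise_data` attribute are taken from the diff above; the import path for `Float8BlockQuantizer` is an assumption and may differ between Transformer Engine versions.

# Hedged sketch only: view/reshape of a block-quantized tensor is expected to alias,
# not copy, the underlying FP8 payload. Requires a CUDA GPU and Transformer Engine.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer  # assumed path

x_hp = torch.rand([16, 16, 512], dtype=torch.bfloat16, device="cuda")
quantizer = Float8BlockQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    rowwise=True,
    columnwise=True,
    block_scaling_dim=1,
)
x_fp8 = quantizer.make_empty(x_hp.shape, dtype=torch.bfloat16, device="cuda")
quantizer.update_quantized(x_hp.clone(), x_fp8)

# Flattening to 2D should leave the quantized storage in place.
x_view = x_fp8.view(-1, 512)
assert x_view._rowwise_data.data_ptr() == x_fp8._rowwise_data.data_ptr()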
tests/pytorch/test_float8tensor.py

@@ -18,7 +18,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
     Float8CurrentScalingQuantizer,
 )
 from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
-from transformer_engine.pytorch.utils import non_tn_fp8_gemm_supported
+from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
 import transformer_engine_torch as tex
 from references.ref_per_tensor_cs import ref_per_tensor_cs_cast
...
@@ -400,7 +400,7 @@ class TestCurrentScalingFloat8Tensor:
         """Check numerical error when casting to FP8"""
         # Skip invalid configurations
-        if non_tn_fp8_gemm_supported() and return_transpose:
+        if is_non_tn_fp8_gemm_supported() and return_transpose:
             pytest.skip("FP8 transpose is neither needed nor supported on current system")
         # Initialize random high precision data
...
tests/pytorch/test_fused_optimizer.py

@@ -12,10 +12,11 @@ from torch import nn
 from torch.testing._internal.common_device_type import largeTensorTest
 import transformer_engine.pytorch as te
 from transformer_engine.common.recipe import DelayedScaling
-from transformer_engine.pytorch.attention import MultiheadAttention
+from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
 from transformer_engine.pytorch import fp8_model_init
 from transformer_engine.pytorch.utils import is_bf16_compatible
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.utils import gpu_autocast_ctx

 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
...
@@ -596,7 +597,7 @@ class AdamTest:
         gt_ = gt.clone()

         # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model(x)
             loss = ((gt - y) ** 2).mean()
...
@@ -605,7 +606,7 @@ class AdamTest:
         scaler.update()

         # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model_(x)
             loss_ = ((gt_ - y) ** 2).mean()
...
@@ -647,7 +648,7 @@ class AdamTest:
         gt_ = gt.clone()

         # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model(x)
             loss = ((gt - y) ** 2).mean()
...
@@ -656,7 +657,7 @@ class AdamTest:
         scaler.update()

         # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model_(x)
             loss_ = ((gt_ - y) ** 2).mean()
...
@@ -705,7 +706,7 @@ class AdamTest:
         gt_ = gt.clone()

         # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model(x)
             loss = ((gt - y) ** 2).mean()
...
@@ -714,7 +715,7 @@ class AdamTest:
         scaler.update()

         # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model_(x)
             loss_ = ((gt_ - y) ** 2).mean()
...
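The `gpu_autocast_ctx` helper itself is not part of this diff, only its usage in the tests above. As a hedged sketch, such a wrapper could simply forward to the non-deprecated `torch.amp.autocast` API; the name and behaviour of the real helper in `transformer_engine.pytorch.utils` may differ.

# Hedged sketch only: one plausible shape for a gpu_autocast_ctx-style wrapper.
import torch

def gpu_autocast_ctx_sketch(**kwargs):
    """Dispatch to torch.amp.autocast('cuda', ...) instead of the deprecated torch.cuda.amp.autocast."""
    return torch.amp.autocast(device_type="cuda", **kwargs)

# Usage mirroring the updated tests:
# with gpu_autocast_ctx_sketch(enabled=True):
#     y = model(x)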
tests/pytorch/test_fused_rope.py

 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
 from typing import Callable, Tuple, Union
 import math
 import pytest
 import torch
-from transformer_engine.pytorch.dot_product_attention.rope import (
+from transformer_engine.pytorch.attention.rope import (
     RotaryPositionEmbedding,
     apply_rotary_pos_emb,
 )
...
@@ -22,6 +22,7 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
     return torch.sum(output * t)

+@pytest.mark.parametrize("start_positions", [True, False])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("seq_length", [2048, 4096])
 @pytest.mark.parametrize("hidden_size", [128, 256])
...
@@ -43,7 +44,17 @@ def test_fused_rope(
     loss_func: Callable,
     cp_size: int,
     interleaved: bool,
+    start_positions: bool,
 ) -> None:
+    if margin == 0 and start_positions == True:
+        # This makes sure that the `start_positions` offsets being applied
+        # are with the maximum length of the rope embeddings.
+        pytest.skip("Skipping test with margin=0 and start_positions=True")
+
+    if start_positions == True and cp_size > 1:
+        # `start_positions` is only supported for `cp_size=1` and inference.
+        pytest.skip("Skipping test with cp_size>1 and start_positions=True")
+
     device = torch.device("cuda:0")
     batch_size, head_num = 2, 64
     t = torch.rand(
...
@@ -51,6 +62,14 @@ def test_fused_rope(
         dtype=dtype,
         device=device,
     )
+
+    # Get arbitrary offsets to be used with RoPE for all the sequences
+    start_positions = (
+        torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
+        if start_positions
+        else None
+    )
+
     if tensor_format == "bshd":
         t = t.transpose(0, 1).contiguous()
     if transpose:
...
@@ -69,14 +88,18 @@ def test_fused_rope(
         t.float(),
         emb,
         tensor_format=tensor_format,
+        start_positions=start_positions,
         interleaved=interleaved,
         fused=False,
         cp_size=cp_size,
         cp_rank=cp_rank,
     ).to(dtype)
     loss_unfused = loss_func(output_unfused)
-    loss_unfused.backward()
-    grad_unfused = t.grad.detach().clone()
+
+    if not isinstance(start_positions, torch.Tensor):
+        loss_unfused.backward()
+        grad_unfused = t.grad.detach().clone()
+
     t.grad = None

     # fused
...
@@ -84,21 +107,29 @@ def test_fused_rope(
         t,
         emb,
         tensor_format=tensor_format,
+        start_positions=start_positions,
         interleaved=interleaved,
         fused=True,
         cp_size=cp_size,
         cp_rank=cp_rank,
     )
     loss_fused = loss_func(output_fused)
-    loss_fused.backward()
-    grad_fused = t.grad.detach().clone()
+
+    if not isinstance(start_positions, torch.Tensor):
+        loss_fused.backward()
+        grad_fused = t.grad.detach().clone()
+
     t.grad = None

     torch.testing.assert_close(output_fused, output_unfused)
-    torch.testing.assert_close(grad_fused, grad_unfused)
+    if not isinstance(start_positions, torch.Tensor):
+        torch.testing.assert_close(grad_fused, grad_unfused)
     assert output_fused.is_contiguous()

+@pytest.mark.parametrize("margin", [10])
+@pytest.mark.parametrize("start_positions", [True, False])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("hidden_size", [128, 256])
 @pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
...
@@ -114,10 +145,25 @@ def test_fused_rope_thd(
     loss_func: Callable,
     cp_size: int,
     interleaved: bool,
+    start_positions: bool,
+    margin: int,
 ) -> None:
+    if start_positions == True and cp_size > 1:
+        # `start_positions` is only supported for `cp_size=1` and inference.
+        pytest.skip("Skipping test with cp_size>1 and start_positions=True")
+
     device = torch.device("cuda:0")
     batch_size, head_num = 2, 64
     cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]
+
+    # Get arbitrary offsets to be used with RoPE for all the sequences
+    start_positions = (
+        torch.randint(0, margin, (len(cu_seqlens) - 1,), dtype=torch.int32, device=device)
+        if start_positions
+        else None
+    )
+
     if cp_size > 1:
         cu_seqlens_padded = [0]
         for i in range(1, len(cu_seqlens)):
...
@@ -152,6 +198,7 @@ def test_fused_rope_thd(
     output_unfused = apply_rotary_pos_emb(
         t.float(),
         emb,
+        start_positions=start_positions,
         tensor_format="thd",
         interleaved=interleaved,
         fused=False,
...
@@ -160,14 +207,17 @@ def test_fused_rope_thd(
         cp_rank=cp_rank,
     ).to(dtype)
     loss_unfused = loss_func(output_unfused)
-    loss_unfused.backward()
-    grad_unfused = t.grad.detach().clone()
+
+    if not isinstance(start_positions, torch.Tensor):
+        loss_unfused.backward()
+        grad_unfused = t.grad.detach().clone()
+
     t.grad = None

     # fused
     output_fused = apply_rotary_pos_emb(
         t,
         emb,
+        start_positions=start_positions,
         interleaved=interleaved,
         fused=True,
         tensor_format="thd",
...
@@ -176,9 +226,15 @@ def test_fused_rope_thd(
         cp_rank=cp_rank,
     )
     loss_fused = loss_func(output_fused)
-    loss_fused.backward()
-    grad_fused = t.grad.detach().clone()
+
+    if not isinstance(start_positions, torch.Tensor):
+        loss_fused.backward()
+        grad_fused = t.grad.detach().clone()
+
     t.grad = None

     torch.testing.assert_close(output_fused, output_unfused)
-    torch.testing.assert_close(grad_fused, grad_unfused)
+    if not isinstance(start_positions, torch.Tensor):
+        torch.testing.assert_close(grad_fused, grad_unfused)
     assert output_fused.is_contiguous()
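The call pattern these tests exercise can be summarized in a short sketch. Argument names (`tensor_format`, `start_positions`, `fused`) follow the test code above; the tensor shapes and the `max_seq_len` call on `RotaryPositionEmbedding` are illustrative assumptions, not an API reference.

# Hedged sketch of applying RoPE with per-sequence start offsets (cp_size=1 only, per the tests).
import torch
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding, apply_rotary_pos_emb

batch_size, seq_len, head_num, head_dim = 2, 2048, 64, 128
margin = 10

t = torch.rand(seq_len, batch_size, head_num, head_dim, dtype=torch.bfloat16, device="cuda")
emb = RotaryPositionEmbedding(head_dim)(seq_len + margin)  # RoPE table longer than the sequence

# Per-sequence offsets into the RoPE table.
start_positions = torch.randint(0, margin, (batch_size,), dtype=torch.int32, device="cuda")

out = apply_rotary_pos_emb(
    t,
    emb,
    tensor_format="sbhd",
    start_positions=start_positions,
    fused=True,
)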
tests/pytorch/test_multi_tensor.py

@@ -160,7 +160,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens
         normab = torch.cat((a.norm().view(1), b.norm().view(1)))
         norm_per_tensor = norm_per_tensor.view(-1, 2)
     else:
-        norm, _ = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], True)
+        norm, _ = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], False)

     reference = torch.full(
         [(sizea + sizeb) * repeat], val, dtype=torch.float32, device=device
...
tests/pytorch/test_numerics.py

@@ -7,7 +7,6 @@ import math
 import os
 from typing import Dict, List, Tuple, Optional
 import pytest
 import copy
 import random
 import torch
...
@@ -40,12 +39,12 @@ from transformer_engine.pytorch import (
     Fp8Unpadding,
 )
 from transformer_engine.pytorch import torch_version
-from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
+from transformer_engine.pytorch.attention.inference import InferenceParams
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
 from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
-from transformer_engine.pytorch.utils import get_device_compute_capability
+from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
...
@@ -135,18 +134,20 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
 def assert_allclose(
-    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
+    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float = None, rtol: float = None
 ) -> bool:
     """Ensures two lists are equal."""
     assert len(l1) == len(l2), "Unequal number of outputs."
     for i, (t1, t2) in enumerate(zip(l1, l2)):
-        tols = dict(atol=atol)
+        tols = dtype_tols(t2.dtype)
         if rtol is not None:
             tols["rtol"] = rtol
+        if atol is not None:
+            tols["atol"] = atol
         result = torch.allclose(t1, t2, **tols)
         if not result:
             diff = torch.abs(t1 - t2)
-            tol = atol + (rtol * torch.abs(t2))
+            tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
             exceed_mask = diff > tol
             if exceed_mask.any():
                 indices = torch.nonzero(exceed_mask, as_tuple=True)
...
@@ -2304,6 +2305,12 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
         pytest.skip("FusedAttention and FlashAttention do not support FP32")
     if use_RoPE:
         pytest.skip("KV cache does not support starting positions for RoPE")
+    if (
+        backend == "FusedAttention"
+        and get_device_compute_capability() == (8, 9)
+        and get_cudnn_version() < (9, 11, 0)
+    ):
+        pytest.skip("Skip KV cache for sm89 and cuDNN < 9.11")

     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
...
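With the reworked `assert_allclose`, tolerances now default to `dtype_tols(t2.dtype)` and explicit `atol`/`rtol` only override those defaults. A small usage sketch within `test_numerics.py` (the helper names come from the file; the random tensors here are placeholders):

# Hedged sketch of the new default-tolerance behaviour of assert_allclose.
t_ref = torch.randn(1024, dtype=torch.bfloat16, device="cuda")
t_out = t_ref + 1e-3 * torch.randn_like(t_ref)

# No atol/rtol given: tolerances come from dtype_tols(t_ref.dtype).
assert_allclose([t_out], [t_ref])

# An explicit atol still overrides the dtype-based default.
assert_allclose([t_out], [t_ref], atol=1e-2)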
tests/pytorch/test_parallel_cross_entropy.py

@@ -19,11 +19,12 @@ class TestParallelCrossEntropy:
             label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
         )

-    def generate_input(self, dtype: torch.dtype, swap_dim: bool):
+    def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
         SQ = random.choice([64, 128])
         batch = random.choice([1, 2])
         vocab = random.choice([64000, 128000])
+        ignore = random.sample(range(0, SQ - 1), 5)

         if swap_dim:
             self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
...
@@ -32,14 +33,27 @@ class TestParallelCrossEntropy:
             self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()
             self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()

+        if ignore_idx:
+            for i in ignore:
+                # Ignore 5 indices
+                if swap_dim:
+                    self.tar_test[i][0] = -100
+                else:
+                    self.tar_test[0][i] = -100
+
         self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
         self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))

     def one_iteration_test(
-        self, dtype: torch.dtype, swap_dim: bool, label_smoothing: float, reduce_loss: bool
+        self,
+        dtype: torch.dtype,
+        swap_dim: bool,
+        label_smoothing: float,
+        reduce_loss: bool,
+        ignore_idx: bool = False,
     ):

-        self.generate_input(dtype, swap_dim)
+        self.generate_input(dtype, swap_dim, ignore_idx)

         self.input_test.requires_grad_(True)
         self.input_ref.requires_grad_(True)
...
@@ -57,6 +71,8 @@ class TestParallelCrossEntropy:
         test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
         torch.testing.assert_close(test_loss, ref_loss, check_dtype=False)
+        if ignore_idx:
+            print(test_loss, ref_loss)
         if reduce_loss:
             torch.testing.assert_close(
                 torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
...
@@ -106,3 +122,15 @@ class TestParallelCrossEntropy:
         self.one_iteration_test(
             dtype=torch.float32, swap_dim=False, label_smoothing=0, reduce_loss=False
         )
+
+    def test_ignore_idx(self):
+        self.generate_iters(5)
+        self.generate_infra(False, 0)
+        for i in range(self.iters):
+            self.one_iteration_test(
+                dtype=torch.float32,
+                swap_dim=random.choice([True, False]),
+                label_smoothing=0,
+                reduce_loss=False,
+                ignore_idx=True,
+            )
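The `-100` sentinel used above matches PyTorch's default `ignore_index` for cross entropy, which is the reference behaviour the new `ignore_idx` path is compared against. A minimal, self-contained sketch (shapes are placeholders, not values from the test):

# Hedged sketch: targets equal to -100 are excluded from the loss by
# torch.nn.CrossEntropyLoss's default ignore_index.
import torch

logits = torch.randn(8, 64000, device="cuda")
targets = torch.randint(0, 64000, (8,), device="cuda")
targets[:5] = -100  # ignore 5 positions

loss = torch.nn.CrossEntropyLoss(reduction="none")(logits, targets)
assert torch.all(loss[:5] == 0)  # ignored positions contribute zero loss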
tests/pytorch/test_sanity.py

@@ -373,7 +373,9 @@ def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
     torch.cuda.synchronize()

-def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
+def _test_sanity_common(
+    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
+):
     if skip_dgrad and skip_wgrad:
         pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
...
@@ -389,7 +391,11 @@ def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad
     use_fp8 = fp8_recipe is not None
     with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
-        te_out = block(te_inp)
+        if not microbatching:
+            te_out = block(te_inp)
+        else:
+            _ = block(te_inp, is_first_microbatch=True)
+            te_out = block(te_inp, is_first_microbatch=False)
     if isinstance(te_out, tuple):
         te_out = te_out[0]
     loss = te_out.sum()
...
@@ -443,8 +449,16 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 @pytest.mark.parametrize("skip_dgrad", all_boolean)
 @pytest.mark.parametrize("normalization", all_normalizations)
+@pytest.mark.parametrize("microbatching", all_boolean)
 def test_sanity_layernorm_linear(
-    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, normalization
+    dtype,
+    fp8_recipe,
+    model,
+    skip_wgrad,
+    zero_centered_gamma,
+    skip_dgrad,
+    normalization,
+    microbatching,
 ):
     config = model_configs[model]
...
@@ -470,7 +484,7 @@ def test_sanity_layernorm_linear(
         params_dtype=dtype,
         device="cuda",
     )
-    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
+    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)

 @pytest.mark.parametrize("dtype", param_types)
...
@@ -478,7 +492,8 @@ def test_sanity_layernorm_linear(
 @pytest.mark.parametrize("model", ["small", "weird"])
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
 @pytest.mark.parametrize("skip_dgrad", all_boolean)
-def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
+@pytest.mark.parametrize("microbatching", all_boolean)
+def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
     config = model_configs[model]
     if fp8_recipe is not None:
...
@@ -501,7 +516,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
         params_dtype=dtype,
         device="cuda",
     )
-    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
+    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)

 @pytest.mark.parametrize("dtype", param_types)
...
@@ -600,8 +615,17 @@ def test_sanity_grouped_linear(
 @pytest.mark.parametrize("skip_dgrad", all_boolean)
 @pytest.mark.parametrize("activation", all_activations)
 @pytest.mark.parametrize("normalization", all_normalizations)
+@pytest.mark.parametrize("microbatching", all_boolean)
 def test_sanity_layernorm_mlp(
-    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation, normalization
+    dtype,
+    fp8_recipe,
+    model,
+    skip_wgrad,
+    zero_centered_gamma,
+    skip_dgrad,
+    activation,
+    normalization,
+    microbatching,
 ):
     config = model_configs[model]
...
@@ -630,7 +654,7 @@ def test_sanity_layernorm_mlp(
         params_dtype=dtype,
         device="cuda",
     )
-    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
+    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)

 @pytest.mark.parametrize("dtype", param_types)
...
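For context, the microbatching path added to `_test_sanity_common` relies on the `is_first_microbatch` forward argument that Transformer Engine modules expose: the first call computes and caches FP8 weight casts, and subsequent calls reuse them. A self-contained sketch (layer sizes, dtype, and FP8 availability are assumptions):

# Hedged sketch of the microbatching call pattern covered by the updated sanity tests.
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16, device="cuda")
inp = torch.randn(32, 1024, dtype=torch.bfloat16, device="cuda")

with te.fp8_autocast(enabled=True):  # requires FP8-capable hardware
    _ = layer(inp, is_first_microbatch=True)     # compute and cache FP8 weight casts
    out = layer(inp, is_first_microbatch=False)  # reuse the cached casts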
transformer_engine/__init__.py

@@ -11,12 +11,12 @@ import transformer_engine.common
 try:
     from . import pytorch
-except (ImportError, StopIteration) as e:
+except ImportError as e:
     pass

 try:
     from . import jax
-except (ImportError, StopIteration) as e:
+except ImportError as e:
     pass

 __version__ = str(metadata.version("transformer_engine"))
transformer_engine/common/CMakeLists.txt

@@ -111,6 +111,11 @@ if(USE_CUDA)
       cudnn_utils.cpp
       transformer_engine.cpp
       common.cu
+      multi_tensor/adam.cu
+      multi_tensor/compute_scale.cu
+      multi_tensor/l2norm.cu
+      multi_tensor/scale.cu
+      multi_tensor/sgd.cu
       transpose/cast_transpose.cu
       transpose/transpose.cu
       transpose/cast_transpose_fusion.cu
...
@@ -148,6 +153,7 @@ if(USE_CUDA)
       fused_rope/fused_rope.cu
       recipe/current_scaling.cu
       recipe/delayed_scaling.cu
+      recipe/fp8_block_scaling.cu
       comm_gemm_overlap/userbuffers/ipcsocket.cc
       comm_gemm_overlap/userbuffers/userbuffers-host.cpp
       comm_gemm_overlap/userbuffers/userbuffers.cu
...
@@ -158,6 +164,11 @@ else()
      cudnn_utils.cpp
      transformer_engine.cpp
      common.cu
+     multi_tensor/adam.cu
+     multi_tensor/compute_scale.cu
+     multi_tensor/l2norm.cu
+     multi_tensor/scale.cu
+     multi_tensor/sgd.cu
      transpose/cast_transpose.cu
      transpose/transpose.cu
      transpose/cast_transpose_fusion.cu
...
@@ -191,6 +202,7 @@ else()
      fused_rope/fused_rope.cu
      recipe/current_scaling.cu
      recipe/delayed_scaling.cu
+     recipe/fp8_block_scaling.cu
     comm_gemm_overlap/userbuffers/ipcsocket.cc
     comm_gemm_overlap/userbuffers/userbuffers-host.cpp
     comm_gemm_overlap/userbuffers/userbuffers.cu
...
@@ -345,6 +357,14 @@ target_include_directories(transformer_engine PRIVATE
 set_source_files_properties(
     fused_softmax/scaled_masked_softmax.cu
     fused_softmax/scaled_upper_triang_masked_softmax.cu
     fused_softmax/scaled_aligned_causal_masked_softmax.cu
+    multi_tensor/adam.cu
+    multi_tensor/compute_scale.cu
+    multi_tensor/l2norm.cu
+    multi_tensor/scale.cu
+    multi_tensor/sgd.cu
+    fused_attn/flash_attn.cu
+    fused_attn/context_parallel.cu
+    fused_attn/kv_cache.cu
     PROPERTIES
     COMPILE_OPTIONS "--use_fast_math")

 option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
...
transformer_engine/common/__init__.py

@@ -9,28 +9,193 @@ import glob
 import sysconfig
 import subprocess
 import ctypes
+import logging
 import os
 import platform
 import importlib
+import functools
 from pathlib import Path
+from importlib.metadata import version, metadata, PackageNotFoundError

 import transformer_engine

+_logger = logging.getLogger(__name__)

-def is_package_installed(package):
-    """Checks if a pip package is installed."""
-    return (
-        subprocess.run(
-            [sys.executable, "-m", "pip", "show", package], capture_output=True, check=False
-        ).returncode
-        == 0
-    )

+@functools.lru_cache(maxsize=None)
+def _is_pip_package_installed(package):
+    """Check if the given package is installed via pip."""
+    # This is needed because we only want to return true
+    # if the python package is installed via pip, and not
+    # if it's importable in the current directory due to
+    # the presence of the shared library module.
+    try:
+        metadata(package)
+    except PackageNotFoundError:
+        return False
+    return True

+@functools.lru_cache(maxsize=None)
+def _find_shared_object_in_te_dir(te_path: Path, prefix: str):
+    """
+    Find a shared object file of given prefix in the top level TE directory.
+    Only the following locations are searched to avoid stray SOs and build
+    artifacts:
+        1. The given top level directory (editable install).
+        2. `transformer_engine` named directories (source install).
+        3. `wheel_lib` named directories (PyPI install).
+    Returns None if no shared object files are found.
+    Raises an error if multiple shared object files are found.
+    """
+    # Ensure top level dir exists and has the module before searching.
+    if not te_path.exists() or not (te_path / "transformer_engine").exists():
+        return None
+
+    files = []
+    search_paths = (
+        te_path,
+        te_path / "transformer_engine",
+        te_path / "transformer_engine/wheel_lib",
+        te_path / "wheel_lib",
+    )
+
+    # Search.
+    for dirname, _, names in os.walk(te_path):
+        if Path(dirname) in search_paths:
+            for name in names:
+                if name.startswith(prefix) and name.endswith(f".{_get_sys_extension()}"):
+                    files.append(Path(dirname, name))
+
+    if len(files) == 0:
+        return None
+    if len(files) == 1:
+        return files[0]
+    raise RuntimeError(f"Multiple files found: {files}")

+@functools.lru_cache(maxsize=None)
+def _get_shared_object_file(library: str) -> Path:
+    """
+    Return the path of the shared object file for the given TE
+    library, one of 'core', 'torch', or 'jax'.
+
+    Several factors affect finding the correct location of the shared object:
+        1. System and environment.
+        2. If the installation is from source or via PyPI.
+            - Source installed .sos are placed in top level dir
+            - Wheel/PyPI installed .sos are placed in 'wheel_lib' dir to avoid conflicts.
+        3. For source installations, is the install editable/inplace?
+        4. The user directory from where TE is being imported.
+    """
+    # Check provided input and determine the correct prefix for .so.
+    assert library in ("core", "torch", "jax"), f"Unsupported TE library {library}."
+    if library == "core":
+        so_prefix = "libtransformer_engine"
+    else:
+        so_prefix = f"transformer_engine_{library}"
+
+    # Check TE install location (will be local if TE is available in current dir for import).
+    te_install_dir = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent
+    so_path_in_install_dir = _find_shared_object_in_te_dir(te_install_dir, so_prefix)
+
+    # Check default python package install location in system.
+    site_packages_dir = Path(sysconfig.get_paths()["purelib"])
+    so_path_in_default_dir = _find_shared_object_in_te_dir(site_packages_dir, so_prefix)
+
+    # Case 1: Typical user workflow: Both locations are the same, return any result.
+    if te_install_dir == site_packages_dir:
+        assert (
+            so_path_in_install_dir is not None
+        ), f"Could not find shared object file for Transformer Engine {library} lib."
+        return so_path_in_install_dir
+
+    # Case 2: ERR! Both locations are different but returned a valid result.
+    # NOTE: Unlike for source installations, pip does not wipe out artifacts from
+    # editable builds. In case developers are executing inside a TE directory via
+    # an inplace build, and then move to a regular build, the local shared object
+    # file will be incorrectly picked up without the following logic.
+    if so_path_in_install_dir is not None and so_path_in_default_dir is not None:
+        raise RuntimeError(
+            f"Found multiple shared object files: {so_path_in_install_dir} and"
+            f" {so_path_in_default_dir}. Remove local shared objects installed"
+            f" here {so_path_in_install_dir} or change the working directory to"
+            "execute from outside TE."
+        )
+
+    # Case 3: Typical dev workflow: Editable install
+    if so_path_in_install_dir is not None:
+        return so_path_in_install_dir
+
+    # Case 4: Executing from inside a TE directory without an inplace build available.
+    if so_path_in_default_dir is not None:
+        return so_path_in_default_dir
+
+    raise RuntimeError(f"Could not find shared object file for Transformer Engine {library} lib.")

+@functools.lru_cache(maxsize=None)
+def load_framework_extension(framework: str):
+    """
+    Load shared library with Transformer Engine framework bindings
+    and check verify correctness if installed via PyPI.
+    """
+    # Supported frameworks.
+    assert framework in ("jax", "torch"), f"Unsupported framework {framework}"
+
+    # Name of the framework extension library.
+    module_name = f"transformer_engine_{framework}"
+
+    # Name of the pip extra dependency for framework extensions from PyPI.
+    extra_dep_name = module_name
+    if framework == "torch":
+        extra_dep_name = "pytorch"
+
+    # If the framework extension pip package is installed, it means that TE is installed via
+    # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
+    # extension are all installed via PyPI and have matching version.
+    if _is_pip_package_installed(module_name):
+        assert _is_pip_package_installed(
+            "transformer_engine"
+        ), "Could not find `transformer-engine`."
+        assert _is_pip_package_installed(
+            "transformer_engine_cu12"
+        ), "Could not find `transformer-engine-cu12`."
+        assert (
+            version(module_name)
+            == version("transformer-engine")
+            == version("transformer-engine-cu12")
+        ), (
+            "TransformerEngine package version mismatch. Found"
+            f" {module_name} v{version(module_name)}, transformer-engine"
+            f" v{version('transformer-engine')}, and transformer-engine-cu12"
+            f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
+            f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"
+        )
+
+    # If the core package is installed via PyPI, log if
+    # the framework extension is not found from PyPI.
+    # Note: Should we error? This is a rare use case.
+    if _is_pip_package_installed("transformer-engine-cu12"):
+        if not _is_pip_package_installed(module_name):
+            _logger.info(
+                "Could not find package %s. Install transformer-engine using "
+                f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'",
+                module_name,
+            )

-def get_te_path():
-    """Find Transformer Engine install path using pip"""
-    return Path(transformer_engine.__path__[0]).parent

+    # After all checks are completed, load the shared object file.
+    spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework))
+    solib = importlib.util.module_from_spec(spec)
+    sys.modules[module_name] = solib
+    spec.loader.exec_module(solib)

+@functools.lru_cache(maxsize=None)
 def _get_sys_extension():
     system = platform.system()
     if system == "Linux":
...
@@ -45,20 +210,47 @@ def _get_sys_extension():
     return extension

-def _load_cudnn():
-    """Load CUDNN shared library."""
-    # Attempt to locate cuDNN in Python dist-packages
-    lib_path = glob.glob(
-        os.path.join(
-            sysconfig.get_path("purelib"),
-            f"nvidia/cudnn/lib/libcudnn.{_get_sys_extension()}.*[0-9]",
-        )
-    )
-    if lib_path:
-        assert (
-            len(lib_path) == 1
-        ), f"Found {len(lib_path)} libcudnn.{_get_sys_extension()}.x in nvidia-cudnn-cuXX."
-        return ctypes.CDLL(lib_path[0], mode=ctypes.RTLD_GLOBAL)

+@functools.lru_cache(maxsize=None)
+def _load_nvidia_cuda_library(lib_name: str):
+    """
+    Attempts to load shared object file installed via pip.
+
+    `lib_name`: Name of package as found in the `nvidia` dir in python environment.
+    """
+    so_paths = glob.glob(
+        os.path.join(
+            sysconfig.get_path("purelib"),
+            f"nvidia/{lib_name}/lib/lib*.{_get_sys_extension()}.*[0-9]",
+        )
+    )
+
+    path_found = len(so_paths) > 0
+    ctypes_handles = []
+
+    if path_found:
+        for so_path in so_paths:
+            ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))
+
+    return path_found, ctypes_handles

+@functools.lru_cache(maxsize=None)
+def _nvidia_cudart_include_dir():
+    """Returns the include directory for cuda_runtime.h if exists in python environment."""
+    try:
+        import nvidia
+    except ModuleNotFoundError:
+        return ""
+
+    include_dir = Path(nvidia.__file__).parent / "cuda_runtime"
+    return str(include_dir) if include_dir.exists() else ""

+@functools.lru_cache(maxsize=None)
+def _load_cudnn():
+    """Load CUDNN shared library."""
+    # Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
     cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
...
@@ -75,28 +267,16 @@ def _load_cudnn():
     if libs:
         return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

+    # Attempt to locate cuDNN in Python dist-packages
+    found, handle = _load_nvidia_cuda_library("cudnn")
+    if found:
+        return handle
+
     # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
     return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)

-def _load_library():
-    """Load shared library with Transformer Engine C extensions"""
-    so_path = get_te_path() / "transformer_engine" / f"libtransformer_engine.{_get_sys_extension()}"
-    if not so_path.exists():
-        so_path = (
-            get_te_path()
-            / "transformer_engine"
-            / "wheel_lib"
-            / f"libtransformer_engine.{_get_sys_extension()}"
-        )
-    if not so_path.exists():
-        so_path = get_te_path() / f"libtransformer_engine.{_get_sys_extension()}"
-    assert so_path.exists(), f"Could not find libtransformer_engine.{_get_sys_extension()}"
-    return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)

+@functools.lru_cache(maxsize=None)
 def _load_nvrtc():
     """Load NVRTC shared library."""
     # Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
...
@@ -107,6 +287,11 @@ def _load_nvrtc():
     if libs:
         return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

+    # Attempt to locate NVRTC in Python dist-packages
+    found, handle = _load_nvidia_cuda_library("cuda_nvrtc")
+    if found:
+        return handle
+
     # Attempt to locate NVRTC via ldconfig
     libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True)
     libs = libs.decode("utf-8").split("\n")
...
@@ -123,10 +308,22 @@ def _load_nvrtc():
     return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)

+@functools.lru_cache(maxsize=None)
+def _load_core_library():
+    """Load shared library with Transformer Engine C extensions"""
+    return ctypes.CDLL(_get_shared_object_file("core"), mode=ctypes.RTLD_GLOBAL)

 if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
     try:
         _CUDNN_LIB_CTYPES = _load_cudnn()
         _NVRTC_LIB_CTYPES = _load_nvrtc()
+        _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
+        _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
+        # Needed to find the correct headers for NVRTC kernels.
+        if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
+            os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir()
     except OSError:
         pass
-    _TE_LIB_CTYPES = _load_library()
+    _TE_LIB_CTYPES = _load_core_library()
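For reference, `load_framework_extension("torch")` essentially resolves the extension shared object through the four cases documented in `_get_shared_object_file` and then executes it through importlib. The sketch below only restates the steps added in this diff, using the same helper names; it is illustrative, not an alternative implementation.

# Hedged sketch of what load_framework_extension("torch") boils down to after its PyPI checks.
import importlib.util
import sys

module_name = "transformer_engine_torch"
so_path = _get_shared_object_file("torch")  # resolves per cases 1-4 above

spec = importlib.util.spec_from_file_location(module_name, so_path)
solib = importlib.util.module_from_spec(spec)
sys.modules[module_name] = solib
spec.loader.exec_module(solib)  # the framework extension is now importable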
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
View file @
f8c2af4c
...
...
@@ -21,12 +21,18 @@
#define HALF_BYTES 2
#define UB_MAX_SM 32
#define AS_VECTOR(shape) std::vector<size_t>(shape.data, shape.data + shape.ndim)
using
namespace
std
::
placeholders
;
namespace
transformer_engine
{
namespace
{
std
::
vector
<
size_t
>
shape_to_vector
(
const
NVTEShape
&
shape
)
{
return
std
::
vector
<
size_t
>
(
shape
.
data
,
shape
.
data
+
shape
.
ndim
);
}
}
// namespace
/***************************************************************************************************
* Comm+GEMM Overlap Common Core
**************************************************************************************************/
...
...
@@ -147,13 +153,50 @@ CommOverlapCore::~CommOverlapCore() {
TensorWrapper
CommOverlapCore
::
get_tensor_chunk
(
const
TensorWrapper
&
source
,
size_t
chunk_offset
,
const
std
::
vector
<
size_t
>
&
chunk_shape
)
{
TensorWrapper
chunk
;
const
auto
scaling_mode
=
source
.
scaling_mode
();
// Tensor dimensions
std
::
vector
<
size_t
>
shape
=
shape_to_vector
(
source
.
shape
());
auto
flatten_shape_to_2d
=
[](
const
std
::
vector
<
size_t
>
&
shape
)
->
std
::
pair
<
size_t
,
size_t
>
{
if
(
shape
.
empty
())
{
return
{
1
,
1
};
}
size_t
height
=
1
;
for
(
size_t
i
=
0
;
i
<
shape
.
size
()
-
1
;
++
i
)
{
height
*=
shape
[
i
];
}
return
{
height
,
shape
.
back
()};
};
size_t
height
,
width
,
chunk_height
,
chunk_width
;
std
::
tie
(
height
,
width
)
=
flatten_shape_to_2d
(
shape
);
std
::
tie
(
chunk_height
,
chunk_width
)
=
flatten_shape_to_2d
(
chunk_shape
);
// Check tensor dimensions
#define NVTE_DIM_CHECK(cond, message) \
NVTE_CHECK(cond, message, " (tensor shape=", shape, ", chunk shape=", chunk_shape, \
", chunk offset=", chunk_offset, ")")
NVTE_DIM_CHECK
(
height
>
0
&&
width
>
0
,
"Attempted to get chunk from empty tensor"
);
NVTE_DIM_CHECK
(
chunk_height
>
0
&&
chunk_width
>
0
,
"Attempted to get empty tensor chunk"
);
NVTE_DIM_CHECK
(
chunk_height
<=
height
&&
chunk_width
<=
width
,
"Attempted to get out-of-bounds tensor chunk"
);
if
(
scaling_mode
==
NVTEScalingMode
::
NVTE_MXFP8_1D_SCALING
)
{
// MXFP8 scale-inverses are padded to a 2D matrix with dims that
// are divisible by 128. UB doesn't handle this padding yet.
NVTE_DIM_CHECK
(
height
%
128
==
0
&&
width
%
128
==
0
,
"Userbuffers requires MXFP8 tensor dims that are divisible by 128"
);
NVTE_DIM_CHECK
(
chunk_height
%
128
==
0
&&
chunk_width
%
128
==
0
,
"Userbuffers requires MXFP8 tensor chunk dims that are divisible by 128"
);
}
#undef NVTE_DIM_CHECK
// Construct tensor chunk
TensorWrapper
chunk
(
scaling_mode
);
for
(
int
param_id
=
0
;
param_id
<
NVTETensorParam
::
kNVTENumTensorParams
;
param_id
++
)
{
auto
param_type
=
static_cast
<
NVTETensorParam
>
(
param_id
);
auto
param
=
source
.
get_parameter
(
param_type
);
auto
param_dptr
=
reinterpret_cast
<
char
*>
(
param
.
data_ptr
);
auto
param_dtype
=
static_cast
<
DType
>
(
param
.
dtype
);
auto
param_shape
=
AS_VECTOR
(
param
.
shape
);
auto
param_shape
=
shape_to_vector
(
param
.
shape
);
if
(
param_dptr
!=
nullptr
)
{
if
(
param_type
==
NVTETensorParam
::
kNVTERowwiseData
||
...
...
@@ -163,8 +206,8 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
param_shape
=
chunk_shape
;
if
(
param_type
==
NVTETensorParam
::
kNVTEColumnwiseData
&&
source
.
scaling_mode
()
!
=
NVTEScalingMode
::
NVTE_
MXFP8_1D
_SCALING
)
{
// Columnwise shape for
non-block
scaled tensors shifts the last dimension to the front
source
.
scaling_mode
()
=
=
NVTEScalingMode
::
NVTE_
DELAYED_TENSOR
_SCALING
)
{
// Columnwise shape for
FP8 tensor-
scaled tensors shifts the last dimension to the front
auto
last_dim
=
param_shape
.
back
();
param_shape
.
pop_back
();
param_shape
.
insert
(
param_shape
.
begin
(),
last_dim
);
...
...
@@ -172,18 +215,16 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
}
else
if
(
source
.
scaling_mode
()
==
NVTEScalingMode
::
NVTE_MXFP8_1D_SCALING
&&
(
param_type
==
NVTETensorParam
::
kNVTERowwiseScaleInv
||
param_type
==
NVTETensorParam
::
kNVTEColumnwiseScaleInv
))
{
// Calculate block scaling offset and size
auto
scaled_tensor_dim_size
=
(
param_type
==
NVTETensorParam
::
kNVTERowwiseScaleInv
)
?
source
.
shape
().
data
[
0
]
:
source
.
columnwise_shape
().
data
[
0
];
auto
scaled_chunk_dim_size
=
(
param_type
==
NVTETensorParam
::
kNVTERowwiseScaleInv
)
?
chunk_shape
.
front
()
:
chunk_shape
.
back
();
auto
chunk_scale_start
=
chunk_offset
/
32
;
auto
chunk_scale_end
=
(
chunk_offset
+
scaled_chunk_dim_size
)
/
32
;
auto
chunk_scale_size
=
chunk_scale_end
-
chunk_scale_start
;
param_dptr
+=
chunk_scale_start
*
typeToSize
(
param_dtype
);
param_shape
=
std
::
vector
<
size_t
>
{
chunk_scale_size
};
// Calculate offset and size for MXFP8 scale-invs
size_t
chunk_scale_height
=
chunk_height
;
size_t
chunk_scale_width
=
chunk_width
;
if
(
param_type
==
NVTETensorParam
::
kNVTERowwiseScaleInv
)
{
chunk_scale_width
/=
32
;
}
else
{
chunk_scale_height
/=
32
;
}
param_dptr
+=
(
chunk_offset
/
32
)
*
typeToSize
(
param_dtype
);
param_shape
=
{
chunk_scale_height
,
chunk_scale_width
};
}
// Set chunked source parameters into the chunked tensor output
...
...
@@ -434,10 +475,21 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
size_t
k
=
transa
?
A
.
size
(
1
)
:
A
.
size
(
0
);
size_t
n
=
_ubuf
.
size
(
0
);
size_t
m_chunk
=
m
/
_num_splits
;
const
std
::
vector
<
size_t
>
input_a_chunk_shape
=
(
transa
?
std
::
vector
<
size_t
>
{
m_chunk
,
k
}
:
std
::
vector
<
size_t
>
{
k
,
m_chunk
});
const
std
::
vector
<
size_t
>
output_chunk_shape
=
{
n
,
m_chunk
};
size_t
input_a_chunk_size
=
m_chunk
*
k
;
size_t
output_chunk_size
=
n
*
m_chunk
;
size_t
workspace_size_chunk
=
workspace
.
numel
()
/
_stream_compute
.
size
();
// Helper function to get bias chunk if needed
auto
maybe_get_bias_chunk
=
[
this
,
&
bias
,
m_chunk
](
size_t
chunk_id
)
->
TensorWrapper
{
if
(
bias
.
dptr
()
==
nullptr
)
{
return
TensorWrapper
();
}
return
get_tensor_chunk
(
bias
,
chunk_id
*
m_chunk
,
{
m_chunk
});
};
// Catch up the default torch stream
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_start_compute
,
stream_main
));
for
(
size_t
i
=
0
;
i
<
_stream_compute
.
size
();
i
++
)
{
...
...
@@ -449,12 +501,13 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
char
*
rs_output_ptr
=
reinterpret_cast
<
char
*>
(
rs_output
.
dptr
());
if
(
_rs_overlap_first_gemm
)
{
auto
input_a_chunk
=
get_tensor_chunk
(
A
,
0
,
{
m_chunk
,
k
});
auto
output_chunk
=
get_buffer_chunk_like
(
D
,
0
,
{
m
,
m_chunk
});
auto
input_a_chunk
=
get_tensor_chunk
(
A
,
0
,
input_a_chunk_shape
);
auto
output_chunk
=
get_buffer_chunk_like
(
D
,
0
,
output_chunk_shape
);
auto
bias_chunk
=
maybe_get_bias_chunk
(
0
);
auto
workspace_chunk
=
get_tensor_chunk
(
workspace
,
0
,
{
workspace_size_chunk
});
if
(
_ub_stream_nums
==
1
)
{
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
_chunk
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
0
]);
}
else
{
...
...
@@ -464,18 +517,19 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
}
for
(
int
i
=
1
;
i
<
_num_splits
;
i
++
)
{
input_a_chunk
=
get_tensor_chunk
(
A
,
i
*
input_a_chunk_size
,
{
m_chunk
,
k
});
output_chunk
=
get_buffer_chunk_like
(
D
,
i
*
output_chunk_size
,
{
n
,
m_chunk
});
input_a_chunk
=
get_tensor_chunk
(
A
,
i
*
input_a_chunk_size
,
input_a_chunk_shape
);
output_chunk
=
get_buffer_chunk_like
(
D
,
i
*
output_chunk_size
,
output_chunk_shape
);
bias_chunk
=
maybe_get_bias_chunk
(
i
);
workspace_chunk
=
get_tensor_chunk
(
workspace
,
(
i
%
_stream_compute
.
size
())
*
workspace_size_chunk
,
{
workspace_size_chunk
});
if
(
_ub_stream_nums
==
1
)
{
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
_chunk
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
i
%
_stream_compute
.
size
()]);
}
else
{
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
_chunk
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
i
%
_stream_compute
.
size
()],
1
,
0
,
i
%
_stream_compute
.
size
());
...
...
@@ -519,13 +573,14 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
}
}
else
{
for
(
int
i
=
0
;
i
<
_num_splits
;
i
++
)
{
auto
input_a_chunk
=
get_tensor_chunk
(
A
,
i
*
input_a_chunk_size
,
{
m_chunk
,
k
});
auto
output_chunk
=
get_buffer_chunk_like
(
D
,
i
*
output_chunk_size
,
{
n
,
m_chunk
});
auto
input_a_chunk
=
get_tensor_chunk
(
A
,
i
*
input_a_chunk_size
,
input_a_chunk_shape
);
auto
output_chunk
=
get_buffer_chunk_like
(
D
,
i
*
output_chunk_size
,
output_chunk_shape
);
auto
bias_chunk
=
maybe_get_bias_chunk
(
i
);
auto
workspace_chunk
=
get_tensor_chunk
(
workspace
,
(
i
%
_stream_compute
.
size
())
*
workspace_size_chunk
,
{
workspace_size_chunk
});
if
(
_ub_stream_nums
==
1
)
{
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
_chunk
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
i
%
_stream_compute
.
size
()]);
...
...
@@ -605,14 +660,17 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
void
*
buffer_ptr
;
_ub_reg
=
register_user_buffer_collective
(
&
buffer_ptr
,
buffer_bytes
,
_ub_comm
,
true
);
if
(
_rank
==
0
)
printf
(
"!!! [UBP2P] Register UBuf %d
\n
"
,
_ub_reg
);
_ubuf
=
TensorWrapper
(
buffer_ptr
,
{
buffer_shape
[
0
]
/
tp_size
*
_num_ubuf_chunks
,
buffer_shape
[
1
]},
buffer_dtype
);
_ubuf
=
TensorWrapper
(
buffer_ptr
,
std
::
vector
<
size_t
>
{
buffer_shape
[
0
]
/
tp_size
*
_num_ubuf_chunks
,
buffer_shape
[
1
]},
buffer_dtype
);
// Create tensor chunks for easy management
char
*
ubuf_byte_ptr
=
reinterpret_cast
<
char
*>
(
buffer_ptr
);
for
(
int
i
=
0
;
i
<
_num_ubuf_chunks
;
i
++
)
{
_ubufs
.
push_back
(
TensorWrapper
(
reinterpret_cast
<
void
*>
(
ubuf_byte_ptr
),
{
buffer_shape
[
0
]
/
tp_size
,
buffer_shape
[
1
]},
buffer_dtype
));
std
::
vector
<
size_t
>
{
buffer_shape
[
0
]
/
tp_size
,
buffer_shape
[
1
]},
buffer_dtype
));
ubuf_byte_ptr
+=
buffer_chunk_bytes
;
}
...
...
@@ -661,7 +719,7 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
TensorWrapper
CommOverlapP2PBase
::
get_buffer_chunk_by_id
(
const
TensorWrapper
&
source
,
size_t
chunk_id
)
{
// Start with a chunk of the source tensor
auto
chunk
=
get_tensor_chunk
(
source
,
0
,
AS_VECTOR
(
_ubufs
[
chunk_id
].
shape
()));
auto
chunk
=
get_tensor_chunk
(
source
,
0
,
shape_to_vector
(
_ubufs
[
chunk_id
].
shape
()));
// Update chunk with offset data pointers from the communication buffer
if
(
chunk
.
dptr
()
!=
nullptr
)
{
...
...
@@ -711,7 +769,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
_stream_send
[
0
],
_start_compute
,
0
));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
_stream_recv
,
_start_compute
,
0
));
auto
input_b
=
get_buffer_chunk_like
(
B
,
0
,
AS_VECTOR
(
B
.
shape
()));
auto
input_b
=
get_buffer_chunk_like
(
B
,
0
,
shape_to_vector
(
B
.
shape
()));
size_t
workspace_size_chunk
=
workspace
.
numel
()
/
_stream_compute
.
size
();
auto
workspace_chunk
=
get_tensor_chunk
(
workspace
,
0
,
{
workspace_size_chunk
});
...
...
@@ -798,8 +856,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// Get communication and GEMM output chunk sizes
const
int
comm_bytes
=
_ubufs
[
0
].
numel
()
*
_ubufs
[
0
].
element_size
();
const
bool
do_gelu
=
pre_gelu_out
.
numel
()
>
0
;
size_t
input_chunk_size
=
n_chunk
*
k
;
size_t
output_chunk_size
=
n_chunk
*
m
;
size_t
workspace_size_chunk
=
workspace
.
numel
()
/
_stream_compute
.
size
();
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_start_compute
,
stream_main
));
...
...
@@ -810,10 +866,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
}
if
(
_aggregate
)
{
const
int
num_steps
=
_tp_size
/
2
;
#ifndef __HIP_PLATFORM_AMD__
input_chunk_size
*=
2
;
output_chunk_size
*=
2
;
#endif
// Chunk dims
std
::
vector
<
size_t
>
input_b_chunk_shape
=
(
transb
?
std
::
vector
<
size_t
>
{
k
,
2
*
n_chunk
}
:
std
::
vector
<
size_t
>
{
2
*
n_chunk
,
k
});
std
::
vector
<
size_t
>
output_chunk_shape
=
{
2
*
n_chunk
,
k
};
size_t
input_b_chunk_size
=
2
*
n_chunk
*
k
;
size_t
output_chunk_size
=
2
*
n_chunk
*
m
;
// Initial 1X input chunk exchange between neighboring peers
int
send_chunk_id
=
_tp_id
;
...
...
@@ -842,8 +901,9 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// GEMM
auto
input_b_chunk
=
get_buffer_chunk_like
(
B
,
input_chunk_size
*
send_chunk_id
,
{
n_chunk
*
2
,
k
});
auto
output_chunk
=
get_tensor_chunk
(
D
,
output_chunk_size
*
send_chunk_id
,
{
n_chunk
*
2
,
m
});
get_buffer_chunk_like
(
B
,
input_b_chunk_size
*
send_chunk_id
,
input_b_chunk_shape
);
auto
output_chunk
=
get_tensor_chunk
(
D
,
output_chunk_size
*
send_chunk_id
,
output_chunk_shape
);
auto
aux_chunk
=
(
do_gelu
)
?
get_tensor_chunk
(
pre_gelu_out
,
output_chunk_size
*
send_chunk_id
,
{
n_chunk
*
2
,
k
})
...
...
@@ -882,6 +942,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
}
}
}
else
{
// Chunk dims
std
::
vector
<
size_t
>
input_b_chunk_shape
=
(
transb
?
std
::
vector
<
size_t
>
{
k
,
n_chunk
}
:
std
::
vector
<
size_t
>
{
n_chunk
,
k
});
std
::
vector
<
size_t
>
output_chunk_shape
=
{
n_chunk
,
m
};
size_t
input_b_chunk_size
=
n_chunk
*
k
;
size_t
output_chunk_size
=
n_chunk
*
m
;
for
(
int
i
=
0
;
i
<
_tp_size
;
i
++
)
{
// Set the userbuffer id. Buffer under send is the input for the current
// GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
...
...
@@ -893,8 +960,10 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
      int recv_offset = comm_bytes * recv_chunk_id;
      // GEMM
-     auto input_b_chunk =
-         get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k});
-     auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m});
+     auto input_b_chunk =
+         get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
+     auto output_chunk =
+         get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
      auto aux_chunk = (do_gelu)
                           ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id,
                                              {n_chunk, k})
...
...
@@ -972,7 +1041,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(
  // Atomic GEMM
  // Process GEMM chunks in the order that AG+GEMM places the output chunks.
- auto output_d = get_buffer_chunk_like(D, 0, AS_VECTOR(D.shape()));
+ auto output_d = get_buffer_chunk_like(D, 0, shape_to_vector(D.shape()));
  nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
                          transa, transb, grad, workspace.data(), accumulate,
                          use_split_accumulator, _math_sms, 0, _tp_size, true, _counter.data(),
                          stream_main);
...
...
@@ -1053,6 +1122,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
    auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k});
    auto output_chunk = get_buffer_chunk_by_id(D, i);
    auto workspace_chunk =
        get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
...
...
transformer_engine/common/common.cu
View file @
f8c2af4c
...
...
@@ -35,6 +35,65 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
}
}
namespace {

constexpr size_t kThreadsPerBlock = 256;

template <typename TVectorized>
__global__ void __launch_bounds__(kThreadsPerBlock)
    memset_kernel(void *__restrict__ ptr, int value, size_t size_in_bytes) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx * sizeof(TVectorized) >= size_in_bytes) {
    return;  // Out of bounds
  }
  if ((idx + 1) * sizeof(TVectorized) > size_in_bytes) {
    // If the buffer size is not an even multiple of the vectorization, manually set the
    // remaining bytes unvectorized.
    size_t remaining_bytes = size_in_bytes - idx * sizeof(TVectorized);
    memset(reinterpret_cast<uint8_t *>(ptr) + idx * sizeof(TVectorized), value, remaining_bytes);
    return;
  }
  union {
    TVectorized value;
    uint8_t data[sizeof(TVectorized)];
  } data;
  for (size_t i = 0; i < sizeof(TVectorized); ++i) {
    data.data[i] = static_cast<uint8_t>(value);
  }
  reinterpret_cast<TVectorized *>(ptr)[idx] = data.value;
}

}  // namespace
#define MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, vectorizedType, stream) \
if (size_in_bytes >= sizeof(vectorizedType) && \
reinterpret_cast<size_t>(ptr) % sizeof(vectorizedType) == 0) { \
size_t numBlocks = DIVUP(size_in_bytes, kThreadsPerBlock * sizeof(vectorizedType)); \
dim3 grid(numBlocks, 1, 1); \
memset_kernel<vectorizedType> \
<<<grid, kThreadsPerBlock, 0, stream>>>(ptr, value, size_in_bytes); \
return; \
}
extern "C" {

void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream) {
  NVTE_API_CALL(nvte_memset);
  NVTE_CHECK(ptr != nullptr, "Pointer for memset must be allocated.");

  if (size_in_bytes > 4096) {
    // Use cudaMemsetAsync for larger sizes.
    cudaMemsetAsync(ptr, value, size_in_bytes, stream);
    return;
  }

  MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float4, stream);
  MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float2, stream);
  MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float, stream);
  MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, uint8_t, stream);
}

}  // extern "C"
void checkCuDriverContext(CUstream stream) {
#ifdef __HIP_PLATFORM_AMD__
  return;
...
...
@@ -144,4 +203,16 @@ bool is_supported_by_CC_100() {
#endif
}

std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
                                                        size_t outer_size, size_t inner_size) {
  std::vector<std::vector<Tensor *>> ret;
  for (size_t i = 0; i < outer_size; ++i) {
    ret.emplace_back();
    for (size_t j = 0; j < inner_size; ++j) {
      ret.back().push_back(reinterpret_cast<Tensor *>(nvte_tensors[i][j]));
    }
  }
  return ret;
}

}  // namespace transformer_engine
transformer_engine/common/common.h
View file @
f8c2af4c
...
...
@@ -116,7 +116,7 @@ struct Tensor {
        columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
        scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}

- int numel() const {
+ size_t numel() const {
    size_t acc = 1;
    for (const auto dim : shape()) {
      acc *= dim;
...
...
@@ -138,6 +138,14 @@ struct Tensor {
    return data.dtype;
  }

+ size_t dim() const {
+   if (!has_data() && has_columnwise_data()) {
+     return columnwise_data.shape.size();
+   } else {
+     return data.shape.size();
+   }
+ }

  std::vector<size_t> shape() const {
/* Note: We sometimes experience spurious compiler errors
* (-Wstringop-overflow) from this function. It appears that GCC
...
...
@@ -243,6 +251,7 @@ constexpr T DIVUP(const T &x, const T &y) {
}

using byte = uint8_t;
+using int16 = int16_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
...
...
@@ -271,6 +280,7 @@ constexpr inline const char *type_name() noexcept;
return #T; \
}
TRANSFORMER_ENGINE_TYPE_NAME(uint8_t)
+TRANSFORMER_ENGINE_TYPE_NAME(int16_t)
TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
TRANSFORMER_ENGINE_TYPE_NAME(float)
...
...
@@ -327,7 +337,7 @@ struct TypeExtrema {
template <typename T>
struct TypeInfo {
- using types = std::tuple<byte, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
+ using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;

  template <typename U, DType current>
  struct Helper {
...
...
@@ -364,6 +374,10 @@ struct TypeInfo {
using type = unsigned char; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt16: { \
using type = int16_t; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt32: { \
using type = int32_t; \
{ __VA_ARGS__ } \
...
...
@@ -400,6 +414,33 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E4M3: { \
using type = fp8e4m3; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E5M2: { \
using type = fp8e5m2; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
...
...
@@ -599,6 +640,9 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
bool is_supported_by_CC_100();

+std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
+                                                        size_t outer_size, size_t inner_size);

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_COMMON_H_
transformer_engine/pytorch/csrc/thd_utils.cuh → transformer_engine/common/fused_attn/context_parallel.cu
View file @ f8c2af4c
...
...
@@ -3,21 +3,25 @@
 *
 * See LICENSE for license information.
 ************************************************************************/

-#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_CUH_
-#define TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_CUH_
 #include <assert.h>
 #include <cuda.h>
 #include <cuda_bf16.h>

 #include "../common.h"
 #include "transformer_engine/fused_attn.h"

+namespace transformer_engine {
+namespace context_parallel {

 struct LseCorrectionFunctor {
-  __forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx,
+  __forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx,
                                              size_t half_idx) {
-    double val = lse[idx];
+    float val = lse[idx];
     float val_per_step = half_lse[half_idx];
-    double max_scale = max(val, val_per_step);
-    double min_scale = min(val, val_per_step);
-    lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale));
+    float max_scale = max(val, val_per_step);
+    float min_scale = min(val, val_per_step);
+    lse[idx] = max_scale + log1pf(expf(min_scale - max_scale));
   }
 };
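For reference, the update the functor applies is the usual numerically stable log-add-exp merge of two log-sum-exp values; this is a restatement of the code above, not text from the commit:

  lse <- log(e^{lse} + e^{lse_step})
       = max(lse, lse_step) + log(1 + e^{min(lse, lse_step) - max(lse, lse_step)}),

which is exactly max_scale + log1pf(expf(min_scale - max_scale)) in the new single-precision path.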
...
...
@@ -49,16 +53,13 @@ struct AddFunctor {
#pragma unroll
    for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) {
-     p_[i] += p[i];
+     p_[i] = p_[i] + p[i];
    }
    reinterpret_cast<float4 *>(token)[idx] = d_;
  }
};

-namespace transformer_engine {
-namespace fused_attn {
/***************************************************************************************************
* Support THD format for Context Parallel: Binary search an array for a target value
**************************************************************************************************/
...
...
@@ -107,6 +108,7 @@ __global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int b
/***************************************************************************************************
* Support THD format for Context Parallel: Read the half of a THD tensor
**************************************************************************************************/
__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch,
                                            int hidden_size_in_bytes, int half_idx,
                                            int dim_size_of_token) {
...
...
@@ -148,8 +150,8 @@ __global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_se
* Support THD format for Context Parallel: softmax_lse related operations
**************************************************************************************************/
-template <typename lse_dtype, bool lse_packed, typename Functor>
-__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch,
+template <bool lse_packed, typename Functor>
+__global__ void thd_lse_kernel(float *lse, float *half_lse, int *cu_seqlens, int batch,
                                int num_heads, int lse_seqlen, int second_half_lse_seqlen) {
  extern __shared__ int cu_seqlens_s[];
  for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
...
...
@@ -218,7 +220,7 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
      idx = row * lse_seqlen + col + seq_len * only_second_half;
      idx_per_step = row * lse_per_step_seqlen + col;
    }
-   float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]);
+   float lse_corrected_exp = expf(lse_per_step[idx_per_step] - lse[idx]);
    idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
    idx = (idx * num_heads + head_id) * dim_per_head;
...
...
@@ -232,7 +234,10 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
      dtype *p_per_step = reinterpret_cast<dtype *>(&data_per_step);
      dtype *p = reinterpret_cast<dtype *>(&data);
      for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) {
-       p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp);
+       p[k] = p[k] + (p_per_step[k] == static_cast<dtype>(0.f)
+                          ? static_cast<dtype>(0.f)
+                          : static_cast<dtype>(static_cast<float>(p_per_step[k]) *
+                                               lse_corrected_exp));
      }
      reinterpret_cast<float4 *>(cur_out)[j] = data;
    }
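Written out, the accumulation above is the standard merge of per-step attention outputs (again a restatement of the code, not new text from the commit):

  out <- out + out_step * exp(lse_step - lse),

where lse is the already-corrected log-sum-exp over all steps, so each partial output is rescaled from its per-step softmax normalization to the global one; elements that are exactly zero in the per-step output contribute nothing.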
...
...
@@ -297,6 +302,442 @@ __global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, in
    }
  }
}  // namespace fused_attn
/***************************************************************************************************
* Support THD format for Context Parallel: Read the half of a THD tensor
**************************************************************************************************/
void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor &half,
                          int half_idx, cudaStream_t stream) {
  using namespace transformer_engine;
  NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4);
  NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);

  auto cu_seqlens_shape = cu_seqlens.shape();
  auto tensor_shape = tensor.shape();

  NVTE_CHECK(cu_seqlens.dim() == 1);
  NVTE_CHECK(cu_seqlens_shape[0] >= 2);

  // Shapes of q and dq are [t, h, d], so the dimension of "t" is 0
  // Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1
  int seq_dim = tensor.dim() == 3 ? 0 : 1;

  int batch = cu_seqlens_shape[0] - 1;
  int num_heads = tensor_shape[seq_dim + 1];
  int dim_per_head = tensor_shape[seq_dim + 2];
  int hidden_size_in_bytes = num_heads * dim_per_head * typeToSize(tensor.dtype());

  // For 128-bits load/store
  NVTE_CHECK(hidden_size_in_bytes % 16 == 0);

  // Launch Kernel
  constexpr unsigned int block = 256;
  unsigned int grid_x = (tensor_shape[seq_dim] / 2 * 32 + block - 1) / block;
  unsigned int grid_y = 1;
  for (int i = 0; i < seq_dim; i++) {
    grid_y *= tensor_shape[i];
  }
  dim3 grid = {grid_x, grid_y};
  thd_read_half_tensor_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
      half.data.dptr, tensor.data.dptr, reinterpret_cast<int *>(cu_seqlens.data.dptr), batch,
      hidden_size_in_bytes, half_idx, tensor_shape[seq_dim]);
}
/***************************************************************************************************
* Support THD format for Context Parallel: softmax_lse related operations
**************************************************************************************************/
void thd_second_half_lse_correction(Tensor lse, const Tensor &lse_per_step,
                                    const Tensor &cu_seqlens, bool lse_packed,
                                    cudaStream_t stream) {
  using namespace transformer_engine;
  NVTE_CHECK(lse.dtype() == DType::kFloat32);
  NVTE_CHECK(lse_per_step.dtype() == DType::kFloat32);
  NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
  NVTE_CHECK(cu_seqlens.dim() == 1);

  int batch, num_heads, lse_seqlen, second_half_lse_seqlen;
  auto cu_seqlens_shape = cu_seqlens.shape();
  auto lse_shape = lse.shape();
  auto lse_per_step_shape = lse_per_step.shape();

  if (lse_packed) {
    NVTE_CHECK(lse.dim() == 2);
    NVTE_CHECK(lse_per_step.dim() == 2);

    batch = cu_seqlens_shape[0] - 1;
    num_heads = lse_shape[0];
    lse_seqlen = lse_shape[1];
    second_half_lse_seqlen = lse_per_step_shape[1];

    NVTE_CHECK(lse_per_step_shape[0] == num_heads);
    NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2);
  } else {
    NVTE_CHECK(lse.dim() == 3);
    NVTE_CHECK(lse_per_step.dim() == 3);

    batch = lse_shape[0];
    num_heads = lse_shape[1];
    lse_seqlen = lse_shape[2];
    second_half_lse_seqlen = lse_per_step_shape[2];

    NVTE_CHECK(lse_per_step_shape[0] == batch);
    NVTE_CHECK(lse_per_step_shape[1] == num_heads);
    NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2);
    NVTE_CHECK(cu_seqlens_shape[0] == batch + 1);
  }

  constexpr unsigned int block = 256;
  unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
  unsigned int grid_y = num_heads;
  dim3 grid = {grid_x, grid_y};

  if (lse_packed) {
    thd_lse_kernel<true, LseCorrectionFunctor>
        <<<grid, block, sizeof(int) * (batch + 1), stream>>>(
            reinterpret_cast<float *>(lse.data.dptr),
            reinterpret_cast<float *>(lse_per_step.data.dptr),
            reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
            second_half_lse_seqlen);
  } else {
    thd_lse_kernel<false, LseCorrectionFunctor>
        <<<grid, block, sizeof(int) * (batch + 1), stream>>>(
            reinterpret_cast<float *>(lse.data.dptr),
            reinterpret_cast<float *>(lse_per_step.data.dptr),
            reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
            second_half_lse_seqlen);
  }
}

void thd_read_second_half_lse(const Tensor &lse, const Tensor &cu_seqlens, Tensor &half_lse,
                              bool lse_packed, int second_half_lse_seqlen, cudaStream_t stream) {
  using namespace transformer_engine;
  NVTE_CHECK(lse.dtype() == DType::kFloat32);
  NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
  NVTE_CHECK(cu_seqlens.dim() == 1);

  int batch, num_heads, lse_seqlen;
  auto cu_seqlens_shape = cu_seqlens.shape();
  auto lse_shape = lse.shape();

  if (lse_packed) {
    NVTE_CHECK(lse.dim() == 2);

    batch = cu_seqlens_shape[0] - 1;
    num_heads = lse_shape[0];
    lse_seqlen = lse_shape[1];

    NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2);
  } else {
    NVTE_CHECK(lse.dim() == 3);

    batch = lse_shape[0];
    num_heads = lse_shape[1];
    lse_seqlen = lse_shape[2];

    NVTE_CHECK(cu_seqlens_shape[0] == batch + 1);
    NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2);
  }

  constexpr unsigned int block = 256;
  unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
  unsigned int grid_y = num_heads;
  dim3 grid = {grid_x, grid_y};

  if (lse_packed) {
    thd_lse_kernel<true, ReadLseFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
        reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
        reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
        second_half_lse_seqlen);
  } else {
    thd_lse_kernel<false, ReadLseFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
        reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
        reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
        second_half_lse_seqlen);
  }
}
/***************************************************************************************************
* Support THD format for Context Parallel: Out correction in forward
**************************************************************************************************/
template <typename dtype, int only_second_half>
static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, const Tensor &lse,
                                      const Tensor &lse_per_step, const Tensor &cu_seqlens,
                                      bool lse_packed, cudaStream_t stream) {
  using namespace transformer_engine;
  NVTE_CHECK(out.dtype() == out_per_step.dtype());
  NVTE_CHECK(lse.dtype() == DType::kFloat32);
  NVTE_CHECK(lse_per_step.dtype() == DType::kFloat32);
  NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);

  auto out_shape = out.shape();
  auto lse_shape = lse.shape();
  auto out_per_step_shape = out_per_step.shape();
  auto lse_per_step_shape = lse_per_step.shape();
  auto cu_seqlens_shape = cu_seqlens.shape();

  int total_tokens = out_shape[0];
  int num_heads = out_shape[1];
  int dim_per_head = out_shape[2];

  NVTE_CHECK(out_per_step_shape[0] == total_tokens / (only_second_half + 1));
  NVTE_CHECK(out_per_step_shape[1] == num_heads);
  NVTE_CHECK(out_per_step_shape[2] == dim_per_head);

  int batch, lse_seqlen, lse_per_step_seqlen;
  if (lse_packed) {
    batch = cu_seqlens_shape[0] - 1;
    lse_seqlen = lse_shape[1];
    lse_per_step_seqlen = lse_per_step_shape[1];

    NVTE_CHECK(lse_shape[0] == num_heads);
    NVTE_CHECK(lse_seqlen >= total_tokens);
    NVTE_CHECK(lse_per_step_shape[0] == num_heads);
    NVTE_CHECK(lse_per_step_seqlen >= lse_seqlen / (only_second_half + 1));
  } else {
    batch = lse_shape[0];
    lse_seqlen = lse_shape[2];
    lse_per_step_seqlen = lse_per_step_shape[2];

    NVTE_CHECK(lse_shape[1] == num_heads);
    NVTE_CHECK(lse_per_step_shape[0] == batch);
    NVTE_CHECK(lse_per_step_shape[1] == num_heads);
    NVTE_CHECK(lse_per_step_seqlen == lse_seqlen / (only_second_half + 1));
    NVTE_CHECK(cu_seqlens_shape[0] == batch + 1);
  }

  constexpr int tile = 16;
  constexpr int block = 512;
  unsigned int grid_x =
      (static_cast<size_t>(total_tokens) / (only_second_half + 1) * tile + block - 1) / block;
  dim3 grid = {grid_x, (unsigned int)num_heads};

  if (lse_packed) {
    thd_out_correction_kernel<dtype, only_second_half, tile, true>
        <<<grid, block, sizeof(int) * (batch + 1), stream>>>(
            reinterpret_cast<dtype *>(out.data.dptr),
            reinterpret_cast<dtype *>(out_per_step.data.dptr),
            reinterpret_cast<float *>(lse.data.dptr),
            reinterpret_cast<float *>(lse_per_step.data.dptr),
            reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
            lse_seqlen, lse_per_step_seqlen);
  } else {
    thd_out_correction_kernel<dtype, only_second_half, tile, false>
        <<<grid, block, sizeof(int) * (batch + 1), stream>>>(
            reinterpret_cast<dtype *>(out.data.dptr),
            reinterpret_cast<dtype *>(out_per_step.data.dptr),
            reinterpret_cast<float *>(lse.data.dptr),
            reinterpret_cast<float *>(lse_per_step.data.dptr),
            reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
            lse_seqlen, lse_per_step_seqlen);
  }
}
void thd_out_correction(Tensor out, const Tensor &out_per_step, const Tensor &lse,
                        const Tensor &lse_per_step, const Tensor &cu_seqlens,
                        bool only_second_half, bool lse_packed, cudaStream_t stream) {
  using namespace transformer_engine;
  if (only_second_half) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
        out.dtype(), dtype,
        thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
                                            lse_packed, stream););
  } else {
    TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
        out.dtype(), dtype,
        thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
                                            lse_packed, stream););
  }
}
/***************************************************************************************************
* Support THD format for Context Parallel: Gradients correction in backward
**************************************************************************************************/
template <typename dtype, typename Functor_0, typename Functor_1, int functor_idx>
static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step,
                                       const Tensor &cu_seqlens, cudaStream_t stream) {
  using namespace transformer_engine;
  NVTE_CHECK(grad.dim() == 3 || grad.dim() == 4);
  NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
  NVTE_CHECK(cu_seqlens.dim() == 1);

  auto grad_shape = grad.shape();
  auto cu_seqlens_shape = cu_seqlens.shape();
  auto grad_per_step_shape = grad_per_step.shape();

  // Shape of dq is [t, h, d], so the dimension of "t" is 0
  // Shape of dkv is [2, t, h, d], so the dimension of "t" is 1
  int seq_dim = grad.dim() == 3 ? 0 : 1;

  int total_tokens = grad_shape[seq_dim];
  int num_heads = grad_shape[seq_dim + 1];
  int dim_per_head = grad_shape[seq_dim + 2];
  int batch = cu_seqlens_shape[0] - 1;

  if constexpr (functor_idx < 2) {
    NVTE_CHECK(grad_per_step_shape[seq_dim] == total_tokens / 2);
  } else {
    NVTE_CHECK(grad_per_step_shape[seq_dim] == total_tokens);
  }
  NVTE_CHECK(grad_per_step_shape[seq_dim + 1] == num_heads);
  NVTE_CHECK(grad_per_step_shape[seq_dim + 2] == dim_per_head);

  size_t hidden_size = num_heads * dim_per_head;
  NVTE_CHECK((hidden_size * typeToSize(grad.dtype())) % 16 == 0);

  constexpr unsigned int block = 256;
  unsigned int grid_x;
  if constexpr (functor_idx < 2) {
    grid_x = (total_tokens / 2 * 32 + block - 1) / block;
  } else {
    grid_x = (total_tokens * 32 + block - 1) / block;
  }
  unsigned int grid_y = 1;
  for (int i = 0; i < seq_dim; i++) {
    grid_y *= grad_shape[i];
  }
  dim3 grid = {grid_x, grid_y};

  thd_grad_correction_kernel<dtype, Functor_0, Functor_1, functor_idx, 32>
      <<<grid, block, sizeof(int) * (batch + 1), stream>>>(
          reinterpret_cast<dtype *>(grad.data.dptr),
          reinterpret_cast<dtype *>(grad_per_step.data.dptr),
          reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, hidden_size, total_tokens);
}
template <typename dtype>
static void thd_grad_dispatcher(Tensor grad, const Tensor &grad_per_step, const Tensor &cu_seqlens,
                                const std::string &first_half, const std::string &second_half,
                                cudaStream_t stream) {
  using namespace transformer_engine;
  if (first_half == "add" && second_half == "none") {
    thd_grad_correction_helper<dtype, AddFunctor<dtype>, EmptyFunctor, 0>(grad, grad_per_step,
                                                                          cu_seqlens, stream);
  } else if (first_half == "copy" && second_half == "none") {
    thd_grad_correction_helper<dtype, CopyFunctor, EmptyFunctor, 0>(grad, grad_per_step,
                                                                    cu_seqlens, stream);
  } else if (first_half == "none" && second_half == "add") {
    thd_grad_correction_helper<dtype, EmptyFunctor, AddFunctor<dtype>, 1>(grad, grad_per_step,
                                                                          cu_seqlens, stream);
  } else if (first_half == "none" && second_half == "copy") {
    thd_grad_correction_helper<dtype, EmptyFunctor, CopyFunctor, 1>(grad, grad_per_step,
                                                                    cu_seqlens, stream);
  } else if (first_half == "add" && second_half == "copy") {
    thd_grad_correction_helper<dtype, AddFunctor<dtype>, CopyFunctor, 2>(grad, grad_per_step,
                                                                         cu_seqlens, stream);
  } else if (first_half == "copy" && second_half == "add") {
    thd_grad_correction_helper<dtype, CopyFunctor, AddFunctor<dtype>, 2>(grad, grad_per_step,
                                                                         cu_seqlens, stream);
  } else {
    NVTE_ERROR("Unsupported Functor of first half and second_half\n");
  }
}

void thd_grad_correction(Tensor grad, const Tensor &grad_per_step, const Tensor &cu_seqlens,
                         const std::string &first_half, const std::string &second_half,
                         cudaStream_t stream) {
  using namespace transformer_engine;
  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      grad.dtype(), dtype,
      thd_grad_dispatcher<dtype>(grad, grad_per_step, cu_seqlens, first_half, second_half,
                                 stream););
}
/***************************************************************************************************
* Support THD format for Context Parallel: Generate partitioned indices for input tokens
**************************************************************************************************/
void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int total_tokens,
                                 int world_size, int rank, cudaStream_t stream) {
  using namespace transformer_engine;
  NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
  NVTE_CHECK(cu_seqlens.dim() == 1);

  auto cu_seqlens_shape = cu_seqlens.shape();
  auto output_shape = output.shape();

  NVTE_CHECK(cu_seqlens_shape[0] >= 2);
  NVTE_CHECK(rank >= 0 && rank < world_size);
  NVTE_CHECK(world_size > 0);
  NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0);

  int batch = cu_seqlens_shape[0] - 1;

  constexpr unsigned int block = 256;
  unsigned int grid = (output_shape[0] + block - 1) / block;
  thd_partition_indices_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
      reinterpret_cast<int *>(output.data.dptr), reinterpret_cast<int *>(cu_seqlens.data.dptr),
      batch, total_tokens, world_size, rank);
}

}  // namespace context_parallel
}  // namespace transformer_engine
#endif
void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens,
                                  NVTETensor half, int half_idx, cudaStream_t stream) {
  NVTE_API_CALL(nvte_thd_read_half_tensor);
  using namespace transformer_engine;

  context_parallel::thd_read_half_tensor(*reinterpret_cast<Tensor *>(tensor),
                                         *reinterpret_cast<Tensor *>(cu_seqlens),
                                         *reinterpret_cast<Tensor *>(half), half_idx, stream);
}

void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
                                            const NVTETensor &cu_seqlens, int lse_packed,
                                            cudaStream_t stream) {
  NVTE_API_CALL(nvte_thd_second_half_lse_correction);
  using namespace transformer_engine;

  context_parallel::thd_second_half_lse_correction(*reinterpret_cast<Tensor *>(lse),
                                                   *reinterpret_cast<Tensor *>(lse_per_step),
                                                   *reinterpret_cast<Tensor *>(cu_seqlens),
                                                   lse_packed, stream);
}

void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
                                      NVTETensor half_lse, int lse_packed,
                                      int second_half_lse_seqlen, cudaStream_t stream) {
  NVTE_API_CALL(nvte_thd_read_second_half_lse);
  using namespace transformer_engine;

  context_parallel::thd_read_second_half_lse(*reinterpret_cast<Tensor *>(lse),
                                             *reinterpret_cast<Tensor *>(cu_seqlens),
                                             *reinterpret_cast<Tensor *>(half_lse), lse_packed,
                                             second_half_lse_seqlen, stream);
}

void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
                                const NVTETensor &lse, const NVTETensor &lse_per_step,
                                const NVTETensor &cu_seqlens, int only_second_half,
                                int lse_packed, cudaStream_t stream) {
  NVTE_API_CALL(nvte_thd_out_correction);
  using namespace transformer_engine;

  context_parallel::thd_out_correction(*reinterpret_cast<Tensor *>(out),
                                       *reinterpret_cast<Tensor *>(out_per_step),
                                       *reinterpret_cast<Tensor *>(lse),
                                       *reinterpret_cast<Tensor *>(lse_per_step),
                                       *reinterpret_cast<Tensor *>(cu_seqlens), only_second_half,
                                       lse_packed, stream);
}

void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
                                 const NVTETensor &cu_seqlens, const char *first_half,
                                 const char *second_half, cudaStream_t stream) {
  NVTE_API_CALL(nvte_thd_grad_correction);
  using namespace transformer_engine;

  std::string first_half_str(first_half);
  std::string second_half_str(second_half);

  context_parallel::thd_grad_correction(*reinterpret_cast<Tensor *>(grad),
                                        *reinterpret_cast<Tensor *>(grad_per_step),
                                        *reinterpret_cast<Tensor *>(cu_seqlens), first_half_str,
                                        second_half_str, stream);
}

void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
                                         int total_tokens, int world_size, int rank,
                                         cudaStream_t stream) {
  NVTE_API_CALL(nvte_thd_get_partitioned_indices);
  using namespace transformer_engine;

  context_parallel::thd_get_partitioned_indices(*reinterpret_cast<Tensor *>(cu_seqlens),
                                                *reinterpret_cast<Tensor *>(output), total_tokens,
                                                world_size, rank, stream);
}
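For orientation, the partitioning these entry points serve follows the usual Transformer Engine context-parallel load-balancing split; this is inferred from the checks above (total_tokens % (world_size * 2) == 0), not restated from the kernel body in this hunk. With world size W, each sequence is cut into 2W equal chunks and rank r receives

  chunks(r) = { r, 2W - 1 - r },  0 <= r < W,

so every rank holds one early and one late chunk and the causal-attention work per rank stays roughly balanced; nvte_cp_thd_get_partitioned_indices writes the token indices of the two chunks owned by the given rank into output.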
transformer_engine/common/fused_attn/flash_attn.cu
0 → 100644
View file @
f8c2af4c
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../common.h"
#include "transformer_engine/fused_attn.h"
namespace transformer_engine {
namespace flash_attention {

constexpr int warp_size = 32;
constexpr int type_size = 2;  // FP16 or BF16
constexpr int nvec = sizeof(uint64_t) / type_size;
constexpr int load_size = warp_size * nvec;
constexpr int block_size = 512;

template <typename T>
__launch_bounds__(block_size) __global__
    void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z,
                            const size_t W) {
  const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
  const int id_in_warp = threadIdx.x % warp_size;
  const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec;
  const T *my_input = qkvi + offset_input;

  const size_t s = warpid / B;
  if (s >= S) return;

  const size_t b = warpid % B;
  const size_t offset_output = blockIdx.y * B * S * Z * W + (s + b * S) * W * Z + id_in_warp * nvec;

  T *my_output = qkv + offset_output;

  for (int i = 0; i < Z; ++i) {
    uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size);
    *out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size * 3);
  }
}

template <typename T>
__launch_bounds__(block_size) __global__
    void prepare_kernel_bwd(const T *q, const T *k, const T *v, T *qkv, const size_t B,
                            const size_t S, const size_t Z, const size_t W) {
  const T *input = blockIdx.y == 0 ? q : (blockIdx.y == 1 ? k : v);

  const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
  const int id_in_warp = threadIdx.x % warp_size;
  const size_t offset_input = warpid * W * Z + id_in_warp * nvec;
  const T *my_input = input + offset_input;

  const size_t b = warpid / S;
  if (b >= B) return;

  const size_t s = warpid % S;
  const size_t offset_output = (b + s * B) * 3 * W * Z + id_in_warp * nvec + blockIdx.y * W;

  T *my_output = qkv + offset_output;

  for (int i = 0; i < Z; ++i) {
    uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size * 3);
    *out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size);
  }
}

void prepare_flash_attn_fwd(Tensor qkvi, Tensor qkv, cudaStream_t stream) {
  using namespace transformer_engine;
  NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor.");
  NVTE_CHECK(qkvi.dtype() == DType::kFloat16 || qkvi.dtype() == DType::kBFloat16);

  auto qkvi_shape = qkvi.shape();
  NVTE_CHECK(qkvi_shape[3] % load_size == 0);
  NVTE_CHECK(qkvi_shape[3] == load_size);

  // [s, b, n, h * 3] -> [3, b, s, n, h]
  std::vector<uint64_t> shape = {3, qkvi_shape[1], qkvi_shape[0], qkvi_shape[2], qkvi_shape[3]};

  size_t warps = qkvi_shape[0] * qkvi_shape[1];
  size_t warps_per_block = block_size / warp_size;
  size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
  dim3 grid(blocks, 3);
  int threads = block_size;

  TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
      qkvi.dtype(), dtype,
      prepare_kernel_fwd<dtype><<<grid, threads, 0, stream>>>(
          reinterpret_cast<dtype *>(qkvi.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
          shape[1], shape[2], shape[3], shape[4]););
}

void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream_t stream) {
  using namespace transformer_engine;
  NVTE_CHECK(q.dim() == 4, "Expected 4-dim tensor.");
  NVTE_CHECK(k.dim() == 4, "Expected 4-dim tensor.");
  NVTE_CHECK(v.dim() == 4, "Expected 4-dim tensor.");
  NVTE_CHECK(q.dtype() == DType::kFloat16 || q.dtype() == DType::kBFloat16);
  NVTE_CHECK(k.dtype() == q.dtype());
  NVTE_CHECK(v.dtype() == q.dtype());

  auto q_shape = q.shape();
  auto k_shape = k.shape();
  auto v_shape = v.shape();
  NVTE_CHECK(q_shape[3] % load_size == 0);
  NVTE_CHECK(q_shape[3] == load_size);
  NVTE_CHECK(k_shape[3] % load_size == 0);
  NVTE_CHECK(k_shape[3] == load_size);
  NVTE_CHECK(v_shape[3] % load_size == 0);
  NVTE_CHECK(v_shape[3] == load_size);

  // 3 x [s, b, n, h] -> [b, s, n, 3 * h]
  std::vector<uint64_t> shape = {q_shape[1], q_shape[0], q_shape[2], 3 * q_shape[3]};

  size_t warps = q_shape[0] * q_shape[1];
  size_t warps_per_block = block_size / warp_size;
  size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
  dim3 grid(blocks, 3);
  int threads = block_size;

  TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
      q.dtype(), dtype,
      prepare_kernel_bwd<dtype><<<grid, threads, 0, stream>>>(
          reinterpret_cast<dtype *>(q.data.dptr), reinterpret_cast<dtype *>(k.data.dptr),
          reinterpret_cast<dtype *>(v.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
          q_shape[0], q_shape[1], q_shape[2], q_shape[3]););
}

}  // namespace flash_attention
}  // namespace transformer_engine
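A small worked check of the constants above (a restatement of the code, not new to the commit): each thread copies one uint64_t per iteration, so

  nvec = sizeof(uint64_t) / type_size = 8 / 2 = 4 half-precision elements,
  load_size = warp_size * nvec = 32 * 4 = 128,

and the NVTE_CHECK(qkvi_shape[3] == load_size) in prepare_flash_attn_fwd therefore requires a head dimension of 128, with one warp moving one head's worth of data per iteration of the Z loop while repacking [s, b, n, 3h] into [3, b, s, n, h].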
void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream) {
  NVTE_API_CALL(nvte_prepare_flash_attn_fwd);
  using namespace transformer_engine;

  flash_attention::prepare_flash_attn_fwd(*reinterpret_cast<Tensor *>(qkvi),
                                          *reinterpret_cast<Tensor *>(qkv), stream);
}

void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
                                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_prepare_flash_attn_bwd);
  using namespace transformer_engine;

  flash_attention::prepare_flash_attn_bwd(*reinterpret_cast<Tensor *>(q),
                                          *reinterpret_cast<Tensor *>(k),
                                          *reinterpret_cast<Tensor *>(v),
                                          *reinterpret_cast<Tensor *>(qkv), stream);
}
transformer_engine/common/fused_attn/fused_attn.cpp
View file @
f8c2af4c
...
...
@@ -1006,3 +1006,18 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
    NVTE_ERROR("Invalid combination of data type and sequence length for fused attention.\n");
  }
}

uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
                                       cudaStream_t stream) {
  NVTE_API_CALL(nvte_get_runtime_num_segments);
  using namespace transformer_engine::fused_attn;
  return GetRuntimeNumSegments(cu_seqlen, workspace, len, stream);
}

void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed,
                                   size_t q_max_seqlen, size_t kv_max_seqlen,
                                   NVTE_Fused_Attn_Backend backend, cudaStream_t stream) {
  NVTE_API_CALL(nvte_populate_rng_state_async);
  using namespace transformer_engine::fused_attn;
  PopulateRngStateAsync(rng_state_dst, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
}
transformer_engine/pytorch/csrc/kv_cache.cuh → transformer_engine/common/fused_attn/kv_cache.cu
View file @ f8c2af4c
...
...
@@ -3,48 +3,15 @@
*
* See LICENSE for license information.
************************************************************************/
-#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
-#define TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
-
-namespace transformer_engine {
-namespace fused_attn {
-
-template <typename scalar_t>
-__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
-                                           int b, int max_seq_len, int h, int d) {
-  // tensor: thd; new_tensor: bshd
-  // cu_seqlens: [b + 1]
-  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
-    int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
-    int thd_offset = cu_seqlens[batch_idx] * h * d;
-    int bshd_offset = batch_idx * max_seq_len * h * d;
-    scalar_t *thd_token = tensor + thd_offset;
-    scalar_t *bshd_token = new_tensor + bshd_offset;
-    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
-      *(bshd_token + i) = *(thd_token + i);
-    }
-  }
-}
+#include "../common.h"
+#include "transformer_engine/fused_attn.h"

-template <typename scalar_t>
-__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
-                                           int b, int max_seq_len, int h, int d) {
-  // tensor: bshd; new_tensor: thd
-  // cu_seqlens: [b + 1]
-  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
-    int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
-    int num_elts = seqlen * h * d;
-    int bshd_offset = batch_idx * max_seq_len * h * d;
-    int thd_offset = cu_seqlens[batch_idx] * h * d;
-    scalar_t *bshd_token = tensor + bshd_offset;
-    scalar_t *thd_token = new_tensor + thd_offset;
-    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
-      *(thd_token + i) = *(bshd_token + i);
-    }
-  }
-}
+namespace transformer_engine {
+namespace kv_cache {

-template <typename scalar_t>
-__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices,
+template <typename dtype>
+__global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices,
                                         int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
                                         int d_v, int b, int max_seq_len) {
  // k_cache, v_cache: bshd
...
...
@@ -75,11 +42,11 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in
}
}
-template <typename scalar_t>
-__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache,
-                                        scalar_t *v_cache, int *page_table, int *cu_new_lens,
-                                        int *cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
-                                        int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
+template <typename dtype>
+__global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cache,
+                                        dtype *v_cache, int *page_table, int *cu_new_lens,
+                                        int *cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
+                                        int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
                                         int max_pages_per_seq, bool is_non_paged) {
// new_k, new_v: qkv_format; k_cache, v_cache: bshd
// cu_new_lens, cu_cached_lens: [b + 1]
...
...
@@ -140,6 +107,191 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar
    }
  }
}  // namespace fused_attn

template <typename dtype>
void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache,
                               Tensor page_table, Tensor cu_new_lens, Tensor cu_cached_lens,
                               NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, int b,
                               int max_ctx_len, int max_seq_len, int max_pages_per_seq,
                               bool is_non_paged, cudaStream_t stream) {
  if (new_k.has_data() && new_v.has_data() && k_cache.has_data() && v_cache.has_data()) {
    if (is_non_paged) {
      reindex_kv_cache_kernel<<<16, 256, 0, stream>>>(
          reinterpret_cast<dtype *>(k_cache.data.dptr),
          reinterpret_cast<dtype *>(v_cache.data.dptr),
          reinterpret_cast<int *>(page_table.data.dptr),
          reinterpret_cast<int *>(cu_new_lens.data.dptr),
          reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
    }
    copy_to_kv_cache_kernel<<<16, 256, 0, stream>>>(
        reinterpret_cast<dtype *>(new_k.data.dptr), reinterpret_cast<dtype *>(new_v.data.dptr),
        reinterpret_cast<dtype *>(k_cache.data.dptr), reinterpret_cast<dtype *>(v_cache.data.dptr),
        reinterpret_cast<int *>(page_table.data.dptr),
        reinterpret_cast<int *>(cu_new_lens.data.dptr),
        reinterpret_cast<int *>(cu_cached_lens.data.dptr), qkv_format, h_kv, d_k, d_v, b,
        max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
  }
}

void copy_to_kv_cache(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache,
                      Tensor page_table, Tensor cu_new_lens, Tensor cu_cached_lens,
                      NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len,
                      int max_pages_per_seq, bool is_non_paged, cudaStream_t stream) {
  int h_kv = new_k.shape()[new_k.dim() - 2];
  int d_k = new_k.shape()[new_k.dim() - 1];
  int d_v = new_v.shape()[new_v.dim() - 1];
  NVTE_CHECK(k_cache.dtype() == v_cache.dtype() && new_k.dtype() == new_v.dtype() &&
                 new_k.dtype() == k_cache.dtype(),
             "new_k, new_v, k_cache and v_cache must be of the same data type.");
  NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
                 qkv_format == NVTE_QKV_Format::NVTE_SBHD ||
                 qkv_format == NVTE_QKV_Format::NVTE_THD,
             "qkv_format must be {BSHD, SBHD, THD}.");
  TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
      k_cache.dtype(), dtype,
      copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
                                       cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
                                       max_seq_len, max_pages_per_seq, is_non_paged, stream););
}

template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
                                           int b, int max_seq_len, int h, int d) {
  // tensor: thd; new_tensor: bshd
  // cu_seqlens: [b + 1]
  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
    int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
    int thd_offset = cu_seqlens[batch_idx] * h * d;
    int bshd_offset = batch_idx * max_seq_len * h * d;
    scalar_t *thd_token = tensor + thd_offset;
    scalar_t *bshd_token = new_tensor + bshd_offset;
    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
      *(bshd_token + i) = *(thd_token + i);
    }
  }
}

template <typename scalar_t>
void convert_thd_to_bshd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_seqlens, int b,
                                  int max_seq_len, int h, int d, cudaStream_t stream) {
  using namespace transformer_engine;
  convert_thd_to_bshd_kernel<<<16, 256, 0, stream>>>(
      reinterpret_cast<scalar_t *>(tensor.data.dptr),
      reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
      reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
}

void convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int b,
                         int max_seq_len, cudaStream_t stream) {
  using namespace transformer_engine;
  auto tensor_shape = tensor.shape();
  TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
      new_tensor.dtype(), dtype,
      convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len,
                                          tensor_shape[1], tensor_shape[2], stream););
}

template <typename scalar_t>
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
                                           int b, int max_seq_len, int h, int d) {
  // tensor: bshd; new_tensor: thd
  // cu_seqlens: [b + 1]
  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
    int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
    int num_elts = seqlen * h * d;
    int bshd_offset = batch_idx * max_seq_len * h * d;
    int thd_offset = cu_seqlens[batch_idx] * h * d;
    scalar_t *bshd_token = tensor + bshd_offset;
    scalar_t *thd_token = new_tensor + thd_offset;
    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
      *(thd_token + i) = *(bshd_token + i);
    }
  }
}

template <typename scalar_t>
void convert_bshd_to_thd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_seqlens, int b,
                                  int max_seq_len, int h, int d, cudaStream_t stream) {
  using namespace transformer_engine;
  convert_bshd_to_thd_kernel<<<16, 256, 0, stream>>>(
      reinterpret_cast<scalar_t *>(tensor.data.dptr),
      reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
      reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
}

void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int t,
                         cudaStream_t stream) {
  using namespace transformer_engine;
  auto tensor_shape = tensor.shape();
  TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
      tensor.dtype(), dtype,
      convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, tensor_shape[0],
                                          tensor_shape[1], tensor_shape[2], tensor_shape[3],
                                          stream););
}

}  // namespace kv_cache
}  // namespace transformer_engine
#endif
/***************************************************************************************************
* KV Cache: Copy new KV tokens to the KV cache
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
* 2. cu_new_lens and cu_cached_lens are in shape [b + 1]; cu_cached_lens include the added lens
* in current step
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
* max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
* Set is_non_paged = True/False to indicate as such.
* 4. is_non_paged = True also re-indexes the KV cache, e.g. the initial batch indices [0, 3, 1, 2]
* becomes [0, 1, 1, 2]. The page_table = batch_indices.unsqueeze(1) is however unchanged.
* batch_indices_post can be used for monotonical indexing, i.e. [0, 1, 2, 3]. batch_indices is
* preserved for the next layer in the same iteration.
* 5. Only supports same page_table for k_cache and v_cache
* 6. Only pad_between_seqs = False when qkv_format = thd, i.e. there should be no pad tokens
* between sequences in new_k and new_v such as [a a a 0..0 b b 0..0 c 0..0].
**************************************************************************************************/
void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
                           NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
                           NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
                           int max_ctx_len, int max_seq_len, int max_pages_per_seq,
                           int is_non_paged, cudaStream_t stream) {
  NVTE_API_CALL(nvte_copy_to_kv_cache);
  using namespace transformer_engine;

  kv_cache::copy_to_kv_cache(*reinterpret_cast<Tensor *>(new_k),
                             *reinterpret_cast<Tensor *>(new_v),
                             *reinterpret_cast<Tensor *>(k_cache),
                             *reinterpret_cast<Tensor *>(v_cache),
                             *reinterpret_cast<Tensor *>(page_table),
                             *reinterpret_cast<Tensor *>(cu_new_lens),
                             *reinterpret_cast<Tensor *>(cu_cached_lens), qkv_format, b,
                             max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged, stream);
}
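A hypothetical caller for the non-paged case, assuming all NVTETensor handles (cache tensors, the step's K/V, page table, and length vectors) are created elsewhere and that "transformer_engine/fused_attn.h" declares the function and enum as in the file above; the wrapper name and the BSHD choice are illustrative. Per the comment block above, a non-paged cache is expressed as a paged cache with page_table of shape [b, 1] and max_pages_per_seq = 1.

// Sketch only; handles are assumed valid and correctly shaped.
#include "transformer_engine/fused_attn.h"

void append_step_to_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
                          NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
                          NVTETensor cu_cached_lens, int batch, int max_ctx_len, int max_seq_len,
                          cudaStream_t stream) {
  nvte_copy_to_kv_cache(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, cu_cached_lens,
                        NVTE_QKV_Format::NVTE_BSHD, batch, max_ctx_len, max_seq_len,
                        /*max_pages_per_seq=*/1, /*is_non_paged=*/1, stream);
}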
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd
**************************************************************************************************/
void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
                              int b, int max_seq_len, cudaStream_t stream) {
  NVTE_API_CALL(nvte_convert_thd_to_bshd);
  using namespace transformer_engine;

  kv_cache::convert_thd_to_bshd(*reinterpret_cast<Tensor *>(tensor),
                                *reinterpret_cast<Tensor *>(cu_seqlens),
                                *reinterpret_cast<Tensor *>(new_tensor), b, max_seq_len, stream);
}
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd
**************************************************************************************************/
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
                              int t, cudaStream_t stream) {
  NVTE_API_CALL(nvte_convert_bshd_to_thd);
  using namespace transformer_engine;

  kv_cache::convert_bshd_to_thd(*reinterpret_cast<Tensor *>(tensor),
                                *reinterpret_cast<Tensor *>(cu_seqlens),
                                *reinterpret_cast<Tensor *>(new_tensor), t, stream);
}