OpenDAS / TransformerEngine / Commits

Commit 4099aa8e
Authored Mar 20, 2025 by yuguo

Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine

Parents: c520cba3, 96f9c6de

Showing 20 changed files with 1949 additions and 943 deletions (+1949, -943)
transformer_engine/common/fused_attn/fused_attn.cpp                          +202  -39
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu      +329  -238
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h       +18   -14
transformer_engine/common/fused_attn/fused_attn_fp8.cu                       +12   -0
transformer_engine/common/fused_attn/utils.cu                                +86   -19
transformer_engine/common/fused_attn/utils.h                                 +16   -7
transformer_engine/common/gemm/cublaslt_gemm.cu                              +10   -3
transformer_engine/common/include/transformer_engine/fused_attn.h            +68   -29
transformer_engine/common/normalization/common.cpp                           +1    -1
transformer_engine/common/util/handle_manager.h (new)                        +52   -0
transformer_engine/common/util/pybind_helper.h                               +19   -1
transformer_engine/jax/cpp_extensions/attention.py                           +34   -13
transformer_engine/jax/csrc/extensions/attention.cpp                         +18   -14
transformer_engine/pytorch/attention.py                                      +822  -546
transformer_engine/pytorch/constants.py                                      +10   -0
transformer_engine/pytorch/cpp_extensions/fused_attn.py                      +29   -0
transformer_engine/pytorch/csrc/common.h                                     +2    -0
transformer_engine/pytorch/csrc/extensions.h                                 +10   -2
transformer_engine/pytorch/csrc/extensions/attention.cu                      +203  -16
transformer_engine/pytorch/csrc/extensions/gemm.cpp                          +8    -1
transformer_engine/common/fused_attn/fused_attn.cpp
(diff collapsed; not shown)
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
(diff collapsed; not shown)
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h

@@ -38,13 +38,15 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(

 void fused_attn_arbitrary_seqlen_fwd_kvpacked(
-    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
-    size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
-    bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
-    int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
-    const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
-    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
-    const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
+    size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
+    size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
+    size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
+    float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
+    const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
+    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
+    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
+    const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
+    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_bwd_kvpacked(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
...
@@ -61,13 +63,15 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(

 void fused_attn_arbitrary_seqlen_fwd(
-    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
-    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
-    size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout,
-    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
-    const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
-    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
-    const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
+    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
+    size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
+    size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v,
+    bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
+    int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K,
+    const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
+    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
+    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
+    const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
+    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_bwd(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
...
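As an aside on how the new paged-KV sizing parameters fit together: with fixed-size pages, the number of pages one sequence needs is a ceiling division of its length by the page size, and max_pages_per_seq_k/max_pages_per_seq_v bound that count across the batch. A minimal sketch, assuming this relationship holds (the helper below is illustrative, not code from this commit):

// Hypothetical helper: pages needed for one sequence in the paged K cache.
#include <cstddef>

inline size_t pages_per_seq(size_t seqlen_kv, size_t page_size_k) {
  // ceil(seqlen_kv / page_size_k) without floating point
  return (seqlen_kv + page_size_k - 1) / page_size_k;
}

// e.g. pages_per_seq(1000, 256) == 4, so max_pages_per_seq_k must be >= 4
// for a 1000-token sequence with 256-token pages.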
transformer_engine/common/fused_attn/fused_attn_fp8.cu

@@ -1679,6 +1679,12 @@ void fused_attn_fp8_fwd_impl_v1(
       s_kv,
       d,
       d,
+      0,
+      0,
+      0,
+      0,
+      0,
+      0,
       bias_b,
       bias_h,
       scaling_factor,
@@ -1977,6 +1983,12 @@ void fused_attn_fp8_bwd_impl_v1(
       s_kv,
       d,
       d,
+      0,
+      0,
+      0,
+      0,
+      0,
+      0,
       bias_b,
       bias_h,
       scaling_factor,
...
transformer_engine/common/fused_attn/utils.cu

@@ -117,6 +117,7 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6
       }
       break;
     case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD:
       if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
           (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
           (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) ||
@@ -223,6 +224,9 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6
       break;
     case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
     case NVTE_QKV_Layout::NVTE_THD_THD_THD:
+    case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD:
       if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
           (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
         strideA[batch_dim_idx] = s_q * h * d;
@@ -243,6 +247,52 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6
         strideA[hidden_transpose_dim_idx] = 1;
       }
       break;
+    case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD:
+      if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
+          (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
+        strideA[batch_dim_idx] = s_kv * h * d;
+        strideA[head_dim_idx] = d;
+        strideA[seqlen_dim_idx] = h * d;
+        strideA[hidden_dim_idx] = 1;
+      } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
+                 (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
+        strideA[batch_dim_idx] = s_kv * h * d;
+        strideA[head_dim_idx] = d;
+        strideA[seqlen_transpose_dim_idx] = h * d;
+        strideA[hidden_transpose_dim_idx] = 1;
+      } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
+                 (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
+        strideA[batch_dim_idx] = h * d;
+        strideA[head_dim_idx] = d;
+        strideA[seqlen_dim_idx] = b * h * d;
+        strideA[hidden_dim_idx] = 1;
+      }
+      break;
+    case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD:
+      if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
+          (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
+        strideA[batch_dim_idx] = h * d;
+        strideA[head_dim_idx] = d;
+        strideA[seqlen_dim_idx] = b * h * d;
+        strideA[hidden_dim_idx] = 1;
+      } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
+                 (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
+        strideA[batch_dim_idx] = h * d;
+        strideA[head_dim_idx] = d;
+        strideA[seqlen_transpose_dim_idx] = b * h * d;
+        strideA[hidden_transpose_dim_idx] = 1;
+      } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
+                 (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
+        strideA[batch_dim_idx] = s_q * h * d;
+        strideA[head_dim_idx] = d;
+        strideA[seqlen_dim_idx] = h * d;
+        strideA[hidden_dim_idx] = 1;
+      }
+      break;
   }

   if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) {
@@ -379,28 +429,44 @@ __device__ void cu_seqlens_padded_to_offsets_impl(
   size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
   auto cu_seqlens_id = min(tid, actual_b);
   if (tid <= max_b) {
-    offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id];
     if (offsets_s != nullptr) {
       offsets_s[tid] = h * cu_seqlens_q_padded[cu_seqlens_id];
     }
-    switch (layout_group) {
-      case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
-        offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
-        offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id];
-        offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id];
-        break;
-      case NVTE_QKV_Layout_Group::NVTE_3HD:
-      case NVTE_QKV_Layout_Group::NVTE_H3D:
-        offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
-        offsets_k[tid] = offsets_q[cu_seqlens_id];
-        offsets_v[tid] = offsets_q[cu_seqlens_id];
-        break;
-      case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
-      case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
-        offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
-        offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id];
-        offsets_v[tid] = offsets_k[cu_seqlens_id];
-        break;
-    }
+    if (offsets_q != nullptr && offsets_o != nullptr) {
+      offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id];
+      switch (layout_group) {
+        case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
+        case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD:
+          offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
+          break;
+        case NVTE_QKV_Layout_Group::NVTE_3HD:
+        case NVTE_QKV_Layout_Group::NVTE_H3D:
+          offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
+          break;
+        case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
+        case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
+          offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
+          break;
+      }
+    }
+    if (offsets_k != nullptr && offsets_v != nullptr) {
+      switch (layout_group) {
+        case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
+        case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD:
+          offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id];
+          offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id];
+          break;
+        case NVTE_QKV_Layout_Group::NVTE_3HD:
+        case NVTE_QKV_Layout_Group::NVTE_H3D:
+          offsets_k[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
+          offsets_v[tid] = offsets_k[cu_seqlens_id];
+          break;
+        case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
+        case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
+          offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id];
+          offsets_v[tid] = offsets_k[cu_seqlens_id];
+          break;
+      }
+    }
   }
 }
@@ -433,6 +499,7 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at
   std::array<int64_t, 4> offsets_qkvo{};
   switch (layout_group) {
     case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
+    case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD:
       offsets_qkvo[0] = num_attn_heads * head_dim_qk * max_seqlen_q;
       offsets_qkvo[1] = num_gqa_groups * head_dim_qk * max_seqlen_kv;
       offsets_qkvo[2] = num_gqa_groups * head_dim_v * max_seqlen_kv;
...
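To make the stride assignments above concrete: for a dense [b, s, h, d] ("BSHD") tensor the strides are simply products of the trailing dimensions, which is what generateMatrixStrides assigns for Q/O in the BSHD cases. A minimal sketch (illustrative only; not the library's API):

#include <array>
#include <cstdint>

// Strides of a contiguous BSHD tensor, ordered [batch, seqlen, head, hidden].
inline std::array<int64_t, 4> bshd_strides(int64_t s, int64_t h, int64_t d) {
  return {s * h * d, h * d, d, 1};  // innermost (hidden) dimension is contiguous
}

// For SBHD storage the sequence dimension is outermost, so the same logical
// [batch, seqlen, head, hidden] view gets strides {h*d, b*h*d, d, 1} -- compare
// the mixed-layout cases added above, where K/V use one convention and Q/O the other.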
transformer_engine/common/fused_attn/utils.h

@@ -93,6 +93,12 @@ struct FADescriptor_v1 {
   std::int64_t s_kv;
   std::int64_t d_qk;
   std::int64_t d_v;
+  std::int64_t num_pages_k;
+  std::int64_t num_pages_v;
+  std::int64_t page_size_k;
+  std::int64_t page_size_v;
+  std::int64_t max_pages_per_seq_k;
+  std::int64_t max_pages_per_seq_v;
   std::int64_t bias_b;
   std::int64_t bias_h;
   float attnScale;
@@ -108,13 +114,16 @@ struct FADescriptor_v1 {
   cudnn_frontend::DataType_t bwd_tensor_type;

   bool operator<(const FADescriptor_v1 &rhs) const {
-    return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, bias_b, bias_h, attnScale, isTraining,
-                    dropoutProbability, layout, mask_type, window_size_left, window_size_right,
-                    deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) <
-           std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.bias_b,
-                    rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
-                    rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic,
-                    rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type);
+    return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
+                    page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
+                    attnScale, isTraining, dropoutProbability, layout, mask_type,
+                    window_size_left, window_size_right, deterministic, bias_type,
+                    fwd_tensor_type, bwd_tensor_type) <
+           std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
+                    rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
+                    rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale,
+                    rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type,
+                    rhs.window_size_left, rhs.window_size_right, rhs.deterministic,
+                    rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type);
   }
 };
...
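The operator< above works because std::tie builds a tuple of references that compares lexicographically, so extending the descriptor only requires listing the six new fields in both tie() calls in the same order, and the struct keeps working as an ordered-map cache key. A standalone sketch with a hypothetical two-field key (not the real descriptor):

#include <cstdint>
#include <map>
#include <tuple>

struct CacheKey {
  int64_t s_kv;
  int64_t page_size_k;
  bool operator<(const CacheKey &rhs) const {
    // Lexicographic: compares s_kv first, then page_size_k.
    return std::tie(s_kv, page_size_k) < std::tie(rhs.s_kv, rhs.page_size_k);
  }
};

std::map<CacheKey, int> plan_cache;  // distinct page sizes now map to distinct plans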
transformer_engine/common/gemm/cublaslt_gemm.cu

@@ -20,6 +20,7 @@
 #include <mutex>

 #include "../common.h"
+#include "../util/handle_manager.h"
 #include "../util/logging.h"
 #include "common/util/cuda_runtime.h"
...
@@ -54,6 +55,10 @@ uint32_t _getAlignment(uintptr_t address) {
   }
 }

+inline void CreateCublasHandle(cublasLtHandle_t *handle) {
+  NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
+}
+
 struct GemmParam {
   void *A;
   void *B;
@@ -147,7 +152,6 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
 #endif  // __HIP_PLATFORM_AMD__

 namespace transformer_engine {

 #ifdef __HIP_PLATFORM_AMD__
 //Forward declaration. The implementation is in rocm_gemm.cu
 void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
@@ -157,6 +161,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                  int math_sm_count, int m_split, int n_split, bool gemm_producer,
                  const Tensor *inputCounter, hipStream_t stream);
 #else  // Use cublasLt
+using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
 void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                  const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
                  int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad,
@@ -209,8 +214,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   float zero = 0.0;
   float beta = (accumulate) ? one : zero;

-  cublasLtHandle_t handle;
-  NVTE_CHECK_CUBLAS(cublasLtCreate(&handle));
+  cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();

   cublasLtMatmulDesc_t operationDesc = nullptr;
   cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
@@ -362,6 +366,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                                                      &pre_gelu_out, sizeof(pre_gelu_out)));
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
         operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
+    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
+    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
+        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
   }

   NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
...
transformer_engine/common/include/transformer_engine/fused_attn.h

@@ -25,24 +25,34 @@ extern "C" {
  * head size, and the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`.
  * `SBHD` and `BSHD`-based layouts are used when sequences in a batch are of equal length
  * or padded to the same length, and `THD`-based layouts are used when sequences have
  * different lengths in a batch.
+ * `Paged_KV`-based layouts are used for paged attention.
  */
 enum NVTE_QKV_Layout {
   NVTE_SB3HD = 0,          /*!< SB3HD layout */
   NVTE_SBH3D = 1,          /*!< SBH3D layout */
   NVTE_SBHD_SB2HD = 2,     /*!< SBHD_SB2HD layout */
   NVTE_SBHD_SBH2D = 3,     /*!< SBHD_SBH2D layout */
   NVTE_SBHD_SBHD_SBHD = 4, /*!< SBHD_SBHD_SBHD layout */
   NVTE_BS3HD = 5,          /*!< BS3HD layout */
   NVTE_BSH3D = 6,          /*!< BSH3D layout */
   NVTE_BSHD_BS2HD = 7,     /*!< BSHD_BS2HD layout */
   NVTE_BSHD_BSH2D = 8,     /*!< BSHD_BSH2D layout */
   NVTE_BSHD_BSHD_BSHD = 9, /*!< BSHD_BSHD_BSHD layout */
   NVTE_T3HD = 10,          /*!< T3HD layout */
   NVTE_TH3D = 11,          /*!< TH3D layout */
   NVTE_THD_T2HD = 12,      /*!< THD_T2HD layout */
   NVTE_THD_TH2D = 13,      /*!< THD_TH2D layout */
   NVTE_THD_THD_THD = 14,   /*!< THD_THD_THD layout */
+  NVTE_SBHD_BSHD_BSHD = 15,          /*!< SBHD_BSHD_BSHD layout */
+  NVTE_BSHD_SBHD_SBHD = 16,          /*!< BSHD_SBHD_SBHD layout */
+  NVTE_THD_BSHD_BSHD = 17,           /*!< THD_BSHD_BSHD layout */
+  NVTE_THD_SBHD_SBHD = 18,           /*!< THD_SBHD_SBHD layout */
+  NVTE_Paged_KV_BSHD_BSHD_BSHD = 19, /*!< Paged_KV_BSHD_BSHD_BSHD layout */
+  NVTE_Paged_KV_BSHD_SBHD_SBHD = 20, /*!< Paged_KV_BSHD_SBHD_SBHD layout */
+  NVTE_Paged_KV_SBHD_BSHD_BSHD = 21, /*!< Paged_KV_SBHD_BSHD_BSHD layout */
+  NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */
+  NVTE_Paged_KV_THD_BSHD_BSHD = 23,  /*!< Paged_KV_THD_BSHD_BSHD layout */
+  NVTE_Paged_KV_THD_SBHD_SBHD = 24,  /*!< Paged_KV_THD_SBHD_SBHD layout */
 };

 /*! \enum NVTE_QKV_Layout_Group
...
@@ -59,18 +69,28 @@ enum NVTE_QKV_Layout_Group {
   NVTE_HD_H2D = 3,
   /*! HD_HD_HD QKV layouts, i.e. BSHD_BSHD_BSHD, SBHD_SBHD_SBHD, THD_THD_THD */
   NVTE_HD_HD_HD = 4,
+  /*! Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD */
+  NVTE_Paged_KV_HD_HD_HD = 5,
 };

 /*! \enum NVTE_QKV_Format
  * \brief QKV formats
  */
 enum NVTE_QKV_Format {
-  /*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD */
+  /*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD,
+      Paged_KV_SBHD_SBHD_SBHD */
   NVTE_SBHD = 0,
-  /*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD */
+  /*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD,
+      Paged_KV_BSHD_BSHD_BSHD */
   NVTE_BSHD = 1,
   /*! THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD */
   NVTE_THD = 2,
+  /*! BSHD format for Q and SBHD format for KV, i.e. BSHD_SBHD_SBHD, Paged_KV_BSHD_SBHD_SBHD */
+  NVTE_BSHD_2SBHD = 3,
+  /*! SBHD format for Q and BSHD format for KV, i.e. SBHD_BSHD_BSHD, Paged_KV_SBHD_BSHD_BSHD */
+  NVTE_SBHD_2BSHD = 4,
+  /*! THD format for Q and BSHD format for KV, i.e. THD_BSHD_BSHD, Paged_KV_THD_BSHD_BSHD */
+  NVTE_THD_2BSHD = 5,
+  /*! THD format for Q and SBHD format for KV, i.e. THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD */
+  NVTE_THD_2SBHD = 6,
 };

 /*! \enum NVTE_Bias_Type
...
@@ -135,6 +155,22 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout);
  */
 NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout);

+/*! \brief Get Q format for a given QKV layout.
+ *
+ * \param[in] qkv_layout QKV layout, e.g. sbhd_bshd_bshd.
+ *
+ * \return q format, e.g. sbhd.
+ */
+NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout);
+
+/*! \brief Get KV format for a given QKV layout.
+ *
+ * \param[in] qkv_layout QKV layout, e.g. sbhd_bshd_bshd.
+ *
+ * \return kv format, e.g. bshd.
+ */
+NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
+
 /*! \brief Get fused attention backend based on input parameters.
  *
  * \param[in] q_dtype The data type of Tensor Q.
...
@@ -312,6 +348,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
  * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
  * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
  * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1].
+ * \param[in] page_table_k Page table for K cache, [batch_size, max_pages_per_seq_k].
+ * \param[in] page_table_v Page table for V cache, [batch_size, max_pages_per_seq_v].
  * \param[in] rng_state Seed and offset of CUDA random number generator.
  * \param[in] max_seqlen_q Max sequence length used for computing for Q.
  *                         it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
...
@@ -329,16 +367,14 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
  * \param[in] workspace Workspace tensor.
  * \param[in] stream CUDA stream used for this operation.
  */
-void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
-                                  NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
-                                  const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
-                                  const NVTETensor cu_seqlens_q_padded,
-                                  const NVTETensor cu_seqlens_kv_padded,
-                                  const NVTETensor rng_state, size_t max_seqlen_q,
-                                  size_t max_seqlen_kv, bool is_training, float attn_scale,
-                                  float dropout, NVTE_QKV_Layout qkv_layout,
-                                  NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-                                  int64_t window_size_left, int64_t window_size_right,
-                                  NVTETensor workspace, cudaStream_t stream);
+void nvte_fused_attn_fwd_kvpacked(
+    const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
+    NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
+    const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+    const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
+    const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
+    size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
+    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+    int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
+    cudaStream_t stream);

 /*! \brief Compute the backward of the dot product attention with packed KV input.
  *
...
@@ -445,6 +481,8 @@ void nvte_fused_attn_bwd_kvpacked(
  * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
  * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
  * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1].
+ * \param[in] page_table_k Page table for K cache, [batch_size, max_pages_per_seq_k].
+ * \param[in] page_table_v Page table for V cache, [batch_size, max_pages_per_seq_v].
  * \param[in] rng_state Seed and offset of CUDA random number generator.
  * \param[in] max_seqlen_q Max sequence length used for computing for Q.
  *                         it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
...
@@ -465,7 +503,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                          const NVTETensor Bias, NVTETensor S, NVTETensor O,
                          NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
                          const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
-                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
+                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
+                         const NVTETensor page_table_v, const NVTETensor rng_state,
                          size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
                          float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
                          NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
...
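A usage sketch for the two new getters: with a mixed layout such as sbhd_bshd_bshd, a single combined format can no longer describe all three tensors, so callers query Q and KV separately. The expected return values follow the doc comments above; the snippet itself is illustrative, not taken from the commit.

#include <transformer_engine/fused_attn.h>

int main() {
  NVTE_QKV_Layout layout = NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD;
  NVTE_QKV_Format q_fmt = nvte_get_q_format(layout);    // expected: NVTE_SBHD
  NVTE_QKV_Format kv_fmt = nvte_get_kv_format(layout);  // expected: NVTE_BSHD
  return (q_fmt == NVTE_QKV_Format::NVTE_SBHD &&
          kv_fmt == NVTE_QKV_Format::NVTE_BSHD) ? 0 : 1;
}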
transformer_engine/common/normalization/common.cpp

@@ -217,7 +217,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       wtype, cpp_dtype, *(reinterpret_cast<cpp_dtype *>(_scalar_dptr.get())) = (cpp_dtype)1.0f;);

-  _handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
+  _handle = cudnnExecutionPlanManager::Instance().GetHandle();

   _graph.set_io_data_type(get_cudnn_fe_dtype(itype))
       .set_intermediate_data_type(get_cudnn_fe_dtype(ctype))
...
transformer_engine/common/util/handle_manager.h (new file, 0 → 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_
#define TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_

#include <vector>

#include "cuda_runtime.h"
#include "logging.h"

namespace transformer_engine::detail {

template <typename Handle, void Create(Handle *), void Destroy(Handle) = nullptr>
class HandleManager {
 public:
  static HandleManager &Instance() {
    static thread_local HandleManager instance;
    return instance;
  }

  Handle GetHandle() {
    static thread_local std::vector<bool> initialized(handles_.size(), false);
    const int device_id = cuda::current_device();
    NVTE_CHECK(0 <= device_id && device_id < handles_.size(), "invalid CUDA device ID");
    if (!initialized[device_id]) {
      Create(&(handles_[device_id]));
      initialized[device_id] = true;
    }
    return handles_[device_id];
  }

  ~HandleManager() {
    if (Destroy != nullptr) {
      for (auto &handle : handles_) {
        Destroy(handle);
      }
    }
  }

 private:
  HandleManager() : handles_(cuda::num_devices(), nullptr) {}

  std::vector<Handle> handles_;
};

}  // namespace transformer_engine::detail

#endif  // TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_
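A usage sketch for the new template: cublaslt_gemm.cu instantiates it without a Destroy callback, but the optional third template parameter allows cleanup when the thread-local manager is torn down. The cuDNN instantiation below is hypothetical, shown only to illustrate the Create/Destroy contract:

#include <cudnn.h>

inline void CreateCudnnHandle(cudnnHandle_t *handle) { cudnnCreate(handle); }
inline void DestroyCudnnHandle(cudnnHandle_t handle) { cudnnDestroy(handle); }

// One handle per (thread, device): Instance() is thread_local, GetHandle()
// lazily calls Create for the current device, and ~HandleManager runs
// Destroy over all handles because it was supplied here.
using CudnnHandleManager = transformer_engine::detail::HandleManager<
    cudnnHandle_t, CreateCudnnHandle, DestroyCudnnHandle>;

cudnnHandle_t get_handle() { return CudnnHandleManager::Instance().GetHandle(); }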
transformer_engine/common/util/pybind_helper.h

@@ -36,6 +36,14 @@
       .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)   \
       .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK",                                          \
              NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);                           \
+  pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local())             \
+      .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD)                                          \
+      .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD)                                          \
+      .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD)                                            \
+      .value("NVTE_SBHD_2BSHD", NVTE_QKV_Format::NVTE_SBHD_2BSHD)                              \
+      .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD)                              \
+      .value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD)                                \
+      .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD);                               \
   pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())             \
       .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD)                                        \
       .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D)                                        \
...
@@ -51,7 +59,17 @@
       .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D)                                          \
       .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD)                                  \
       .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D)                                  \
-      .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD);                           \
+      .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD)                            \
+      .value("NVTE_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD)                      \
+      .value("NVTE_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD)                      \
+      .value("NVTE_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD)                        \
+      .value("NVTE_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD)                        \
+      .value("NVTE_Paged_KV_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD)    \
+      .value("NVTE_Paged_KV_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD)    \
+      .value("NVTE_Paged_KV_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD)    \
+      .value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD)    \
+      .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD)      \
+      .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD);     \
   pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend",                       \
                                            pybind11::module_local())                           \
       .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)        \
       .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)  \
...
transformer_engine/jax/cpp_extensions/attention.py

@@ -295,7 +295,10 @@ class FusedAttnFwdPrimitive(BasePrimitive):
         elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
             # cuDNN 9.6 reduces the required softmax shape
             if get_cudnn_version() >= (9, 6, 0):
-                softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
+                if config.qkv_layout.is_thd():
+                    softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1)
+                else:
+                    softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
             else:
                 softmax_shape = (
                     *batch_shape,
...
@@ -607,28 +610,49 @@ class FusedAttnFwdPrimitive(BasePrimitive):
     def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
         del result_infos
         q_spec = get_padded_spec(arg_infos[0])
+        # when supported softmax_aux shape is (b, s, h, 1) for thd on cudnn 9.6+
+        # otherwise softmax_aux shape is (b, h, s, 1) or (b, h, s, max_segments)
+        is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd()
         if config.qkv_layout.is_qkvpacked():
             # q_spec = (...batch, q_seqlen, 3, head, hidden)
             out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
-            softmax_aux_sharding = NamedSharding(
-                mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)
-            )
+            if not is_packed_softmax:
+                softmax_aux_sharding = NamedSharding(
+                    mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)
+                )
+            else:
+                softmax_aux_sharding = NamedSharding(
+                    mesh, PartitionSpec(*q_spec[:-4], q_spec[-4], q_spec[-2], None)
+                )
         elif config.qkv_layout.is_kvpacked():
             # q_spec = (...batch, q_seqlen, head, hidden)
             # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
             out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
-            softmax_aux_sharding = NamedSharding(
-                mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
-            )
+            if not is_packed_softmax:
+                softmax_aux_sharding = NamedSharding(
+                    mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
+                )
+            else:
+                softmax_aux_sharding = NamedSharding(
+                    mesh, PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None)
+                )
         elif config.qkv_layout.is_separate():
             # q_spec = (...batch, q_seqlen, head, hidden)
             # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
             out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
-            softmax_aux_sharding = NamedSharding(
-                mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
-            )
+            if not is_packed_softmax:
+                softmax_aux_sharding = NamedSharding(
+                    mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
+                )
+            else:
+                softmax_aux_sharding = NamedSharding(
+                    mesh, PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None)
+                )
         else:
             raise ValueError(f"Unsupported {config.qkv_layout=}")
         rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
         return (out_sharding, softmax_aux_sharding, rng_state_sharding)
...
@@ -2236,7 +2260,6 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
                 subblock_config,
             )
-            # TODO(rewang): THD softmax_aux layout is acutally [B, S, H]
             softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1))

         def skip_correction(_output, _softmax_aux, output_per_step, softmax_aux_per_step):
...
@@ -2272,8 +2295,6 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
                 carry = scan_kv_block(i, carry)
             (_, _, _, output, softmax_aux) = carry
-            softmax_aux = softmax_aux.reshape((batch, head, q_max_seqlen, 1))
-
             return output.astype(q.dtype), softmax_aux, rng_state

         return mesh, fwd_impl, out_shardings, arg_shardings
...
transformer_engine/jax/csrc/extensions/attention.cpp

@@ -129,6 +129,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
   auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
   auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
+  auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);

   NVTETensorPack aux_output_tensors;
   nvte_tensor_pack_create(&aux_output_tensors);
...
@@ -164,15 +165,16 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
     nvte_fused_attn_fwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-        ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
-        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability,
-        qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
-        query_workspace_tensor.data(), nullptr);
+        ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
+        dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
+        kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
+        mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
     nvte_fused_attn_fwd(
         q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
         o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
         kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
-        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
+        dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
+        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
         dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
         window_size_right, query_workspace_tensor.data(), nullptr);
...
@@ -256,6 +258,7 @@ static void FusedAttnForwardImpl(
                                backend, softmax_aux);

   /* Call the underlying NVTE API */
+  auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
     auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
     auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
...
@@ -273,9 +276,10 @@ static void FusedAttnForwardImpl(
     nvte_fused_attn_fwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
-        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability,
-        qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
-        workspace_tensor.data(), stream);
+        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(),
+        dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
+        is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
+        window_size_left, window_size_right, workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
     auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
...
@@ -283,13 +287,13 @@ static void FusedAttnForwardImpl(
     auto q_tensor = TensorWrapper(q, q_shape, dtype);
     auto k_tensor = TensorWrapper(k, k_shape, dtype);
     auto v_tensor = TensorWrapper(v, v_shape, dtype);
-    nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
-                        s_tensor.data(), o_tensor.data(), &aux_output_tensors,
-                        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-                        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
-                        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
-                        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-                        window_size_left, window_size_right, workspace_tensor.data(), stream);
+    nvte_fused_attn_fwd(
+        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
+        o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
+        kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
+        dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
+        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability,
+        qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
+        workspace_tensor.data(), stream);
   } else {
     NVTE_ERROR("Unsupported qkv_layout.");
   }
...
transformer_engine/pytorch/attention.py
(diff collapsed; not shown)
transformer_engine/pytorch/constants.py

@@ -64,6 +64,16 @@ QKVLayouts = (
     "thd_t2hd",
     "thd_th2d",
     "thd_thd_thd",
+    "sbhd_bshd_bshd",
+    "bshd_sbhd_sbhd",
+    "thd_bshd_bshd",
+    "thd_sbhd_sbhd",
+    "paged_kv_bshd_bshd_bshd",
+    "paged_kv_bshd_sbhd_sbhd",
+    "paged_kv_sbhd_bshd_bshd",
+    "paged_kv_sbhd_sbhd_sbhd",
+    "paged_kv_thd_bshd_bshd",
+    "paged_kv_thd_sbhd_sbhd",
 )

 LayerTypes = ("encoder", "decoder")
...
...
transformer_engine/pytorch/cpp_extensions/fused_attn.py
View file @
4099aa8e
...
@@ -9,6 +9,7 @@ import torch
...
@@ -9,6 +9,7 @@ import torch
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
from
transformer_engine_torch
import
(
from
transformer_engine_torch
import
(
NVTE_QKV_Layout
,
NVTE_QKV_Layout
,
NVTE_QKV_Format
,
NVTE_Bias_Type
,
NVTE_Bias_Type
,
NVTE_Mask_Type
,
NVTE_Mask_Type
,
NVTE_Fused_Attn_Backend
,
NVTE_Fused_Attn_Backend
,
...
@@ -31,6 +32,16 @@ TORCH_DType = {
...
@@ -31,6 +32,16 @@ TORCH_DType = {
tex
.
DType
.
kInt32
:
torch
.
int32
,
tex
.
DType
.
kInt32
:
torch
.
int32
,
}
}
QKVFormat
=
{
"bshd"
:
NVTE_QKV_Format
.
NVTE_BSHD
,
"sbhd"
:
NVTE_QKV_Format
.
NVTE_SBHD
,
"thd"
:
NVTE_QKV_Format
.
NVTE_THD
,
"sbhd_2bshd"
:
NVTE_QKV_Format
.
NVTE_SBHD_2BSHD
,
"bshd_2sbhd"
:
NVTE_QKV_Format
.
NVTE_BSHD_2SBHD
,
"thd_2bshd"
:
NVTE_QKV_Format
.
NVTE_THD_2BSHD
,
"thd_2sbhd"
:
NVTE_QKV_Format
.
NVTE_THD_2SBHD
,
}
QKVLayout
=
{
QKVLayout
=
{
"sb3hd"
:
NVTE_QKV_Layout
.
NVTE_SB3HD
,
"sb3hd"
:
NVTE_QKV_Layout
.
NVTE_SB3HD
,
"sbh3d"
:
NVTE_QKV_Layout
.
NVTE_SBH3D
,
"sbh3d"
:
NVTE_QKV_Layout
.
NVTE_SBH3D
,
...
@@ -47,6 +58,16 @@ QKVLayout = {
...
@@ -47,6 +58,16 @@ QKVLayout = {
"thd_t2hd"
:
NVTE_QKV_Layout
.
NVTE_THD_T2HD
,
"thd_t2hd"
:
NVTE_QKV_Layout
.
NVTE_THD_T2HD
,
"thd_th2d"
:
NVTE_QKV_Layout
.
NVTE_THD_TH2D
,
"thd_th2d"
:
NVTE_QKV_Layout
.
NVTE_THD_TH2D
,
"thd_thd_thd"
:
NVTE_QKV_Layout
.
NVTE_THD_THD_THD
,
"thd_thd_thd"
:
NVTE_QKV_Layout
.
NVTE_THD_THD_THD
,
"sbhd_bshd_bshd"
:
NVTE_QKV_Layout
.
NVTE_SBHD_BSHD_BSHD
,
"bshd_sbhd_sbhd"
:
NVTE_QKV_Layout
.
NVTE_BSHD_SBHD_SBHD
,
"thd_bshd_bshd"
:
NVTE_QKV_Layout
.
NVTE_THD_BSHD_BSHD
,
"thd_sbhd_sbhd"
:
NVTE_QKV_Layout
.
NVTE_THD_SBHD_SBHD
,
"paged_kv_bshd_bshd_bshd"
:
NVTE_QKV_Layout
.
NVTE_Paged_KV_BSHD_BSHD_BSHD
,
"paged_kv_bshd_sbhd_sbhd"
:
NVTE_QKV_Layout
.
NVTE_Paged_KV_BSHD_SBHD_SBHD
,
"paged_kv_sbhd_bshd_bshd"
:
NVTE_QKV_Layout
.
NVTE_Paged_KV_SBHD_BSHD_BSHD
,
"paged_kv_sbhd_sbhd_sbhd"
:
NVTE_QKV_Layout
.
NVTE_Paged_KV_SBHD_SBHD_SBHD
,
"paged_kv_thd_bshd_bshd"
:
NVTE_QKV_Layout
.
NVTE_Paged_KV_THD_BSHD_BSHD
,
"paged_kv_thd_sbhd_sbhd"
:
NVTE_QKV_Layout
.
NVTE_Paged_KV_THD_SBHD_SBHD
,
}
}
AttnBiasType
=
{
AttnBiasType
=
{
...
@@ -100,6 +121,8 @@ def fused_attn_fwd(
...
@@ -100,6 +121,8 @@ def fused_attn_fwd(
attn_bias
:
torch
.
Tensor
=
None
,
attn_bias
:
torch
.
Tensor
=
None
,
cu_seqlens_q_padded
:
torch
.
Tensor
=
None
,
cu_seqlens_q_padded
:
torch
.
Tensor
=
None
,
cu_seqlens_kv_padded
:
torch
.
Tensor
=
None
,
cu_seqlens_kv_padded
:
torch
.
Tensor
=
None
,
page_table_k
:
torch
.
Tensor
=
None
,
page_table_v
:
torch
.
Tensor
=
None
,
s_quantizer
:
Quantizer
=
None
,
s_quantizer
:
Quantizer
=
None
,
o_quantizer
:
Quantizer
=
None
,
o_quantizer
:
Quantizer
=
None
,
attn_scale
:
float
=
None
,
attn_scale
:
float
=
None
,
@@ -148,6 +171,10 @@ def fused_attn_fwd(
         cumulative sequence offsets for Q; shape [batch_size + 1]
     cu_seqlens_kv_padded: torch.Tensor, default = None
         cumulative sequence offsets for KV; shape [batch_size + 1]
+    page_table_k: torch.Tensor, default = None
+        page table for K cache; shape [batch_size, max_pages_per_seq_k]
+    page_table_v: torch.Tensor, default = None
+        page table for V cache; shape [batch_size, max_pages_per_seq_v]
     s_quantizer: Quantizer, default = None
         Quantizer object for the intermediate value S.
     o_quantizer: Quantizer, default = None
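To make the new arguments concrete, a hedged sketch of building page tables (all sizes invented for illustration): each row holds the page ids backing one sequence's pages in the cache.

    import torch

    batch_size, max_pages_per_seq = 4, 8
    # page_table_k[j, i] is the page id at which page i of sequence j lives in the K cache.
    page_table_k = torch.arange(batch_size * max_pages_per_seq, dtype=torch.int32,
                                device="cuda").reshape(batch_size, max_pages_per_seq)
    page_table_v = page_table_k  # K and V caches share one page table elsewhere in this commit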
@@ -268,6 +295,8 @@ def fused_attn_fwd(
         fake_dtype,
         cu_seqlens_q_padded,
         cu_seqlens_kv_padded,
+        page_table_k,
+        page_table_v,
         s_quantizer,
         o_quantizer,
         attn_bias,
transformer_engine/pytorch/csrc/common.h  View file @ 4099aa8e

@@ -14,6 +14,8 @@
 #include <ATen/cudnn/Handle.h>
 #include <ATen/native/DispatchStub.h>
 #include <c10/macros/Macros.h>
+#include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e5m2.h>
 #include <cuda_runtime.h>
 #ifndef USE_ROCM
 #include <cublasLt.h>
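The two new c10 headers pull in the FP8 element types (at::Float8_e4m3fn, at::Float8_e5m2) that the KV-cache code below dispatches on; on the Python side these correspond to torch.float8_e4m3fn and torch.float8_e5m2 in recent PyTorch releases. A small illustrative buffer:

    import torch

    # An FP8 K cache buffer; sizes are made up for illustration.
    k_cache = torch.zeros(2, 64, 4, 32, device="cuda").to(torch.float8_e4m3fn)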
transformer_engine/pytorch/csrc/extensions.h  View file @ 4099aa8e

@@ -51,8 +51,9 @@ std::vector<py::object> fused_attn_fwd(
     const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
     const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
     const c10::optional<at::Tensor> cu_seqlens_q_padded,
-    const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
-    py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
+    const c10::optional<at::Tensor> cu_seqlens_kv_padded,
+    const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
+    py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
     const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);

 std::vector<py::object> fused_attn_bwd(
@@ -69,6 +70,13 @@ std::vector<py::object> fused_attn_bwd(
 at::Tensor fa_prepare_fwd(at::Tensor qkvi);
 at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
+at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len);
+at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t);
+void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache,
+                      torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens,
+                      torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int b,
+                      int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged);

 /***************************************************************************************************
  * GEMM
  **************************************************************************************************/
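For intuition, a pure-PyTorch sketch of what convert_thd_to_bshd computes (reference semantics only, not the CUDA kernel added later in this commit; convert_thd_to_bshd_ref is our illustrative name):

    import torch

    def convert_thd_to_bshd_ref(tensor, cu_seqlens, b, max_seq_len):
        # Scatter variable-length sequences packed as [t, h, d] into a
        # zero-padded [b, max_seq_len, h, d] tensor, using cu_seqlens
        # (shape [b + 1]) as the per-sequence offsets.
        t, h, d = tensor.shape
        out = torch.zeros(b, max_seq_len, h, d, dtype=tensor.dtype, device=tensor.device)
        for i in range(b):
            start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
            out[i, : end - start] = tensor[start:end]
        return out

convert_bshd_to_thd is the inverse gather back into a packed [t, h, d] tensor.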
transformer_engine/pytorch/csrc/extensions/attention.cu  View file @ 4099aa8e

@@ -4,6 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/
 #include "extensions.h"
+#include "kv_cache.cuh"
 #include "thd_utils.cuh"

 constexpr int block_size = 512;
@@ -95,8 +96,9 @@ std::vector<py::object> fused_attn_fwd(
     const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
     const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
     const c10::optional<at::Tensor> cu_seqlens_q_padded,
-    const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
-    py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
+    const c10::optional<at::Tensor> cu_seqlens_kv_padded,
+    const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
+    py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
     const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
 #ifdef __HIP_PLATFORM_AMD__
   static_assert(false,
@@ -135,6 +137,7 @@ std::vector<py::object> fused_attn_fwd(
   TensorWrapper te_Bias;
   TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
   TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded;
+  TensorWrapper te_page_table_k, te_page_table_v;
   if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
     // FP8
     auto h = q_shape[q_shape.size() - 2];
@@ -179,6 +182,19 @@ std::vector<py::object> fused_attn_fwd(
         cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32);
   }
+  if ((page_table_k.has_value()) && (page_table_v.has_value())) {
+    auto page_table_k_sizes = page_table_k.value().sizes().vec();
+    std::vector<size_t> page_table_k_shape{page_table_k_sizes.begin(), page_table_k_sizes.end()};
+    auto page_table_v_sizes = page_table_v.value().sizes().vec();
+    std::vector<size_t> page_table_v_shape{page_table_v_sizes.begin(), page_table_v_sizes.end()};
+    te_page_table_k =
+        makeTransformerEngineTensor(page_table_k.value().data_ptr(), page_table_k_shape,
+                                    DType::kInt32, nullptr, nullptr, nullptr);
+    te_page_table_v =
+        makeTransformerEngineTensor(page_table_v.value().data_ptr(), page_table_v_shape,
+                                    DType::kInt32, nullptr, nullptr, nullptr);
+  }
   // extract rng seed and offset
   auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
       rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
@@ -196,13 +212,13 @@ std::vector<py::object> fused_attn_fwd(
   TensorWrapper workspace;

   // populate tensors with appropriate shapes and dtypes
-  nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
-                      te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
-                      te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
-                      te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q,
-                      max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type,
-                      attn_mask_type, window_size[0], window_size[1], workspace.data(),
-                      at::cuda::getCurrentCUDAStream());
+  nvte_fused_attn_fwd(
+      te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(),
+      &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
+      te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
+      te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
+      attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1],
+      workspace.data(), at::cuda::getCurrentCUDAStream());

   // allocate memory for workspace and auxiliary output tensors
   auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
@@ -250,13 +266,13 @@ std::vector<py::object> fused_attn_fwd(
   }

   // execute the kernel
-  nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
-                      te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
-                      te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
-                      te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q,
-                      max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type,
-                      attn_mask_type, window_size[0], window_size[1], workspace.data(),
-                      at::cuda::getCurrentCUDAStream());
+  nvte_fused_attn_fwd(
+      te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(),
+      &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
+      te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
+      te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
+      attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1],
+      workspace.data(), at::cuda::getCurrentCUDAStream());

   // destroy tensor wrappers, but not allocated memory
   nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
@@ -1027,3 +1043,174 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
   return output;
 }
+
+/***************************************************************************************************
+ * KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd
+ **************************************************************************************************/
+
+template <typename scalar_t>
+void convert_thd_to_bshd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens,
+                                  int b, int max_seq_len, int h, int d) {
+  transformer_engine::fused_attn::convert_thd_to_bshd_kernel
+      <<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+          reinterpret_cast<scalar_t *>(tensor.data_ptr<scalar_t>()),
+          reinterpret_cast<scalar_t *>(new_tensor.data_ptr<scalar_t>()),
+          cu_seqlens.data_ptr<int>(), b, max_seq_len, h, d);
+}
+
+at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len) {
+  int h = tensor.size(1);
+  int d = tensor.size(2);
+  std::vector<int64_t> shape = {b, max_seq_len, h, d};
+  at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type()));
+  if (new_tensor.scalar_type() == at::ScalarType::Half) {
+    using dtype = at::Half;
+    convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
+  } else if (new_tensor.scalar_type() == at::ScalarType::BFloat16) {
+    using dtype = at::BFloat16;
+    convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
+  } else if (new_tensor.scalar_type() == at::ScalarType::Float) {
+    using dtype = float;
+    convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
+  } else if (new_tensor.scalar_type() == at::ScalarType::Float8_e4m3fn) {
+    using dtype = at::Float8_e4m3fn;
+    convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
+  } else if (new_tensor.scalar_type() == at::ScalarType::Float8_e5m2) {
+    using dtype = at::Float8_e5m2;
+    convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
+  } else {
+    NVTE_ERROR("Unsupported dtype for KV cache.\n");
+  }
+  return new_tensor;
+}
+
+/***************************************************************************************************
+ * KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd
+ **************************************************************************************************/
+
+template <typename scalar_t>
+void convert_bshd_to_thd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens,
+                                  int b, int max_seq_len, int h, int d) {
+  transformer_engine::fused_attn::convert_bshd_to_thd_kernel
+      <<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+          reinterpret_cast<scalar_t *>(tensor.data_ptr<scalar_t>()),
+          reinterpret_cast<scalar_t *>(new_tensor.data_ptr<scalar_t>()),
+          cu_seqlens.data_ptr<int>(), b, max_seq_len, h, d);
+}
+
+at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) {
+  int b = tensor.size(0);
+  int max_seq_len = tensor.size(1);
+  int h = tensor.size(2);
+  int d = tensor.size(3);
+  std::vector<int64_t> shape = {t, h, d};
+  at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type()));
+  if (tensor.scalar_type() == at::ScalarType::Half) {
+    using dtype = at::Half;
+    convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
+  } else if (tensor.scalar_type() == at::ScalarType::BFloat16) {
+    using dtype = at::BFloat16;
+    convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
+  } else if (tensor.scalar_type() == at::ScalarType::Float) {
+    using dtype = float;
+    convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
+  } else if (tensor.scalar_type() == at::ScalarType::Float8_e4m3fn) {
+    using dtype = at::Float8_e4m3fn;
+    convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
+  } else if (tensor.scalar_type() == at::ScalarType::Float8_e5m2) {
+    using dtype = at::Float8_e5m2;
+    convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
+  } else {
+    NVTE_ERROR("Unsupported dtype for KV cache.\n");
+  }
+  return new_tensor;
+}
+
+/***************************************************************************************************
+ * KV Cache: Copy new KV tokens to the KV cache
+ * 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
+ * 2. cu_new_lens and cu_cached_lens are in shape [b + 1]; cu_cached_lens include the lens added
+ *    in the current step
+ * 3. A non-paged KV cache is a special case of a paged KV cache, with a page_table of shape
+ *    [b, 1] and max_pages_per_seq = 1. The same underlying kernel serves both non-paged and
+ *    paged; set is_non_paged = True/False accordingly.
+ * 4. is_non_paged = True also re-indexes the KV cache, e.g. the initial batch indices [0, 3, 1, 2]
+ *    become [0, 1, 1, 2]. The page_table = batch_indices.unsqueeze(1) is however unchanged.
+ *    batch_indices_post can be used for monotonic indexing, i.e. [0, 1, 2, 3]. batch_indices is
+ *    preserved for the next layer in the same iteration.
+ * 5. Only supports the same page_table for k_cache and v_cache
+ * 6. Only supports pad_between_seqs = False when qkv_format = thd, i.e. there should be no pad
+ *    tokens between sequences in new_k and new_v such as [a a a 0..0 b b 0..0 c 0..0]
+ **************************************************************************************************/
+
+template <typename scalar_t>
+void copy_to_kv_cache_launcher(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache,
+                               at::Tensor v_cache, at::Tensor page_table, at::Tensor cu_new_lens,
+                               at::Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
+                               int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
+                               int max_pages_per_seq, bool is_non_paged) {
+  if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr &&
+      k_cache.data_ptr() != nullptr && v_cache.data_ptr() != nullptr) {
+    if (is_non_paged) {
+      transformer_engine::fused_attn::reindex_kv_cache_kernel
+          <<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+              reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
+              reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()),
+              page_table.data_ptr<int>(), cu_new_lens.data_ptr<int>(),
+              cu_cached_lens.data_ptr<int>(), h_kv, d_k, d_v, b, max_seq_len);
+    }
+    transformer_engine::fused_attn::copy_to_kv_cache_kernel
+        <<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+            reinterpret_cast<scalar_t *>(new_k.data_ptr<scalar_t>()),
+            reinterpret_cast<scalar_t *>(new_v.data_ptr<scalar_t>()),
+            reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
+            reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()),
+            page_table.data_ptr<int>(), cu_new_lens.data_ptr<int>(),
+            cu_cached_lens.data_ptr<int>(), qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
+            max_seq_len, max_pages_per_seq, is_non_paged);
+  }
+}
+
+void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache,
+                      at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens,
+                      NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len,
+                      int max_pages_per_seq, bool is_non_paged) {
+  int h_kv = new_k.size(-2);
+  int d_k = new_k.size(-1);
+  int d_v = new_v.size(-1);
+  NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() &&
+                 new_k.scalar_type() == new_v.scalar_type() &&
+                 new_k.scalar_type() == k_cache.scalar_type(),
+             "new_k, new_v, k_cache and v_cache must be of the same data type.");
+  NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
+                 qkv_format == NVTE_QKV_Format::NVTE_SBHD ||
+                 qkv_format == NVTE_QKV_Format::NVTE_THD,
+             "qkv_format must be {BSHD, SBHD, THD}.");
+  if (k_cache.scalar_type() == at::ScalarType::Half) {
+    using dtype = at::Half;
+    copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
+                                     cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
+                                     max_seq_len, max_pages_per_seq, is_non_paged);
+  } else if (k_cache.scalar_type() == at::ScalarType::BFloat16) {
+    using dtype = at::BFloat16;
+    copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
+                                     cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
+                                     max_seq_len, max_pages_per_seq, is_non_paged);
+  } else if (k_cache.scalar_type() == at::ScalarType::Float) {
+    using dtype = float;
+    copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
+                                     cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
+                                     max_seq_len, max_pages_per_seq, is_non_paged);
+  } else if (k_cache.scalar_type() == at::ScalarType::Float8_e4m3fn) {
+    using dtype = at::Float8_e4m3fn;
+    copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
+                                     cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
+                                     max_seq_len, max_pages_per_seq, is_non_paged);
+  } else if (k_cache.scalar_type() == at::ScalarType::Float8_e5m2) {
+    using dtype = at::Float8_e5m2;
+    copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
+                                     cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
+                                     max_seq_len, max_pages_per_seq, is_non_paged);
+  } else {
+    NVTE_ERROR("Unsupported dtype for KV cache.\n");
+  }
+}
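To see the calling convention end to end, a hedged sketch of driving the new entry point from Python for the non-paged case; the tex.copy_to_kv_cache binding name and all tensor sizes are assumptions for illustration, following the signature and the comment block above:

    import torch
    import transformer_engine_torch as tex
    from transformer_engine_torch import NVTE_QKV_Format

    b, max_seq_len, h_kv, d = 2, 64, 4, 32
    # Non-paged is paged with one page per sequence (item 3): page_table is the
    # batch indices with shape [b, 1], and max_pages_per_seq = 1.
    page_table = torch.arange(b, dtype=torch.int32, device="cuda").unsqueeze(1)
    k_cache = torch.zeros(b, max_seq_len, h_kv, d, dtype=torch.bfloat16, device="cuda")
    v_cache = torch.zeros_like(k_cache)
    new_k = torch.randn(b, 1, h_kv, d, dtype=torch.bfloat16, device="cuda")  # 1 new token per seq
    new_v = torch.randn_like(new_k)
    cu_new_lens = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda")
    cu_cached_lens = cu_new_lens.clone()  # cached lens include this step's new tokens (item 2)
    tex.copy_to_kv_cache(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, cu_cached_lens,
                         NVTE_QKV_Format.NVTE_BSHD, b, 1, max_seq_len, 1, True)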
transformer_engine/pytorch/csrc/extensions/gemm.cpp  View file @ 4099aa8e

@@ -36,6 +36,10 @@ namespace transformer_engine::pytorch {
 namespace detail {
+bool is_low_precision(const DType type) {
+  return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2;
+}
+
 std::vector<size_t> getGemmOutputShape(const NVTEShape &A_shape, const bool transa,
                                        const NVTEShape &B_shape, const bool transb) {
   // Flatten outer dims to get 2D matrices
@@ -96,6 +100,9 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   TensorWrapper A_tensor = makeTransformerEngineTensor(A, none);
   TensorWrapper B_tensor = makeTransformerEngineTensor(B, none);
+  const bool low_precision =
+      detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());
+
   // Check tensor dimensions
   const auto &A_shape = A_tensor.shape();
   const auto &B_shape = B_tensor.shape();
@@ -137,7 +144,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   // Activation input tensor
   MaybeTensor pre_gelu_out = std::nullopt;
-  DType gelu_type = bias_type;
+  DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
   if (gelu) {
     if (!grad) {
       auto dtype = GetATenDType(gelu_type);
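In effect, the pre-GELU buffer's dtype now follows the GEMM output D unless either operand is FP8; a one-line Python mirror of the new dispatch (variable names illustrative):

    # low_precision is true iff A's or B's dtype is FP8 (kFloat8E4M3 / kFloat8E5M2);
    # only then does the pre-GELU buffer keep the bias dtype.
    gelu_type = bias_type if low_precision else d_dtype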