gaoqiong / composable_kernel · Commits · 00cb7e41

Commit 00cb7e41, authored Jun 25, 2023 by danyao12
modify comment
Parent: c07c2b55

Showing 2 changed files with 13 additions and 13 deletions.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v1.cpp  +4 -4
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp    +9 -9
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v1.cpp (view file @ 00cb7e41)
@@ -32,7 +32,7 @@ Kernel outputs:
 #define PRINT_HOST 0
 #define USING_MASK 0
-#define DIM 64  // DIM should be a multiple of 8.
+#define DIM 128 // DIM should be a multiple of 8.
 #include <iostream>
 #include <numeric>
@@ -78,7 +78,7 @@ using GemmDataType = F16;
 using AccDataType      = F32;
 using ShuffleDataType  = F32;
 using LSEDataType      = F32;
-using ZDataType        = INT32; // INT32
+using ZDataType        = U16;   // INT32
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
@@ -89,7 +89,7 @@ static constexpr ck::index_t NumDimK = 1;
 static constexpr ck::index_t NumDimO = 1;
 // When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
 // When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
-static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
+static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
 #if USING_MASK
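The two comments in this hunk state a rule (F32 gives 4, F16/BF16 gives 8) that the changed constant now follows for an F16 output. A minimal sketch of where those numbers plausibly come from, assuming they track a 16-byte (128-bit) vectorized store; the helper and type aliases below are hypothetical and are not part of this commit or of composable_kernel:

// Hypothetical illustration, not part of this commit: the comments above read
// as ScalarPerVector = 16 / sizeof(OutputDataType), i.e. 4 for F32 and 8 for
// F16/BF16 (a 16-byte vector width divided by the element size).
#include <cstddef>
#include <cstdint>

using F16 = std::uint16_t; // stand-in for ck::half_t in this sketch
using F32 = float;

template <typename OutputDataType>
constexpr std::size_t ScalarPerVectorFor()
{
    return 16 / sizeof(OutputDataType); // 16-byte vector width / element size
}

static_assert(ScalarPerVectorFor<F32>() == 4, "F32 -> 4");
static_assert(ScalarPerVectorFor<F16>() == 8, "F16/BF16 -> 8");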
@@ -104,7 +104,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
 static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = true;
+static constexpr bool Deterministic = false;
 // DIM should be a multiple of 8.
 // If DIM <= 32 , ues prototype1 1st template.
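The trailing comments of this file ("DIM should be a multiple of 8", "If DIM <= 32, use prototype1 1st template") place constraints on the DIM value that this commit changes from 64 to 128. A minimal compile-time check of those constraints, written as a sketch; the static_asserts and the constexpr flag are hypothetical and do not appear in the example:

// Hypothetical guard, not part of this commit: checks the constraints that
// the comments above place on DIM, for the new value DIM = 128.
#define DIM 128

static_assert(DIM % 8 == 0, "DIM should be a multiple of 8.");

// Per the comment, DIM <= 32 would select the prototype1 1st template.
constexpr bool UsePrototype1FirstTemplate = (DIM <= 32);
static_assert(!UsePrototype1FirstTemplate, "DIM = 128 does not select the prototype1 1st template.");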
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp (view file @ 00cb7e41)
@@ -879,7 +879,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                 m2,   // MPerXdl
                 n2,   // NGroupNum
                 n3,   // NInputNum
-                n4)); // registerNum
+                n4)); // RegisterNum
         constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = // for blockwise copy
             make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
@@ -889,7 +889,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                 m2,   // MPerXdl
                 n2,   // NGroupNum
                 n3,   // NInputNum
-                n4,   // registerNum
+                n4,   // RegisterNum
                 I1)); // I1
         constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
@@ -902,7 +902,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                 m2,   // MPerXdl
                 n2,   // NGroupNum
                 n3,   // NInputNum
-                n4)); // registerNum
+                n4)); // RegisterNum
         constexpr auto z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
             blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
@@ -974,7 +974,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                 m2,  // MPerXdl
                 n2,  // NGroupNum
                 n3,  // NInputNum
-                n4>, // registerNum
+                n4>, // RegisterNum
                 Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                 7, // DstVectorDim
                 1, // DstScalarPerVector
@@ -982,12 +982,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                 1, // DstScalarStrideInVector
                 true>{
                 z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-                make_multi_index(0, // mrepeat
+                make_multi_index(0, // MRepeat
-                                 0, // nrepeat
+                                 0, // NRepeat
                 wave_id[I0],     // MWaveId
                 wave_id[I1],     // NWaveId
                 wave_m_n_id[I1], // MPerXdl
-                0, // group
+                0, // NGroupIndex
                 wave_m_n_id[I0], // NInputIndex
                 0),
                 tensor_operation::element_wise::PassThrough{}};
@@ -1003,8 +1003,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                 1,
                 1,
                 true>{z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
-                make_multi_index(0, // mrepeat
+                make_multi_index(0, // MRepeat
-                                 0, // nrepeat
+                                 0, // NRepeat
                 wave_id[I0], // MWaveId
                 wave_id[I1], // NWaveId
                 wave_m_n_id[I1] / ZN4,