gaoqiong / composable_kernel · Commits

Commit 51ec5aa0
authored Mar 09, 2023 by danyao12
parent 80ef43a2

modify argc and macro

Showing 3 changed files with 42 additions and 60 deletions
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp   +17 −25
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp      +23 −33
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward.cpp   +2 −2
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp
@@ -25,7 +25,7 @@ Kernel outputs:
 #define PRINT_HOST 0
 #define USING_MASK 0
-#define RANGE_HDKO 1 // 0~2
+#define DIM 64 // DIM should be a multiple of 8.
 #include <iostream>
 #include <numeric>
@@ -91,11 +91,11 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
 static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
 
-// Headdim/K/O should be a multiple of 8.
-// If Headdim/K/O <= 32 , ues prototype1 1st template.
-// If 32 < Headdim/K/O <= 64 , ues prototype1 2nd template.
-// If 64 < Headdim/K/O <= 128, ues prototype2 2nd template.
-#if(RANGE_HDKO == 0)
+// DIM should be a multiple of 8.
+// If DIM <= 32 , ues prototype1 1st template.
+// If 32 < DIM <= 64 , ues prototype1 2nd template.
+// If 64 < DIM <= 128, ues prototype2 2nd template.
+#if(DIM <= 32)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
         NumDimG,
@@ -163,7 +163,7 @@ using DeviceGemmInstance =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec>;   // MaskingSpecialization
-#elif(RANGE_HDKO == 1)
+#elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
         NumDimG,
@@ -299,7 +299,7 @@ using DeviceGemmInstance =
 //         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
 //         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
 //         MaskingSpec>;   // MaskingSpecialization
-#elif(RANGE_HDKO == 2)
+#elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
         NumDimG,
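The three hunks above replace the indirect RANGE_HDKO bucket index (0, 1, 2) with the head dimension itself: DIM now carries the actual K/O size, and the preprocessor picks the kernel template by comparing it against the 32/64/128 thresholds directly. A minimal sketch of the resulting dispatch, with hypothetical placeholder types standing in for the CK device-op instances:

// Sketch of the new compile-time selection pattern. Only the #if structure
// mirrors the diff; KernelV1_32 etc. are placeholders, not CK types.
#define DIM 64 // head dimension; must be a multiple of 8

struct KernelV1_32 {};  // stand-in for the V1 "1st template" (DIM <= 32)
struct KernelV1_64 {};  // stand-in for the V1 "2nd template" (32 < DIM <= 64)
struct KernelV2_128 {}; // stand-in for the V2 template (64 < DIM <= 128)

#if(DIM <= 32)
using DeviceGemmInstance = KernelV1_32;
#elif(DIM <= 64)
using DeviceGemmInstance = KernelV1_64;
#elif(DIM <= 128)
using DeviceGemmInstance = KernelV2_128;
#else
#error "DIM > 128 is not covered by these templates"
#endif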
@@ -478,21 +478,13 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
     ck::index_t M = 512;
     ck::index_t N = 512;
-#if(RANGE_HDKO == 0)
-    ck::index_t K = 32; // K/O<=32
-#elif(RANGE_HDKO == 1)
-    ck::index_t K = 64; // 32<K/O<=64
-#elif(RANGE_HDKO == 2)
-    ck::index_t K = 80; // 64<K/O<=128
-#endif
-    ck::index_t O = K;
+    ck::index_t K = DIM;
+    ck::index_t O = DIM;
     ck::index_t G0 = 54;
     ck::index_t G1 = 16;
-    float alpha = 1.f / std::sqrt(K);
     bool input_permute = false;
     bool output_permute = false;
@@ -510,7 +502,7 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 14)
+    else if(argc == 13)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
@@ -523,11 +515,10 @@ int run(int argc, char* argv[])
         G0 = std::stoi(argv[8]);
         G1 = std::stoi(argv[9]);
-        alpha = std::stof(argv[10]);
-        p_drop = std::stof(argv[11]);
-        input_permute = std::stoi(argv[12]);
-        output_permute = std::stoi(argv[13]);
+        p_drop = std::stof(argv[10]);
+        input_permute = std::stoi(argv[11]);
+        output_permute = std::stoi(argv[12]);
     }
     else
     {
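With alpha dropped from the command line, the full-argument form now takes 13 arguments instead of 14, and everything after argv[9] shifts down one slot. A sketch of the post-commit parse, using only the indices visible in the hunks above (argv[4..7] are not shown in the diff and are deliberately omitted; the struct and function names are illustrative, not from the example):

#include <string>

// Hedged sketch of the new 13-argument layout (argc == 13 including argv[0]).
struct Args
{
    int do_verification, init_method, time_kernel;
    int G0, G1;
    float p_drop;
    bool input_permute, output_permute;
};

inline Args parse(char* argv[])
{
    Args a{};
    a.do_verification = std::stoi(argv[1]);
    a.init_method     = std::stoi(argv[2]);
    a.time_kernel     = std::stoi(argv[3]);
    a.G0              = std::stoi(argv[8]);
    a.G1              = std::stoi(argv[9]);
    a.p_drop          = std::stof(argv[10]); // was argv[11]; alpha was argv[10]
    a.input_permute   = std::stoi(argv[11]); // was argv[12]
    a.output_permute  = std::stoi(argv[12]); // was argv[13]
    return a;
    // alpha is no longer a CLI argument; it is derived as 1.f / std::sqrt(K).
}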
@@ -543,6 +534,7 @@ int run(int argc, char* argv[])
     float p_dropout = 1 - p_drop;
     uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
     float rp_dropout = 1.0 / p_dropout;
+    float alpha = 1.f / std::sqrt(K);
 
     std::cout << "do_verification: " << do_verification << std::endl;
     std::cout << "init_method: " << init_method << std::endl;
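The added line computes alpha only after K has been parsed, alongside the existing dropout scalars. A small worked example of those values, assuming p_drop = 0.2 (the default this commit sets elsewhere in these examples) and K = DIM = 64; the reading of the 16-bit value as a keep-threshold for a uniform 16-bit random draw is an assumption, not stated in the diff:

#include <cmath>
#include <cstdint>
#include <iostream>

int main()
{
    float p_drop = 0.2f;
    int K = 64;

    float p_dropout = 1 - p_drop;                // keep probability: 0.8
    std::uint16_t p_dropout_in_16bits =
        std::uint16_t(std::floor(p_dropout * 65535.0)); // 52428; assumed to be
                                                        // a keep-threshold for
                                                        // a 16-bit random draw
    float rp_dropout = 1.0f / p_dropout;         // 1.25: rescales kept values
    float alpha = 1.f / std::sqrt(K);            // 0.125: softmax scale, now
                                                 // derived from K, not argv

    std::cout << p_dropout_in_16bits << " " << rp_dropout << " " << alpha << "\n";
}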
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
@@ -32,7 +32,7 @@ Kernel outputs:
 #define PRINT_HOST 0
 #define USING_MASK 0
-#define RANGE_HDKO 0 // 0~2
+#define DIM 64 // DIM should be a multiple of 8.
 #include <iostream>
 #include <numeric>
@@ -100,11 +100,11 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
 static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
 
-// Headdim/K/O should be a multiple of 8.
-// If Headdim/K/O <= 32 , ues bwd prototype1 1st template.
-// If 32 < Headdim/K/O <= 64 , ues bwd prototype1 2nd template.
-// If 64 < Headdim/K/O <= 128, ues bwd prototype2 2nd template.
-#if(RANGE_HDKO == 0)
+// DIM should be a multiple of 8.
+// If DIM <= 32 , ues prototype1 1st template.
+// If 32 < DIM <= 64 , ues prototype1 2nd template.
+// If 64 < DIM <= 128, ues prototype2 2nd template.
+#if(DIM <= 32)
 using DeviceGemmInstanceFWD =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
         NumDimG,
@@ -242,7 +242,7 @@ using DeviceGemmInstanceBWD =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec>;   // MaskingSpecialization
-#elif(RANGE_HDKO == 1)
+#elif(DIM <= 64)
 using DeviceGemmInstanceFWD =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
         NumDimG,
@@ -448,7 +448,7 @@ using DeviceGemmInstanceBWD =
 //         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
 //         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
 //         MaskingSpec>;   // MaskingSpecialization
-#elif(RANGE_HDKO == 2)
+#elif(DIM <= 128)
 using DeviceGemmInstanceFWD =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
         NumDimG,
@@ -657,14 +657,12 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     ref_gemm0_invoker.Run(ref_gemm0_argument);
 
     // masking
-#if USING_MASK
     auto N = s_g_m_n.GetLengths()[2];
     const auto mask = DeviceGemmInstanceFWD::C0MatrixMask(N);
     s_g_m_n.ForEach([&](auto& self, auto idx) {
         if(mask.IsMaskedElement(idx[1], idx[2]))
             self(idx) = -ck::NumericLimits<float>::Infinity();
     });
-#endif
 
     // P = Softmax(S)
     auto ref_softmax = ReferenceSoftmaxInstance{};
@@ -699,25 +697,17 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
     ck::index_t M = 512; // 512
     ck::index_t N = 512; // 512
-#if(RANGE_HDKO == 0)
-    ck::index_t K = 32; // K/O<=32
-#elif(RANGE_HDKO == 1)
-    ck::index_t K = 64; // 32<K/O<=64
-#elif(RANGE_HDKO == 2)
-    ck::index_t K = 72; // 64<K/O<=128
-#endif
-    ck::index_t O = K;
+    ck::index_t K = DIM;
+    ck::index_t O = DIM;
     ck::index_t G0 = 4; // 54
     ck::index_t G1 = 6; // 16
-    float alpha = 1.f / std::sqrt(K);
-    bool input_permute = false;
-    bool output_permute = false;
-    float p_drop = 0.3;
+    bool input_permute = true;
+    bool output_permute = true;
+    float p_drop = 0.2;
     const unsigned long long seed = 1;
     const unsigned long long offset = 0;
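This hunk also flips the default input_permute/output_permute from false to true. Following the reshape/permute comments at the top of run(), the flags select between storing Y as [G0, G1, M, O] and as the permuted [G0, M, G1, O]. An illustrative offset calculation under that reading (a hypothetical helper, not code from the example):

#include <cstddef>

// Where element (g0, g1, m, o) of Y lands in memory under each layout,
// assuming dense row-major strides derived from the permute comment
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]).
std::size_t y_offset(std::size_t g0, std::size_t g1, std::size_t m, std::size_t o,
                     std::size_t G1, std::size_t M, std::size_t O, bool permute)
{
    if(permute) // [G0, M, G1, O]
        return ((g0 * M + m) * G1 + g1) * O + o;
    return ((g0 * G1 + g1) * M + m) * O + o; // [G0, G1, M, O]
}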
@@ -731,7 +721,7 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 14)
+    else if(argc == 13)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
@@ -744,11 +734,10 @@ int run(int argc, char* argv[])
         G0 = std::stoi(argv[8]);
         G1 = std::stoi(argv[9]);
-        alpha = std::stof(argv[10]);
-        p_drop = std::stof(argv[11]);
-        input_permute = std::stoi(argv[12]);
-        output_permute = std::stoi(argv[13]);
+        p_drop = std::stof(argv[10]);
+        input_permute = std::stoi(argv[11]);
+        output_permute = std::stoi(argv[12]);
     }
     else
     {
@@ -761,9 +750,10 @@ int run(int argc, char* argv[])
         exit(0);
     }
 
     float p_dropout = 1 - p_drop;
     uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
     float rp_dropout = 1.0 / p_dropout;
+    float alpha = 1.f / std::sqrt(K);
 
     std::cout << "do_verification: " << do_verification << std::endl;
     std::cout << "init_method: " << init_method << std::endl;
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward.cpp
@@ -480,8 +480,8 @@ int run(int argc, char* argv[])
     float alpha = 1.f / std::sqrt(DIM);
     float p_drop = 0.2;
-    bool input_permute = false;
-    bool output_permute = false;
+    bool input_permute = true;
+    bool output_permute = true;
     const unsigned long long seed = 1;
     const unsigned long long offset = 0;