Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c56f28b0
Commit
c56f28b0
authored
Mar 07, 2023
by
danyao12
Browse files
let grouped fwd test more random
parent
62357796
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
13 deletions
+17
-13
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp
...ale_softmax_gemm/batched_multihead_attention_backward.cpp
+2
-1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
..._scale_softmax_gemm/batched_multihead_attention_train.cpp
+2
-1
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
..._softmax_gemm/run_grouped_multihead_attention_forward.inc
+13
-11
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp
View file @
c56f28b0
...
...
@@ -231,7 +231,8 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
// using DeviceGemmInstance =
// using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<
// NumDimG,
// NumDimM,
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
View file @
c56f28b0
...
...
@@ -380,7 +380,8 @@ using DeviceGemmInstanceBWD =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
// using DeviceGemmInstanceBWD =
// using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<
// NumDimG,
// NumDimM,
...
...
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
View file @
c56f28b0
...
...
@@ -10,10 +10,7 @@ int run(int argc, char* argv[])
bool
input_permute
=
false
;
bool
output_permute
=
true
;
float
p_drop
=
0.1
;
float
p_dropout
=
1
-
p_drop
;
uint16_t
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
float
rp_dropout
=
1.0
/
p_dropout
;
float
p_drop
=
0.2
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
...
...
@@ -27,14 +24,15 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
6
)
else
if
(
argc
==
7
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
input_permute
=
std
::
stoi
(
argv
[
4
]);
output_permute
=
std
::
stoi
(
argv
[
5
]);
p_drop
=
std
::
stoi
(
argv
[
4
]);
input_permute
=
std
::
stoi
(
argv
[
5
]);
output_permute
=
std
::
stoi
(
argv
[
6
]);
}
else
{
...
...
@@ -45,6 +43,10 @@ int run(int argc, char* argv[])
exit
(
0
);
}
float
p_dropout
=
1
-
p_drop
;
uint16_t
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
float
rp_dropout
=
1.0
/
p_dropout
;
float
alpha
=
1
;
// scaling after 1st gemm
std
::
size_t
group_count
=
8
;
...
...
@@ -81,14 +83,14 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
128
*
(
rand
()
%
8
+
1
);
int
N
=
128
*
(
rand
()
%
8
+
1
);
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
#if(RANGE_HDKO == 0)
int
K
=
32
;
// K/O<=32
#elif(RANGE_HDKO == 1)
int
K
=
6
4
;
// 32<K/O<=64
int
K
=
5
6
;
// 32<K/O<=64
#elif(RANGE_HDKO == 2)
int
K
=
72
;
// 64<K/O<=128
int
K
=
80
;
// 64<K/O<=128
#endif
int
O
=
K
;
int
G0
=
rand
()
%
3
+
1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment