gaoqiong/composable_kernel

Commit 067e71a8, authored Feb 16, 2023 by guangzlu
Parent: 937fcc07

added dropout verify for grouped mha fp16 fwd
Showing 3 changed files with 21 additions and 23 deletions:

  example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp (+3, -3)
  example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc (+12, -20)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp (+6, -0)
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp (view file @ 067e71a8)
@@ -103,8 +103,8 @@ using DeviceGemmInstance =
         128,         // MPerBlock
         128,         // NPerBlock
         32,          // KPerBlock
-        64,          // Gemm1NPerBlock
-        32,          // Gemm1KPerBlock
+        128,         // Gemm1NPerBlock
+        64,          // Gemm1KPerBlock
         8,           // AK1
         8,           // BK1
         2,           // B1K1
@@ -112,7 +112,7 @@ using DeviceGemmInstance =
         32,          // NPerXDL
         1,           // MXdlPerWave
         4,           // NXdlPerWave
-        2,           // Gemm1NXdlPerWave
+        4,           // Gemm1NXdlPerWave
         S<4, 64, 1>, // ABlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,
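Note: Gemm1NPerBlock and Gemm1NXdlPerWave tile the N dimension of the second GEMM (P*V), which corresponds to the output head size O. A plausible reading of this change, not stated in the commit itself, is that doubling the tile from 64 to 128 lets a single block cover the O = 128 head size the example now generates (see run_grouped_multihead_attention_forward.inc below). A minimal sketch of the implied divisibility constraint, with names invented for illustration:

    // Hypothetical standalone check, not composable_kernel API: the head
    // size O must be divisible by the per-block tile so each workgroup
    // maps to a full tile of the second GEMM's N dimension.
    constexpr int Gemm1NPerBlock = 128; // from the updated DeviceGemmInstance
    constexpr int HeadSizeO      = 128; // O used by the example below
    static_assert(HeadSizeO % Gemm1NPerBlock == 0,
                  "O must be a multiple of Gemm1NPerBlock");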
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc (view file @ 067e71a8)
@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
 {
     bool do_verification = true;
     int init_method      = 1;
-    bool time_kernel     = true;
+    bool time_kernel     = false;
     bool input_permute   = false;
     bool output_permute  = true;
@@ -45,7 +45,7 @@ int run(int argc, char* argv[])
         exit(0);
     }
-    float alpha = 0.25; // scaling after 1st gemm
+    float alpha = 1; // scaling after 1st gemm

     std::size_t group_count = 8;
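A note on alpha (inferred from standard attention, not from the commit): alpha scales the first GEMM's output before the softmax, and the conventional choice is 1/sqrt(head size). With K = 128 that would be about 0.088, so alpha = 1 here reads as a deliberate simplification for verification. A hedged sketch of the conventional value:

    // Sketch only: the usual 1/sqrt(d_k) attention scale, assuming alpha
    // plays the role described by the "scaling after 1st gemm" comment.
    #include <cmath>

    float attention_scale(int head_size_k)
    {
        return 1.0f / std::sqrt(static_cast<float>(head_size_k)); // ~0.0884f for K = 128
    }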
@@ -76,27 +76,17 @@ int run(int argc, char* argv[])
     std::size_t flop = 0, num_byte = 0;
-    std::cout << "group count " << group_count << ". printing first 4 groups\n";
+    // std::cout << "group count " << group_count << ". printing first 4 groups\n";
     for(std::size_t i = 0; i < group_count; i++)
     {
-        int M = 512;
-        int N = 512;
-        int K = 40;
-        int O = 40;
+        int M = 128 * (rand() % 8 + 1);
+        int N = 128 * (rand() % 8 + 1);
+        int K = 128;
+        int O = 128;
         int G0 = rand() % 3 + 1;
         int G1 = rand() % 5 + 1;
+        // int M = 128 * (rand() % 8 + 1);
+        // int N = 128 * (rand() % 8 + 1);
+        // int K = 40;
+        // int O = 40 * (rand() % 2 + 1);
+        // int G0 = rand() % 3 + 1;
+        // int G1 = rand() % 5 + 1;

         std::cout << "group id" << i << " M, N, K, O, G0, G1 is " << M << "," << N << ","
                   << K << "," << O << "," << G0 << "," << G1 << std::endl;

         g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O});

         std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
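The new shape generation keeps M and N as multiples of 128, matching the 128x128 MPerBlock/NPerBlock tile of the device instance above, and pins K = O = 128, which also satisfies the new SizeK == SizeO guard added in the gridwise header below. A self-contained sketch of the same scheme (struct and function names are illustrative, not from the repo):

    #include <cstdlib>

    struct ProblemShape { int M, N, K, O, G0, G1; };

    // Mirrors the loop body above: tile-aligned sequence lengths, fixed
    // head sizes, and small random batch/head counts per group.
    ProblemShape random_shape()
    {
        ProblemShape s;
        s.M  = 128 * (std::rand() % 8 + 1); // 128..1024, multiple of MPerBlock
        s.N  = 128 * (std::rand() % 8 + 1); // 128..1024, multiple of NPerBlock
        s.K  = 128;                         // Q/K head size
        s.O  = 128;                         // V/output head size (== K)
        s.G0 = std::rand() % 3 + 1;
        s.G1 = std::rand() % 5 + 1;
        return s;
    }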
@@ -322,7 +312,6 @@ int run(int argc, char* argv[])
     Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
     Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
     Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
-    // Tensor<CDataType> z_gs_ms_ns_host_result(z_gs_ms_os_lengths, z_gs_ms_os_strides);
     Tensor<LSEDataType> lse_g_m_host_result({G0 * G1, M}); // scratch object after gemm1
     Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
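z_g_m_n holds, per (group, M, N) element, the random values the kernel produced for dropout; recording them is what lets the host reference replay the same keep/drop decisions, which is presumably the "dropout verify" this commit refers to. A minimal host-side sketch under that assumption (ZDataType, the threshold encoding, and the scaling are all illustrative, not taken from the repo):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Replay the device's recorded dropout decisions on the host softmax
    // output so host and device results remain directly comparable.
    void apply_recorded_dropout(std::vector<float>& p,         // host softmax probs
                                const std::vector<uint8_t>& z, // recorded random values
                                float p_drop)
    {
        const uint8_t threshold = static_cast<uint8_t>(p_drop * 255.0f);
        const float scale       = 1.0f / (1.0f - p_drop); // inverted dropout
        for(std::size_t j = 0; j < p.size(); ++j)
            p[j] = (z[j] >= threshold) ? p[j] * scale : 0.0f;
    }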
@@ -416,11 +405,10 @@ int run(int argc, char* argv[])
             atol = 1e-2;
         }
+        printf("group id is %lu\n", i);
-        // bool pass_ =
-        //     ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData);
         bool pass_ = ck::utils::check_err(c_gs_ms_os_device_result.mData,
                                           c_gs_ms_os_host_result.mData,
                                           "Error: Incorrect results c!",
@@ -433,6 +421,10 @@ int run(int argc, char* argv[])
                                           atol);
         pass &= pass_;
     }
+    if(pass)
+    {
+        std::cout << "Verification passed." << std::endl;
+    }
 }

 return pass ? 0 : 1;
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp (view file @ 067e71a8)
@@ -273,6 +273,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         const auto K =
             a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
         const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
+        if(Gemm1N != K)
+        {
+            std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
+            return false;
+        }
+
         if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
         {
             return false;
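The only change in the gridwise kernel is this new argument check: SizeK, the Q/K head size reconstructed from the A descriptor as AK0 * AK1, must equal SizeO (Gemm1N, the V/output head size) before the kernel accepts the problem. A standalone sketch of the same guard, with the descriptor reads replaced by plain integers:

    #include <iostream>

    // Same logic as the validity check above, minus the descriptor plumbing.
    bool check_head_sizes(int K /* = AK0 * AK1 */, int Gemm1N /* = O */)
    {
        if(Gemm1N != K)
        {
            std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
            return false; // reject the problem before launching the kernel
        }
        return true;
    }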