Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b9cb659d
Commit
b9cb659d
authored
Feb 27, 2023
by
guangzlu
Browse files
added time test for run_grouped_multihead_attention_forward.inc
parent
a2b5c7ac
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
5 deletions
+31
-5
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
..._softmax_gemm/run_grouped_multihead_attention_forward.inc
+31
-5
No files found.
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
View file @
b9cb659d
...
...
@@ -5,12 +5,11 @@ int run(int argc, char* argv[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
fals
e
;
bool
time_kernel
=
tru
e
;
bool
input_permute
=
false
;
bool
output_permute
=
true
;
float
p_drop
=
0.2
;
float
p_dropout
=
1
-
p_drop
;
uint16_t
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
...
...
@@ -56,7 +55,8 @@ int run(int argc, char* argv[])
std
::
vector
<
const
void
*>
p_b0
;
std
::
vector
<
const
void
*>
p_b1
;
std
::
vector
<
void
*>
p_c
;
std
::
vector
<
void
*>
p_z
;
std
::
vector
<
void
*>
p_z
;
// for result verification
std
::
vector
<
void
*>
p_z_nullptr
;
// for time test
std
::
vector
<
void
*>
p_lse
;
std
::
vector
<
std
::
vector
<
int
>>
g0_g1_m_n_k_o
;
...
...
@@ -221,6 +221,7 @@ int run(int argc, char* argv[])
p_b1
.
push_back
(
b1_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_c
.
push_back
(
c_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_z
.
push_back
(
z_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_z_nullptr
.
push_back
(
nullptr
);
p_lse
.
push_back
(
lse_tensors_device
[
i
]
->
GetDeviceBuffer
());
}
...
...
@@ -233,12 +234,13 @@ int run(int argc, char* argv[])
// do GEMM
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
p_a
,
p_b0
,
p_b1
,
p_c
,
p_z
,
p_z
_nullptr
,
p_lse
,
{},
// p_acc0_biases
{},
// p_acc1_biases
...
...
@@ -252,7 +254,6 @@ int run(int argc, char* argv[])
{
seed
,
offset
});
// dropout random seed and offset, offset should be
// at least the number of elements on a thread
// specify workspace for problem_desc
DeviceMem
problem_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
...
...
@@ -277,6 +278,31 @@ int run(int argc, char* argv[])
bool
pass
=
true
;
if
(
do_verification
)
{
argument
=
gemm
.
MakeArgument
(
p_a
,
p_b0
,
p_b1
,
p_c
,
p_z
,
p_lse
,
{},
// p_acc0_biases
{},
// p_acc1_biases
problem_descs
,
a_element_op
,
b0_element_op
,
acc0_element_op
,
b1_element_op
,
c_element_op
,
p_drop
,
// dropout ratio
{
seed
,
offset
});
// dropout random seed and offset, offset should be
// at least the number of elements on a thread
// specify workspace for problem_desc
DeviceMem
problem_desc_workspace_verify
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
problem_desc_workspace_verify
.
GetDeviceBuffer
());
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
const
int
&
G0
=
g0_g1_m_n_k_o
[
i
][
0
];
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment