Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
fe6ee651
Commit
fe6ee651
authored
Jan 11, 2023
by
fsx950223
Browse files
add workspace size
parent
3c6a9b06
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
1 deletion
+17
-1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_fp16.cpp
...oftmax_gemm/grouped_multihead_attention_backward_fp16.cpp
+12
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle.hpp
...ice_grouped_multihead_attention_backward_xdl_cshuffle.hpp
+5
-0
No files found.
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_fp16.cpp
View file @
fe6ee651
...
@@ -316,7 +316,7 @@ int run(int argc, char* argv[])
...
@@ -316,7 +316,7 @@ int run(int argc, char* argv[])
std
::
vector
<
DeviceMemPtr
>
ygrad_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
ygrad_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
kgrad_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
kgrad_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
vgrad_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
vgrad_tensors_device
;
std
::
size_t
group_count
=
1
;
std
::
size_t
group_count
=
3
;
std
::
size_t
flop
=
0
,
num_byte
=
0
;
std
::
size_t
flop
=
0
,
num_byte
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
){
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
){
// int M = 128 * (rand() % 8 + 1);
// int M = 128 * (rand() % 8 + 1);
...
@@ -538,6 +538,17 @@ int run(int argc, char* argv[])
...
@@ -538,6 +538,17 @@ int run(int argc, char* argv[])
Scale
{
alpha
},
Scale
{
alpha
},
QKVElementOp
{},
QKVElementOp
{},
YElementOp
{});
YElementOp
{});
DeviceMem
problem_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
problem_desc_workspace
.
GetDeviceBuffer
());
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
std
::
cout
<<
gemm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
}
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle.hpp
View file @
fe6ee651
...
@@ -1062,6 +1062,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
...
@@ -1062,6 +1062,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
{
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
group_count_
*
sizeof
(
GroupKernelArg
);
}
static
auto
MakeArgument
(
const
std
::
vector
<
const
DataType
*>&
p_As
,
static
auto
MakeArgument
(
const
std
::
vector
<
const
DataType
*>&
p_As
,
const
std
::
vector
<
const
DataType
*>&
p_Bs
,
const
std
::
vector
<
const
DataType
*>&
p_Bs
,
const
std
::
vector
<
const
DataType
*>&
p_B1s
,
const
std
::
vector
<
const
DataType
*>&
p_B1s
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment