Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5509e684
Commit
5509e684
authored
Jan 13, 2023
by
fsx950223
Browse files
fix arguments
parent
2ebc3248
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
27 deletions
+27
-27
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle.hpp
...ice_grouped_multihead_attention_backward_xdl_cshuffle.hpp
+27
-27
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle.hpp
View file @
5509e684
...
@@ -665,15 +665,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
...
@@ -665,15 +665,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
const
std
::
vector
<
const
DataType
*>&
p_As
,
Argument
(
const
std
::
vector
<
const
void
*>&
p_As
,
const
std
::
vector
<
const
DataType
*>&
p_Bs
,
const
std
::
vector
<
const
void
*>&
p_Bs
,
const
std
::
vector
<
const
DataType
*>&
p_B1s
,
const
std
::
vector
<
const
void
*>&
p_B1s
,
const
std
::
vector
<
const
DataType
*>&
p_Cs
,
// for dS
const
std
::
vector
<
const
void
*>&
p_Cs
,
// for dS
const
std
::
vector
<
const
LSEDataType
*>&
p_LSEs
,
const
std
::
vector
<
const
void
*>&
p_LSEs
,
const
std
::
vector
<
const
DataType
*>&
p_Ygrads
,
const
std
::
vector
<
const
void
*>&
p_Ygrads
,
std
::
vector
<
DataType
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
DataType
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
DataType
*>&
p_Vgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>&
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>&
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>&
p_acc1_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>&
p_acc1_biases
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
...
@@ -1042,15 +1042,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
...
@@ -1042,15 +1042,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
group_count_
*
sizeof
(
GroupKernelArg
);
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
group_count_
*
sizeof
(
GroupKernelArg
);
}
}
static
auto
MakeArgument
(
const
std
::
vector
<
const
DataType
*>&
p_As
,
static
auto
MakeArgument
(
const
std
::
vector
<
const
void
*>&
p_As
,
const
std
::
vector
<
const
DataType
*>&
p_Bs
,
const
std
::
vector
<
const
void
*>&
p_Bs
,
const
std
::
vector
<
const
DataType
*>&
p_B1s
,
const
std
::
vector
<
const
void
*>&
p_B1s
,
const
std
::
vector
<
const
DataType
*>&
p_Cs
,
// for dS
const
std
::
vector
<
const
void
*>&
p_Cs
,
// for dS
const
std
::
vector
<
const
LSEDataType
*>&
p_LSEs
,
const
std
::
vector
<
const
void
*>&
p_LSEs
,
const
std
::
vector
<
const
DataType
*>&
p_Ygrads
,
const
std
::
vector
<
const
void
*>&
p_Ygrads
,
std
::
vector
<
DataType
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
DataType
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
DataType
*>&
p_Vgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>&
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>&
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>&
p_acc1_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>&
p_acc1_biases
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
...
@@ -1084,15 +1084,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
...
@@ -1084,15 +1084,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
// polymorphic
// polymorphic
// FIXME: constness
// FIXME: constness
std
::
unique_ptr
<
BaseArgument
>
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
vector
<
const
DataType
*>&
p_As
,
MakeArgumentPointer
(
const
std
::
vector
<
const
void
*>&
p_As
,
const
std
::
vector
<
const
DataType
*>&
p_Bs
,
const
std
::
vector
<
const
void
*>&
p_Bs
,
const
std
::
vector
<
const
DataType
*>&
p_B1s
,
const
std
::
vector
<
const
void
*>&
p_B1s
,
const
std
::
vector
<
const
DataType
*>&
p_Cs
,
// for dS
const
std
::
vector
<
const
void
*>&
p_Cs
,
// for dS
const
std
::
vector
<
const
LSEDataType
*>&
p_LSEs
,
const
std
::
vector
<
const
void
*>&
p_LSEs
,
const
std
::
vector
<
const
DataType
*>&
p_Ygrads
,
const
std
::
vector
<
const
void
*>&
p_Ygrads
,
std
::
vector
<
DataType
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
DataType
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
DataType
*>&
p_Vgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>&
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>&
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>&
p_acc1_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>&
p_acc1_biases
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment