Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
63ea1d70
You need to sign in or sign up before continuing.
Commit
63ea1d70
authored
Sep 26, 2023
by
letaoqin
Browse files
update include file name
parent
cef44211
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
51 deletions
+51
-51
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle.hpp
...n/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle.hpp
+17
-17
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle.hpp
...n/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle.hpp
+34
-34
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle.hpp
View file @
63ea1d70
...
...
@@ -9,7 +9,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_m
ultiple_head_flash_attention_fwd
.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_m
ha_infer
.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
...
...
@@ -207,22 +207,22 @@ template <index_t NumDimG,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedMultiheadAttentionForward_Xdl
:
public
DeviceBatchedMultiheadAttentionInfer
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
BDataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
C0ElementwiseOperation
,
B1ElementwiseOperation
,
C1DEElementwiseOperation
,
MaskingSpec
>
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
BDataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
C0ElementwiseOperation
,
B1ElementwiseOperation
,
C1DEElementwiseOperation
,
MaskingSpec
>
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle.hpp
View file @
63ea1d70
...
...
@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_m
ultiple_head_flash_attention_fwd
.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_m
ha_infer
.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle.hpp"
...
...
@@ -196,22 +196,22 @@ template <index_t NumDimG,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceGroupedMultiheadAttentionForward_Xdl
:
public
DeviceGroupedMultiheadAttentionInfer
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
BDataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
MaskingSpec
>
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
BDataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
MaskingSpec
>
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
...
...
@@ -231,23 +231,23 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
#endif
using
DeviceOp
=
DeviceGroupedMultiheadAttentionForward_Xdl
;
using
ProblemDesc
=
typename
DeviceGroupedMultiheadAttention
Forward
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
BDataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
MaskingSpec
>::
ProblemDesc
;
using
ProblemDesc
=
typename
DeviceGroupedMultiheadAttention
Infer
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
BDataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
MaskingSpec
>::
ProblemDesc
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment