gaoqiong / composable_kernel · Commits

Commit 6cc7d0de
Authored Jun 30, 2023 by danyao12
Parent: 38f48480

    rename device ops

Changes: 26 · Showing 6 changed files with 68 additions and 64 deletions (+68 −64)
Files changed on this page:

  +15 −14  include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_kloop_v1.hpp
  +15 −14  include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_kloop_v2.hpp
  +15 −14  include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
  +15 −14  include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
  +5 −5    include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v1.hpp
  +3 −3    include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
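Taken together, the hunks below fold the loop flavor (Kloop/Qloop) into the backward operator names and the pipeline version (V1/V2) into the forward ones. Read off the diffs, the rename map is:

  kloop_v1.hpp: DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 -> DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
  kloop_v2.hpp: DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 -> DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
  qloop_v1.hpp: DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 -> DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
  qloop_v2.hpp: DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 -> DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
  fwd_v1.hpp:   DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle     -> DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
  fwd_v2.hpp:   DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle     -> DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2

The __global__ kernels are renamed to match: the backward kernels gain a kloop_/qloop_ infix, and the forward V1 header's kernel, previously named kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2, becomes ..._v1. Before this commit the K-loop and Q-loop headers declared identically named classes and kernels, as did the two forward headers, so one translation unit could not include both variants at once. A hedged illustration of what the rename enables, assuming CK's usual ck::tensor_operation::device namespace (template arguments elided, hence comment form):

    // After this commit both backward flavors can coexist in one file:
    //
    //   #include ".../device_grouped_mha_bwd_xdl_cshuffle_kloop_v1.hpp"
    //   #include ".../device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp"
    //
    //   using KloopOp = ck::tensor_operation::device::
    //       DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1</*...*/>;
    //   using QloopOp = ck::tensor_operation::device::
    //       DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1</*...*/>;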
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_kloop_v1.hpp

@@ -40,7 +40,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-        kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1(
+        kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v1(
             const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
             const index_t group_count,
             const AElementwiseOperation a_element_op,
@@ -254,7 +254,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
+struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -266,7 +266,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
     // TODO: implement bias combination
     static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");

-    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1;
+    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1;
     struct ProblemDesc
     {
         std::vector<index_t> a_gs_ms_ks_lengths;
@@ -956,16 +956,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
         float ave_time = 0;

         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1<
-                GridwiseGemm,
-                GroupKernelArg,
-                AElementwiseOperation,
-                BElementwiseOperation,
-                AccElementwiseOperation,
-                B1ElementwiseOperation,
-                CElementwiseOperation,
-                has_main_k_block_loop_,
-                Deterministic>;
+            const auto kernel =
+                kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v1<
+                    GridwiseGemm,
+                    GroupKernelArg,
+                    AElementwiseOperation,
+                    BElementwiseOperation,
+                    AccElementwiseOperation,
+                    B1ElementwiseOperation,
+                    CElementwiseOperation,
+                    has_main_k_block_loop_,
+                    Deterministic>;

             return launch_and_time_kernel(stream_config,
@@ -1209,7 +1210,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1"
+        str << "DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
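In each backward header, the launch_kernel lambda takes has_main_k_block_loop_ as an auto parameter and then uses it as a non-type template argument of the kernel. That only works because the caller hands it a compile-time constant (an integral_constant-like tag) rather than a plain bool. A minimal self-contained sketch of that dispatch pattern; dispatch and run_kernel here are hypothetical stand-ins, not CK helpers:

    #include <type_traits>

    template <bool HasMainKBlockLoop>
    void run_kernel()
    {
        // A real kernel would be instantiated differently per branch; this
        // stub only shows the flag arriving as a template argument.
    }

    template <typename F>
    void dispatch(bool has_main_k_block_loop, F&& launch)
    {
        // Map the runtime bool onto a compile-time constant, instantiating
        // the lambda (and thus the kernel) for both branches.
        if(has_main_k_block_loop)
            launch(std::integral_constant<bool, true>{});
        else
            launch(std::integral_constant<bool, false>{});
    }

    int main()
    {
        auto launch = [](auto has_main_k_block_loop_) {
            // The parameter's type encodes the value, so it can be used as a
            // non-type template argument, as in the hunk above.
            run_kernel<decltype(has_main_k_block_loop_)::value>();
        };
        dispatch(true, launch);
    }

Instantiating both branches up front lets the kernel body specialize on the flag (e.g. skipping the main K-block loop entirely) with no per-iteration branch inside the kernel.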
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_kloop_v2.hpp

@@ -40,7 +40,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-        kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2(
+        kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v2(
             const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
             const index_t group_count,
             const AElementwiseOperation a_element_op,
@@ -254,7 +254,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
+struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -266,7 +266,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
     // TODO: implement bias combination
     static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");

-    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2;
+    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2;
     struct ProblemDesc
     {
         std::vector<index_t> a_gs_ms_ks_lengths;
@@ -948,16 +948,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         float ave_time = 0;

         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<
-                GridwiseGemm,
-                GroupKernelArg,
-                AElementwiseOperation,
-                BElementwiseOperation,
-                AccElementwiseOperation,
-                B1ElementwiseOperation,
-                CElementwiseOperation,
-                has_main_k_block_loop_,
-                Deterministic>;
+            const auto kernel =
+                kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v2<
+                    GridwiseGemm,
+                    GroupKernelArg,
+                    AElementwiseOperation,
+                    BElementwiseOperation,
+                    AccElementwiseOperation,
+                    B1ElementwiseOperation,
+                    CElementwiseOperation,
+                    has_main_k_block_loop_,
+                    Deterministic>;

             return launch_and_time_kernel(stream_config,
@@ -1200,7 +1201,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2"
+        str << "DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
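Every kernel touched here carries the same launch-bounds guard shown in the hunks above. Notably, the backward kernels pass a hard-coded 1 as the min-blocks-per-CU argument, keeping the usual CK_MIN_BLOCK_PER_CU macro only in a comment, while the forward V1 kernel (further below) still passes the macro. A minimal HIP/CUDA sketch of the guard, with illustrative macro values (in CK these come from the library's configuration headers):

    // Illustrative values; in CK these macros are defined by the library.
    #ifndef CK_USE_LAUNCH_BOUNDS
    #define CK_USE_LAUNCH_BOUNDS 1
    #endif
    #ifndef CK_MAX_THREAD_PER_BLOCK
    #define CK_MAX_THREAD_PER_BLOCK 256
    #endif

    __global__ void
    #if CK_USE_LAUNCH_BOUNDS
        // Cap the block size at compile time and request at least one
        // resident block per compute unit, which bounds per-thread register
        // usage so the compiler can guarantee that occupancy.
        __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
    #endif
        example_kernel(float* out)
    {
        out[threadIdx.x] = 0.f;
    }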
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp

@@ -40,7 +40,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-        kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1(
+        kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1(
             const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
             const index_t group_count,
             const AElementwiseOperation a_element_op,
@@ -251,7 +251,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
+struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -263,7 +263,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
     // TODO: implement bias combination
     static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");

-    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1;
+    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1;
     struct ProblemDesc
     {
         std::vector<index_t> a_gs_ms_ks_lengths;
@@ -961,16 +961,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
         float ave_time = 0;

         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1<
-                GridwiseGemm,
-                GroupKernelArg,
-                AElementwiseOperation,
-                BElementwiseOperation,
-                AccElementwiseOperation,
-                B1ElementwiseOperation,
-                CElementwiseOperation,
-                has_main_k_block_loop_,
-                Deterministic>;
+            const auto kernel =
+                kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1<
+                    GridwiseGemm,
+                    GroupKernelArg,
+                    AElementwiseOperation,
+                    BElementwiseOperation,
+                    AccElementwiseOperation,
+                    B1ElementwiseOperation,
+                    CElementwiseOperation,
+                    has_main_k_block_loop_,
+                    Deterministic>;

             return launch_and_time_kernel(stream_config,
@@ -1207,7 +1208,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1"
+        str << "DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
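The kernels above take group_kernel_args and a group_count, so one launch covers every problem in the group; on the host side, each problem is described by a ProblemDesc like the one whose first field, a_gs_ms_ks_lengths, is visible in the hunks. A hedged sketch of filling one descriptor per group: the struct below is a stand-in rather than CK's full ProblemDesc, and the [batch, heads, seq_len, head_dim] split into G/M/K dimension groups is an assumption about a typical attention layout.

    #include <cstdint>
    #include <vector>

    using index_t = int32_t; // CK's index_t is an integer type; exact alias assumed

    struct ProblemDescSketch
    {
        std::vector<index_t> a_gs_ms_ks_lengths; // ordered [G0..., M0..., K0...]
    };

    // One descriptor per attention problem; ragged sequence lengths are what
    // make the grouped API useful.
    std::vector<ProblemDescSketch> make_group(index_t batch,
                                              index_t heads,
                                              const std::vector<index_t>& seq_lens,
                                              index_t head_dim)
    {
        std::vector<ProblemDescSketch> descs;
        for(index_t len : seq_lens)
            // Two G dims (batch, heads), one M dim (seq_len), one K dim (head_dim).
            descs.push_back({{batch, heads, len, head_dim}});
        return descs;
    }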
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp

@@ -40,7 +40,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-        kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2(
+        kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2(
             const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
             const index_t group_count,
             const AElementwiseOperation a_element_op,
@@ -258,7 +258,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
+struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -270,7 +270,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
     // TODO: implement bias combination
     static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");

-    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2;
+    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2;
     struct ProblemDesc
     {
         std::vector<index_t> a_gs_ms_ks_lengths;
@@ -968,16 +968,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         float ave_time = 0;

         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<
-                GridwiseGemm,
-                GroupKernelArg,
-                AElementwiseOperation,
-                BElementwiseOperation,
-                AccElementwiseOperation,
-                B1ElementwiseOperation,
-                CElementwiseOperation,
-                has_main_k_block_loop_,
-                Deterministic>;
+            const auto kernel =
+                kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2<
+                    GridwiseGemm,
+                    GroupKernelArg,
+                    AElementwiseOperation,
+                    BElementwiseOperation,
+                    AccElementwiseOperation,
+                    B1ElementwiseOperation,
+                    CElementwiseOperation,
+                    has_main_k_block_loop_,
+                    Deterministic>;

             return launch_and_time_kernel(stream_config,
@@ -1219,7 +1220,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2"
+        str << "DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v1.hpp

@@ -39,7 +39,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
+        kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1(
             const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
             const index_t group_count,
             const AElementwiseOperation a_element_op,
@@ -250,7 +250,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
+struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
     : public DeviceGroupedMultiheadAttentionForward<NumDimG,
                                                     NumDimM,
                                                     NumDimN,
@@ -290,7 +290,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif

-    using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle;
+    using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1;
     using ProblemDesc = typename DeviceGroupedMultiheadAttentionForward<NumDimG,
                                                                         NumDimM,
                                                                         NumDimN,
@@ -813,7 +813,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         auto launch_kernel =
             [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
                 const auto kernel =
-                    kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
+                    kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1<GridwiseGemm,
                                                                      GemmAccDataType,
                                                                      GroupKernelArg,
                                                                      AElementwiseOperation,
@@ -1123,7 +1123,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle"
+        str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
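Besides versioning the struct name, the forward V1 diff points its launch at a kernel named ..._v1 where it previously referenced the ..._v2 kernel, so each forward header now owns a matching kernel name. The last hunk also updates the identifier that the operator reports, which is what tuning logs and instance selection see. A hedged sketch of that stringstream pattern (the free function and parameter list here are illustrative, not CK's member signature):

    #include <sstream>
    #include <string>

    std::string get_type_string_sketch(int block_size, int m_per_block)
    {
        // Compose the versioned operator identifier the same way the hunk
        // above does, streaming the name followed by template parameters.
        auto str = std::stringstream();
        str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1"
            << "<" << block_size << ", " << m_per_block << ", ...>";
        // e.g. "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<256, 128, ...>"
        return str.str();
    }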
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp

@@ -256,7 +256,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
+struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
     : public DeviceGroupedMultiheadAttentionForward<NumDimG,
                                                     NumDimM,
                                                     NumDimN,
@@ -296,7 +296,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif

-    using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle;
+    using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2;
     using ProblemDesc = typename DeviceGroupedMultiheadAttentionForward<NumDimG,
                                                                         NumDimM,
                                                                         NumDimN,
@@ -1145,7 +1145,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle"
+        str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "