Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c703632f
Commit
c703632f
authored
Sep 25, 2023
by
letaoqin
Browse files
fix class name
parent
6bc73d41
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
4 deletions
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle.hpp
...n/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle.hpp
+3
-3
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle.hpp
+1
-1
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle.hpp
View file @
c703632f
...
@@ -44,7 +44,7 @@ __global__ void
...
@@ -44,7 +44,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_batched_mutiple_head_flash_attention_forward
(
kernel_batched_mu
l
tiple_head_flash_attention_forward
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
D0DataType
*
p_d0_grid
,
const
D0DataType
*
p_d0_grid
,
...
@@ -376,7 +376,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
...
@@ -376,7 +376,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
D0GridDesc_G_M_N
d0_grid_desc_g_m_n_
;
D0GridDesc_G_M_N
d0_grid_desc_g_m_n_
;
};
};
using
GridwiseGemm
=
GridwiseMutiHeadFlashAttentionForward_Xdl_CShuffle
<
using
GridwiseGemm
=
GridwiseMu
l
tiHeadFlashAttentionForward_Xdl_CShuffle
<
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
D0DataType
,
D0DataType
,
GemmAccDataType
,
GemmAccDataType
,
...
@@ -641,7 +641,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
...
@@ -641,7 +641,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
float
ave_time
=
0
;
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
kernel_batched_mutiple_head_flash_attention_forward
<
const
auto
kernel
=
kernel_batched_mu
l
tiple_head_flash_attention_forward
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
D0DataType
,
D0DataType
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle.hpp
View file @
c703632f
...
@@ -86,7 +86,7 @@ template <typename FloatAB,
...
@@ -86,7 +86,7 @@ template <typename FloatAB,
bool
PadN
,
bool
PadN
,
bool
MaskOutUpperTriangle
,
bool
MaskOutUpperTriangle
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseMutiHeadFlashAttentionForward_Xdl_CShuffle
struct
GridwiseMu
l
tiHeadFlashAttentionForward_Xdl_CShuffle
{
{
static_assert
(
D0BlockTransferSrcScalarPerVector
==
1
||
static_assert
(
D0BlockTransferSrcScalarPerVector
==
1
||
D0BlockTransferSrcScalarPerVector
==
2
||
D0BlockTransferSrcScalarPerVector
==
2
||
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment