Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ca1a816d
Commit
ca1a816d
authored
Dec 29, 2024
by
Po Yen Chen
Browse files
Move splitkv partitioner logics into splitkv kernel
parent
f31fad7d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
57 additions
and
72 deletions
+57
-72
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+1
-3
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+0
-1
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+56
-14
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
...ile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
+0
-54
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
View file @
ca1a816d
...
@@ -107,9 +107,7 @@ using fmha_epilogue =
...
@@ -107,9 +107,7 @@ using fmha_epilogue =
false, false>>;
false, false>>;
using fmha_kernel =
using fmha_kernel =
ck_tile::FmhaFwdSplitKVKernel<ck_tile::FmhaFwdSplitKVTilePartitioner<fmha_shape>,
ck_tile::FmhaFwdSplitKVKernel<fmha_pipeline, fmha_epilogue>;
fmha_pipeline,
fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
{{
...
...
include/ck_tile/ops/fmha.hpp
View file @
ca1a816d
...
@@ -16,7 +16,6 @@
...
@@ -16,7 +16,6 @@
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
ca1a816d
...
@@ -17,10 +17,9 @@
...
@@ -17,10 +17,9 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
template
<
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
struct
FmhaFwdSplitKVKernel
struct
FmhaFwdSplitKVKernel
{
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaPipeline
=
ck_tile
::
remove_cvref_t
<
FmhaPipeline_
>
;
using
FmhaPipeline
=
ck_tile
::
remove_cvref_t
<
FmhaPipeline_
>
;
using
EpiloguePipeline
=
ck_tile
::
remove_cvref_t
<
EpiloguePipeline_
>
;
using
EpiloguePipeline
=
ck_tile
::
remove_cvref_t
<
EpiloguePipeline_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaPipeline
::
kBlockSize
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaPipeline
::
kBlockSize
;
...
@@ -60,7 +59,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -60,7 +59,7 @@ struct FmhaFwdSplitKVKernel
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
// clang-format on
// clang-format on
__host__
static
std
::
string
GetName
()
CK_TILE_HOST
static
std
::
string
GetName
()
{
{
// sync with generate.py
// sync with generate.py
// clang-format off
// clang-format off
...
@@ -235,7 +234,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -235,7 +234,7 @@ struct FmhaFwdSplitKVKernel
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
GroupModeKargs
,
BatchModeKargs
>
;
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
GroupModeKargs
,
BatchModeKargs
>
;
template
<
bool
Cond
=
!
kIsGroupMode
>
template
<
bool
Cond
=
!
kIsGroupMode
>
__host__
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
q_ptr
,
MakeKargs
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
v_ptr
,
...
@@ -359,7 +358,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -359,7 +358,7 @@ struct FmhaFwdSplitKVKernel
}
}
template
<
bool
Cond
=
kIsGroupMode
>
template
<
bool
Cond
=
kIsGroupMode
>
__host__
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
q_ptr
,
MakeKargs
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
v_ptr
,
...
@@ -473,16 +472,60 @@ struct FmhaFwdSplitKVKernel
...
@@ -473,16 +472,60 @@ struct FmhaFwdSplitKVKernel
return
kargs
;
return
kargs
;
}
}
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
)
ck_tile
::
index_t
num_splits
)
{
{
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
max_seqlen_q
,
hdim_v
,
num_splits
);
// TODO: this may need tuning
if
constexpr
(
kIsGroupMode
)
{
return
dim3
(
nhead
,
batch_size
,
ck_tile
::
integer_divide_ceil
(
max_seqlen_q
,
FmhaPipeline
::
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
FmhaPipeline
::
kN1
)
*
num_splits
);
}
else
{
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_seqlen_q
,
FmhaPipeline
::
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
FmhaPipeline
::
kN1
)
*
num_splits
,
nhead
,
batch_size
);
}
}
CK_TILE_DEVICE
static
constexpr
auto
GetTileIndex
(
const
Kargs
&
kargs
)
{
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
kargs
.
hdim_v
,
FmhaPipeline
::
kN1
);
const
auto
f
=
[](
index_t
dividend
,
index_t
divisor
)
{
index_t
quotient
=
dividend
/
divisor
;
index_t
modulus
=
dividend
-
quotient
*
divisor
;
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
if
constexpr
(
kIsGroupMode
)
{
const
auto
[
mn
,
i_split
]
=
f
(
blockIdx
.
z
,
kargs
.
num_splits
);
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
mn
,
num_tile_n1
);
const
index_t
i_nhead
=
blockIdx
.
x
;
const
index_t
i_batch
=
blockIdx
.
y
;
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_split
,
i_nhead
,
i_batch
);
}
else
{
const
auto
[
mn
,
i_split
]
=
f
(
blockIdx
.
x
,
kargs
.
num_splits
);
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
mn
,
num_tile_n1
);
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_split
,
i_nhead
,
i_batch
);
}
}
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
{
...
@@ -495,8 +538,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -495,8 +538,7 @@ struct FmhaFwdSplitKVKernel
__shared__
char
smem_ptr
[
GetSmemSize
()];
__shared__
char
smem_ptr
[
GetSmemSize
()];
// divide problem
// divide problem
const
auto
[
i_tile_m
,
i_tile_n
,
i_split
,
i_nhead
,
i_batch
]
=
const
auto
[
i_tile_m
,
i_tile_n
,
i_split
,
i_nhead
,
i_batch
]
=
GetTileIndex
(
kargs
);
TilePartitioner
{}(
kargs
.
seqlen_q
,
kargs
.
hdim_v
,
kargs
.
num_splits
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
FmhaPipeline
::
kM0
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
FmhaPipeline
::
kM0
);
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
deleted
100644 → 0
View file @
f31fad7d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
BlockFmhaShape_
>
struct
FmhaFwdSplitKVTilePartitioner
{
using
BlockFmhaShape
=
ck_tile
::
remove_cvref_t
<
BlockFmhaShape_
>
;
static
constexpr
ck_tile
::
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
ck_tile
::
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
ck_tile
::
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
ck_tile
::
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
ck_tile
::
index_t
kK1
=
BlockFmhaShape
::
kK1
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
)
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_seqlen_q
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
)
*
num_splits
,
nhead
,
batch_size
);
}
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_q*/
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
)
{
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
const
auto
f
=
[](
index_t
dividend
,
index_t
divisor
)
{
index_t
quotient
=
dividend
/
divisor
;
index_t
modulus
=
dividend
-
quotient
*
divisor
;
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
mn
,
i_split
]
=
f
(
blockIdx
.
x
,
num_splits
);
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
mn
,
num_tile_n1
);
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_split
,
i_nhead
,
i_batch
);
}
};
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment