Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7cd4e574
Commit
7cd4e574
authored
Jan 01, 2025
by
Po Yen Chen
Browse files
Add group mode block-mapping for fmha splitkv kernel
parent
db952741
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
13 deletions
+35
-13
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+35
-13
No files found.
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
7cd4e574
...
@@ -60,7 +60,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -60,7 +60,7 @@ struct FmhaFwdSplitKVKernel
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
// clang-format on
// clang-format on
__host__
static
std
::
string
GetName
()
CK_TILE_HOST
static
std
::
string
GetName
()
{
{
// sync with generate.py
// sync with generate.py
// clang-format off
// clang-format off
...
@@ -237,7 +237,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -237,7 +237,7 @@ struct FmhaFwdSplitKVKernel
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
GroupModeKargs
,
BatchModeKargs
>
;
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
GroupModeKargs
,
BatchModeKargs
>
;
template
<
bool
Cond
=
!
kIsGroupMode
>
template
<
bool
Cond
=
!
kIsGroupMode
>
__host__
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
q_ptr
,
MakeKargs
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
v_ptr
,
...
@@ -361,7 +361,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -361,7 +361,7 @@ struct FmhaFwdSplitKVKernel
}
}
template
<
bool
Cond
=
kIsGroupMode
>
template
<
bool
Cond
=
kIsGroupMode
>
__host__
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
CK_TILE_HOST
static
constexpr
std
::
enable_if_t
<
Cond
,
Kargs
>
MakeKargs
(
const
void
*
q_ptr
,
MakeKargs
(
const
void
*
q_ptr
,
const
void
*
k_ptr
,
const
void
*
k_ptr
,
const
void
*
v_ptr
,
const
void
*
v_ptr
,
...
@@ -482,10 +482,20 @@ struct FmhaFwdSplitKVKernel
...
@@ -482,10 +482,20 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
num_splits
)
ck_tile
::
index_t
num_splits
)
{
{
// TODO: this may need tuning
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_seqlen_q
,
FmhaPipeline
::
kM0
)
*
if
constexpr
(
kIsGroupMode
)
ck_tile
::
integer_divide_ceil
(
hdim_v
,
FmhaPipeline
::
kN1
)
*
num_splits
,
{
nhead
,
return
dim3
(
nhead
,
batch_size
);
batch_size
,
ck_tile
::
integer_divide_ceil
(
max_seqlen_q
,
FmhaPipeline
::
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
FmhaPipeline
::
kN1
)
*
num_splits
);
}
else
{
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_seqlen_q
,
FmhaPipeline
::
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
FmhaPipeline
::
kN1
)
*
num_splits
,
nhead
,
batch_size
);
}
}
}
CK_TILE_DEVICE
static
constexpr
auto
GetTileIndex
(
const
Kargs
&
kargs
)
CK_TILE_DEVICE
static
constexpr
auto
GetTileIndex
(
const
Kargs
&
kargs
)
...
@@ -498,15 +508,27 @@ struct FmhaFwdSplitKVKernel
...
@@ -498,15 +508,27 @@ struct FmhaFwdSplitKVKernel
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
};
const
auto
[
mn
,
i_split
]
=
f
(
blockIdx
.
x
,
kargs
.
num_splits
);
if
constexpr
(
kIsGroupMode
)
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
mn
,
num_tile_n1
);
{
const
index_t
i_nhead
=
blockIdx
.
y
;
const
auto
[
mn
,
i_split
]
=
f
(
blockIdx
.
z
,
kargs
.
num_splits
);
const
index_t
i_batch
=
blockIdx
.
z
;
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
mn
,
num_tile_n1
);
const
index_t
i_nhead
=
blockIdx
.
x
;
const
index_t
i_batch
=
blockIdx
.
y
;
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_split
,
i_nhead
,
i_batch
);
}
else
{
const
auto
[
mn
,
i_split
]
=
f
(
blockIdx
.
x
,
kargs
.
num_splits
);
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
mn
,
num_tile_n1
);
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_split
,
i_nhead
,
i_batch
);
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_split
,
i_nhead
,
i_batch
);
}
}
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment