gaoqiong / composable_kernel_ROCM / Commits / f1793462

Commit f1793462, authored Feb 07, 2025 by rocking

Add rowwise dynamic quant kernel

parent b5d201dd
Showing 7 changed files with 96 additions and 48 deletions (+96 −48)
- example/ck_tile/01_fmha/codegen/cpp_symbol_map.py (+0 −1)
- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py (+33 −26)
- example/ck_tile/01_fmha/fmha_fwd.cpp (+17 −1)
- example/ck_tile/01_fmha/fmha_fwd.hpp (+20 −0)
- include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp (+4 −1)
- include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp (+10 −9)
- include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp (+12 −10)
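The commit threads one new boolean through the whole FMHA forward stack: the codegen gains an `rdquant` flag, which becomes `kDoFp8RowwiseDynamicQuant` in `TileFmhaTraits`, `BlockFmhaPipelineProblem`, `fmha_fwd_traits_`, and `FmhaFwdKernel`, and a new `fp8bf16` precision (fp8 Q/K/V and P, bf16 output O) is registered alongside it. In rowwise dynamic quantization, each row's fp8 scale is derived at run time from that row's absolute maximum, instead of the precomputed per-tensor scale used by the existing `squant` (static quant) path. A minimal scalar sketch of the idea — not ck_tile code; `fp8_max = 240.0f` assumes an e4m3-fnuz-like range, and the quantized values are kept as `float` standing in for a real fp8 storage type:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Sketch only: per-row dynamic quantization. The scale is computed from the
// row's absolute maximum at run time, unlike the static squant path where a
// single per-tensor scale is supplied up front.
std::vector<float> quantize_row(const std::vector<float>& row, float& row_scale,
                                float fp8_max = 240.0f /* assumed fp8 range */)
{
    float amax = 0.0f;
    for(float v : row)
        amax = std::max(amax, std::fabs(v));
    row_scale = (amax > 0.0f) ? amax / fp8_max : 1.0f; // per-row dequant factor
    std::vector<float> q(row.size());
    for(std::size_t i = 0; i < row.size(); ++i)
        q[i] = std::clamp(row[i] / row_scale, -fp8_max, fp8_max); // value to narrow to fp8
    return q;
}

int main()
{
    std::vector<float> row{0.5f, -2.0f, 1.25f};
    float scale = 0.0f;
    auto q = quantize_row(row, scale);
    return 0;
}
```

Per the help text added below for the new `rdquant` option, this rescaling is fused into the kernel rather than performed as a separate pass.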
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py

```diff
@@ -6,7 +6,6 @@ FWD_DTYPE_MAP = {
     "fp16"    : "FmhaFwdFp16",
     "bf16"    : "FmhaFwdBf16",
     "fp8"     : "FmhaFwdFp8",
     "fp8fp16" : "FmhaFwdFp8Fp16",
     "fp8bf16" : "FmhaFwdFp8Bf16"
 }
```
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

```diff
@@ -56,6 +56,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
                                                    {F_lse},
                                                    {F_dropout},
                                                    {F_squant},
+                                                   {F_rdquant},
                                                    {F_occupancy}>;
 using fmha_mask_{F_idx} = {F_mask};
@@ -88,7 +89,7 @@ using fmha_kernel_{F_idx} =
     ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
 using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
-    {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+    {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_rdquant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
 
 #include <iostream>
@@ -123,9 +124,9 @@ FMHA_FWD_API_PER_HDIM_CASE="""    {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
 }}
 """
-FMHA_FWD_API_INNER_DISPATCH="""    {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
+FMHA_FWD_API_INNER_DISPATCH="""    {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.do_fp8_rowwise_dynamic_quant == {F_rdquant})&&
         ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
-        using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+        using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_rdquant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
         return fmha_fwd_<trait_>(s, a);
     }}
 """
```
```diff
@@ -149,6 +150,7 @@ class FmhaFwdApiTrait:
     lse     : str #
     dropout : str
     squant  : str #
+    rdquant : str #
     spad    : str
     skpad   : str
     dpad    : str
```
```diff
@@ -157,7 +159,7 @@
     @property
     def name(self) -> str:
         return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-' +\
-            f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
+            f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.rdquant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
 
     @property
     def scheck(self) -> str:
```
```diff
@@ -218,6 +220,7 @@ class FmhaFwdPipeline:
     F_lse     : str #
     F_dropout : str #
     F_squant  : str #
+    F_rdquant : str #
     F_mask    : str # value from MASK_MAP
 
     @property
```
```diff
@@ -241,6 +244,7 @@ class FmhaFwdPipeline:
         if self.F_lse == 't' : n += '_lse'
         if self.F_dropout == 't' : n += '_dropout'
         if self.F_squant == 't' : n += '_squant'
+        if self.F_rdquant == 't' : n += '_rdquant'
         return n
 
 class FmhaFwdApiPool:
```
```diff
@@ -271,7 +275,7 @@ class FmhaFwdApiPool:
             F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout],
-            F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
+            F_squant=BOOL_MAP[trait.squant], F_rdquant=BOOL_MAP[trait.rdquant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
             F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
             F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
```
```diff
@@ -357,6 +361,7 @@ class FmhaFwdKernel:
                 F_lse=BOOL_MAP[self.F_pipeline.F_lse],
                 F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
                 F_squant=BOOL_MAP[self.F_pipeline.F_squant],
+                F_rdquant=BOOL_MAP[self.F_pipeline.F_rdquant],
                 F_occupancy=self.F_tile.F_occupancy,
                 F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
                 F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
```
```diff
@@ -391,6 +396,7 @@ class FmhaFwdKernel:
             lse=self.F_pipeline.F_lse,
             dropout=self.F_pipeline.F_dropout,
             squant=self.F_pipeline.F_squant,
+            rdquant=self.F_pipeline.F_rdquant,
             spad=self.F_pipeline.F_spad,
             skpad=self.F_pipeline.F_skpad,
             dpad=self.F_pipeline.F_dpad,
```
```diff
@@ -407,7 +413,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
             '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
             '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
         }
-    elif dtype == 'fp8' or dtype == 'bf8':
+    elif dtype == 'fp8' or dtype == 'bf8' or dtype == 'fp8bf16':
         return {
             '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
             '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
```
```diff
@@ -424,39 +430,39 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
     # TODO: the order of List matters! the later in this list will be also be checked later
     # TODO: currently for qr pipeline, let 't' padding to appear later!!
     # TODO: how to design this more generic?
     squant = 't' if dtype == 'fp8' else 'f'
     pipelines = []
     if dtype in ['fp16', 'bf16']:
         for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
             if hdim == 256:
             # if True:
-                pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
-                pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
+                pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, 'f', 'f', mask))
+                pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, 'f', 'f', mask))
-                pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
-                pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
+                pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
+                pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
             else:
                 if bias == "bias":
                     # TODO: rocm 6.2 compiler problem if using qr_async for bias case
-                    pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
-                    pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
-                    pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
+                    pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, 'f', 'f', mask))
+                    pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
+                    pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, 'f', 'f', mask))
+                    pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
                 else:
-                    pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
+                    pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, 'f', 'f', mask))
+                    pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
+                    pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, 'f', 'f', mask))
+                    pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
                 if receipt == 1 and bias != "bias":
-                    pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
-                    pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
-    elif dtype in ['fp8', 'bf8']:
+                    pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask)) # TODO: cover arbitraty hdim
+                    pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, 'f', 'f', mask)) # TODO: cover arbitraty hdim
+    elif dtype in ['fp8']:
         # no need lse/dropout kernels
+        # squant
         for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
-            pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
-    elif dtype in ['fp8fp16', 'fp8bf16']:
-        # TODO
-        None
+            pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', 't', 'f', mask))
+    elif dtype in ['fp8bf16']:
+        # rdquant
+        pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', 'no', 'f', 'f', 'f', 't', 's_no'))
     else:
         assert False
     return pipelines
```
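Note the signature change running through this hunk: `FmhaFwdPipeline` grows by one positional argument, so where callers previously passed a single quantization flag they now pass the pair (squant, rdquant) — e.g. `..., dropout, squant, mask` becomes `..., dropout, 'f', 'f', mask` in the fp16/bf16 branches. The only pipeline registered with rdquant = `'t'` is the new `fp8bf16` one, pinned to the `qr`/`col` pipeline with no bias, no LSE, no dropout, and the `s_no` mask.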
```diff
@@ -492,6 +498,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
                 cond &= pipeline.F_vlayout == 'row'
                 cond &= pipeline.F_bias in ['no', 'alibi']
                 cond &= pipeline.F_squant == 'f'
+                cond &= pipeline.F_rdquant == 'f'
                 if not cond:
                     continue
             api_pool.register_traits(k.api_trait())
```
example/ck_tile/01_fmha/fmha_fwd.cpp

```diff
@@ -85,6 +85,10 @@ auto create_args(int argc, char* argv[])
                 "P and O.\n"
                 "calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, "
                 "range_p, range_o")
+        .insert("rdquant",
+                "0",
+                "if using rowwise dynamic quantization fusion or not.\n"
+                "0: no dynamic quant. 1: apply rowwise dynamic quantization with respect Q, K and V.\n")
         .insert("iperm",
                 "1",
                 "permute input\n"
```
```diff
@@ -95,7 +99,7 @@ auto create_args(int argc, char* argv[])
                 "n or 0, no bias\n"
                 "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
                 "a(libi) or 2, alibi with 1*h. a:1, b*h")
-        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
+        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/fp8bf16")
         .insert("mask",
                 "0",
                 "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
```
```diff
@@ -176,6 +180,14 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
     }
 }
 
+template <>
+auto get_elimit<FmhaFwdFp8Bf16>(std::string init_method)
+{
+    double rtol = 1e-2;
+    double atol = 1e-2;
+    return ck_tile::make_tuple(rtol, atol);
+}
+
 int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
 {
     // If we have enough to almost fill the SMs, then just use 1 split
```
```diff
@@ -1580,6 +1592,10 @@ int main(int argc, char* argv[])
     {
         return run<FmhaFwdFp8>(arg_parser) ? 0 : -2;
     }
+    else if(data_type == "fp8bf16")
+    {
+        return run<FmhaFwdFp8Bf16>(arg_parser) ? 0 : -2;
+    }
 
     return -3;
 }
```
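With the driver changes above, the new path should be reachable from the command line as something like `./tile_example_fmha_fwd -prec=fp8bf16 -rdquant=1` — the binary name is assumed from the ck_tile example layout; only the new `-prec` value and the `-rdquant` option come from this diff. A `data_type` the driver does not recognize still falls through to `return -3`.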
example/ck_tile/01_fmha/fmha_fwd.hpp

```diff
@@ -107,6 +107,22 @@ struct FmhaFwdTypeConfig<FmhaFwdBf8>
     using ODataType = ck_tile::bf8_t;
 };
 
+template <>
+struct FmhaFwdTypeConfig<FmhaFwdFp8Bf16>
+{
+    using QDataType             = ck_tile::fp8_t;
+    using KDataType             = ck_tile::fp8_t;
+    using VDataType             = ck_tile::fp8_t;
+    using BiasDataType          = float;
+    using RandValOutputDataType = uint8_t;
+    using LSEDataType           = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
+    using SaccDataType          = float; // data type for first gemm accumulation
+    using SMPLComputeDataType   = float; // data type for reduction, softmax
+    using PDataType             = ck_tile::fp8_t; // data type for A matrix of second gemm
+    using OaccDataType          = float; // data type for second gemm accumulation
+    using ODataType             = ck_tile::bf16_t;
+};
+
 struct FmhaMasks
 {
     using NoMask = ck_tile::GenericAttentionMask<false>;
```
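The new type config keeps the fp8 compute path of `FmhaFwdFp8` — fp8 Q/K/V and P, fp32 accumulators — and widens only the output `O` to bf16. A plausible reading (an assumption; the epilogue itself is not part of this diff) is that the fp32 accumulator rows are dequantized with their dynamic row scales before the store, and a bf16 output preserves that rescaled range instead of re-clipping it into fp8. A scalar sketch of such a store:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Sketch only (the real ck_tile epilogue is tile-based): dequantize a row
// of the fp32 O accumulator with its dynamic row scale, then narrow to
// bf16. bf16 is stood in by its raw 16-bit pattern via fp32 truncation.
using bf16_raw = std::uint16_t;

static bf16_raw float_to_bf16(float f)
{
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return static_cast<bf16_raw>(u >> 16); // round-toward-zero truncation
}

void store_row_bf16(const std::vector<float>& o_acc, float row_scale,
                    std::vector<bf16_raw>& o_out)
{
    o_out.resize(o_acc.size());
    for(std::size_t i = 0; i < o_acc.size(); ++i)
        o_out[i] = float_to_bf16(o_acc[i] * row_scale); // dequantize, then narrow
}

int main() {}
```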
```diff
@@ -635,6 +651,7 @@ template <ck_tile::index_t HDim_,
           bool kStoreLse_,
           bool kHasDropout_,
           bool kDoFp8StaticQuant_,
+          bool kDoFp8RowwiseDynamicQuant_,
           bool kPadS_,
           bool kPadSK_,
           bool kPadD_,
```
```diff
@@ -657,6 +674,7 @@ struct fmha_fwd_traits_
     static constexpr bool kStoreLse = kStoreLse_;
     static constexpr bool kHasDropout = kHasDropout_;
     static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
+    static constexpr bool kDoFp8RowwiseDynamicQuant = kDoFp8RowwiseDynamicQuant_;
     static constexpr bool kPadS = kPadS_;
     static constexpr bool kPadSK = kPadSK_;
     static constexpr bool kPadD = kPadD_;
```
```diff
@@ -789,6 +807,7 @@ struct fmha_fwd_traits
     bool has_lse;
     bool has_dropout;
     bool do_fp8_static_quant;
+    bool do_fp8_rowwise_dynamic_quant;
     // TODO: padding check is inside this api
 };
 float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
```
```diff
@@ -804,6 +823,7 @@ struct fmha_fwd_splitkv_traits
     bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
     bool has_lse;
     bool do_fp8_static_quant;
+    bool do_fp8_rowwise_dynamic_quant;
     // TODO: padding check is inside this api
 };
 float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,
```
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp

```diff
@@ -51,6 +51,8 @@ struct FmhaFwdKernel
     static constexpr bool kStoreLSE         = FmhaPipeline::kStoreLSE;
     static constexpr bool kHasDropout       = FmhaPipeline::kHasDropout;
     static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
+    static constexpr bool kDoFp8RowwiseDynamicQuant =
+        FmhaPipeline::Problem::kDoFp8RowwiseDynamicQuant;
     using FmhaMask                 = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
     static constexpr bool kHasMask = FmhaMask::IsMasking;
```
```diff
@@ -93,7 +95,8 @@ struct FmhaFwdKernel
                (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
                "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
                (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
-               (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "") + (kHasDropout ? "_dropout" : "") + (kDoFp8StaticQuant ? "_squant" : "");
+               (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "") + (kHasDropout ? "_dropout" : "") + (kDoFp8StaticQuant ? "_squant" : "") +
+               (kDoFp8RowwiseDynamicQuant ? "_rdquant" : "");
 #undef _SS_
 #undef _TS_
         // clang-format on
```
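The kernel's generated name now appends an `_rdquant` suffix after the existing `_lse`/`_dropout`/`_squant` markers, matching the `'_rdquant'` tag that `FmhaFwdPipeline.name` adds on the codegen side, so runtime kernel names and generated identifiers stay in sync.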
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp

```diff
@@ -46,15 +46,16 @@ struct BlockFmhaPipelineProblem
     static constexpr bool kIsGroupMode = kIsGroupMode_;
 
     // attributes from traits
-    static constexpr bool kPadSeqLenQ       = Traits::kPadSeqLenQ;
-    static constexpr bool kPadSeqLenK       = Traits::kPadSeqLenK;
-    static constexpr bool kPadHeadDimQ      = Traits::kPadHeadDimQ;
-    static constexpr bool kPadHeadDimV      = Traits::kPadHeadDimV;
-    static constexpr auto BiasEnum          = Traits::BiasEnum;
-    static constexpr bool kStoreLSE         = Traits::kStoreLSE;
-    static constexpr bool kHasDropout       = Traits::kHasDropout;
-    static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
-    static constexpr index_t kBlockPerCu    = Traits::kBlockPerCu;
+    static constexpr bool kPadSeqLenQ               = Traits::kPadSeqLenQ;
+    static constexpr bool kPadSeqLenK               = Traits::kPadSeqLenK;
+    static constexpr bool kPadHeadDimQ              = Traits::kPadHeadDimQ;
+    static constexpr bool kPadHeadDimV              = Traits::kPadHeadDimV;
+    static constexpr auto BiasEnum                  = Traits::BiasEnum;
+    static constexpr bool kStoreLSE                 = Traits::kStoreLSE;
+    static constexpr bool kHasDropout               = Traits::kHasDropout;
+    static constexpr bool kDoFp8StaticQuant         = Traits::kDoFp8StaticQuant;
+    static constexpr bool kDoFp8RowwiseDynamicQuant = Traits::kDoFp8RowwiseDynamicQuant;
+    static constexpr index_t kBlockPerCu            = Traits::kBlockPerCu;
 };
 
 template <typename QDataType_,
```
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp

```diff
@@ -18,19 +18,21 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
           bool kStoreLSE_,
           bool kHasDropout_,
           bool kDoFp8StaticQuant_,
+          bool kDoFp8RowwiseDynamicQuant_,
           index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
 struct TileFmhaTraits
 {
-    static constexpr bool kPadSeqLenQ       = kPadSeqLenQ_;
-    static constexpr bool kPadSeqLenK       = kPadSeqLenK_;
-    static constexpr bool kPadHeadDimQ      = kPadHeadDimQ_;
-    static constexpr bool kPadHeadDimV      = kPadHeadDimV_;
-    static constexpr auto BiasEnum          = BiasEnum_;
-    static constexpr bool kHasBiasGrad      = kHasBiasGrad_;
-    static constexpr bool kStoreLSE         = kStoreLSE_;
-    static constexpr bool kHasDropout       = kHasDropout_;
-    static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
-    static constexpr index_t kBlockPerCu    = kBlockPerCu_;
+    static constexpr bool kPadSeqLenQ               = kPadSeqLenQ_;
+    static constexpr bool kPadSeqLenK               = kPadSeqLenK_;
+    static constexpr bool kPadHeadDimQ              = kPadHeadDimQ_;
+    static constexpr bool kPadHeadDimV              = kPadHeadDimV_;
+    static constexpr auto BiasEnum                  = BiasEnum_;
+    static constexpr bool kHasBiasGrad              = kHasBiasGrad_;
+    static constexpr bool kStoreLSE                 = kStoreLSE_;
+    static constexpr bool kHasDropout               = kHasDropout_;
+    static constexpr bool kDoFp8StaticQuant         = kDoFp8StaticQuant_;
+    static constexpr bool kDoFp8RowwiseDynamicQuant = kDoFp8RowwiseDynamicQuant_;
+    static constexpr index_t kBlockPerCu            = kBlockPerCu_;
 };
 
 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
```
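The new template parameter slots between `kDoFp8StaticQuant_` and `kBlockPerCu_`, so every existing `TileFmhaTraits` instantiation needs exactly one extra argument. A self-contained mirror of the widened trait list (names from the diff; member subset only, everything else simplified):

```cpp
#include <cstdint>

using index_t = std::int32_t;
enum class BlockAttentionBiasEnum { NO_BIAS, ELEMENTWISE_BIAS, ALIBI };

// Mirror of the widened parameter order, showing where the new flag sits
// relative to its neighbors. Not the real ck_tile::TileFmhaTraits.
template <bool kPadSeqLenQ_, bool kPadSeqLenK_, bool kPadHeadDimQ_, bool kPadHeadDimV_,
          BlockAttentionBiasEnum BiasEnum_, bool kHasBiasGrad_,
          bool kStoreLSE_, bool kHasDropout_,
          bool kDoFp8StaticQuant_,
          bool kDoFp8RowwiseDynamicQuant_, // new parameter from this commit
          index_t kBlockPerCu_ = -1>
struct TileFmhaTraitsSketch
{
    static constexpr bool kDoFp8StaticQuant         = kDoFp8StaticQuant_;
    static constexpr bool kDoFp8RowwiseDynamicQuant = kDoFp8RowwiseDynamicQuant_;
};

using RdquantTraits = TileFmhaTraitsSketch<false, false, false, false,
                                           BlockAttentionBiasEnum::NO_BIAS, false,
                                           false, false,
                                           /*squant*/ false, /*rdquant*/ true>;
static_assert(RdquantTraits::kDoFp8RowwiseDynamicQuant, "flag reaches the trait");

int main() {}
```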