Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ad3e94bb
Commit
ad3e94bb
authored
Jul 28, 2024
by
danyao12
Browse files
fwd dropout revert
parent
a0c92495
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
433 additions
and
153 deletions
+433
-153
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
+5
-5
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+1
-0
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+9
-10
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+9
-10
example/ck_tile/01_fmha/fmha_fwd.cpp
example/ck_tile/01_fmha/fmha_fwd.cpp
+1
-1
example/ck_tile/01_fmha/fmha_fwd.hpp
example/ck_tile/01_fmha/fmha_fwd.hpp
+7
-3
include/ck_tile/ops/fmha/block/block_dropout.hpp
include/ck_tile/ops/fmha/block/block_dropout.hpp
+294
-10
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+28
-32
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+41
-39
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
...mha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+8
-12
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
...peline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
+7
-8
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
...ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
+1
-4
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
+7
-9
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+8
-7
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp
...le/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp
+2
-2
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+1
-1
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
+4
-0
No files found.
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
View file @
ad3e94bb
...
@@ -67,11 +67,11 @@ BIAS_CHECK_MAP = {
...
@@ -67,11 +67,11 @@ BIAS_CHECK_MAP = {
}
}
DROPOUT_MAP
=
{
DROPOUT_MAP
=
{
"no"
:
"ck_tile::BlockDropout<false, true, false>"
,
"no"
:
"ck_tile::BlockDropout
Bwd
<false, true, false>"
,
"dropout_wg32"
:
"ck_tile::BlockDropout<true, true, false>"
,
"dropout_wg32"
:
"ck_tile::BlockDropout
Bwd
<true, true, false>"
,
"dropout_wg32_storerandval"
:
"ck_tile::BlockDropout<true, true, true >"
,
"dropout_wg32_storerandval"
:
"ck_tile::BlockDropout
Bwd
<true, true, true >"
,
"dropout_wg16"
:
"ck_tile::BlockDropout<true, false, false>"
,
"dropout_wg16"
:
"ck_tile::BlockDropout
Bwd
<true, false, false>"
,
"dropout_wg16_storerandval"
:
"ck_tile::BlockDropout<true, false, true >"
"dropout_wg16_storerandval"
:
"ck_tile::BlockDropout
Bwd
<true, false, true >"
}
}
DROPOUT_CHECK_MAP
=
{
DROPOUT_CHECK_MAP
=
{
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
ad3e94bb
...
@@ -62,6 +62,7 @@ using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
...
@@ -62,6 +62,7 @@ using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_dbias},
{F_dbias},
false,
false,
false,
false,
false,
{F_occupancy}>;
{F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_mask_{F_idx} = {F_mask};
using fmha_dropout_{F_idx} = {F_dropout};
using fmha_dropout_{F_idx} = {F_dropout};
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
View file @
ad3e94bb
...
@@ -53,10 +53,10 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
...
@@ -53,10 +53,10 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_bias},
{F_bias},
false,
false,
{F_lse},
{F_lse},
{F_dropout},
{F_squant},
{F_squant},
{F_occupancy}>;
{F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_mask_{F_idx} = {F_mask};
using fmha_dropout_{F_idx} = {F_dropout};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
...
@@ -73,7 +73,6 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
...
@@ -73,7 +73,6 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
fmha_shape_{F_idx},
fmha_shape_{F_idx},
{F_mode},
{F_mode},
fmha_mask_{F_idx},
fmha_mask_{F_idx},
fmha_dropout_{F_idx},
fmha_trait_{F_idx}>;
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = {F_pipeline}<
using fmha_pipeline_{F_idx} = {F_pipeline}<
...
@@ -90,7 +89,7 @@ using fmha_kernel_{F_idx} =
...
@@ -90,7 +89,7 @@ using fmha_kernel_{F_idx} =
fmha_epilogue_{F_idx}>;
fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx},
fmha_dropout_{F_idx
}, {F_
bias
}, {F_
lse
}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
{F_pipeline_enum}, fmha_mask_{F_idx},
{F_bias
}, {F_
lse
}, {F_
dropout
}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
#include <iostream>
#include <iostream>
...
@@ -125,9 +124,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
...
@@ -125,9 +124,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
}}
}}
"""
"""
FMHA_FWD_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && ({F_dropout
_check
}) && (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (
t.has_dropout ==
{F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_
dropout}, {F_
bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse
}, {F_dropout
}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a);
return fmha_fwd_<trait_>(s, a);
}}
}}
"""
"""
...
@@ -239,7 +238,7 @@ class FmhaFwdPipeline:
...
@@ -239,7 +238,7 @@ class FmhaFwdPipeline:
else
:
else
:
if
self
.
F_mask
!=
'no'
:
n
+=
f
'_m
{
self
.
F_mask
[
0
]
}
'
if
self
.
F_mask
!=
'no'
:
n
+=
f
'_m
{
self
.
F_mask
[
0
]
}
'
if
self
.
F_lse
==
't'
:
n
+=
'_lse'
if
self
.
F_lse
==
't'
:
n
+=
'_lse'
if
self
.
F_dropout
!
=
'
no
'
:
n
+=
f
'_
{
self
.
F_
dropout
}
'
if
self
.
F_dropout
=
=
'
t
'
:
n
+=
'_dropout'
if
self
.
F_squant
==
't'
:
n
+=
'_squant'
if
self
.
F_squant
==
't'
:
n
+=
'_squant'
return
n
return
n
...
@@ -270,7 +269,7 @@ class FmhaFwdApiPool:
...
@@ -270,7 +269,7 @@ class FmhaFwdApiPool:
inners
=
inners
+
FMHA_FWD_API_INNER_DISPATCH
.
format
(
F_if
=
if_k
,
F_mode
=
MODE_MAP
[
trait
.
mode
],
F_vlayout
=
LAYOUT_MAP
[
trait
.
vlayout
],
inners
=
inners
+
FMHA_FWD_API_INNER_DISPATCH
.
format
(
F_if
=
if_k
,
F_mode
=
MODE_MAP
[
trait
.
mode
],
F_vlayout
=
LAYOUT_MAP
[
trait
.
vlayout
],
F_pipeline_enum
=
PIPELINE_ENUM_MAP
[
trait
.
pipeline_tag
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_pipeline_enum
=
PIPELINE_ENUM_MAP
[
trait
.
pipeline_tag
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_mask_check
=
get_mask_check_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_bias_check
=
BIAS_CHECK_MAP
[
trait
.
bias
],
F_bias
=
BIAS_MAP
[
trait
.
bias
],
F_mask_check
=
get_mask_check_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_bias_check
=
BIAS_CHECK_MAP
[
trait
.
bias
],
F_bias
=
BIAS_MAP
[
trait
.
bias
],
F_lse
=
BOOL_MAP
[
trait
.
lse
],
F_dropout
_check
=
DROPOUT_CHECK_MAP
[
trait
.
dropout
],
F_dropout
=
DROPOUT_MAP
[
trait
.
dropout
]
,
F_lse
=
BOOL_MAP
[
trait
.
lse
],
F_dropout
=
BOOL_MAP
[
trait
.
dropout
]
,
F_squant
=
BOOL_MAP
[
trait
.
squant
],
F_scheck
=
trait
.
scheck
,
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_squant
=
BOOL_MAP
[
trait
.
squant
],
F_scheck
=
trait
.
scheck
,
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_spad
=
BOOL_MAP
[
trait
.
spad
],
F_skpad
=
BOOL_MAP
[
trait
.
skpad
],
F_dpad
=
BOOL_MAP
[
trait
.
dpad
],
F_dvpad
=
BOOL_MAP
[
trait
.
dvpad
],
F_spad
=
BOOL_MAP
[
trait
.
spad
],
F_skpad
=
BOOL_MAP
[
trait
.
skpad
],
F_dpad
=
BOOL_MAP
[
trait
.
dpad
],
F_dvpad
=
BOOL_MAP
[
trait
.
dvpad
],
F_bm0
=
trait
.
bm0
,
F_bn0
=
trait
.
bn0
,
F_bk0
=
trait
.
bk0
,
F_bn1
=
trait
.
bn1
,
F_bk1
=
trait
.
bk1
,
F_bk0blen
=
trait
.
bk0blen
,
F_bm0
=
trait
.
bm0
,
F_bn0
=
trait
.
bn0
,
F_bk0
=
trait
.
bk0
,
F_bn1
=
trait
.
bn1
,
F_bk1
=
trait
.
bk1
,
F_bk0blen
=
trait
.
bk0blen
,
...
@@ -348,7 +347,7 @@ class FmhaFwdKernel:
...
@@ -348,7 +347,7 @@ class FmhaFwdKernel:
F_dvpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_dvpad
],
F_dvpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_dvpad
],
F_bias
=
BIAS_MAP
[
self
.
F_pipeline
.
F_bias
],
F_bias
=
BIAS_MAP
[
self
.
F_pipeline
.
F_bias
],
F_lse
=
BOOL_MAP
[
self
.
F_pipeline
.
F_lse
],
F_lse
=
BOOL_MAP
[
self
.
F_pipeline
.
F_lse
],
F_dropout
=
DROPOUT
_MAP
[
self
.
F_pipeline
.
F_dropout
],
F_dropout
=
BOOL
_MAP
[
self
.
F_pipeline
.
F_dropout
],
F_squant
=
BOOL_MAP
[
self
.
F_pipeline
.
F_squant
],
F_squant
=
BOOL_MAP
[
self
.
F_pipeline
.
F_squant
],
F_occupancy
=
self
.
F_tile
.
F_occupancy
,
F_occupancy
=
self
.
F_tile
.
F_occupancy
,
F_pipeline_enum
=
PIPELINE_ENUM_MAP
[
self
.
F_pipeline
.
tag
],
F_pipeline_enum
=
PIPELINE_ENUM_MAP
[
self
.
F_pipeline
.
tag
],
...
@@ -420,7 +419,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
...
@@ -420,7 +419,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
squant
=
't'
if
dtype
==
'fp8'
else
'f'
squant
=
't'
if
dtype
==
'fp8'
else
'f'
pipelines
=
[]
pipelines
=
[]
if
dtype
in
[
'fp16'
,
'bf16'
]:
if
dtype
in
[
'fp16'
,
'bf16'
]:
for
mask
,
bias
,
lse
,
dropout
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
(),
[
"t"
,
"f"
],
list
(
DROPOUT_MAP
.
keys
())[:
3
]):
for
mask
,
bias
,
lse
,
dropout
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
(),
[
"t"
,
"f"
],
[
"t"
,
"f"
]):
if
hdim
==
256
:
if
hdim
==
256
:
# if True:
# if True:
pipelines
.
append
(
FmhaFwdPipeline
(
'qr'
,
'row'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
FmhaFwdPipeline
(
'qr'
,
'row'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
...
@@ -439,7 +438,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
...
@@ -439,7 +438,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
elif
dtype
in
[
'fp8'
,
'bf8'
]:
elif
dtype
in
[
'fp8'
,
'bf8'
]:
# no need lse/dropout kernels
# no need lse/dropout kernels
for
mask
,
bias
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
()):
for
mask
,
bias
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
()):
pipelines
.
append
(
FmhaFwdPipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
'f'
,
'
no
'
,
squant
,
mask
))
pipelines
.
append
(
FmhaFwdPipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
'f'
,
'
f
'
,
squant
,
mask
))
else
:
else
:
assert
False
assert
False
return
pipelines
return
pipelines
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
View file @
ad3e94bb
...
@@ -29,7 +29,6 @@ FMHA_FWD_SPLITKV_PIPELINE_MAP = {
...
@@ -29,7 +29,6 @@ FMHA_FWD_SPLITKV_PIPELINE_MAP = {
FMHA_FWD_SPLITKV_KERNEL_BODY
=
"""
FMHA_FWD_SPLITKV_KERNEL_BODY
=
"""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_mask_{F_idx} = {F_mask};
using fmha_mask_{F_idx} = {F_mask};
using fmha_dropout_{F_idx} = {F_dropout};
namespace {{
namespace {{
template <bool kHasUnevenSplits>
template <bool kHasUnevenSplits>
...
@@ -52,6 +51,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
...
@@ -52,6 +51,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
{F_bias},
{F_bias},
false,
false,
{F_lse},
{F_lse},
{F_dropout},
{F_squant},
{F_squant},
kHasUnevenSplits,
kHasUnevenSplits,
{F_occupancy}>;
{F_occupancy}>;
...
@@ -71,7 +71,6 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
...
@@ -71,7 +71,6 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
fmha_shape,
fmha_shape,
{F_mode},
{F_mode},
fmha_mask_{F_idx},
fmha_mask_{F_idx},
fmha_dropout_{F_idx},
fmha_trait>;
fmha_trait>;
using fmha_pipeline = {F_pipeline}<
using fmha_pipeline = {F_pipeline}<
...
@@ -99,7 +98,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
...
@@ -99,7 +98,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
}}
}}
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx},
fmha_dropout_{F_idx
}, {F_
bias
}, {F_
lse
}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
{F_pipeline_enum}, fmha_mask_{F_idx},
{F_bias
}, {F_
lse
}, {F_
dropout
}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
#include <iostream>
#include <iostream>
...
@@ -225,9 +224,9 @@ float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream
...
@@ -225,9 +224,9 @@ float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream
}}
}}
"""
"""
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && ({F_dropout
_check
}) && (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (
t.has_dropout ==
{F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_
dropout}, {F_
bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse
}, {F_dropout
}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
...
@@ -268,7 +267,7 @@ class FmhaFwdSplitKVPipeline:
...
@@ -268,7 +267,7 @@ class FmhaFwdSplitKVPipeline:
else
:
else
:
if
self
.
F_mask
!=
'no'
:
n
+=
f
'_m
{
self
.
F_mask
[
0
]
}
'
if
self
.
F_mask
!=
'no'
:
n
+=
f
'_m
{
self
.
F_mask
[
0
]
}
'
if
self
.
F_lse
==
't'
:
n
+=
'_lse'
if
self
.
F_lse
==
't'
:
n
+=
'_lse'
if
self
.
F_dropout
!
=
'
no
'
:
n
+=
f
'_
{
self
.
F_
dropout
}
'
if
self
.
F_dropout
=
=
'
t
'
:
n
+=
'_dropout'
if
self
.
F_squant
==
't'
:
n
+=
'_squant'
if
self
.
F_squant
==
't'
:
n
+=
'_squant'
return
n
return
n
...
@@ -323,7 +322,7 @@ class FmhaFwdSplitKVApiPool:
...
@@ -323,7 +322,7 @@ class FmhaFwdSplitKVApiPool:
inners
=
inners
+
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
.
format
(
F_if
=
if_k
,
F_mode
=
MODE_MAP
[
trait
.
mode
],
F_vlayout
=
LAYOUT_MAP
[
trait
.
vlayout
],
inners
=
inners
+
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
.
format
(
F_if
=
if_k
,
F_mode
=
MODE_MAP
[
trait
.
mode
],
F_vlayout
=
LAYOUT_MAP
[
trait
.
vlayout
],
F_pipeline_enum
=
PIPELINE_ENUM_MAP
[
trait
.
pipeline_tag
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_pipeline_enum
=
PIPELINE_ENUM_MAP
[
trait
.
pipeline_tag
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_mask_check
=
get_mask_check_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_bias_check
=
BIAS_CHECK_MAP
[
trait
.
bias
],
F_bias
=
BIAS_MAP
[
trait
.
bias
],
F_mask_check
=
get_mask_check_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_bias_check
=
BIAS_CHECK_MAP
[
trait
.
bias
],
F_bias
=
BIAS_MAP
[
trait
.
bias
],
F_lse
=
BOOL_MAP
[
trait
.
lse
],
F_dropout
_check
=
DROPOUT_CHECK_MAP
[
trait
.
dropout
],
F_dropout
=
DROPOUT_MAP
[
trait
.
dropout
]
,
F_lse
=
BOOL_MAP
[
trait
.
lse
],
F_dropout
=
BOOL_MAP
[
trait
.
dropout
]
,
F_squant
=
BOOL_MAP
[
trait
.
squant
],
F_scheck
=
trait
.
scheck
,
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_squant
=
BOOL_MAP
[
trait
.
squant
],
F_scheck
=
trait
.
scheck
,
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_spad
=
BOOL_MAP
[
trait
.
spad
],
F_skpad
=
BOOL_MAP
[
trait
.
skpad
],
F_dpad
=
BOOL_MAP
[
trait
.
dpad
],
F_dvpad
=
BOOL_MAP
[
trait
.
dvpad
],
F_spad
=
BOOL_MAP
[
trait
.
spad
],
F_skpad
=
BOOL_MAP
[
trait
.
skpad
],
F_dpad
=
BOOL_MAP
[
trait
.
dpad
],
F_dvpad
=
BOOL_MAP
[
trait
.
dvpad
],
F_bm0
=
trait
.
bm0
,
F_bn0
=
trait
.
bn0
,
F_bk0
=
trait
.
bk0
,
F_bn1
=
trait
.
bn1
,
F_bk1
=
trait
.
bk1
,
F_bk0blen
=
trait
.
bk0blen
,
F_bm0
=
trait
.
bm0
,
F_bn0
=
trait
.
bn0
,
F_bk0
=
trait
.
bk0
,
F_bn1
=
trait
.
bn1
,
F_bk1
=
trait
.
bk1
,
F_bk0blen
=
trait
.
bk0blen
,
...
@@ -384,7 +383,7 @@ class FmhaFwdSplitKVKernel:
...
@@ -384,7 +383,7 @@ class FmhaFwdSplitKVKernel:
F_dvpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_dvpad
],
F_dvpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_dvpad
],
F_bias
=
BIAS_MAP
[
self
.
F_pipeline
.
F_bias
],
F_bias
=
BIAS_MAP
[
self
.
F_pipeline
.
F_bias
],
F_lse
=
BOOL_MAP
[
self
.
F_pipeline
.
F_lse
],
F_lse
=
BOOL_MAP
[
self
.
F_pipeline
.
F_lse
],
F_dropout
=
DROPOUT
_MAP
[
self
.
F_pipeline
.
F_dropout
],
F_dropout
=
BOOL
_MAP
[
self
.
F_pipeline
.
F_dropout
],
F_squant
=
BOOL_MAP
[
self
.
F_pipeline
.
F_squant
],
F_squant
=
BOOL_MAP
[
self
.
F_pipeline
.
F_squant
],
F_occupancy
=
self
.
F_tile
.
F_occupancy
,
F_occupancy
=
self
.
F_tile
.
F_occupancy
,
F_pipeline_enum
=
PIPELINE_ENUM_MAP
[
self
.
F_pipeline
.
tag
],
F_pipeline_enum
=
PIPELINE_ENUM_MAP
[
self
.
F_pipeline
.
tag
],
...
@@ -535,7 +534,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -535,7 +534,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
pipelines
=
[]
pipelines
=
[]
if
dtype
in
[
'fp16'
,
'bf16'
]:
if
dtype
in
[
'fp16'
,
'bf16'
]:
# splitkv kernel donot support dropout
# splitkv kernel donot support dropout
for
mask
,
bias
,
lse
,
dropout
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
(),
[
"t"
,
"f"
],
list
(
DROPOUT_MAP
.
keys
())[:
1
]):
for
mask
,
bias
,
lse
,
dropout
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
(),
[
"t"
,
"f"
],
[
"f"
]):
if
hdim
==
256
:
if
hdim
==
256
:
# if True:
# if True:
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
...
@@ -554,7 +553,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -554,7 +553,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
elif
dtype
in
[
'fp8'
,
'bf8'
]:
elif
dtype
in
[
'fp8'
,
'bf8'
]:
# no need lse/dropout kernels
# no need lse/dropout kernels
for
mask
,
bias
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
()):
for
mask
,
bias
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
()):
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
'f'
,
'
no
'
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
'f'
,
'
f
'
,
squant
,
mask
))
else
:
else
:
assert
False
assert
False
return
pipelines
return
pipelines
...
...
example/ck_tile/01_fmha/fmha_fwd.cpp
View file @
ad3e94bb
...
@@ -622,7 +622,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -622,7 +622,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
bias
.
type
,
bias
.
type
,
lse
,
lse
,
p_drop
>
0.0
f
,
p_drop
>
0.0
f
,
s_randval
,
squant
};
squant
};
auto
p_compute_element_func
=
[
&
]()
{
auto
p_compute_element_func
=
[
&
]()
{
...
@@ -745,6 +744,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -745,6 +744,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask
.
right
,
mask
.
right
,
static_cast
<
ck_tile
::
index_t
>
(
mask
.
type
),
static_cast
<
ck_tile
::
index_t
>
(
mask
.
type
),
p_drop
,
p_drop
,
s_randval
,
{
drop_seed
,
drop_offset
}};
{
drop_seed
,
drop_offset
}};
}();
}();
...
...
example/ck_tile/01_fmha/fmha_fwd.hpp
View file @
ad3e94bb
...
@@ -143,6 +143,7 @@ struct fmha_fwd_args
...
@@ -143,6 +143,7 @@ struct fmha_fwd_args
ck_tile
::
index_t
window_size_right
;
ck_tile
::
index_t
window_size_right
;
ck_tile
::
index_t
mask_type
;
ck_tile
::
index_t
mask_type
;
float
p_drop
;
float
p_drop
;
bool
s_randval
;
std
::
tuple
<
uint64_t
,
uint64_t
>
drop_seed_offset
;
std
::
tuple
<
uint64_t
,
uint64_t
>
drop_seed_offset
;
};
};
...
@@ -189,6 +190,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
...
@@ -189,6 +190,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args
.
window_size_right
,
args
.
window_size_right
,
args
.
mask_type
,
args
.
mask_type
,
args
.
p_drop
,
args
.
p_drop
,
args
.
s_randval
,
args
.
drop_seed_offset
);
args
.
drop_seed_offset
);
}
}
else
else
...
@@ -233,6 +235,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
...
@@ -233,6 +235,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args
.
window_size_right
,
args
.
window_size_right
,
args
.
mask_type
,
args
.
mask_type
,
args
.
p_drop
,
args
.
p_drop
,
args
.
s_randval
,
args
.
drop_seed_offset
);
args
.
drop_seed_offset
);
}
}
}();
}();
...
@@ -289,6 +292,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
...
@@ -289,6 +292,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args
.
window_size_right
,
args
.
window_size_right
,
args
.
mask_type
,
args
.
mask_type
,
args
.
p_drop
,
args
.
p_drop
,
args
.
s_randval
,
args
.
drop_seed_offset
);
args
.
drop_seed_offset
);
}
}
else
else
...
@@ -337,6 +341,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
...
@@ -337,6 +341,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args
.
window_size_right
,
args
.
window_size_right
,
args
.
mask_type
,
args
.
mask_type
,
args
.
p_drop
,
args
.
p_drop
,
args
.
s_randval
,
args
.
drop_seed_offset
);
args
.
drop_seed_offset
);
}
}
}();
}();
...
@@ -422,9 +427,9 @@ template <ck_tile::index_t HDim_,
...
@@ -422,9 +427,9 @@ template <ck_tile::index_t HDim_,
bool
kIsVLayoutRowMajor_
,
bool
kIsVLayoutRowMajor_
,
ck_tile
::
BlockFmhaPipelineEnum
FmhaPipelineEnum_
,
ck_tile
::
BlockFmhaPipelineEnum
FmhaPipelineEnum_
,
typename
FmhaMask_
,
typename
FmhaMask_
,
typename
FmhaDropout_
,
ck_tile
::
BlockAttentionBiasEnum
BiasEnum_
,
ck_tile
::
BlockAttentionBiasEnum
BiasEnum_
,
bool
kStoreLse_
,
bool
kStoreLse_
,
bool
kHasDropout_
,
bool
kDoFp8StaticQuant_
,
bool
kDoFp8StaticQuant_
,
bool
kPadS_
,
bool
kPadS_
,
bool
kPadSK_
,
bool
kPadSK_
,
...
@@ -444,9 +449,9 @@ struct fmha_fwd_traits_
...
@@ -444,9 +449,9 @@ struct fmha_fwd_traits_
static
constexpr
bool
kIsVLayoutRowMajor
=
kIsVLayoutRowMajor_
;
static
constexpr
bool
kIsVLayoutRowMajor
=
kIsVLayoutRowMajor_
;
static
constexpr
auto
FmhaPipelineEnum
=
FmhaPipelineEnum_
;
static
constexpr
auto
FmhaPipelineEnum
=
FmhaPipelineEnum_
;
using
FmhaMask
=
ck_tile
::
remove_cvref_t
<
FmhaMask_
>
;
using
FmhaMask
=
ck_tile
::
remove_cvref_t
<
FmhaMask_
>
;
using
FmhaDropout
=
ck_tile
::
remove_cvref_t
<
FmhaDropout_
>
;
static
constexpr
auto
BiasEnum
=
BiasEnum_
;
static
constexpr
auto
BiasEnum
=
BiasEnum_
;
static
constexpr
bool
kStoreLse
=
kStoreLse_
;
static
constexpr
bool
kStoreLse
=
kStoreLse_
;
static
constexpr
bool
kHasDropout
=
kHasDropout_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
static
constexpr
bool
kPadS
=
kPadS_
;
static
constexpr
bool
kPadS
=
kPadS_
;
static
constexpr
bool
kPadSK
=
kPadSK_
;
static
constexpr
bool
kPadSK
=
kPadSK_
;
...
@@ -503,7 +508,6 @@ struct fmha_fwd_traits
...
@@ -503,7 +508,6 @@ struct fmha_fwd_traits
bias_enum
bias_type
;
// 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bias_enum
bias_type
;
// 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool
has_lse
;
bool
has_lse
;
bool
has_dropout
;
bool
has_dropout
;
bool
is_store_randval
;
bool
do_fp8_static_quant
;
bool
do_fp8_static_quant
;
// TODO: padding check is inside this api
// TODO: padding check is inside this api
};
};
...
...
include/ck_tile/ops/fmha/block/block_dropout.hpp
View file @
ad3e94bb
...
@@ -8,11 +8,295 @@
...
@@ -8,11 +8,295 @@
namespace
ck_tile
{
namespace
ck_tile
{
struct
NullBlockDropout
{
template
<
typename
BlockGemm
,
bool
IsFwd
=
true
,
typename
RandValDramBlockWindowTmp
>
__host__
__device__
static
constexpr
auto
MakeRandvalDramWindow
(
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
index_t
seqlen_qk_start
)
{
(
void
)
randval_dram_block_window_tmp
;
(
void
)
seqlen_qk_start
;
return
make_null_tile_window
(
make_tuple
(
number
<
0
>
{},
number
<
0
>
{}));
}
};
struct
BlockDropout
{
CK_TILE_HOST_DEVICE
BlockDropout
(
index_t
i_batch
,
index_t
i_head
,
index_t
nheads
,
unsigned
long
long
seed
,
unsigned
long
long
offset
,
float
rp_undrop_
,
uint8_t
p_undrop_in_uint8_t_
,
bool
is_store_randval_
)
:
ph
(
seed
,
offset
+
(
i_batch
*
nheads
+
i_head
)
*
get_warp_size
()
+
get_lane_id
()),
rp_undrop
(
rp_undrop_
),
p_undrop_in_uint8_t
(
p_undrop_in_uint8_t_
),
is_store_randval
(
is_store_randval_
)
{
}
template
<
typename
BlockGemm
,
bool
IsFwd
=
true
,
typename
RandValDramBlockWindowTmp
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeRandvalDramWindow
(
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
index_t
seqlen_qk_start
)
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
kMPerStep
=
MWarp
*
WG
::
kM
;
constexpr
index_t
kNPerStep
=
NWarp
*
WG
::
kN
;
const
auto
block_origin
=
randval_dram_block_window_tmp
.
get_window_origin
();
auto
randval_dram_window
=
[
&
]()
{
if
constexpr
(
IsFwd
)
{
return
make_tile_window
(
randval_dram_block_window_tmp
.
get_bottom_tensor_view
(),
ck_tile
::
make_tuple
(
number
<
kMPerStep
>
{},
number
<
kNPerStep
>
{}),
{
block_origin
.
at
(
number
<
0
>
{}),
seqlen_qk_start
});
// M/N
}
else
{
return
make_tile_window
(
randval_dram_block_window_tmp
.
get_bottom_tensor_view
(),
ck_tile
::
make_tuple
(
number
<
kMPerStep
>
{},
number
<
kNPerStep
>
{}),
{
seqlen_qk_start
,
block_origin
.
at
(
number
<
1
>
{})});
// M/N
}
}();
return
randval_dram_window
;
}
template
<
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeRandValLdsBlockDescriptor
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
kMPerStep
=
MWarp
*
WG
::
kM
;
constexpr
index_t
kNPerStep
=
WG
::
kN
;
constexpr
index_t
kN1
=
8
;
constexpr
index_t
kN0
=
kNPerStep
/
kN1
;
constexpr
auto
randval_lds_block_desc_0
=
make_naive_tensor_descriptor
(
ck_tile
::
make_tuple
(
number
<
kN0
>
{},
number
<
kMPerStep
>
{},
number
<
kN1
>
{}),
ck_tile
::
make_tuple
(
number
<
(
kMPerStep
+
1
)
*
kN1
>
{},
number
<
kN1
>
{},
number
<
1
>
{}),
number
<
kN1
>
{},
number
<
1
>
{});
constexpr
auto
randval_lds_block_desc
=
transform_tensor_descriptor
(
randval_lds_block_desc_0
,
ck_tile
::
make_tuple
(
make_pass_through_transform
(
number
<
kMPerStep
>
{}),
make_merge_transform
(
ck_tile
::
make_tuple
(
number
<
kN0
>
{},
number
<
kN1
>
{}))),
ck_tile
::
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
ck_tile
::
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
randval_lds_block_desc
;
}
template
<
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeRandValTileDistribution
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
1
;
constexpr
index_t
NIterPerWarp
=
1
;
constexpr
auto
randval_block_outer_part_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
// Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
constexpr
auto
randval_block_inner_part_dstr_encoding
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
BlockGemm
::
ADataType
,
half_t
>
&&
std
::
is_same_v
<
typename
BlockGemm
::
BDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
BlockGemm
::
CDataType
,
float
>
)
{
return
typename
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
::
CWarpDstrEncoding
{};
}
else
{
return
typename
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
::
CWarpDstrEncoding
{};
}
}();
constexpr
auto
randval_block_part_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
randval_block_outer_part_dstr_encoding
,
randval_block_inner_part_dstr_encoding
);
return
make_static_tile_distribution
(
randval_block_part_dstr_encode
);
}
template
<
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeRandValLdsShuffleTileDistribution
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
1
;
constexpr
index_t
NIterPerWarp
=
1
;
constexpr
auto
randval_block_outer_part_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
randval_block_part_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
randval_block_outer_part_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
return
make_static_tile_distribution
(
randval_block_part_dstr_encode
);
}
template
<
typename
BlockGemm
,
typename
PComputeDataType
,
typename
RandValOutputDataType
,
typename
PComputeWindow
,
typename
RandValDramWindow
>
CK_TILE_HOST_DEVICE
void
Run
(
void
*
randval_ptr
,
const
index_t
start_n0_idx
,
PComputeWindow
&
p_compute
,
RandValDramWindow
&
randval_dram_window
)
const
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
using
BlockGemmShape
=
remove_cvref_t
<
typename
BlockGemm
::
BlockGemmShape
>
;
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
constexpr
index_t
kMPerStep
=
MWarp
*
WG
::
kM
;
constexpr
index_t
kNPerStep
=
NWarp
*
WG
::
kN
;
// randval tile in LDS
auto
randval_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
uint8_t
*>
(
randval_ptr
),
MakeRandValLdsBlockDescriptor
<
BlockGemm
>
());
auto
randval_lds_window
=
make_tile_window
(
randval_lds
,
MakeRandValLdsBlockDescriptor
<
BlockGemm
>
().
get_lengths
(),
{
0
,
0
});
// register distribute
auto
randval_dist_generated
=
make_static_distributed_tensor
<
uint8_t
>
(
MakeRandValTileDistribution
<
BlockGemm
>
());
static_assert
(
randval_dist_generated
.
kThreadElementSpaceSize
==
16
);
auto
randval_lds_read_window
=
make_tile_window
(
randval_lds_window
.
get_bottom_tensor_view
(),
randval_lds_window
.
get_window_lengths
(),
randval_lds_window
.
get_window_origin
(),
MakeRandValLdsShuffleTileDistribution
<
BlockGemm
>
());
const
int
start_m0_idx
=
randval_dram_window
.
get_window_origin
().
at
(
number
<
0
>
{});
if
(
is_store_randval
)
{
static_for
<
0
,
kMPerBlock
/
kMPerStep
,
1
>
{}([
&
](
auto
i_m0
)
{
static_for
<
0
,
kNPerBlock
/
kNPerStep
,
1
>
{}([
&
](
auto
i_n0
)
{
int
block_row_start
=
(
start_m0_idx
/
WG
::
kM
)
+
(
i_m0
*
MWarp
)
+
get_warp_id
();
int
block_col_start
=
(
start_n0_idx
/
WG
::
kN
)
+
i_n0
;
uint2
rowcol
=
make_uint2
(
block_row_start
,
block_col_start
);
// generate random number
uint8_t
random_uint8_t
[
16
];
ph
.
get_random_16x8
(
random_uint8_t
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
));
constexpr
auto
randval_dist_generated_spans
=
decltype
(
randval_dist_generated
)
::
get_distributed_spans
();
int
i_random_idx
=
0
;
sweep_tile_span
(
randval_dist_generated_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
randval_dist_generated_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
ck_tile
::
make_tuple
(
idx0
,
idx1
);
randval_dist_generated
(
i_j_idx
)
=
random_uint8_t
[
i_random_idx
++
];
});
});
// save to LDS
store_tile
(
randval_lds_window
,
randval_dist_generated
);
block_sync_lds
();
// read from LDS to register
auto
randval
=
load_tile
(
randval_lds_read_window
);
// save to Global
const
auto
randval_store
=
cast_tile
<
RandValOutputDataType
>
(
randval
);
store_tile
(
randval_dram_window
,
randval_store
);
move_tile_window
(
randval_dram_window
,
{
0
,
kNPerStep
});
});
move_tile_window
(
randval_dram_window
,
{
kMPerStep
,
-
kNPerBlock
});
});
move_tile_window
(
randval_dram_window
,
{
-
kMPerBlock
,
kNPerBlock
});
};
static_for
<
0
,
kMPerBlock
/
kMPerStep
,
1
>
{}([
&
](
auto
i_m0
)
{
static_for
<
0
,
kNPerBlock
/
kNPerStep
,
1
>
{}([
&
](
auto
i_n0
)
{
int
block_row_start
=
(
start_m0_idx
/
WG
::
kM
)
+
(
i_m0
*
MWarp
)
+
get_warp_id
();
int
block_col_start
=
(
start_n0_idx
/
WG
::
kN
)
+
i_n0
;
uint2
rowcol
=
make_uint2
(
block_row_start
,
block_col_start
);
// generate random number
uint8_t
random_uint8_t
[
16
];
ph
.
get_random_16x8
(
random_uint8_t
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
));
constexpr
auto
randval_dist_generated_spans
=
decltype
(
randval_dist_generated
)
::
get_distributed_spans
();
int
i_random_idx
=
0
;
sweep_tile_span
(
randval_dist_generated_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
randval_dist_generated_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
ck_tile
::
make_tuple
(
idx0
,
idx1
);
randval_dist_generated
(
i_j_idx
)
=
random_uint8_t
[
i_random_idx
++
];
});
});
// save to LDS
store_tile
(
randval_lds_window
,
randval_dist_generated
);
block_sync_lds
();
// read from LDS to register
auto
randval
=
load_tile
(
randval_lds_read_window
);
constexpr
auto
randval_spans
=
decltype
(
randval
)
::
get_distributed_spans
();
sweep_tile_span
(
randval_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
randval_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
p_idx0
=
tile_distributed_index
<
i_m0
>
{};
constexpr
auto
p_idx1
=
tile_distributed_index
<
i_n0
,
idx1
.
impl_
.
at
(
1
),
idx1
.
impl_
.
at
(
2
)
>
{};
constexpr
auto
p_idx
=
ck_tile
::
make_tuple
(
p_idx0
,
p_idx1
);
constexpr
auto
r_idx
=
ck_tile
::
make_tuple
(
idx0
,
idx1
);
p_compute
(
p_idx
)
=
randval
[
r_idx
]
<=
p_undrop_in_uint8_t
?
p_compute
[
p_idx
]
*
rp_undrop
:
PComputeDataType
(
0
);
});
});
});
});
}
ck_tile
::
philox
ph
;
const
float
rp_undrop
;
const
uint8_t
p_undrop_in_uint8_t
;
const
bool
is_store_randval
;
};
template
<
bool
IsDropout_
,
bool
IsWG32_
,
bool
IsStoreRandval_
>
template
<
bool
IsDropout_
,
bool
IsWG32_
,
bool
IsStoreRandval_
>
struct
BlockDropout
;
struct
BlockDropout
Bwd
;
template
<
bool
IsWG32_
,
bool
IsStoreRandval_
>
template
<
bool
IsWG32_
,
bool
IsStoreRandval_
>
struct
BlockDropout
<
false
,
IsWG32_
,
IsStoreRandval_
>
struct
BlockDropout
Bwd
<
false
,
IsWG32_
,
IsStoreRandval_
>
{
{
static
constexpr
bool
IsDropout
=
false
;
static
constexpr
bool
IsDropout
=
false
;
static
constexpr
bool
IsStoreRandval
=
IsStoreRandval_
;
static
constexpr
bool
IsStoreRandval
=
IsStoreRandval_
;
...
@@ -30,7 +314,7 @@ struct BlockDropout<false, IsWG32_, IsStoreRandval_>
...
@@ -30,7 +314,7 @@ struct BlockDropout<false, IsWG32_, IsStoreRandval_>
};
};
template
<
bool
IsWG32_
,
bool
IsStoreRandval_
>
template
<
bool
IsWG32_
,
bool
IsStoreRandval_
>
struct
BlockDropout
<
true
,
IsWG32_
,
IsStoreRandval_
>
struct
BlockDropout
Bwd
<
true
,
IsWG32_
,
IsStoreRandval_
>
{
{
static
constexpr
bool
IsDropout
=
true
;
static
constexpr
bool
IsDropout
=
true
;
// true: 32*32 warp gemm
// true: 32*32 warp gemm
...
@@ -38,13 +322,13 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
...
@@ -38,13 +322,13 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
static
constexpr
bool
IsWG32
=
IsWG32_
;
static
constexpr
bool
IsWG32
=
IsWG32_
;
static
constexpr
bool
IsStoreRandval
=
IsStoreRandval_
;
static
constexpr
bool
IsStoreRandval
=
IsStoreRandval_
;
CK_TILE_HOST_DEVICE
BlockDropout
(
index_t
i_batch
,
CK_TILE_HOST_DEVICE
BlockDropout
Bwd
(
index_t
i_batch
,
index_t
i_head
,
index_t
i_head
,
index_t
nheads
,
index_t
nheads
,
unsigned
long
long
seed
,
unsigned
long
long
seed
,
unsigned
long
long
offset
,
unsigned
long
long
offset
,
float
rp_undrop_
,
float
rp_undrop_
,
uint8_t
p_undrop_in_uint8_t_
)
uint8_t
p_undrop_in_uint8_t_
)
:
ph
(
seed
,
:
ph
(
seed
,
offset
+
(
i_batch
*
nheads
+
i_head
)
*
get_warp_size
()
+
offset
+
(
i_batch
*
nheads
+
i_head
)
*
get_warp_size
()
+
(
IsWG32
?
get_lane_id
()
:
((
get_lane_id
()
&
47
)
+
((
get_warp_id
()
&
1
)
<<
4
)))),
(
IsWG32
?
get_lane_id
()
:
((
get_lane_id
()
&
47
)
+
((
get_warp_id
()
&
1
)
<<
4
)))),
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
ad3e94bb
...
@@ -47,12 +47,10 @@ struct FmhaFwdKernel
...
@@ -47,12 +47,10 @@ struct FmhaFwdKernel
static
constexpr
bool
kPadHeadDimV
=
FmhaPipeline
::
kPadHeadDimV
;
static
constexpr
bool
kPadHeadDimV
=
FmhaPipeline
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
FmhaPipeline
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
FmhaPipeline
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
FmhaPipeline
::
kStoreLSE
;
static
constexpr
bool
kStoreLSE
=
FmhaPipeline
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
FmhaPipeline
::
kHasDropout
;
static
constexpr
bool
kDoFp8StaticQuant
=
FmhaPipeline
::
Problem
::
kDoFp8StaticQuant
;
static
constexpr
bool
kDoFp8StaticQuant
=
FmhaPipeline
::
Problem
::
kDoFp8StaticQuant
;
using
FmhaMask
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
FmhaMask
>
;
using
FmhaMask
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
FmhaMask
>
;
using
FmhaDropout
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
FmhaDropout
>
;
static
constexpr
bool
kHasMask
=
FmhaMask
::
IsMasking
;
static
constexpr
bool
kHasMask
=
FmhaMask
::
IsMasking
;
static
constexpr
bool
kHasDropout
=
FmhaDropout
::
IsDropout
;
static
constexpr
bool
kIsStoreRandval
=
FmhaDropout
::
IsStoreRandval
;
// clang-format off
// clang-format off
template
<
typename
T
>
struct
t2s
;
template
<
typename
T
>
struct
t2s
;
...
@@ -89,8 +87,7 @@ struct FmhaFwdKernel
...
@@ -89,8 +87,7 @@ struct FmhaFwdKernel
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kStoreLSE
?
"_lse"
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
)
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kStoreLSE
?
"_lse"
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
)
+
(
kDoFp8StaticQuant
?
"_squant"
:
""
);
(
kIsStoreRandval
?
"_storerandval"
:
""
)
+
(
kDoFp8StaticQuant
?
"_squant"
:
""
);
#undef _SS_
#undef _SS_
#undef _TS_
#undef _TS_
// clang-format on
// clang-format on
...
@@ -188,6 +185,7 @@ struct FmhaFwdKernel
...
@@ -188,6 +185,7 @@ struct FmhaFwdKernel
}
}
float
rp_undrop
=
1
;
float
rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
bool
is_store_randval
=
false
;
uint64_t
drop_seed
=
1
;
uint64_t
drop_seed
=
1
;
uint64_t
drop_offset
=
0
;
uint64_t
drop_offset
=
0
;
void
*
rand_val_ptr
=
nullptr
;
void
*
rand_val_ptr
=
nullptr
;
...
@@ -279,6 +277,7 @@ struct FmhaFwdKernel
...
@@ -279,6 +277,7 @@ struct FmhaFwdKernel
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
{
Kargs
kargs
{{
q_ptr
,
Kargs
kargs
{{
q_ptr
,
...
@@ -346,13 +345,11 @@ struct FmhaFwdKernel
...
@@ -346,13 +345,11 @@ struct FmhaFwdKernel
if
constexpr
(
kHasDropout
)
if
constexpr
(
kHasDropout
)
{
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
);
if
constexpr
(
kIsStoreRandval
)
kargs
.
rand_val_ptr
=
rand_val_ptr
;
{
kargs
.
stride_randval
=
stride_randval
;
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
batch_stride_randval
=
batch_stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
is_store_randval
=
s_randval
;
kargs
.
batch_stride_randval
=
batch_stride_randval
;
}
}
}
return
kargs
;
return
kargs
;
...
@@ -395,6 +392,7 @@ struct FmhaFwdKernel
...
@@ -395,6 +392,7 @@ struct FmhaFwdKernel
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
{
Kargs
kargs
{{
q_ptr
,
Kargs
kargs
{{
q_ptr
,
...
@@ -460,12 +458,10 @@ struct FmhaFwdKernel
...
@@ -460,12 +458,10 @@ struct FmhaFwdKernel
if
constexpr
(
kHasDropout
)
if
constexpr
(
kHasDropout
)
{
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
);
if
constexpr
(
kIsStoreRandval
)
kargs
.
rand_val_ptr
=
rand_val_ptr
;
{
kargs
.
stride_randval
=
stride_randval
;
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
is_store_randval
=
s_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
}
}
}
return
kargs
;
return
kargs
;
...
@@ -530,7 +526,7 @@ struct FmhaFwdKernel
...
@@ -530,7 +526,7 @@ struct FmhaFwdKernel
{
{
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
}
}
if
constexpr
(
k
IsStoreRandval
)
if
constexpr
(
k
HasDropout
)
{
{
batch_offset_randval
=
query_start
*
kargs
.
stride_randval
;
batch_offset_randval
=
query_start
*
kargs
.
stride_randval
;
}
}
...
@@ -570,7 +566,7 @@ struct FmhaFwdKernel
...
@@ -570,7 +566,7 @@ struct FmhaFwdKernel
{
{
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
}
}
if
constexpr
(
k
IsStoreRandval
)
if
constexpr
(
k
HasDropout
)
{
{
batch_offset_randval
=
batch_offset_randval
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_randval
;
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_randval
;
...
@@ -748,28 +744,28 @@ struct FmhaFwdKernel
...
@@ -748,28 +744,28 @@ struct FmhaFwdKernel
}
}
}();
}();
// dropout
auto
dropout
=
[
&
,
i_nhead_
=
i_nhead
,
i_batch_
=
i_batch
]()
{
auto
dropout
=
[
&
,
i_nhead_
=
i_nhead
,
i_batch_
=
i_batch
]()
{
if
constexpr
(
kHasDropout
)
if
constexpr
(
kHasDropout
)
{
{
return
FmhaDropout
{
i_batch_
,
return
BlockDropout
{
i_batch_
,
i_nhead_
,
i_nhead_
,
kargs
.
num_head_q
,
kargs
.
num_head_q
,
kargs
.
drop_seed
,
kargs
.
drop_seed
,
kargs
.
drop_offset
,
kargs
.
drop_offset
,
kargs
.
rp_undrop
,
kargs
.
rp_undrop
,
kargs
.
p_undrop_in_uint8_t
};
kargs
.
p_undrop_in_uint8_t
,
kargs
.
is_store_randval
};
}
}
else
else
{
{
return
Fmha
Dropout
{};
return
NullBlock
Dropout
{};
};
};
}();
}();
auto
randval_dram_window
=
[
&
,
i_nhead_
=
i_nhead
]()
{
auto
randval_dram_window
=
[
&
,
i_nhead_
=
i_nhead
]()
{
constexpr
auto
randval_dram_window_lengths
=
constexpr
auto
randval_dram_window_lengths
=
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN0
>
{});
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN0
>
{});
if
constexpr
(
k
IsStoreRandval
)
if
constexpr
(
k
HasDropout
)
{
{
RandValOutputDataType
*
rand_val_ptr
=
RandValOutputDataType
*
rand_val_ptr
=
reinterpret_cast
<
RandValOutputDataType
*>
(
kargs
.
rand_val_ptr
)
+
reinterpret_cast
<
RandValOutputDataType
*>
(
kargs
.
rand_val_ptr
)
+
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
ad3e94bb
...
@@ -46,12 +46,10 @@ struct FmhaFwdSplitKVKernel
...
@@ -46,12 +46,10 @@ struct FmhaFwdSplitKVKernel
static
constexpr
bool
kPadHeadDimQ
=
FmhaPipeline
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimQ
=
FmhaPipeline
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
FmhaPipeline
::
kPadHeadDimV
;
static
constexpr
bool
kPadHeadDimV
=
FmhaPipeline
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
FmhaPipeline
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
FmhaPipeline
::
BiasEnum
;
static
constexpr
bool
kHasDropout
=
FmhaPipeline
::
kHasDropout
;
static
constexpr
bool
kDoFp8StaticQuant
=
FmhaPipeline
::
Problem
::
kDoFp8StaticQuant
;
static
constexpr
bool
kDoFp8StaticQuant
=
FmhaPipeline
::
Problem
::
kDoFp8StaticQuant
;
using
FmhaMask
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
FmhaMask
>
;
using
FmhaMask
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
FmhaMask
>
;
using
FmhaDropout
=
ck_tile
::
remove_cvref_t
<
typename
FmhaPipeline
::
FmhaDropout
>
;
static
constexpr
bool
kHasMask
=
FmhaMask
::
IsMasking
;
static
constexpr
bool
kHasMask
=
FmhaMask
::
IsMasking
;
static
constexpr
bool
kHasDropout
=
FmhaDropout
::
IsDropout
;
static
constexpr
bool
kIsStoreRandval
=
FmhaDropout
::
IsStoreRandval
;
// clang-format off
// clang-format off
template
<
typename
T
>
struct
t2s
;
template
<
typename
T
>
struct
t2s
;
...
@@ -88,8 +86,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -88,8 +86,7 @@ struct FmhaFwdSplitKVKernel
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
)
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
)
+
(
kDoFp8StaticQuant
?
"_squant"
:
""
);
(
kIsStoreRandval
?
"_storerandval"
:
""
)
+
(
kDoFp8StaticQuant
?
"_squant"
:
""
);
#undef _SS_
#undef _SS_
#undef _TS_
#undef _TS_
// clang-format on
// clang-format on
...
@@ -192,6 +189,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -192,6 +189,7 @@ struct FmhaFwdSplitKVKernel
}
}
float
rp_undrop
=
1
;
float
rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
bool
is_store_randval
=
false
;
uint64_t
drop_seed
=
1
;
uint64_t
drop_seed
=
1
;
uint64_t
drop_offset
=
0
;
uint64_t
drop_offset
=
0
;
void
*
rand_val_ptr
=
nullptr
;
void
*
rand_val_ptr
=
nullptr
;
...
@@ -284,6 +282,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -284,6 +282,7 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
{
Kargs
kargs
{{
q_ptr
,
Kargs
kargs
{{
q_ptr
,
...
@@ -351,13 +350,11 @@ struct FmhaFwdSplitKVKernel
...
@@ -351,13 +350,11 @@ struct FmhaFwdSplitKVKernel
if
constexpr
(
kHasDropout
)
if
constexpr
(
kHasDropout
)
{
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
);
if
constexpr
(
kIsStoreRandval
)
kargs
.
rand_val_ptr
=
rand_val_ptr
;
{
kargs
.
stride_randval
=
stride_randval
;
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
batch_stride_randval
=
batch_stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
is_store_randval
=
s_randval
;
kargs
.
batch_stride_randval
=
batch_stride_randval
;
}
}
}
return
kargs
;
return
kargs
;
...
@@ -405,6 +402,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -405,6 +402,7 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
{
Kargs
kargs
{{
q_ptr
,
Kargs
kargs
{{
q_ptr
,
...
@@ -471,12 +469,10 @@ struct FmhaFwdSplitKVKernel
...
@@ -471,12 +469,10 @@ struct FmhaFwdSplitKVKernel
if
constexpr
(
kHasDropout
)
if
constexpr
(
kHasDropout
)
{
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
);
if
constexpr
(
kIsStoreRandval
)
kargs
.
rand_val_ptr
=
rand_val_ptr
;
{
kargs
.
stride_randval
=
stride_randval
;
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
is_store_randval
=
s_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
}
}
}
return
kargs
;
return
kargs
;
...
@@ -540,7 +536,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -540,7 +536,7 @@ struct FmhaFwdSplitKVKernel
{
{
batch_offset_bias
=
query_start
*
kargs
.
stride_bias
+
key_start
;
batch_offset_bias
=
query_start
*
kargs
.
stride_bias
+
key_start
;
}
}
if
constexpr
(
k
IsStoreRandval
)
if
constexpr
(
k
HasDropout
)
{
{
batch_offset_randval
=
query_start
*
kargs
.
stride_randval
;
batch_offset_randval
=
query_start
*
kargs
.
stride_randval
;
}
}
...
@@ -575,7 +571,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -575,7 +571,7 @@ struct FmhaFwdSplitKVKernel
{
{
batch_offset_bias
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_bias
;
batch_offset_bias
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_bias
;
}
}
if
constexpr
(
k
IsStoreRandval
)
if
constexpr
(
k
HasDropout
)
{
{
batch_offset_randval
=
batch_offset_randval
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_randval
;
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_randval
;
...
@@ -747,27 +743,33 @@ struct FmhaFwdSplitKVKernel
...
@@ -747,27 +743,33 @@ struct FmhaFwdSplitKVKernel
}();
}();
// dropout
// dropout
auto
dropout
=
[
&
,
i_nhead_
=
i_nhead
,
i_batch_
=
i_batch
]()
{
float
rp_undrop
=
1
;
if
constexpr
(
kHasDropout
)
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
{
uint64_t
drop_seed
=
0
;
return
FmhaDropout
{
i_batch_
,
uint64_t
drop_offset
=
0
;
i_nhead_
,
bool
is_store_randval
=
false
;
kargs
.
num_head_q
,
kargs
.
drop_seed
,
if
constexpr
(
kHasDropout
)
kargs
.
drop_offset
,
{
kargs
.
rp_undrop
,
rp_undrop
=
kargs
.
rp_undrop
;
kargs
.
p_undrop_in_uint8_t
};
p_undrop_in_uint8_t
=
kargs
.
p_undrop_in_uint8_t
;
}
drop_seed
=
kargs
.
drop_seed
;
else
drop_offset
=
kargs
.
drop_offset
;
{
is_store_randval
=
kargs
.
is_store_randval
;
return
FmhaDropout
{};
}
};
BlockDropout
dropout
(
i_batch
,
}();
i_nhead
,
kargs
.
num_head_q
,
drop_seed
,
drop_offset
,
rp_undrop
,
p_undrop_in_uint8_t
,
is_store_randval
);
auto
randval_dram_window
=
[
&
,
i_nhead_
=
i_nhead
]()
{
auto
randval_dram_window
=
[
&
,
i_nhead_
=
i_nhead
]()
{
constexpr
auto
randval_dram_window_lengths
=
constexpr
auto
randval_dram_window_lengths
=
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN0
>
{});
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN0
>
{});
if
constexpr
(
k
IsStoreRandval
)
if
constexpr
(
k
HasDropout
)
{
{
RandValOutputDataType
*
rand_val_ptr
=
RandValOutputDataType
*
rand_val_ptr
=
reinterpret_cast
<
RandValOutputDataType
*>
(
kargs
.
rand_val_ptr
)
+
reinterpret_cast
<
RandValOutputDataType
*>
(
kargs
.
rand_val_ptr
)
+
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
View file @
ad3e94bb
...
@@ -28,7 +28,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -28,7 +28,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaDropout
=
remove_cvref_t
<
typename
Problem
::
FmhaDropout
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
...
@@ -50,7 +49,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -50,7 +49,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
true
;
// always store LSE (acc)
static
constexpr
bool
kStoreLSE
=
true
;
// always store LSE (acc)
static
constexpr
bool
kHasDropout
=
false
;
// ignore this flag
static
constexpr
bool
kHasUnevenSplits
=
Problem
::
kHasUnevenSplits
;
static
constexpr
bool
kHasUnevenSplits
=
Problem
::
kHasUnevenSplits
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
...
@@ -141,7 +141,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -141,7 +141,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Fmha
Dropout
dropout
)
const
Block
Dropout
&
dropout
)
const
{
{
static_assert
(
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
@@ -249,7 +249,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -249,7 +249,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
template
MakeRandvalDramWindow
<
decltype
(
gemm_0
)>(
auto
randval_dram_window
=
dropout
.
MakeRandvalDramWindow
<
decltype
(
gemm_0
)
>
(
randval_dram_block_window_tmp
,
seqlen_k_start
);
randval_dram_block_window_tmp
,
seqlen_k_start
);
auto
v_dram_window
=
auto
v_dram_window
=
...
@@ -501,14 +501,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -501,14 +501,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
});
});
});
});
if
constexpr
(
FmhaDropout
::
I
sDropout
)
if
constexpr
(
kHa
sDropout
)
{
{
dropout
.
template
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>(
dropout
.
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>
(
smem_ptr
,
smem_ptr
,
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
randval_dram_window
);
q_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
randval_dram_window
);
}
}
block_sync_lds
();
block_sync_lds
();
...
@@ -641,7 +637,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -641,7 +637,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Fmha
Dropout
dropout
)
const
Block
Dropout
&
dropout
)
const
{
{
return
operator
()(
q_dram_block_window_tmp
,
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
identity
{},
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp
View file @
ad3e94bb
...
@@ -29,7 +29,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -29,7 +29,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaDropout
=
remove_cvref_t
<
typename
Problem
::
FmhaDropout
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
...
@@ -55,7 +54,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -55,7 +54,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
static
constexpr
bool
kPadHeadDimQ
=
true
;
// support multiple of vector(like 8x)
static
constexpr
bool
kPadHeadDimQ
=
true
;
// support multiple of vector(like 8x)
static
constexpr
bool
kPadHeadDimV
=
true
;
// support multiple of vector(like 8x)
static
constexpr
bool
kPadHeadDimV
=
true
;
// support multiple of vector(like 8x)
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
true
;
// always store LSE (acc)
static
constexpr
bool
kStoreLSE
=
true
;
// always store LSE (acc)
static
constexpr
bool
kHasDropout
=
false
;
// ignore this flag
static
constexpr
bool
kHasUnevenSplits
=
Problem
::
kHasUnevenSplits
;
static
constexpr
bool
kHasUnevenSplits
=
Problem
::
kHasUnevenSplits
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
...
@@ -153,7 +153,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -153,7 +153,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Fmha
Dropout
dropout
)
const
Block
Dropout
&
dropout
)
const
{
{
static_assert
(
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
@@ -301,7 +301,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -301,7 +301,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
template
MakeRandvalDramWindow
<
decltype
(
gemm_0
)>(
auto
randval_dram_window
=
dropout
.
MakeRandvalDramWindow
<
decltype
(
gemm_0
)
>
(
randval_dram_block_window_tmp
,
seqlen_k_start
);
randval_dram_block_window_tmp
,
seqlen_k_start
);
auto
v_dram_window
=
auto
v_dram_window
=
...
@@ -584,13 +584,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -584,13 +584,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
});
});
});
});
if
constexpr
(
FmhaDropout
::
I
sDropout
)
if
constexpr
(
kHa
sDropout
)
{
{
auto
randval_ptr
=
auto
randval_ptr
=
reinterpret_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeKV
<
Problem
>();
reinterpret_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeKV
<
Problem
>();
dropout
.
template
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>(
dropout
.
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>
(
randval_ptr
,
randval_ptr
,
q_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
+
i_total_loops
*
kN0
,
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
p_compute
,
randval_dram_window
);
randval_dram_window
);
...
@@ -742,7 +741,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
...
@@ -742,7 +741,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Fmha
Dropout
dropout
)
const
Block
Dropout
&
dropout
)
const
{
{
return
operator
()(
q_dram_block_window_tmp
,
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
identity
{},
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
View file @
ad3e94bb
...
@@ -21,7 +21,6 @@ template <typename QDataType_,
...
@@ -21,7 +21,6 @@ template <typename QDataType_,
typename
BlockFmhaShape_
,
typename
BlockFmhaShape_
,
bool
kIsGroupMode_
,
bool
kIsGroupMode_
,
typename
FmhaMask_
,
typename
FmhaMask_
,
typename
FmhaDropout_
,
typename
Traits_
>
typename
Traits_
>
struct
BlockFmhaPipelineProblem
struct
BlockFmhaPipelineProblem
{
{
...
@@ -38,7 +37,6 @@ struct BlockFmhaPipelineProblem
...
@@ -38,7 +37,6 @@ struct BlockFmhaPipelineProblem
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
BlockFmhaShape_
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
BlockFmhaShape_
>
;
using
FmhaMask
=
remove_cvref_t
<
FmhaMask_
>
;
using
FmhaMask
=
remove_cvref_t
<
FmhaMask_
>
;
using
FmhaDropout
=
remove_cvref_t
<
FmhaDropout_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kBlockSize
=
BlockFmhaShape
::
NumWarps
*
get_warp_size
();
static
constexpr
index_t
kBlockSize
=
BlockFmhaShape
::
NumWarps
*
get_warp_size
();
...
@@ -51,6 +49,7 @@ struct BlockFmhaPipelineProblem
...
@@ -51,6 +49,7 @@ struct BlockFmhaPipelineProblem
static
constexpr
bool
kPadHeadDimV
=
Traits
::
kPadHeadDimV
;
static
constexpr
bool
kPadHeadDimV
=
Traits
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Traits
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
Traits
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Traits
::
kStoreLSE
;
static
constexpr
bool
kStoreLSE
=
Traits
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
Traits
::
kHasDropout
;
static
constexpr
bool
kDoFp8StaticQuant
=
Traits
::
kDoFp8StaticQuant
;
static
constexpr
bool
kDoFp8StaticQuant
=
Traits
::
kDoFp8StaticQuant
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
static
constexpr
index_t
kBlockPerCu
=
Traits
::
kBlockPerCu
;
};
};
...
@@ -69,7 +68,6 @@ template <typename QDataType,
...
@@ -69,7 +68,6 @@ template <typename QDataType,
typename
BlockFmhaShape
,
typename
BlockFmhaShape
,
bool
kIsGroupMode
,
bool
kIsGroupMode
,
typename
FmhaMask
,
typename
FmhaMask
,
typename
FmhaDropout
,
typename
Traits
>
typename
Traits
>
struct
BlockFmhaFwdSplitKVPipelineProblem
:
BlockFmhaPipelineProblem
<
QDataType
,
struct
BlockFmhaFwdSplitKVPipelineProblem
:
BlockFmhaPipelineProblem
<
QDataType
,
KDataType
,
KDataType
,
...
@@ -85,7 +83,6 @@ struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
...
@@ -85,7 +83,6 @@ struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
BlockFmhaShape
,
BlockFmhaShape
,
kIsGroupMode
,
kIsGroupMode
,
FmhaMask
,
FmhaMask
,
FmhaDropout
,
Traits
>
Traits
>
{
{
static
constexpr
bool
kHasUnevenSplits
=
kIsGroupMode
||
Traits
::
kHasUnevenSplits
;
static
constexpr
bool
kHasUnevenSplits
=
kIsGroupMode
||
Traits
::
kHasUnevenSplits
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
View file @
ad3e94bb
...
@@ -29,7 +29,6 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -29,7 +29,6 @@ struct BlockFmhaPipelineQRKSVS
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaDropout
=
remove_cvref_t
<
typename
Problem
::
FmhaDropout
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
...
@@ -52,6 +51,7 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -52,6 +51,7 @@ struct BlockFmhaPipelineQRKSVS
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
// ... together with tensor distribution. tensor dist should able to overwrite this
...
@@ -100,6 +100,8 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -100,6 +100,8 @@ struct BlockFmhaPipelineQRKSVS
static
constexpr
const
char
*
name
=
"qr"
;
static
constexpr
const
char
*
name
=
"qr"
;
using
DropoutType
=
std
::
conditional_t
<
kHasDropout
,
BlockDropout
,
NullBlockDropout
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
...
@@ -139,7 +141,7 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -139,7 +141,7 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Fmha
Dropout
dropout
)
const
Dropout
Type
&
dropout
)
const
{
{
static_assert
(
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
@@ -484,14 +486,10 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -484,14 +486,10 @@ struct BlockFmhaPipelineQRKSVS
});
});
});
});
if
constexpr
(
FmhaDropout
::
I
sDropout
)
if
constexpr
(
kHa
sDropout
)
{
{
dropout
.
template
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>(
dropout
.
template
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>(
smem_ptr
,
smem_ptr
,
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
randval_dram_window
);
q_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
randval_dram_window
);
}
}
block_sync_lds
();
block_sync_lds
();
...
@@ -622,7 +620,7 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -622,7 +620,7 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Fmha
Dropout
dropout
)
const
Dropout
Type
&
dropout
)
const
{
{
return
operator
()(
q_dram_block_window_tmp
,
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
identity
{},
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
ad3e94bb
...
@@ -30,7 +30,6 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -30,7 +30,6 @@ struct BlockFmhaPipelineQRKSVSAsync
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaDropout
=
remove_cvref_t
<
typename
Problem
::
FmhaDropout
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
...
@@ -57,6 +56,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -57,6 +56,7 @@ struct BlockFmhaPipelineQRKSVSAsync
static
constexpr
bool
kPadHeadDimV
=
true
;
// support multiple of vector(like 8x)
static
constexpr
bool
kPadHeadDimV
=
true
;
// support multiple of vector(like 8x)
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
// ... together with tensor distribution. tensor dist should able to overwrite this
...
@@ -82,7 +82,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -82,7 +82,7 @@ struct BlockFmhaPipelineQRKSVSAsync
else
else
{
{
// minimize occupancy
// minimize occupancy
if
constexpr
(
BiasEnum
!=
BlockAttentionBiasEnum
::
NO_BIAS
&&
FmhaDropout
::
I
sDropout
)
if
constexpr
(
BiasEnum
!=
BlockAttentionBiasEnum
::
NO_BIAS
&&
kHa
sDropout
)
{
{
return
1
;
return
1
;
}
}
...
@@ -118,6 +118,8 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -118,6 +118,8 @@ struct BlockFmhaPipelineQRKSVSAsync
static
constexpr
const
char
*
name
=
"qr_async"
;
static
constexpr
const
char
*
name
=
"qr_async"
;
using
DropoutType
=
std
::
conditional_t
<
kHasDropout
,
BlockDropout
,
NullBlockDropout
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
...
@@ -157,7 +159,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -157,7 +159,7 @@ struct BlockFmhaPipelineQRKSVSAsync
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Fmha
Dropout
dropout
)
const
Dropout
Type
&
dropout
)
const
{
{
static_assert
(
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
@@ -303,7 +305,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -303,7 +305,7 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr
auto
k_pre_np
=
[
&
]()
{
constexpr
auto
k_pre_np
=
[
&
]()
{
if
constexpr
(
kPadSeqLenK
&&
if
constexpr
(
kPadSeqLenK
&&
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
(
BiasEnum
!=
BlockAttentionBiasEnum
::
NO_BIAS
&&
FmhaDropout
::
I
sDropout
)))
(
BiasEnum
!=
BlockAttentionBiasEnum
::
NO_BIAS
&&
kHa
sDropout
)))
return
bool_constant
<
true
>
{};
return
bool_constant
<
true
>
{};
else
else
return
bool_constant
<
false
>
{};
return
bool_constant
<
false
>
{};
...
@@ -587,13 +589,12 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -587,13 +589,12 @@ struct BlockFmhaPipelineQRKSVSAsync
});
});
});
});
if
constexpr
(
FmhaDropout
::
I
sDropout
)
if
constexpr
(
kHa
sDropout
)
{
{
auto
randval_ptr
=
auto
randval_ptr
=
reinterpret_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeKV
<
Problem
>();
reinterpret_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeKV
<
Problem
>();
dropout
.
template
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>(
dropout
.
template
Run
<
decltype
(
gemm_0
),
SMPLComputeDataType
,
RandValOutputDataType
>(
randval_ptr
,
randval_ptr
,
q_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
+
i_total_loops
*
kN0
,
seqlen_k_start
+
i_total_loops
*
kN0
,
p_compute
,
p_compute
,
randval_dram_window
);
randval_dram_window
);
...
@@ -746,7 +747,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -746,7 +747,7 @@ struct BlockFmhaPipelineQRKSVSAsync
PositionEncoding
position_encoding
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Fmha
Dropout
dropout
)
const
Dropout
Type
&
dropout
)
const
{
{
return
operator
()(
q_dram_block_window_tmp
,
return
operator
()(
q_dram_block_window_tmp
,
identity
{},
identity
{},
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp
View file @
ad3e94bb
...
@@ -28,7 +28,6 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
...
@@ -28,7 +28,6 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
FmhaDropout
=
remove_cvref_t
<
typename
Problem
::
FmhaDropout
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
...
@@ -51,6 +50,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
...
@@ -51,6 +50,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kHasDropout
=
Problem
::
kHasDropout
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
// ... together with tensor distribution. tensor dist should able to overwrite this
...
@@ -124,7 +124,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
...
@@ -124,7 +124,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
float
descale_qk
,
float
descale_qk
,
float
descale_sv
,
float
descale_sv
,
void
*
smem_ptr
,
void
*
smem_ptr
,
Fmha
Dropout
&
/*dropout*/
)
const
// not supported
Block
Dropout
&
/*dropout*/
)
const
// not supported
{
{
static_assert
(
static_assert
(
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
ad3e94bb
...
@@ -718,7 +718,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
...
@@ -718,7 +718,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeDropout
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeDropout
()
{
{
if
constexpr
(
Problem
::
FmhaDropout
::
I
sDropout
)
if
constexpr
(
Problem
::
kHa
sDropout
)
{
{
constexpr
auto
gemm_0
=
QXPolicy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_0
=
QXPolicy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
config
=
constexpr
auto
config
=
...
...
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
View file @
ad3e94bb
...
@@ -15,6 +15,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
...
@@ -15,6 +15,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
BlockAttentionBiasEnum
BiasEnum_
,
BlockAttentionBiasEnum
BiasEnum_
,
bool
kHasBiasGrad_
,
bool
kHasBiasGrad_
,
bool
kStoreLSE_
,
bool
kStoreLSE_
,
bool
kHasDropout_
,
bool
kDoFp8StaticQuant_
,
bool
kDoFp8StaticQuant_
,
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
struct
TileFmhaTraits
struct
TileFmhaTraits
...
@@ -26,6 +27,7 @@ struct TileFmhaTraits
...
@@ -26,6 +27,7 @@ struct TileFmhaTraits
static
constexpr
auto
BiasEnum
=
BiasEnum_
;
static
constexpr
auto
BiasEnum
=
BiasEnum_
;
static
constexpr
bool
kHasBiasGrad
=
kHasBiasGrad_
;
static
constexpr
bool
kHasBiasGrad
=
kHasBiasGrad_
;
static
constexpr
bool
kStoreLSE
=
kStoreLSE_
;
static
constexpr
bool
kStoreLSE
=
kStoreLSE_
;
static
constexpr
bool
kHasDropout
=
kHasDropout_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
};
...
@@ -37,6 +39,7 @@ template <bool kPadSeqLenQ /* padding for seqlen_q */,
...
@@ -37,6 +39,7 @@ template <bool kPadSeqLenQ /* padding for seqlen_q */,
BlockAttentionBiasEnum
BiasEnum
,
BlockAttentionBiasEnum
BiasEnum
,
bool
kHasBiasGrad
,
bool
kHasBiasGrad
,
bool
kStoreLSE
,
bool
kStoreLSE
,
bool
kHasDropout
,
bool
kDoFp8StaticQuant
,
bool
kDoFp8StaticQuant
,
bool
kHasUnevenSplits_
=
true
,
bool
kHasUnevenSplits_
=
true
,
index_t
kBlockPerCu
=
-
1
/* overwrite occupancy if not -1 */
>
index_t
kBlockPerCu
=
-
1
/* overwrite occupancy if not -1 */
>
...
@@ -47,6 +50,7 @@ struct TileFmhaFwdSplitKVTraits : TileFmhaTraits<kPadSeqLenQ,
...
@@ -47,6 +50,7 @@ struct TileFmhaFwdSplitKVTraits : TileFmhaTraits<kPadSeqLenQ,
BiasEnum
,
BiasEnum
,
kHasBiasGrad
,
kHasBiasGrad
,
kStoreLSE
,
kStoreLSE
,
kHasDropout
,
kDoFp8StaticQuant
,
kDoFp8StaticQuant
,
kBlockPerCu
>
kBlockPerCu
>
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment