gaoqiong / composable_kernel_ROCM / Commits / 74f1516c

Commit 74f1516c authored Jul 10, 2024 by danyao12

    tmp save

parent 497ccb87

Changes 43 · Showing 20 changed files (page 1 of 3) with 1972 additions and 528 deletions (+1972 −528)
example/ck_tile/01_fmha/CMakeLists.txt                                            +1   −2
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py                                 +16  −0
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py                                   +9   −8
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py                           +9   −8
example/ck_tile/01_fmha/fmha_bwd.cpp                                              +36  −8
example/ck_tile/01_fmha/fmha_bwd.hpp                                              +72  −8
example/ck_tile/01_fmha/fmha_fwd.cpp                                              +1   −1
example/ck_tile/01_fmha/fmha_fwd.hpp                                              +3   −7
include/ck_tile/core/algorithm/coordinate_transform.hpp                           +9   −33
include/ck_tile/core/tensor/tile_distribution.hpp                                 +3   −2
include/ck_tile/ops/fmha.hpp                                                      +2   −7
include/ck_tile/ops/fmha/block/block_dropout.hpp                                  +105 −79
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp                               +414 −265
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp                     +8   −12
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp                               +43  −36
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp                       +30  −29
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp                   +142 −0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp                     +3   −3
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp      +0   −20
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp  +1066 −0
example/ck_tile/01_fmha/CMakeLists.txt

@@ -55,11 +55,10 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
 # ... because they are auto-generated
 if(FMHA_FWD_FAST_EXP2)
     list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
     list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
 else()
     list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
     list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
 endif()
 list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
 # Allow comparing floating points directly in order to check sentinel values
 list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py

@@ -66,6 +66,22 @@ BIAS_CHECK_MAP = {
     "alibi" : "bias_enum::alibi"
 }

+DROPOUT_MAP = {
+    "no"                        : "ck_tile::BlockDropout<false, true, false>",
+    "dropout_wg32"              : "ck_tile::BlockDropout<true, true, false>",
+    "dropout_wg32_storerandval" : "ck_tile::BlockDropout<true, true, true >",
+    "dropout_wg16"              : "ck_tile::BlockDropout<true, false, false>",
+    "dropout_wg16_storerandval" : "ck_tile::BlockDropout<true, false, true >"
+}
+
+DROPOUT_CHECK_MAP = {
+    "no"                        : "t.has_dropout == false",
+    "dropout_wg32"              : "t.has_dropout == true && t.is_store_randval == false",
+    "dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
+    "dropout_wg16"              : "t.has_dropout == true && t.is_store_randval == false",
+    "dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
+}
+
 MODE_MAP = {
     "batch" : "false",
     "group" : "true"
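Editor's note: for orientation, a minimal self-contained sketch (not part of this commit, and independent of the real ck_tile headers) of how the three boolean template parameters named by DROPOUT_MAP select a dropout variant. The mock struct below only mirrors the BlockDropout<IsDropout, IsWG32, IsStoreRandval> signature introduced in block_dropout.hpp; everything else is illustrative.

// Standalone mock of the new BlockDropout<IsDropout, IsWG32, IsStoreRandval>
// signature; the real implementation lives in
// include/ck_tile/ops/fmha/block/block_dropout.hpp.
#include <cstdio>

template <bool IsDropout_ = true, bool IsWG32_ = true, bool IsStoreRandval_ = false>
struct BlockDropoutMock
{
    static constexpr bool IsDropout      = IsDropout_;      // dropout enabled at all?
    static constexpr bool IsWG32         = IsWG32_;         // 32x32 warp gemm (true) vs 16x16 (false)
    static constexpr bool IsStoreRandval = IsStoreRandval_; // write random values back to DRAM?
};

// The five DROPOUT_MAP entries correspond to these instantiations:
using dropout_no                = BlockDropoutMock<false, true,  false>;
using dropout_wg32              = BlockDropoutMock<true,  true,  false>;
using dropout_wg32_storerandval = BlockDropoutMock<true,  true,  true>;
using dropout_wg16              = BlockDropoutMock<true,  false, false>;
using dropout_wg16_storerandval = BlockDropoutMock<true,  false, true>;

int main()
{
    static_assert(!dropout_no::IsDropout, "the 'no' variant disables dropout entirely");
    static_assert(dropout_wg16_storerandval::IsStoreRandval, "storerandval variants persist the random tile");
    std::printf("wg32 storerandval: dropout=%d wg32=%d store=%d\n",
                dropout_wg32_storerandval::IsDropout,
                dropout_wg32_storerandval::IsWG32,
                dropout_wg32_storerandval::IsStoreRandval);
}

The DROPOUT_CHECK_MAP entries above are the matching runtime dispatch conditions the generated API uses to pick one of these instantiations.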
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

@@ -53,10 +53,10 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
                                                     {F_bias},
                                                     false,
                                                     {F_lse},
-                                                    {F_dropout},
                                                     {F_squant},
                                                     {F_occupancy}>;
 using fmha_mask_{F_idx} = {F_mask};
+using fmha_dropout_{F_idx} = {F_dropout};

 using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,

@@ -73,6 +73,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
     fmha_shape_{F_idx},
     {F_mode},
     fmha_mask_{F_idx},
+    fmha_dropout_{F_idx},
     fmha_trait_{F_idx}>;

 using fmha_pipeline_{F_idx} = {F_pipeline}<

@@ -89,7 +90,7 @@ using fmha_kernel_{F_idx} =
     fmha_epilogue_{F_idx}>;

 using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
-                        {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+                        {F_pipeline_enum}, fmha_mask_{F_idx}, fmha_dropout_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

 #include <iostream>

@@ -124,9 +125,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
 }}
 """
-FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
+FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && ({F_dropout_check}) && (t.do_fp8_static_quant == {F_squant}) &&
                     ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
-        using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+        using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
         return fmha_fwd_<trait_>(s, a);
     }}
 """

@@ -238,7 +239,7 @@ class FmhaFwdPipeline:
         else:
             if self.F_mask != 'no': n += f'_m{self.F_mask[0]}'
             if self.F_lse == 't': n += '_lse'
-            if self.F_dropout == 't': n += '_dropout'
+            if self.F_dropout != 'no': n += f'_{self.F_dropout}'
             if self.F_squant == 't': n += '_squant'
         return n

@@ -269,7 +270,7 @@ class FmhaFwdApiPool:
                 inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                                F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                                F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
-                               F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout],
+                               F_lse=BOOL_MAP[trait.lse], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
                                F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                                F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                                F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,

@@ -344,7 +345,7 @@ class FmhaFwdKernel:
                 F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
                 F_bias=BIAS_MAP[self.F_pipeline.F_bias],
                 F_lse=BOOL_MAP[self.F_pipeline.F_lse],
-                F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
+                F_dropout=DROPOUT_MAP[self.F_pipeline.F_dropout],
                 F_squant=BOOL_MAP[self.F_pipeline.F_squant],
                 F_occupancy=self.F_tile.F_occupancy,
                 F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],

@@ -416,7 +417,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
         squant = 't' if dtype == 'fp8' else 'f'
         pipelines = []
         if dtype in ['fp16', 'bf16']:
-            for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
+            for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], list(DROPOUT_MAP.keys())[:3]):
                 if hdim == 256:
                 # if True:
                     pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py

@@ -29,6 +29,7 @@ FMHA_FWD_SPLITKV_PIPELINE_MAP = {
 FMHA_FWD_SPLITKV_KERNEL_BODY="""
 using fmha_dtype_{F_idx} = {F_dtype};
 using fmha_mask_{F_idx} = {F_mask};
+using fmha_dropout_{F_idx} = {F_dropout};

 namespace {{
 template <bool kHasUnevenSplits>

@@ -51,7 +52,6 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
     {F_bias},
     false,
     {F_lse},
-    {F_dropout},
     {F_squant},
     kHasUnevenSplits,
     {F_occupancy}>;

@@ -71,6 +71,7 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
     fmha_shape,
     {F_mode},
     fmha_mask_{F_idx},
+    fmha_dropout_{F_idx},
     fmha_trait>;

 using fmha_pipeline = {F_pipeline}<

@@ -98,7 +99,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
 }}

 using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
-                        {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+                        {F_pipeline_enum}, fmha_mask_{F_idx}, fmha_dropout_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

 #include <iostream>

@@ -224,9 +225,9 @@ float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream
 }}
 """
-FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
+FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && ({F_dropout_check}) && (t.do_fp8_static_quant == {F_squant}) &&
                     ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
-        using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+        using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
         using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
         return fmha_fwd_splitkv_<traits_, traits2_>(s, a);

@@ -267,7 +268,7 @@ class FmhaFwdSplitKVPipeline:
         else:
             if self.F_mask != 'no': n += f'_m{self.F_mask[0]}'
             if self.F_lse == 't': n += '_lse'
-            if self.F_dropout == 't': n += '_dropout'
+            if self.F_dropout != 'no': n += f'_{self.F_dropout}'
             if self.F_squant == 't': n += '_squant'
         return n

@@ -322,7 +323,7 @@ class FmhaFwdSplitKVApiPool:
                 inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                                F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                                F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
-                               F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout],
+                               F_lse=BOOL_MAP[trait.lse], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
                                F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                                F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                                F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,

@@ -380,7 +381,7 @@ class FmhaFwdSplitKVKernel:
                 F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
                 F_bias=BIAS_MAP[self.F_pipeline.F_bias],
                 F_lse=BOOL_MAP[self.F_pipeline.F_lse],
-                F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
+                F_dropout=DROPOUT_MAP[self.F_pipeline.F_dropout],
                 F_squant=BOOL_MAP[self.F_pipeline.F_squant],
                 F_occupancy=self.F_tile.F_occupancy,
                 F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],

@@ -531,7 +532,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
         pipelines = []
         if dtype in ['fp16', 'bf16']:
             # splitkv kernel donot support dropout
-            for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"]):
+            for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], list(DROPOUT_MAP.keys())[:1]):
                 if hdim == 256:
                 # if True:
                     pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
example/ck_tile/01_fmha/fmha_bwd.cpp

@@ -87,7 +87,11 @@ auto create_args(int argc, char* argv[])
         .insert("drop_offset", "0", "offset for random number generator")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
         .insert("warmup", "5", "number of iterations before benchmark the kernel")
-        .insert("repeat", "20", "number of iterations to benchmark the kernel");
+        .insert("repeat", "20", "number of iterations to benchmark the kernel")
+        .insert("deterministic",
+                "0",
+                "if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
+                "will not be used");

     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);

@@ -180,6 +184,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     int stream_warmup = arg_parser.get_int("warmup");
     int stream_repeat = arg_parser.get_int("repeat");
     bool kname        = arg_parser.get_bool("kname");
+    bool deterministic = arg_parser.get_bool("deterministic");

     ck_tile::stream_config stream_config{nullptr,
                                          true,

@@ -265,6 +270,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
         (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
     const ck_tile::index_t shape_seqlen_k =
         (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
+    const ck_tile::index_t kN0 = (hdim_q > 32 & hdim_q <= 128) ? 128 : 64;
+    const ck_tile::index_t nsplits =
+        deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;

     ck_tile::HostTensor<QDataType> q_host(
         get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));

@@ -302,6 +310,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
         use_dbias ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
                   : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
+    ck_tile::HostTensor<AccDataType> dq_acc_host(
+        i_perm ? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
+               : std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});

     if(init_method == 0)
     {

@@ -362,6 +374,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
     ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
     ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());

     q_buf.ToDevice(q_host.data());
     k_buf.ToDevice(k_host.data());

@@ -387,8 +400,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
     std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
               << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
               << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias
-              << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask << std::flush;
+              << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval
+              << ", deterministic:" << deterministic << ", mask:" << mask << std::flush;
+
+    std::size_t workspace_size =
+        dq_acc_host.get_element_space_size_in_bytes() * sizeof(AccDataType) / (1024 * 1024);
+    if(deterministic == 1)
+    {
+        std::cout << "\nDeterministic mode ON: " << workspace_size
+                  << " MByte memory workspace allocated" << std::endl;
+    }

     auto fmha_traits = fmha_bwd_traits{hdim_q,
                                        hdim_v,

@@ -397,7 +419,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
                                        mask.type,
                                        bias.type,
                                        use_dbias,
-                                       p_drop > 0.0f};
+                                       p_drop > 0.0f,
+                                       s_randval,
+                                       deterministic};

     auto fmha_args = [&]() {
         assert(nhead % nhead_k == 0);
         /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,

@@ -437,6 +461,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
         const ck_tile::index_t batch_stride_dk    = (nhead * shape_seqlen_k * hdim_q);
         const ck_tile::index_t batch_stride_dv    = (nhead * shape_seqlen_k * hdim_v);
         const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
+        const ck_tile::index_t split_stride_dq_acc =
+            (shape_batch * nhead * shape_seqlen_q * hdim_q);

         return fmha_bwd_args{q_buf.GetDeviceBuffer(),
                              k_buf.GetDeviceBuffer(),

@@ -452,6 +478,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              dk_buf.GetDeviceBuffer(),
                              dv_buf.GetDeviceBuffer(),
                              dbias_buf.GetDeviceBuffer(),
+                             dq_acc_buf.GetDeviceBuffer(),
                              seqstart_q.GetDeviceBuffer(),
                              seqstart_k.GetDeviceBuffer(),
                              nullptr,

@@ -496,12 +523,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              batch_stride_dk,
                              batch_stride_dv,
                              batch_stride_dbias,
+                             split_stride_dq_acc,
                              mask.left,
                              mask.right,
                              static_cast<ck_tile::index_t>(mask.type),
                              p_drop,
                              p_undrop,
                              s_randval,
                              {drop_seed, drop_offset}};
     }();

@@ -738,6 +765,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     lse_buf.ToDevice(lse_host.data());
     dq_buf.SetZero();
     dbias_buf.SetZero();
+    dq_acc_buf.SetZero();

     ck_tile::stream_config stream_config_v{nullptr,
                                            true,
                                            0,
                                            0,
                                            1,
                                            arg_parser.get_str("timer") == std::string("gpu")};
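Editor's note: as a rough aid (not part of the commit), the sketch below mirrors the host-side bookkeeping added above, using the same formulas the example uses: the N-tile size kN0 picked from hdim_q, the number of splits for the deterministic multi-buffer dq reduction, and the size of the dq_acc workspace. The helper integer_divide_ceil is reimplemented locally so the snippet stands alone, and the example sizes are made up.

// Standalone sketch of the deterministic-mode workspace sizing done in
// example/ck_tile/01_fmha/fmha_bwd.cpp (all helpers reimplemented here).
#include <cstdint>
#include <cstdio>

static std::int64_t integer_divide_ceil(std::int64_t a, std::int64_t b) { return (a + b - 1) / b; }

int main()
{
    const std::int64_t batch = 2, nhead = 8, seqlen_q = 4096, max_seqlen_k = 4096, hdim_q = 128;
    const bool deterministic = true;

    // kN0 follows the example's heuristic: 128-wide K tiles for 32 < hdim_q <= 128, else 64.
    const std::int64_t kN0 = (hdim_q > 32 && hdim_q <= 128) ? 128 : 64;

    // One dq_acc buffer per K tile when deterministic, otherwise a single buffer
    // that the kernel updates with atomics.
    const std::int64_t nsplits = deterministic ? integer_divide_ceil(max_seqlen_k, kN0) : 1;

    // dq_acc is fp32 (AccDataType) with shape [nsplits, batch, nhead, seqlen_q, hdim_q].
    const std::int64_t workspace_bytes =
        nsplits * batch * nhead * seqlen_q * hdim_q * static_cast<std::int64_t>(sizeof(float));

    std::printf("nsplits=%lld, dq_acc workspace = %lld MiB\n",
                static_cast<long long>(nsplits),
                static_cast<long long>(workspace_bytes / (1024 * 1024)));
}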
example/ck_tile/01_fmha/fmha_bwd.hpp

@@ -77,6 +77,7 @@ struct fmha_bwd_args
     void* dk_ptr;
     void* dv_ptr;
     void* dbias_ptr;
+    void* dq_acc_ptr;
     const void* seqstart_q_ptr;
     const void* seqstart_k_ptr;
     const void* seqlen_k_ptr;

@@ -120,12 +121,12 @@ struct fmha_bwd_args
     ck_tile::index_t batch_stride_dk;
     ck_tile::index_t batch_stride_dv;
     ck_tile::index_t batch_stride_dbias;
+    ck_tile::index_t split_stride_dq_acc;
     ck_tile::index_t window_size_left;
     ck_tile::index_t window_size_right;
     ck_tile::index_t mask_type;
     float p_drop;
     float p_undrop;
     bool s_randval;
     std::tuple<uint64_t, uint64_t> drop_seed_offset;
 };

@@ -145,10 +146,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
                                              args.do_ptr,
                                              args.d_ptr,
                                              args.rand_val_ptr,
                                              args.dq_ptr,
                                              args.dk_ptr,
                                              args.dv_ptr,
                                              args.dbias_ptr,
+                                             args.dq_acc_ptr,
                                              args.seqstart_q_ptr,
                                              args.seqstart_k_ptr,
                                              args.seqlen_k_ptr,

@@ -175,11 +176,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
                                              args.nhead_stride_lsed,
                                              args.nhead_stride_dbias,
                                              args.batch_stride_lsed,
+                                             args.split_stride_dq_acc,
                                              args.window_size_left,
                                              args.window_size_right,
                                              args.mask_type,
                                              args.p_drop,
                                              args.s_randval,
                                              args.drop_seed_offset);
         }
         else

@@ -192,10 +193,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
                                              args.do_ptr,
                                              args.d_ptr,
                                              args.rand_val_ptr,
                                              args.dq_ptr,
                                              args.dk_ptr,
                                              args.dv_ptr,
                                              args.dbias_ptr,
+                                             args.dq_acc_ptr,
                                              args.seqlen_q,
                                              args.seqlen_k,
                                              args.hdim_q,

@@ -230,11 +231,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
                                              args.batch_stride_dk,
                                              args.batch_stride_dv,
                                              args.batch_stride_dbias,
+                                             args.split_stride_dq_acc,
                                              args.window_size_left,
                                              args.window_size_right,
                                              args.mask_type,
                                              args.p_drop,
                                              args.s_randval,
                                              args.drop_seed_offset);
         }
     }();

@@ -286,19 +287,54 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
     return ck_tile::make_tuple(kargs, grids);
 }

+template <typename FmhaBwdConvertQGradKernel>
+auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
+{
+    auto kargs = [&] {
+        // create group mode kernel arguments
+        if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
+        {
+            return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
+                                                        args.dq_ptr,
+                                                        args.seqstart_q_ptr,
+                                                        args.seqlen_k_ptr,
+                                                        args.hdim_q,
+                                                        args.stride_q,
+                                                        args.nhead_stride_q,
+                                                        args.split_stride_dq_acc);
+        }
+        else
+        { // create batch mode kernel arguments
+            return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
+                                                        args.dq_ptr,
+                                                        args.seqlen_q,
+                                                        args.seqlen_k,
+                                                        args.hdim_q,
+                                                        args.stride_q,
+                                                        args.nhead_stride_q,
+                                                        args.batch_stride_q,
+                                                        args.split_stride_dq_acc);
+        }
+    }();
+
+    dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
+    return ck_tile::make_tuple(kargs, grids);
+}
+
 // this is used to pattern-match internl kernel implementation, not to instantiate kernel
 template <ck_tile::index_t HDim_,
           typename DataType_,
           bool kIsGroupMode_,
           ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
           typename FmhaMask_,
+          typename FmhaDropout_,
           ck_tile::BlockAttentionBiasEnum BiasEnum_,
           bool kHasBiasGrad_,
           bool kHasDropout_,
           bool kPadS_,
           bool kPadSK_,
           bool kPadD_,
-          bool kPadDv_>
+          bool kPadDv_,
+          bool kIsDeterministic_>
 struct fmha_bwd_dq_dk_dv_traits_
 {
     static constexpr ck_tile::index_t HDim = HDim_;

@@ -306,13 +342,14 @@ struct fmha_bwd_dq_dk_dv_traits_
     static constexpr bool kIsGroupMode        = kIsGroupMode_;
     static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
     using FmhaMask                            = ck_tile::remove_cvref_t<FmhaMask_>;
+    using FmhaDropout                         = ck_tile::remove_cvref_t<FmhaDropout_>;
     static constexpr auto BiasEnum            = BiasEnum_;
     static constexpr bool kHasBiasGrad        = kHasBiasGrad_;
     static constexpr bool kHasDropout         = kHasDropout_;
     static constexpr bool kPadS               = kPadS_;
     static constexpr bool kPadSK              = kPadSK_;
     static constexpr bool kPadD               = kPadD_;
     static constexpr bool kPadDv              = kPadDv_;
+    static constexpr bool kIsDeterministic    = kIsDeterministic_;
 };

 template <typename Traits_>

@@ -343,6 +380,31 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
 template <typename Traits_>
 std::string fmha_bwd_dot_do_o_get_name_();

+template <ck_tile::index_t HDim_,
+          typename DataType_,
+          bool kIsGroupMode_,
+          bool kPadS_,
+          bool kPadD_,
+          bool kIsDeterministic_>
+struct fmha_bwd_convert_dq_traits_
+{
+    static constexpr ck_tile::index_t HDim = HDim_;
+    using DataType                         = ck_tile::remove_cvref_t<DataType_>;
+    static constexpr bool kIsGroupMode     = kIsGroupMode_;
+    static constexpr bool kPadS            = kPadS_;
+    static constexpr bool kPadD            = kPadD_;
+    static constexpr bool kIsDeterministic = kIsDeterministic_;
+};
+
+template <typename Traits_>
+float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
+
+template <typename Traits_>
+void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
+
+template <typename Traits_>
+std::string fmha_bwd_convert_dq_get_name_();
+
 // This is the public API, will be generated by script
 struct fmha_bwd_traits
 {

@@ -354,6 +416,8 @@ struct fmha_bwd_traits
     bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
     bool has_dbias;
     bool has_dropout;
+    bool is_store_randval;
+    bool is_deterministic;
     // TODO: padding check is inside this api
 };
 float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
example/ck_tile/01_fmha/fmha_fwd.cpp

@@ -622,6 +622,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                                        bias.type,
                                        lse,
                                        p_drop > 0.0f,
+                                       s_randval,
                                        squant};

     auto p_compute_element_func = [&]() {

@@ -744,7 +745,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              mask.right,
                              static_cast<ck_tile::index_t>(mask.type),
                              p_drop,
-                             s_randval,
                              {drop_seed, drop_offset}};
     }();
example/ck_tile/01_fmha/fmha_fwd.hpp

@@ -143,7 +143,6 @@ struct fmha_fwd_args
     ck_tile::index_t window_size_right;
     ck_tile::index_t mask_type;
     float p_drop;
-    bool s_randval;
     std::tuple<uint64_t, uint64_t> drop_seed_offset;
 };

@@ -190,7 +189,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
                                              args.window_size_right,
                                              args.mask_type,
                                              args.p_drop,
-                                             args.s_randval,
                                              args.drop_seed_offset);
         }
         else

@@ -235,7 +233,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
                                              args.window_size_right,
                                              args.mask_type,
                                              args.p_drop,
-                                             args.s_randval,
                                              args.drop_seed_offset);
         }
     }();

@@ -292,7 +289,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
                                              args.window_size_right,
                                              args.mask_type,
                                              args.p_drop,
-                                             args.s_randval,
                                              args.drop_seed_offset);
         }
         else

@@ -341,7 +337,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
                                              args.window_size_right,
                                              args.mask_type,
                                              args.p_drop,
-                                             args.s_randval,
                                              args.drop_seed_offset);
         }
     }();

@@ -427,9 +422,9 @@ template <ck_tile::index_t HDim_,
           bool kIsVLayoutRowMajor_,
           ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
           typename FmhaMask_,
+          typename FmhaDropout_,
           ck_tile::BlockAttentionBiasEnum BiasEnum_,
           bool kStoreLse_,
-          bool kHasDropout_,
           bool kDoFp8StaticQuant_,
           bool kPadS_,
           bool kPadSK_,

@@ -449,9 +444,9 @@ struct fmha_fwd_traits_
     static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
     static constexpr auto FmhaPipelineEnum   = FmhaPipelineEnum_;
     using FmhaMask                           = ck_tile::remove_cvref_t<FmhaMask_>;
+    using FmhaDropout                        = ck_tile::remove_cvref_t<FmhaDropout_>;
     static constexpr auto BiasEnum           = BiasEnum_;
     static constexpr bool kStoreLse          = kStoreLse_;
-    static constexpr bool kHasDropout        = kHasDropout_;
     static constexpr bool kDoFp8StaticQuant  = kDoFp8StaticQuant_;
     static constexpr bool kPadS              = kPadS_;
     static constexpr bool kPadSK             = kPadSK_;

@@ -508,6 +503,7 @@ struct fmha_fwd_traits
     bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
     bool has_lse;
     bool has_dropout;
+    bool is_store_randval;
     bool do_fp8_static_quant;
     // TODO: padding check is inside this api
 };
include/ck_tile/core/algorithm/coordinate_transform.hpp

@@ -1341,7 +1341,7 @@ struct modulo : public base_transform<1, 1>
 };

 // 2D XOR, NOTE: "xor" is a keyword
-template <typename LowLengths, typename RightShift>
+template <typename LowLengths>
 struct xor_t : public base_transform<2, 2>
 {
     static constexpr auto type_enum = coord_transform_enum::xor_t;

@@ -1352,15 +1352,10 @@ struct xor_t : public base_transform<2, 2>
     using UpLengths = LowLengths;

     UpLengths up_lengths_;
-    RightShift right_shift_;

-    CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {}
+    CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{} {}

-    CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths, const RightShift& right_shift)
-        : up_lengths_{low_lengths}, right_shift_{right_shift}
-    {
-    }
+    CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}

     CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
     {

@@ -1378,13 +1373,8 @@ struct xor_t : public base_transform<2, 2>
         idx_low(number<0>{}) = idx_up[number<0>{}];

-        const auto idx_low_1_tmp =
-            (idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}];
-        const auto idx_low_1 =
-            (idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp;
-        idx_low(number<1>{}) = idx_low_1;
+        idx_low(number<1>{}) =
+            idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
     }

     template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>

@@ -1419,8 +1409,7 @@ struct xor_t : public base_transform<2, 2>
     CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
     {
-        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
-               ck_tile::is_known_at_compile_time<RightShift>::value;
+        return ck_tile::is_known_at_compile_time<UpLengths>::value;
     }

     // MUST be static function

@@ -1432,14 +1421,6 @@ struct xor_t : public base_transform<2, 2>
         array<index_t, 2> up_vector_lengths = low_vector_lengths;
         array<index_t, 2> up_vector_strides = low_vector_strides;

-        if constexpr(ck_tile::is_known_at_compile_time<RightShift>::value)
-        {
-            if(low_vector_lengths[1] != -1)
-            {
-                up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
-            }
-        }
-
         return make_tuple(up_vector_lengths, up_vector_strides);
     }

@@ -1452,10 +1433,6 @@ struct xor_t : public base_transform<2, 2>
         print(up_lengths_);
         printf(", ");

-        //
-        printf("right_shift_: ");
-        print(right_shift_);
-
         printf("}");
     }
 };

@@ -1655,11 +1632,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
     return modulo<Modulus, UpLength>{modulus, up_length};
 }

-template <typename LowLengths, typename RightShift>
-CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths,
-                                                      const RightShift& right_shift)
+template <typename LowLengths>
+CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
 {
-    return xor_t<LowLengths, RightShift>{low_lengths, right_shift};
+    return xor_t<LowLengths>{low_lengths};
 }

 template <typename LowLength, typename OffsetLength>
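Editor's note: to make the simplified lower-index rule above concrete, here is a small standalone sketch (not ck_tile code) of the 2D XOR mapping the transform now performs, assuming a power-of-two second dimension as is typical for swizzled layouts. It also checks that applying the map twice returns the original index, which is what makes the transform invertible.

// Standalone illustration of the new xor_t lower-index rule:
//   low[0] = up[0]
//   low[1] = up[1] ^ (up[0] % len1)
// assuming len1 is a power of two.
#include <array>
#include <cassert>
#include <cstdio>

static std::array<int, 2> xor_calculate_lower(std::array<int, 2> up, int len1)
{
    return {up[0], up[1] ^ (up[0] % len1)};
}

int main()
{
    const int len0 = 8, len1 = 8; // 8x8 tile; len1 must be a power of two here
    for(int i = 0; i < len0; ++i)
        for(int j = 0; j < len1; ++j)
        {
            auto low  = xor_calculate_lower({i, j}, len1);
            auto back = xor_calculate_lower(low, len1); // XOR with the same key undoes itself
            assert(back[0] == i && back[1] == j);
        }
    std::printf("xor mapping of row 3: ");
    for(int j = 0; j < len1; ++j)
        std::printf("%d ", xor_calculate_lower({3, j}, len1)[1]);
    std::printf("\n");
}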
include/ck_tile/core/tensor/tile_distribution.hpp

@@ -746,7 +746,8 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
     return make_tuple(
         make_static_tile_distribution(
             tile_distribution_encoding<typename Encoding::RsLengths,
-                                       decltype(sliced_h_lengths), // only need to change the
+                                       remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
+                                                                                   // change the
+                                                                                   // h_lengths type
                                        typename Encoding::Ps2RHssMajor,
                                        typename Encoding::Ps2RHssMinor,
include/ck_tile/ops/fmha.hpp

@@ -16,13 +16,8 @@
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
include/ck_tile/ops/fmha/block/block_dropout.hpp

@@ -22,20 +22,27 @@ struct NullBlockDropout
     }
 };

+template <bool IsDropout_ = true, bool IsWG32_ = true, bool IsStoreRandval_ = false>
 struct BlockDropout
 {
+    static constexpr bool IsDropout = IsDropout_;
+    // true: 32*32 warp gemm
+    // false: 16*16 warp gemm
+    static constexpr bool IsWG32         = IsWG32_;
+    static constexpr bool IsStoreRandval = IsStoreRandval_;
+
     CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch,
                                      index_t i_head,
                                      index_t nheads,
                                      unsigned long long seed,
                                      unsigned long long offset,
                                      float rp_undrop_,
-                                     uint8_t p_undrop_in_uint8_t_,
-                                     bool is_store_randval_)
-        : ph(seed, offset + (i_batch * nheads + i_head) * get_warp_size() + get_lane_id()),
+                                     uint8_t p_undrop_in_uint8_t_)
+        : ph(seed,
+             offset + (i_batch * nheads + i_head) * get_warp_size() +
+                 (IsWG32 ? get_lane_id() : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))),
          rp_undrop(rp_undrop_),
-         p_undrop_in_uint8_t(p_undrop_in_uint8_t_),
-         is_store_randval(is_store_randval_)
+         p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
     {
     }

@@ -43,6 +50,8 @@ struct BlockDropout
     CK_TILE_HOST_DEVICE static constexpr auto
     MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
                           index_t seqlen_qk_start)
     {
+        if constexpr(IsDropout)
+        {
             constexpr auto config =
                 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();

@@ -72,6 +81,14 @@ struct BlockDropout
             return randval_dram_window;
         }
+        else
+        {
+            (void)randval_dram_block_window_tmp;
+            (void)seqlen_qk_start;
+            return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
+        }
+    }

     template <typename BlockGemm>
     CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor()

@@ -122,16 +139,23 @@ struct BlockDropout
             sequence<0, 0>>{};

         // Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
         // except headdim256.
         constexpr auto randval_block_inner_part_dstr_encoding = []() {
             if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
                          std::is_same_v<typename BlockGemm::BDataType, half_t> &&
                          std::is_same_v<typename BlockGemm::CDataType, float>)
             {
                 if constexpr(IsWG32)
                     return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
                 else
                     return typename WarpGemmMfmaF16F16F32M16N16K16::CWarpDstrEncoding{};
             }
             else
             {
                 if constexpr(IsWG32)
                     return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
                 else
                     return typename WarpGemmMfmaBf16Bf16F32M16N16K16::CWarpDstrEncoding{};
             }
         }();

@@ -175,6 +199,7 @@ struct BlockDropout
               typename PComputeWindow,
               typename RandValDramWindow>
     CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
+                                 const index_t start_m0_idx,
                                  const index_t start_n0_idx,
                                  PComputeWindow& p_compute,
                                  RandValDramWindow& randval_dram_window) const

@@ -208,43 +233,6 @@ struct BlockDropout
             randval_lds_window.get_window_origin(),
             MakeRandValLdsShuffleTileDistribution<BlockGemm>());

-        const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{});
-        if(is_store_randval)
-        {
-            static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
-                static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
-                    int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
-                    int block_col_start = (start_n0_idx / WG::kN) + i_n0;
-                    uint2 rowcol = make_uint2(block_row_start, block_col_start);
-                    // generate random number
-                    uint8_t random_uint8_t[16];
-                    ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
-                    constexpr auto randval_dist_generated_spans =
-                        decltype(randval_dist_generated)::get_distributed_spans();
-                    int i_random_idx = 0;
-                    sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
-                        sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
-                            constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
-                            randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
-                        });
-                    });
-                    // save to LDS
-                    store_tile(randval_lds_window, randval_dist_generated);
-                    block_sync_lds();
-                    // read from LDS to register
-                    auto randval = load_tile(randval_lds_read_window);
-                    // save to Global
-                    const auto randval_store = cast_tile<RandValOutputDataType>(randval);
-                    store_tile(randval_dram_window, randval_store);
-                    move_tile_window(randval_dram_window, {0, kNPerStep});
-                });
-                move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
-            });
-            move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
-        };
         static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
             static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
                 int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();

@@ -282,8 +270,23 @@ struct BlockDropout
                                       : PComputeDataType(0);
                     });
                 });
+                // save to Global
+                if constexpr(IsStoreRandval)
+                {
+                    const auto randval_store = cast_tile<RandValOutputDataType>(randval);
+                    store_tile(randval_dram_window, randval_store);
+                    move_tile_window(randval_dram_window, {0, kNPerStep});
+                }
             });
+            if constexpr(IsStoreRandval)
+            {
+                move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
+            }
         });
+        if constexpr(IsStoreRandval)
+        {
+            move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
+        }
     }

     template <typename BlockGemm,

@@ -291,6 +294,7 @@ struct BlockDropout
               typename PComputeWindow,
               typename RandValDramWindow>
     CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
+                                 const index_t start_n0_idx,
                                  PComputeWindow& p_compute,
                                  RandValDramWindow& randval_dram_window) const
     {

@@ -308,25 +312,48 @@ struct BlockDropout
         // register distribute
         auto randval =
             make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
         if constexpr(IsWG32)
             static_assert(randval.kThreadElementSpaceSize == 16);
         else
             static_assert(randval.kThreadElementSpaceSize == 4);

-        const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{});
         static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
             static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
-                int block_row_start = (start_m0_idx / WG::kM) + i_m0;
-                int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
+                int block_row_start, block_col_start;
+                if constexpr(IsWG32)
+                {
+                    block_row_start = (start_m0_idx / WG::kM) + i_m0;
+                    block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
+                }
+                else
+                {
+                    block_row_start = start_m0_idx / 32;
+                    block_col_start = (start_n0_idx / 32) + get_warp_id() / 2;
+                }
                 uint2 rowcol = make_uint2(block_row_start, block_col_start);

                 // generate random number
                 uint8_t random_uint8_t[16];
+                uint8_t* random_uint8_t_ = random_uint8_t;
                 ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
+                if constexpr(!IsWG32)
+                {
+                    // m0t0 ~m0t15/m0t32~m0t47: 0
+                    // m0t16~m0t31/m0t48~m0t63: 1
+                    // m1t0 ~m1t15/m1t32~m1t47: 2
+                    // m1t16~m1t31/m1t48~m1t63: 3
+                    int start_idx = ((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
+                    uint32_t* random_uint32_t = reinterpret_cast<uint32_t*>(random_uint8_t);
+                    random_uint8_t_ = reinterpret_cast<uint8_t*>(&random_uint32_t[start_idx]);
+                }

                 constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
                 int i_random_idx = 0;
                 sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
                     sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
                         constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
-                        randval(r_idx) = random_uint8_t[i_random_idx++];
+                        randval(r_idx) = random_uint8_t_[i_random_idx++];
                         constexpr auto p_idx0 = tile_distributed_index<i_m0,
                                                                        idx0.impl_.at(1),
                                                                        idx0.impl_.at(2)>{};
                         constexpr auto p_idx1 = tile_distributed_index<i_n0>{};

@@ -337,19 +364,19 @@ struct BlockDropout
                     });
                 });
                 // save to Global
-                if(is_store_randval)
+                if constexpr(IsStoreRandval)
                 {
                     const auto randval_store = cast_tile<RandValOutputDataType>(randval);
                     store_tile(randval_dram_window, randval_store);
                     move_tile_window(randval_dram_window, {kMPerStep, 0});
                 }
             });
-            if(is_store_randval)
+            if constexpr(IsStoreRandval)
             {
                 move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
             }
         });
-        if(is_store_randval)
+        if constexpr(IsStoreRandval)
         {
             move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
         }

@@ -358,7 +385,6 @@ struct BlockDropout
     ck_tile::philox ph;
     const float rp_undrop;
     const uint8_t p_undrop_in_uint8_t;
-    const bool is_store_randval;
 };

 } // namespace ck_tile
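Editor's note: a small standalone model (assumptions only, not ck_tile code) of the per-thread Philox offset the new BlockDropout constructor computes; warp_size, lane_id and warp_id stand in for get_warp_size(), get_lane_id() and get_warp_id(). The WG16 branch folds the lane and warp indices together, which is what lets the 16x16 warp-gemm backward path reproduce the random stream the forward kernel used.

// Standalone model of the Philox offset computed in BlockDropout's constructor:
//   offset + (i_batch * nheads + i_head) * warp_size
//          + (IsWG32 ? lane_id : ((lane_id & 47) + ((warp_id & 1) << 4)))
#include <cstdint>
#include <cstdio>

static std::uint64_t philox_offset(std::uint64_t offset,
                                   int i_batch, int i_head, int nheads,
                                   int lane_id, int warp_id, bool is_wg32,
                                   int warp_size = 64)
{
    const int lane_term = is_wg32 ? lane_id : ((lane_id & 47) + ((warp_id & 1) << 4));
    return offset + static_cast<std::uint64_t>(i_batch * nheads + i_head) * warp_size + lane_term;
}

int main()
{
    // Same (batch, head): with WG32 every one of the 64 lanes gets its own offset;
    // with WG16 the offset depends on (lane_id & 47) and the low bit of the warp id.
    for(int warp = 0; warp < 2; ++warp)
        for(int lane : {0, 15, 16, 47, 63})
            std::printf("warp %d lane %2d -> wg32 %llu, wg16 %llu\n", warp, lane,
                        static_cast<unsigned long long>(philox_offset(0, 0, 0, 1, lane, warp, true)),
                        static_cast<unsigned long long>(philox_offset(0, 0, 0, 1, lane, warp, false)));
}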
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp

(Diff collapsed in the web view; +414 −265 lines not shown here.)
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp

@@ -7,38 +7,34 @@
 namespace ck_tile {

-template <typename BlockFmhaShape_>
-struct FmhaBwdTilePartitioner
+template <ck_tile::index_t kN0>
+struct FmhaBwdKTilePartitioner
 {
-    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
-
-    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
-
     CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                 ck_tile::index_t nhead_,
                                                 ck_tile::index_t seqlen_k_)
     {
         // TODO: this may need tuning
-        return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
+        return dim3(batch_size_, nhead_, ck_tile::integer_divide_ceil(seqlen_k_, kN0));
     }

     CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
     {
-        const index_t i_block = blockIdx.x;
+        const index_t i_block = blockIdx.z;
         const index_t i_nhead = blockIdx.y;
-        const index_t i_batch = blockIdx.z;
+        const index_t i_batch = blockIdx.x;

         return ck_tile::make_tuple(i_block, i_nhead, i_batch);
     }
 };

-template <ck_tile::index_t kBlockSize>
-struct FmhaBwdOGradDotOTilePartitioner
+template <ck_tile::index_t kM0>
+struct FmhaBwdQTilePartitioner
 {
     CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                 ck_tile::index_t nhead_,
                                                 ck_tile::index_t seqlen_q_)
     {
         // TODO: this may need tuning
-        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
+        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
     }

     CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
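Editor's note: for reference, a tiny host-side sketch (standalone, not the ck_tile header) of the new grid layout. The K-tile partitioner now places the batch index on grid.x, the head on grid.y, and the seqlen_k tiles on grid.z, and the device side reads them back in that order; dim3 is mocked so this compiles without a GPU toolchain.

// Standalone sketch of FmhaBwdKTilePartitioner's grid layout after this commit:
// GridSize(batch, nhead, seqlen_k) = dim3(batch, nhead, ceil(seqlen_k / kN0)),
// and a block maps (blockIdx.z, blockIdx.y, blockIdx.x) -> (i_tile, i_nhead, i_batch).
#include <cstdio>

struct dim3_mock { unsigned x, y, z; }; // stand-in for HIP/CUDA dim3

template <int kN0>
dim3_mock grid_size(int batch, int nhead, int seqlen_k)
{
    const unsigned n_tiles = static_cast<unsigned>((seqlen_k + kN0 - 1) / kN0);
    return {static_cast<unsigned>(batch), static_cast<unsigned>(nhead), n_tiles};
}

int main()
{
    constexpr int kN0 = 128;              // K-tile size along seqlen_k (illustrative)
    auto g = grid_size<kN0>(4, 16, 4096); // batch=4, heads=16, seqlen_k=4096
    std::printf("grid = (%u, %u, %u)  // x: batch, y: head, z: seqlen_k tiles\n", g.x, g.y, g.z);
}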
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp

@@ -47,10 +47,12 @@ struct FmhaFwdKernel
     static constexpr bool kPadHeadDimV      = FmhaPipeline::kPadHeadDimV;
     static constexpr auto BiasEnum          = FmhaPipeline::BiasEnum;
     static constexpr bool kStoreLSE         = FmhaPipeline::kStoreLSE;
-    static constexpr bool kHasDropout       = FmhaPipeline::kHasDropout;
     static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
     using FmhaMask                          = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
+    using FmhaDropout                       = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
     static constexpr bool kHasMask          = FmhaMask::IsMasking;
+    static constexpr bool kHasDropout       = FmhaDropout::IsDropout;
+    static constexpr bool kIsStoreRandval   = FmhaDropout::IsStoreRandval;

     // clang-format off
     template <typename T> struct t2s;

@@ -87,7 +89,8 @@ struct FmhaFwdKernel
             (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
             "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
             (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
-            (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
+            (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kIsStoreRandval ? "_storerandval" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
 #undef _SS_
 #undef _TS_
         // clang-format on

@@ -185,7 +188,6 @@ struct FmhaFwdKernel
         }

         float rp_undrop             = 1;
         uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
-        bool is_store_randval       = false;
         uint64_t drop_seed          = 1;
         uint64_t drop_offset        = 0;
         void* rand_val_ptr          = nullptr;

@@ -277,7 +279,6 @@ struct FmhaFwdKernel
                                     ck_tile::index_t window_size_right,
                                     ck_tile::index_t mask_type,
                                     float p_drop,
-                                    bool s_randval,
                                     const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
     {
         Kargs kargs{{q_ptr,

@@ -345,11 +346,13 @@ struct FmhaFwdKernel
         if constexpr(kHasDropout)
         {
             kargs.init_dropout(p_drop, drop_seed_offset);
+            if constexpr(kIsStoreRandval)
+            {
                 kargs.rand_val_ptr         = rand_val_ptr;
                 kargs.stride_randval       = stride_randval;
                 kargs.nhead_stride_randval = nhead_stride_randval;
                 kargs.batch_stride_randval = batch_stride_randval;
-                kargs.is_store_randval     = s_randval;
+            }
         }

         return kargs;

@@ -392,7 +395,6 @@ struct FmhaFwdKernel
                                     ck_tile::index_t window_size_right,
                                     ck_tile::index_t mask_type,
                                     float p_drop,
-                                    bool s_randval,
                                     const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
     {
         Kargs kargs{{q_ptr,

@@ -458,10 +460,12 @@ struct FmhaFwdKernel
         if constexpr(kHasDropout)
         {
             kargs.init_dropout(p_drop, drop_seed_offset);
+            if constexpr(kIsStoreRandval)
+            {
                 kargs.rand_val_ptr         = rand_val_ptr;
                 kargs.stride_randval       = stride_randval;
                 kargs.nhead_stride_randval = nhead_stride_randval;
-                kargs.is_store_randval     = s_randval;
+            }
         }

         return kargs;

@@ -526,7 +530,7 @@ struct FmhaFwdKernel
             {
                 batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
             }
-            if constexpr(kHasDropout)
+            if constexpr(kIsStoreRandval)
             {
                 batch_offset_randval = query_start * kargs.stride_randval;
             }

@@ -566,7 +570,7 @@ struct FmhaFwdKernel
             {
                 batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
             }
-            if constexpr(kHasDropout)
+            if constexpr(kIsStoreRandval)
             {
                 batch_offset_randval = static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;

@@ -744,28 +748,31 @@ struct FmhaFwdKernel
             }
         }();

-        auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
-            if constexpr(kHasDropout)
-            {
-                return BlockDropout{i_batch_,
-                                    i_nhead_,
-                                    kargs.num_head_q,
-                                    kargs.drop_seed,
-                                    kargs.drop_offset,
-                                    kargs.rp_undrop,
-                                    kargs.p_undrop_in_uint8_t,
-                                    kargs.is_store_randval};
-            }
-            else
-            {
-                return NullBlockDropout{};
-            };
-        }();
+        // dropout
+        float rp_undrop             = 1;
+        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
+        uint64_t drop_seed          = 0;
+        uint64_t drop_offset        = 0;
+        if constexpr(kHasDropout)
+        {
+            rp_undrop           = kargs.rp_undrop;
+            p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
+            drop_seed           = kargs.drop_seed;
+            drop_offset         = kargs.drop_offset;
+        }
+        FmhaDropout dropout(i_batch,
+                            i_nhead,
+                            kargs.num_head_q,
+                            drop_seed,
+                            drop_offset,
+                            rp_undrop,
+                            p_undrop_in_uint8_t);

         auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
             constexpr auto randval_dram_window_lengths =
                 make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
-            if constexpr(kHasDropout)
+            if constexpr(kIsStoreRandval)
             {
                 RandValOutputDataType* rand_val_ptr =
                     reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp

@@ -46,10 +46,12 @@ struct FmhaFwdSplitKVKernel
     static constexpr bool kPadHeadDimQ      = FmhaPipeline::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV      = FmhaPipeline::kPadHeadDimV;
     static constexpr auto BiasEnum          = FmhaPipeline::BiasEnum;
-    static constexpr bool kHasDropout       = FmhaPipeline::kHasDropout;
     static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
     using FmhaMask                          = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
+    using FmhaDropout                       = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
     static constexpr bool kHasMask          = FmhaMask::IsMasking;
+    static constexpr bool kHasDropout       = FmhaDropout::IsDropout;
+    static constexpr bool kIsStoreRandval   = FmhaDropout::IsStoreRandval;

     // clang-format off
     template <typename T> struct t2s;

@@ -86,7 +88,8 @@ struct FmhaFwdSplitKVKernel
             (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
             "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
             (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
-            (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
+            (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kIsStoreRandval ? "_storerandval" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
 #undef _SS_
 #undef _TS_
         // clang-format on

@@ -189,7 +192,6 @@ struct FmhaFwdSplitKVKernel
         }

         float rp_undrop             = 1;
         uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
-        bool is_store_randval       = false;
         uint64_t drop_seed          = 1;
         uint64_t drop_offset        = 0;
         void* rand_val_ptr          = nullptr;

@@ -282,7 +284,6 @@ struct FmhaFwdSplitKVKernel
                                     ck_tile::index_t window_size_right,
                                     ck_tile::index_t mask_type,
                                     float p_drop,
-                                    bool s_randval,
                                     const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
     {
         Kargs kargs{{q_ptr,

@@ -350,11 +351,13 @@ struct FmhaFwdSplitKVKernel
         if constexpr(kHasDropout)
         {
             kargs.init_dropout(p_drop, drop_seed_offset);
+            if constexpr(kIsStoreRandval)
+            {
                 kargs.rand_val_ptr         = rand_val_ptr;
                 kargs.stride_randval       = stride_randval;
                 kargs.nhead_stride_randval = nhead_stride_randval;
                 kargs.batch_stride_randval = batch_stride_randval;
-                kargs.is_store_randval     = s_randval;
+            }
         }

         return kargs;

@@ -402,7 +405,6 @@ struct FmhaFwdSplitKVKernel
                                     ck_tile::index_t window_size_right,
                                     ck_tile::index_t mask_type,
                                     float p_drop,
-                                    bool s_randval,
                                     const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
     {
         Kargs kargs{{q_ptr,

@@ -469,10 +471,12 @@ struct FmhaFwdSplitKVKernel
         if constexpr(kHasDropout)
         {
             kargs.init_dropout(p_drop, drop_seed_offset);
+            if constexpr(kIsStoreRandval)
+            {
                 kargs.rand_val_ptr         = rand_val_ptr;
                 kargs.stride_randval       = stride_randval;
                 kargs.nhead_stride_randval = nhead_stride_randval;
-                kargs.is_store_randval     = s_randval;
+            }
         }

         return kargs;

@@ -536,7 +540,7 @@ struct FmhaFwdSplitKVKernel
             {
                 batch_offset_bias = query_start * kargs.stride_bias + key_start;
             }
-            if constexpr(kHasDropout)
+            if constexpr(kIsStoreRandval)
             {
                 batch_offset_randval = query_start * kargs.stride_randval;
             }

@@ -571,7 +575,7 @@ struct FmhaFwdSplitKVKernel
             {
                 batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
             }
-            if constexpr(kHasDropout)
+            if constexpr(kIsStoreRandval)
            {
                 batch_offset_randval = static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;

@@ -747,7 +751,6 @@ struct FmhaFwdSplitKVKernel
         uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
         uint64_t drop_seed          = 0;
         uint64_t drop_offset        = 0;
-        bool is_store_randval       = false;

         if constexpr(kHasDropout)
         {

@@ -755,21 +758,19 @@ struct FmhaFwdSplitKVKernel
             p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
             drop_seed           = kargs.drop_seed;
             drop_offset         = kargs.drop_offset;
-            is_store_randval    = kargs.is_store_randval;
         }
-        BlockDropout dropout(i_batch,
+        FmhaDropout dropout(i_batch,
                             i_nhead,
                             kargs.num_head_q,
                             drop_seed,
                             drop_offset,
                             rp_undrop,
-                            p_undrop_in_uint8_t,
-                            is_store_randval);
+                            p_undrop_in_uint8_t);

         auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
             constexpr auto randval_dram_window_lengths =
                 make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
-            if constexpr(kHasDropout)
+            if constexpr(kIsStoreRandval)
             {
                 RandValOutputDataType* rand_val_ptr =
                     reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

namespace ck_tile {

template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdConvertQGrad
{
    using AccDataType   = remove_cvref_t<typename Problem::AccDataType>;
    using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;

    static constexpr index_t kM0         = Problem::Shape::kM0;
    static constexpr index_t kN0         = Problem::Shape::kN0;
    static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
    static constexpr index_t kBlockSize  = Problem::kBlockSize;
    static constexpr index_t kQKHeaddim  = Problem::Shape::kQKHeaddim;

    static constexpr bool kIsGroupMode     = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ      = Problem::kPadSeqLenQ;
    static constexpr bool kPadHeadDimQ     = Problem::kPadHeadDimQ;
    static constexpr bool kIsDeterministic = Problem::kIsDeterministic;

    static constexpr index_t kAlignmentQGradAcc =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc<Problem>();
    static constexpr index_t kAlignmentQGrad =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGrad<Problem>();

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

    // Convert only
    template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE void operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
                                        QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const
    {
        static_assert(
            std::is_same_v<AccDataType, remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType, remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");
        static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");

        auto dq_acc_dram_window =
            make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                             dq_acc_dram_block_window_tmp.get_window_lengths(),
                             dq_acc_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakePostQGradAccDramTileDistribution<Problem>());

        auto dq_acc   = load_tile(dq_acc_dram_window);
        const auto dq = cast_tile<QGradDataType>(dq_acc);
        store_tile(dq_dram_block_window_tmp, dq);
    }

    // Reduce + Convert
    template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE void operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
                                        QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
                                        index_t nsplits) const
    {
        static_assert(
            std::is_same_v<AccDataType, remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType, remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");
        static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");

        auto dq_acc_dram_window = make_tile_window(
            dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
            dq_acc_dram_block_window_tmp.get_window_lengths(),
            dq_acc_dram_block_window_tmp.get_window_origin(),
            Policy::template MakePostQGradAccDeterministicDramTileDistribution<Problem>());

        auto dq_acc = decltype(load_tile(dq_acc_dram_window)){};
        clear_tile(dq_acc);

        constexpr auto dq_acc_spans = decltype(dq_acc)::get_distributed_spans();
        index_t i_total_loops = 0;

        auto dq_acc_buf = load_tile(dq_acc_dram_window);
        move_tile_window(dq_acc_dram_window, {1, 0, 0});
        do
        {
            sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                    sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                        constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                        dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
                    });
                });
            });
            dq_acc_buf = load_tile(dq_acc_dram_window);
            move_tile_window(dq_acc_dram_window, {1, 0, 0});
            i_total_loops += 1;
        } while(i_total_loops < (nsplits - 1));

        sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
            sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                    constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                    dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
                });
            });
        });

        // declare dq
        constexpr auto dq_converted_dstr =
            Policy::template MakePostQGradAccDeterministicDramTileDistribution<Problem>();
        auto dq_converted = make_static_distributed_tensor<QGradDataType>(dq_converted_dstr);

        sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
            sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                    constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                    dq_converted(n_i_j_idx) = type_convert<QGradDataType>(dq_acc[n_i_j_idx]);
                });
            });
        });

        constexpr auto dq_dstr =
            Policy::template MakePostQGradDeterministicDramTileDistribution<Problem>();
        auto dq = make_static_distributed_tensor<QGradDataType>(dq_dstr);
        dq.get_thread_buffer() = dq_converted.get_thread_buffer();

        store_tile(dq_dram_block_window_tmp, dq);
    }
};

} // namespace ck_tile
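Editor's note: to illustrate the reduce-then-convert idea the new pipeline implements, here is a plain C++ sketch (assumed shapes, no ck_tile tile windows): nsplits partial fp32 dq accumulators laid out split-major are summed and the total is then converted to the output precision in one pass. The output type is left as float here purely to keep the example dependency-free; in the real kernel QGradDataType is fp16/bf16.

// Standalone sketch of the "reduce + convert" path of BlockFmhaBwdConvertQGrad:
// dq_acc has shape [nsplits][seqlen_q][hdim_q] in fp32; the splits are summed
// and the result is converted to the output precision.
#include <cstdio>
#include <vector>

using AccT = float;
using OutT = float; // stands in for QGradDataType (fp16/bf16 in ck_tile)

static std::vector<OutT> convert_dq(const std::vector<AccT>& dq_acc,
                                    int nsplits, int seqlen_q, int hdim_q)
{
    std::vector<OutT> dq(static_cast<size_t>(seqlen_q) * hdim_q, OutT{0});
    for(int m = 0; m < seqlen_q; ++m)
        for(int d = 0; d < hdim_q; ++d)
        {
            AccT sum = 0;
            for(int s = 0; s < nsplits; ++s) // reduce over the per-K-tile partial dq buffers
                sum += dq_acc[(static_cast<size_t>(s) * seqlen_q + m) * hdim_q + d];
            dq[static_cast<size_t>(m) * hdim_q + d] = static_cast<OutT>(sum); // convert
        }
    return dq;
}

int main()
{
    const int nsplits = 4, seqlen_q = 2, hdim_q = 3;
    std::vector<AccT> dq_acc(nsplits * seqlen_q * hdim_q, 0.25f); // each split contributes 0.25
    auto dq = convert_dq(dq_acc, nsplits, seqlen_q, hdim_q);
    std::printf("dq[0] = %.2f (expected 1.00)\n", dq[0]);
}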
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp

@@ -4,11 +4,11 @@
 #pragma once

 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

 namespace ck_tile {

-template <typename Problem, typename Policy = BlockFmhaBwdOGradDotODefaultPolicy>
+template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
 struct BlockFmhaBwdOGradDotO
 {
     using ODataType = remove_cvref_t<typename Problem::ODataType>;

@@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO
     static constexpr index_t kAlignmentO = kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
     static constexpr index_t kAlignmentOGrad =
-        kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
+        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();

     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp  (deleted, mode 100644 → 0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

namespace ck_tile {

// These templates are not used here.
using BlockFmhaBwdOGradDotODefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_      = */ false,
                                      /* QTLoadOnce_     = */ false,
                                      /* KLoadOnce_      = */ false,
                                      /* KTLoadOnce_     = */ false,
                                      /* VLoadOnce_      = */ false,
                                      /* OGradLoadOnce_  = */ false,
                                      /* OGradTLoadOnce_ = */ false>;

} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp →
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp  (renamed)

(Diff collapsed in the web view; contents not shown here.)