gaoqiong / composable_kernel_ROCM · Commits

Commit 74f1516c, authored Jul 10, 2024 by danyao12
Parent: 497ccb87

    tmp save

Changes: 43 · Showing 20 changed files with 1972 additions and 528 deletions (+1972, -528)
example/ck_tile/01_fmha/CMakeLists.txt                                              +1    -2
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py                                   +16   -0
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py                                     +9    -8
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py                             +9    -8
example/ck_tile/01_fmha/fmha_bwd.cpp                                                +36   -8
example/ck_tile/01_fmha/fmha_bwd.hpp                                                +72   -8
example/ck_tile/01_fmha/fmha_fwd.cpp                                                +1    -1
example/ck_tile/01_fmha/fmha_fwd.hpp                                                +3    -7
include/ck_tile/core/algorithm/coordinate_transform.hpp                             +9    -33
include/ck_tile/core/tensor/tile_distribution.hpp                                   +3    -2
include/ck_tile/ops/fmha.hpp                                                        +2    -7
include/ck_tile/ops/fmha/block/block_dropout.hpp                                    +105  -79
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp                                 +414  -265
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp                       +8    -12
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp                                 +43   -36
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp                         +30   -29
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp                     +142  -0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp                       +3    -3
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp        +0    -20
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp    +1066 -0
example/ck_tile/01_fmha/CMakeLists.txt

@@ -55,11 +55,10 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
 # ... because they are auto-generated
 if(FMHA_FWD_FAST_EXP2)
     list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
-    list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
 else()
     list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
-    list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
 endif()
+list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
 # Allow comparing floating points directly in order to check sentinel values
 list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py

@@ -66,6 +66,22 @@ BIAS_CHECK_MAP = {
     "alibi" : "bias_enum::alibi"
 }
 
+DROPOUT_MAP = {
+    "no"                        : "ck_tile::BlockDropout<false, true, false>",
+    "dropout_wg32"              : "ck_tile::BlockDropout<true, true, false>",
+    "dropout_wg32_storerandval" : "ck_tile::BlockDropout<true, true, true>",
+    "dropout_wg16"              : "ck_tile::BlockDropout<true, false, false>",
+    "dropout_wg16_storerandval" : "ck_tile::BlockDropout<true, false, true>"
+}
+
+DROPOUT_CHECK_MAP = {
+    "no"                        : "t.has_dropout == false",
+    "dropout_wg32"              : "t.has_dropout == true && t.is_store_randval == false",
+    "dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
+    "dropout_wg16"              : "t.has_dropout == true && t.is_store_randval == false",
+    "dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
+}
+
 MODE_MAP = {
     "batch" : "false",
     "group" : "true"
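Note: the three boolean template parameters in these map entries select whether dropout is active at all, whether the warp gemm tile is 32x32 (versus 16x16), and whether the generated random values are written back to global memory. A minimal sketch of the template these entries name (it mirrors the declaration added to include/ck_tile/ops/fmha/block/block_dropout.hpp later in this diff; only the flags are reproduced here):

    // Sketch only: mirrors the BlockDropout declaration further down in this diff.
    namespace ck_tile {
    template <bool IsDropout_ = true, bool IsWG32_ = true, bool IsStoreRandval_ = false>
    struct BlockDropout
    {
        static constexpr bool IsDropout      = IsDropout_;      // dropout active at all
        static constexpr bool IsWG32         = IsWG32_;         // 32x32 warp gemm (else 16x16)
        static constexpr bool IsStoreRandval = IsStoreRandval_; // write randvals to global
    };
    } // namespace ck_tile

    // e.g. the "dropout_wg16_storerandval" entry names this type:
    using DropoutWg16Store = ck_tile::BlockDropout<true, false, true>;
    static_assert(!DropoutWg16Store::IsWG32 && DropoutWg16Store::IsStoreRandval, "");

    int main() {}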
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

@@ -53,10 +53,10 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
                                                    {F_bias},
                                                    false,
                                                    {F_lse},
-                                                   {F_dropout},
                                                    {F_squant},
                                                    {F_occupancy}>;
 using fmha_mask_{F_idx} = {F_mask};
+using fmha_dropout_{F_idx} = {F_dropout};
 
 using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,

@@ -73,6 +73,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
     fmha_shape_{F_idx},
     {F_mode},
     fmha_mask_{F_idx},
+    fmha_dropout_{F_idx},
     fmha_trait_{F_idx}>;
 
 using fmha_pipeline_{F_idx} = {F_pipeline}<

@@ -89,7 +90,7 @@ using fmha_kernel_{F_idx} =
     fmha_epilogue_{F_idx}>;
 
 using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
-                                       {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+                                       {F_pipeline_enum}, fmha_mask_{F_idx}, fmha_dropout_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
 
 #include <iostream>

@@ -124,9 +125,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
 }}
 """
-FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
+FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && ({F_dropout_check}) && (t.do_fp8_static_quant == {F_squant}) &&
                 ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
-        using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+        using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
         return fmha_fwd_<trait_>(s, a);
 }}
 """

@@ -238,7 +239,7 @@ class FmhaFwdPipeline:
        else:
            if self.F_mask != 'no': n += f'_m{self.F_mask[0]}'
            if self.F_lse == 't': n += '_lse'
-           if self.F_dropout == 't': n += '_dropout'
+           if self.F_dropout != 'no': n += f'_{self.F_dropout}'
            if self.F_squant == 't': n += '_squant'
        return n

@@ -269,7 +270,7 @@ class FmhaFwdApiPool:
                inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                               F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                               F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
-                              F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout],
+                              F_lse=BOOL_MAP[trait.lse], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
                               F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                               F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                               F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,

@@ -344,7 +345,7 @@ class FmhaFwdKernel:
                F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
                F_bias=BIAS_MAP[self.F_pipeline.F_bias],
                F_lse=BOOL_MAP[self.F_pipeline.F_lse],
-               F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
+               F_dropout=DROPOUT_MAP[self.F_pipeline.F_dropout],
                F_squant=BOOL_MAP[self.F_pipeline.F_squant],
                F_occupancy=self.F_tile.F_occupancy,
                F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],

@@ -416,7 +417,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
        squant = 't' if dtype == 'fp8' else 'f'
        pipelines = []
        if dtype in ['fp16', 'bf16']:
-           for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
+           for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], list(DROPOUT_MAP.keys())[:3]):
                if hdim == 256:
                # if True:
                    pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
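Note: the generated dispatch now selects a fully specialized dropout type instead of testing a single boolean. A minimal, self-contained sketch of this dispatch pattern (runtime flags choosing a template instance); the names Traits/dispatch are illustrative, not the library's API:

    #include <cstdio>

    template <bool HasDropout, bool StoreRandval>
    struct Traits
    {
        static float run() { return HasDropout ? (StoreRandval ? 3.f : 2.f) : 1.f; }
    };

    // Mirrors the shape of the generated FMHA_FWD_API_INNER_DISPATCH chain:
    // each runtime-flag combination routes to one compile-time instantiation.
    float dispatch(bool has_dropout, bool is_store_randval)
    {
        if(has_dropout == false) return Traits<false, false>::run();
        if(has_dropout && is_store_randval == false) return Traits<true, false>::run();
        return Traits<true, true>::run();
    }

    int main() { std::printf("%f\n", dispatch(true, true)); }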
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py

@@ -29,6 +29,7 @@ FMHA_FWD_SPLITKV_PIPELINE_MAP = {
 FMHA_FWD_SPLITKV_KERNEL_BODY = """
 using fmha_dtype_{F_idx} = {F_dtype};
 using fmha_mask_{F_idx} = {F_mask};
+using fmha_dropout_{F_idx} = {F_dropout};
 
 namespace {{
 
 template <bool kHasUnevenSplits>

@@ -51,7 +52,6 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
                                                      {F_bias},
                                                      false,
                                                      {F_lse},
-                                                     {F_dropout},
                                                      {F_squant},
                                                      kHasUnevenSplits,
                                                      {F_occupancy}>;

@@ -71,6 +71,7 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
     fmha_shape,
     {F_mode},
     fmha_mask_{F_idx},
+    fmha_dropout_{F_idx},
     fmha_trait>;
 
 using fmha_pipeline = {F_pipeline}<

@@ -98,7 +99,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
 }}
 
 using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
-                                       {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+                                       {F_pipeline_enum}, fmha_mask_{F_idx}, fmha_dropout_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
 
 #include <iostream>

@@ -224,9 +225,9 @@ float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream
 }}
 """
-FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
+FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && ({F_dropout_check}) && (t.do_fp8_static_quant == {F_squant}) &&
                 ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
-        using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+        using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
         using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
         return fmha_fwd_splitkv_<traits_, traits2_>(s, a);

@@ -267,7 +268,7 @@ class FmhaFwdSplitKVPipeline:
        else:
            if self.F_mask != 'no': n += f'_m{self.F_mask[0]}'
            if self.F_lse == 't': n += '_lse'
-           if self.F_dropout == 't': n += '_dropout'
+           if self.F_dropout != 'no': n += f'_{self.F_dropout}'
            if self.F_squant == 't': n += '_squant'
        return n

@@ -322,7 +323,7 @@ class FmhaFwdSplitKVApiPool:
                inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                               F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                               F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
-                              F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout],
+                              F_lse=BOOL_MAP[trait.lse], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
                               F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                               F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                               F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,

@@ -380,7 +381,7 @@ class FmhaFwdSplitKVKernel:
                F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
                F_bias=BIAS_MAP[self.F_pipeline.F_bias],
                F_lse=BOOL_MAP[self.F_pipeline.F_lse],
-               F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
+               F_dropout=DROPOUT_MAP[self.F_pipeline.F_dropout],
                F_squant=BOOL_MAP[self.F_pipeline.F_squant],
                F_occupancy=self.F_tile.F_occupancy,
                F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],

@@ -531,7 +532,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            # the splitkv kernel does not support dropout
-           for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"]):
+           for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], list(DROPOUT_MAP.keys())[:1]):
                if hdim == 256:
                # if True:
                    pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
example/ck_tile/01_fmha/fmha_bwd.cpp

@@ -87,7 +87,11 @@ auto create_args(int argc, char* argv[])
         .insert("drop_offset", "0", "offset for random number generator")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
         .insert("warmup", "5", "number of iterations before benchmark the kernel")
-        .insert("repeat", "20", "number of iterations to benchmark the kernel");
+        .insert("repeat", "20", "number of iterations to benchmark the kernel")
+        .insert("deterministic",
+                "0",
+                "if set to 1 will use multi-buffer reduction strategy for dq, atomic operation "
+                "will not be used");
 
     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);

@@ -177,9 +181,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
         seed.reset();
     }
-    int stream_warmup = arg_parser.get_int("warmup");
-    int stream_repeat = arg_parser.get_int("repeat");
-    bool kname        = arg_parser.get_bool("kname");
+    int stream_warmup  = arg_parser.get_int("warmup");
+    int stream_repeat  = arg_parser.get_int("repeat");
+    bool kname         = arg_parser.get_bool("kname");
+    bool deterministic = arg_parser.get_bool("deterministic");
 
     ck_tile::stream_config stream_config{nullptr,
                                          true,

@@ -265,6 +270,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
         (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
     const ck_tile::index_t shape_seqlen_k =
         (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
+    const ck_tile::index_t kN0 = (hdim_q > 32 & hdim_q <= 128) ? 128 : 64;
+    const ck_tile::index_t nsplits =
+        deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
 
     ck_tile::HostTensor<QDataType> q_host(
         get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));

@@ -302,6 +310,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
         use_dbias ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
                   : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
+    ck_tile::HostTensor<AccDataType> dq_acc_host(
+        i_perm ? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
+               : std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
 
     if(init_method == 0)
     {

@@ -362,6 +374,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
     ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
     ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
 
     q_buf.ToDevice(q_host.data());
     k_buf.ToDevice(k_host.data());

@@ -387,8 +400,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
     std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
               << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
               << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias
-              << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask
-              << std::flush;
+              << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval
+              << ", deterministic:" << deterministic << ", mask:" << mask << std::flush;
+
+    std::size_t workspace_size =
+        dq_acc_host.get_element_space_size_in_bytes() * sizeof(AccDataType) / (1024 * 1024);
+    if(deterministic == 1)
+    {
+        std::cout << "\nDeterministic mode ON: " << workspace_size
+                  << " MByte memory workspace allocated" << std::endl;
+    }
 
     auto fmha_traits = fmha_bwd_traits{hdim_q,
                                        hdim_v,

@@ -397,7 +419,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
                                        mask.type,
                                        bias.type,
                                        use_dbias,
-                                       p_drop > 0.0f};
+                                       p_drop > 0.0f,
+                                       s_randval,
+                                       deterministic};
 
     auto fmha_args = [&]() {
         assert(nhead % nhead_k == 0);
         /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,

@@ -437,6 +461,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
         const ck_tile::index_t batch_stride_dk    = (nhead * shape_seqlen_k * hdim_q);
         const ck_tile::index_t batch_stride_dv    = (nhead * shape_seqlen_k * hdim_v);
         const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
+        const ck_tile::index_t split_stride_dq_acc =
+            (shape_batch * nhead * shape_seqlen_q * hdim_q);
 
         return fmha_bwd_args{q_buf.GetDeviceBuffer(),
                              k_buf.GetDeviceBuffer(),

@@ -452,6 +478,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              dk_buf.GetDeviceBuffer(),
                              dv_buf.GetDeviceBuffer(),
                              dbias_buf.GetDeviceBuffer(),
+                             dq_acc_buf.GetDeviceBuffer(),
                              seqstart_q.GetDeviceBuffer(),
                              seqstart_k.GetDeviceBuffer(),
                              nullptr,

@@ -496,12 +523,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              batch_stride_dk,
                              batch_stride_dv,
                              batch_stride_dbias,
+                             split_stride_dq_acc,
                              mask.left,
                              mask.right,
                              static_cast<ck_tile::index_t>(mask.type),
                              p_drop,
                              p_undrop,
                              s_randval,
                              {drop_seed, drop_offset}};
     }();

@@ -738,6 +765,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
             lse_buf.ToDevice(lse_host.data());
             dq_buf.SetZero();
             dbias_buf.SetZero();
+            dq_acc_buf.SetZero();
 
             ck_tile::stream_config stream_config_v{
                 nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
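Note: for a sense of scale, the deterministic path sizes its dq accumulation workspace as nsplits x batch x nhead x seqlen_q x hdim_q fp32 elements, with nsplits = ceil(max_seqlen_k / kN0). A standalone back-of-envelope sketch mirroring the kN0/nsplits logic above (values are illustrative):

    #include <cstdio>

    int main()
    {
        const long batch = 2, nhead = 8, seqlen_q = 4096, hdim_q = 128;
        const long max_seqlen_k = 4096;

        const long kN0     = (hdim_q > 32 && hdim_q <= 128) ? 128 : 64;
        const long nsplits = (max_seqlen_k + kN0 - 1) / kN0; // integer_divide_ceil

        const long elems = nsplits * batch * nhead * seqlen_q * hdim_q;
        const long bytes = elems * sizeof(float); // AccDataType is fp32
        std::printf("nsplits=%ld, workspace=%ld MiB\n", nsplits, bytes / (1024 * 1024));
    }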
example/ck_tile/01_fmha/fmha_bwd.hpp

@@ -77,6 +77,7 @@ struct fmha_bwd_args
     void* dk_ptr;
     void* dv_ptr;
     void* dbias_ptr;
+    void* dq_acc_ptr;
     const void* seqstart_q_ptr;
     const void* seqstart_k_ptr;
     const void* seqlen_k_ptr;

@@ -120,12 +121,12 @@ struct fmha_bwd_args
     ck_tile::index_t batch_stride_dk;
     ck_tile::index_t batch_stride_dv;
     ck_tile::index_t batch_stride_dbias;
+    ck_tile::index_t split_stride_dq_acc;
     ck_tile::index_t window_size_left;
     ck_tile::index_t window_size_right;
     ck_tile::index_t mask_type;
     float p_drop;
     float p_undrop;
     bool s_randval;
     std::tuple<uint64_t, uint64_t> drop_seed_offset;
 };

@@ -145,10 +146,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
             args.do_ptr,
             args.d_ptr,
             args.rand_val_ptr,
-            args.dq_ptr,
             args.dk_ptr,
             args.dv_ptr,
             args.dbias_ptr,
+            args.dq_acc_ptr,
             args.seqstart_q_ptr,
             args.seqstart_k_ptr,
             args.seqlen_k_ptr,

@@ -175,11 +176,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
             args.nhead_stride_lsed,
             args.nhead_stride_dbias,
             args.batch_stride_lsed,
+            args.split_stride_dq_acc,
             args.window_size_left,
             args.window_size_right,
             args.mask_type,
             args.p_drop,
             args.s_randval,
             args.drop_seed_offset);
     }
     else

@@ -192,10 +193,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
             args.do_ptr,
             args.d_ptr,
             args.rand_val_ptr,
-            args.dq_ptr,
             args.dk_ptr,
             args.dv_ptr,
             args.dbias_ptr,
+            args.dq_acc_ptr,
             args.seqlen_q,
             args.seqlen_k,
             args.hdim_q,

@@ -230,11 +231,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
             args.batch_stride_dk,
             args.batch_stride_dv,
             args.batch_stride_dbias,
+            args.split_stride_dq_acc,
             args.window_size_left,
             args.window_size_right,
             args.mask_type,
             args.p_drop,
             args.s_randval,
             args.drop_seed_offset);
     }
 }();

@@ -286,19 +287,54 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
     return ck_tile::make_tuple(kargs, grids);
 }
 
+template <typename FmhaBwdConvertQGradKernel>
+auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
+{
+    auto kargs = [&] {
+        // create group mode kernel arguments
+        if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
+        {
+            return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
+                                                        args.dq_ptr,
+                                                        args.seqstart_q_ptr,
+                                                        args.seqlen_k_ptr,
+                                                        args.hdim_q,
+                                                        args.stride_q,
+                                                        args.nhead_stride_q,
+                                                        args.split_stride_dq_acc);
+        }
+        else
+        { // create batch mode kernel arguments
+            return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
+                                                        args.dq_ptr,
+                                                        args.seqlen_q,
+                                                        args.seqlen_k,
+                                                        args.hdim_q,
+                                                        args.stride_q,
+                                                        args.nhead_stride_q,
+                                                        args.batch_stride_q,
+                                                        args.split_stride_dq_acc);
+        }
+    }();
+
+    dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
+    return ck_tile::make_tuple(kargs, grids);
+}
+
 // this is used to pattern-match internal kernel implementation, not to instantiate kernel
 template <ck_tile::index_t HDim_,
           typename DataType_,
           bool kIsGroupMode_,
           ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
           typename FmhaMask_,
+          typename FmhaDropout_,
           ck_tile::BlockAttentionBiasEnum BiasEnum_,
           bool kHasBiasGrad_,
-          bool kHasDropout_,
           bool kPadS_,
           bool kPadSK_,
           bool kPadD_,
-          bool kPadDv_>
+          bool kPadDv_,
+          bool kIsDeterministic_>
 struct fmha_bwd_dq_dk_dv_traits_
 {
     static constexpr ck_tile::index_t HDim = HDim_;

@@ -306,13 +342,14 @@ struct fmha_bwd_dq_dk_dv_traits_
     static constexpr bool kIsGroupMode        = kIsGroupMode_;
     static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
     using FmhaMask                            = ck_tile::remove_cvref_t<FmhaMask_>;
+    using FmhaDropout                         = ck_tile::remove_cvref_t<FmhaDropout_>;
     static constexpr auto BiasEnum            = BiasEnum_;
     static constexpr bool kHasBiasGrad        = kHasBiasGrad_;
-    static constexpr bool kHasDropout         = kHasDropout_;
     static constexpr bool kPadS               = kPadS_;
     static constexpr bool kPadSK              = kPadSK_;
     static constexpr bool kPadD               = kPadD_;
     static constexpr bool kPadDv              = kPadDv_;
+    static constexpr bool kIsDeterministic    = kIsDeterministic_;
 };
 
 template <typename Traits_>

@@ -343,6 +380,31 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
 template <typename Traits_>
 std::string fmha_bwd_dot_do_o_get_name_();
 
+template <ck_tile::index_t HDim_,
+          typename DataType_,
+          bool kIsGroupMode_,
+          bool kPadS_,
+          bool kPadD_,
+          bool kIsDeterministic_>
+struct fmha_bwd_convert_dq_traits_
+{
+    static constexpr ck_tile::index_t HDim = HDim_;
+    using DataType                         = ck_tile::remove_cvref_t<DataType_>;
+    static constexpr bool kIsGroupMode     = kIsGroupMode_;
+    static constexpr bool kPadS            = kPadS_;
+    static constexpr bool kPadD            = kPadD_;
+    static constexpr bool kIsDeterministic = kIsDeterministic_;
+};
+
+template <typename Traits_>
+float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
+
+template <typename Traits_>
+void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
+
+template <typename Traits_>
+std::string fmha_bwd_convert_dq_get_name_();
+
 // This is the public API, will be generated by script
 struct fmha_bwd_traits
 {

@@ -354,6 +416,8 @@ struct fmha_bwd_traits
     bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
     bool has_dbias;
     bool has_dropout;
+    bool is_store_randval;
+    bool is_deterministic;
     // TODO: padding check is inside this api
 };
 float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
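Note: the new convert-dq stage reduces the nsplits partial dq_acc buffers into the final dq without atomic adds, which is what makes the result reproducible. A host-side conceptual sketch of that reduction (the real FmhaBwdConvertQGradKernel performs it as a tiled GPU pass, stepping between split buffers via split_stride_dq_acc; names beyond those in the diff are illustrative):

    #include <vector>
    #include <cstddef>

    void convert_dq(const std::vector<float>& dq_acc, // [nsplits * split_stride] partials
                    std::vector<float>& dq,           // [n] final gradient
                    std::size_t nsplits,
                    std::size_t split_stride)         // elements between split buffers
    {
        for(std::size_t i = 0; i < dq.size(); ++i)
        {
            float acc = 0.f;
            for(std::size_t s = 0; s < nsplits; ++s)
                acc += dq_acc[s * split_stride + i]; // fixed order => deterministic
            dq[i] = acc;
        }
    }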
example/ck_tile/01_fmha/fmha_fwd.cpp

@@ -622,6 +622,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                                        bias.type,
                                        lse,
                                        p_drop > 0.0f,
+                                       s_randval,
                                        squant};
 
     auto p_compute_element_func = [&]() {

@@ -744,7 +745,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              mask.right,
                              static_cast<ck_tile::index_t>(mask.type),
                              p_drop,
-                             s_randval,
                              {drop_seed, drop_offset}};
     }();
example/ck_tile/01_fmha/fmha_fwd.hpp

@@ -143,7 +143,6 @@ struct fmha_fwd_args
     ck_tile::index_t window_size_right;
     ck_tile::index_t mask_type;
     float p_drop;
-    bool s_randval;
     std::tuple<uint64_t, uint64_t> drop_seed_offset;
 };

@@ -190,7 +189,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
                                       args.window_size_right,
                                       args.mask_type,
                                       args.p_drop,
-                                      args.s_randval,
                                       args.drop_seed_offset);
     }
     else

@@ -235,7 +233,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
                                       args.window_size_right,
                                       args.mask_type,
                                       args.p_drop,
-                                      args.s_randval,
                                       args.drop_seed_offset);
     }
 }();

@@ -292,7 +289,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
                                       args.window_size_right,
                                       args.mask_type,
                                       args.p_drop,
-                                      args.s_randval,
                                       args.drop_seed_offset);
     }
     else

@@ -341,7 +337,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
                                       args.window_size_right,
                                       args.mask_type,
                                       args.p_drop,
-                                      args.s_randval,
                                       args.drop_seed_offset);
     }
 }();

@@ -427,9 +422,9 @@ template <ck_tile::index_t HDim_,
           bool kIsVLayoutRowMajor_,
           ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
           typename FmhaMask_,
+          typename FmhaDropout_,
           ck_tile::BlockAttentionBiasEnum BiasEnum_,
           bool kStoreLse_,
-          bool kHasDropout_,
           bool kDoFp8StaticQuant_,
           bool kPadS_,
           bool kPadSK_,

@@ -449,9 +444,9 @@ struct fmha_fwd_traits_
     static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
     static constexpr auto FmhaPipelineEnum   = FmhaPipelineEnum_;
     using FmhaMask                           = ck_tile::remove_cvref_t<FmhaMask_>;
+    using FmhaDropout                        = ck_tile::remove_cvref_t<FmhaDropout_>;
     static constexpr auto BiasEnum           = BiasEnum_;
     static constexpr bool kStoreLse          = kStoreLse_;
-    static constexpr bool kHasDropout        = kHasDropout_;
     static constexpr bool kDoFp8StaticQuant  = kDoFp8StaticQuant_;
     static constexpr bool kPadS              = kPadS_;
     static constexpr bool kPadSK             = kPadSK_;

@@ -508,6 +503,7 @@ struct fmha_fwd_traits
     bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
     bool has_lse;
     bool has_dropout;
+    bool is_store_randval;
     bool do_fp8_static_quant;
     // TODO: padding check is inside this api
 };
include/ck_tile/core/algorithm/coordinate_transform.hpp

@@ -1341,7 +1341,7 @@ struct modulo : public base_transform<1, 1>
 };
 
 // 2D XOR, NOTE: "xor" is a keyword
-template <typename LowLengths, typename RightShift>
+template <typename LowLengths>
 struct xor_t : public base_transform<2, 2>
 {
     static constexpr auto type_enum = coord_transform_enum::xor_t;

@@ -1352,15 +1352,10 @@ struct xor_t : public base_transform<2, 2>
     using UpLengths = LowLengths;
 
     UpLengths up_lengths_;
-    RightShift right_shift_;
 
-    CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {}
+    CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{} {}
 
-    CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths, const RightShift& right_shift)
-        : up_lengths_{low_lengths}, right_shift_{right_shift}
-    {
-    }
+    CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}
 
     CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
     {

@@ -1378,13 +1373,8 @@ struct xor_t : public base_transform<2, 2>
         idx_low(number<0>{}) = idx_up[number<0>{}];
 
-        const auto idx_low_1_tmp =
-            (idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}];
-        const auto idx_low_1 =
-            (idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp;
-        idx_low(number<1>{}) = idx_low_1;
+        idx_low(number<1>{}) = idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
     }
 
     template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>

@@ -1419,8 +1409,7 @@ struct xor_t : public base_transform<2, 2>
     CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
     {
-        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
-               ck_tile::is_known_at_compile_time<RightShift>::value;
+        return ck_tile::is_known_at_compile_time<UpLengths>::value;
     }
 
     // MUST be static function

@@ -1432,14 +1421,6 @@ struct xor_t : public base_transform<2, 2>
         array<index_t, 2> up_vector_lengths = low_vector_lengths;
         array<index_t, 2> up_vector_strides = low_vector_strides;
 
-        if constexpr(ck_tile::is_known_at_compile_time<RightShift>::value)
-        {
-            if(low_vector_lengths[1] != -1)
-            {
-                up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
-            }
-        }
-
         return make_tuple(up_vector_lengths, up_vector_strides);
     }

@@ -1452,10 +1433,6 @@ struct xor_t : public base_transform<2, 2>
         print(up_lengths_);
         printf(", ");
 
-        //
-        printf("right_shift_: ");
-        print(right_shift_);
-
         printf("}");
     }
 };

@@ -1655,11 +1632,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
     return modulo<Modulus, UpLength>{modulus, up_length};
 }
 
-template <typename LowLengths, typename RightShift>
-CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths,
-                                                      const RightShift& right_shift)
+template <typename LowLengths>
+CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
 {
-    return xor_t<LowLengths, RightShift>{low_lengths, right_shift};
+    return xor_t<LowLengths>{low_lengths};
 }
 
 template <typename LowLength, typename OffsetLength>
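Note: with the RightShift parameter dropped, the lower-index mapping is a plain XOR: the second coordinate is XOR-ed with (first coordinate mod length). XOR is an involution, so applying the map twice returns the original index. A standalone sketch of just that mapping (values illustrative):

    #include <cassert>

    struct xor2d
    {
        int len1; // up_lengths_[1]

        void lower(const int idx_up[2], int idx_low[2]) const
        {
            idx_low[0] = idx_up[0];
            idx_low[1] = idx_up[1] ^ (idx_up[0] % len1);
        }
    };

    int main()
    {
        xor2d t{8};
        int up[2] = {5, 3}, low[2], back[2];
        t.lower(up, low);   // low = {5, 3 ^ (5 % 8)} = {5, 6}
        t.lower(low, back); // involution: back == up
        assert(back[0] == up[0] && back[1] == up[1]);
    }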
include/ck_tile/core/tensor/tile_distribution.hpp

@@ -746,8 +746,9 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
     return make_tuple(
         make_static_tile_distribution(
             tile_distribution_encoding<typename Encoding::RsLengths,
-                                       decltype(sliced_h_lengths), // only need to change the
-                                                                   // h_lengths type
+                                       remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
+                                                                                   // change the
+                                                                                   // h_lengths type
                                        typename Encoding::Ps2RHssMajor,
                                        typename Encoding::Ps2RHssMinor,
                                        typename Encoding::Ys2RHsMajor,
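Note: the point of wrapping decltype in remove_cvref_t is that decltype on a named variable can carry const/reference qualifiers that should not leak into a type-list parameter. A small self-contained illustration (std::remove_cvref_t is the C++20 equivalent of ck_tile's own remove_cvref_t):

    #include <type_traits>

    int main()
    {
        const int x  = 0;
        const int& r = x;
        static_assert(std::is_same_v<decltype(r), const int&>);
        static_assert(std::is_same_v<std::remove_cvref_t<decltype(r)>, int>);
    }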
include/ck_tile/ops/fmha.hpp

@@ -16,13 +16,8 @@
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
include/ck_tile/ops/fmha/block/block_dropout.hpp

@@ -22,20 +22,27 @@ struct NullBlockDropout
     }
 };
 
+template <bool IsDropout_ = true, bool IsWG32_ = true, bool IsStoreRandval_ = false>
 struct BlockDropout
 {
+    static constexpr bool IsDropout = IsDropout_;
+    // true: 32*32 warp gemm
+    // false: 16*16 warp gemm
+    static constexpr bool IsWG32         = IsWG32_;
+    static constexpr bool IsStoreRandval = IsStoreRandval_;
+
     CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch,
                                      index_t i_head,
                                      index_t nheads,
                                      unsigned long long seed,
                                      unsigned long long offset,
                                      float rp_undrop_,
-                                     uint8_t p_undrop_in_uint8_t_,
-                                     bool is_store_randval_)
-        : ph(seed, offset + (i_batch * nheads + i_head) * get_warp_size() + get_lane_id()),
+                                     uint8_t p_undrop_in_uint8_t_)
+        : ph(seed,
+             offset + (i_batch * nheads + i_head) * get_warp_size() +
+                 (IsWG32 ? get_lane_id()
+                         : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))),
           rp_undrop(rp_undrop_),
-          p_undrop_in_uint8_t(p_undrop_in_uint8_t_),
-          is_store_randval(is_store_randval_)
+          p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
     {
     }

@@ -44,33 +51,43 @@ struct BlockDropout
     MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
                           index_t seqlen_qk_start)
     {
-        constexpr auto config =
-            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
-        using WG                    = remove_cvref_t<decltype(config.template at<0>())>;
-        constexpr index_t MWarp     = config.template at<1>();
-        constexpr index_t NWarp     = config.template at<2>();
-        constexpr index_t kMPerStep = MWarp * WG::kM;
-        constexpr index_t kNPerStep = NWarp * WG::kN;
-
-        const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
-        auto randval_dram_window = [&]() {
-            if constexpr(IsFwd)
-            {
-                return make_tile_window(
-                    randval_dram_block_window_tmp.get_bottom_tensor_view(),
-                    ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
-                    {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
-            }
-            else
-            {
-                return make_tile_window(
-                    randval_dram_block_window_tmp.get_bottom_tensor_view(),
-                    ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
-                    {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
-            }
-        }();
-
-        return randval_dram_window;
+        if constexpr(IsDropout)
+        {
+            constexpr auto config =
+                BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
+            using WG                    = remove_cvref_t<decltype(config.template at<0>())>;
+            constexpr index_t MWarp     = config.template at<1>();
+            constexpr index_t NWarp     = config.template at<2>();
+            constexpr index_t kMPerStep = MWarp * WG::kM;
+            constexpr index_t kNPerStep = NWarp * WG::kN;
+
+            const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
+            auto randval_dram_window = [&]() {
+                if constexpr(IsFwd)
+                {
+                    return make_tile_window(
+                        randval_dram_block_window_tmp.get_bottom_tensor_view(),
+                        ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
+                        {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
+                }
+                else
+                {
+                    return make_tile_window(
+                        randval_dram_block_window_tmp.get_bottom_tensor_view(),
+                        ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
+                        {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
+                }
+            }();
+
+            return randval_dram_window;
+        }
+        else
+        {
+            (void)randval_dram_block_window_tmp;
+            (void)seqlen_qk_start;
+            return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
+        }
     }
 
     template <typename BlockGemm>

@@ -122,16 +139,23 @@ struct BlockDropout
             sequence<0, 0>>{};
 
         // Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
+        // except headdim256.
        constexpr auto randval_block_inner_part_dstr_encoding = []() {
            if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
                         std::is_same_v<typename BlockGemm::BDataType, half_t> &&
                         std::is_same_v<typename BlockGemm::CDataType, float>)
            {
-               return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
+               if constexpr(IsWG32)
+                   return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
+               else
+                   return typename WarpGemmMfmaF16F16F32M16N16K16::CWarpDstrEncoding{};
            }
            else
            {
-               return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
+               if constexpr(IsWG32)
+                   return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
+               else
+                   return typename WarpGemmMfmaBf16Bf16F32M16N16K16::CWarpDstrEncoding{};
            }
        }();

@@ -175,6 +199,7 @@ struct BlockDropout
               typename PComputeWindow,
               typename RandValDramWindow>
     CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
+                                 const index_t start_m0_idx,
                                  const index_t start_n0_idx,
                                  PComputeWindow& p_compute,
                                  RandValDramWindow& randval_dram_window) const

@@ -208,43 +233,6 @@ struct BlockDropout
             randval_lds_window.get_window_origin(),
             MakeRandValLdsShuffleTileDistribution<BlockGemm>());
 
-        const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{});
-        if(is_store_randval)
-        {
-            static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
-                static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
-                    int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
-                    int block_col_start = (start_n0_idx / WG::kN) + i_n0;
-                    uint2 rowcol        = make_uint2(block_row_start, block_col_start);
-
-                    // generate random number
-                    uint8_t random_uint8_t[16];
-                    ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
-
-                    constexpr auto randval_dist_generated_spans =
-                        decltype(randval_dist_generated)::get_distributed_spans();
-                    int i_random_idx = 0;
-                    sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
-                        sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
-                            constexpr auto i_j_idx          = ck_tile::make_tuple(idx0, idx1);
-                            randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
-                        });
-                    });
-                    // save to LDS
-                    store_tile(randval_lds_window, randval_dist_generated);
-                    block_sync_lds();
-                    // read from LDS to register
-                    auto randval = load_tile(randval_lds_read_window);
-                    // save to Global
-                    const auto randval_store = cast_tile<RandValOutputDataType>(randval);
-                    store_tile(randval_dram_window, randval_store);
-                    move_tile_window(randval_dram_window, {0, kNPerStep});
-                });
-                move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
-            });
-            move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
-        };
-
         static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
             static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
                 int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();

@@ -282,8 +270,23 @@ struct BlockDropout
                                              : PComputeDataType(0);
                     });
                 });
+                // save to Global
+                if constexpr(IsStoreRandval)
+                {
+                    const auto randval_store = cast_tile<RandValOutputDataType>(randval);
+                    store_tile(randval_dram_window, randval_store);
+                    move_tile_window(randval_dram_window, {0, kNPerStep});
+                }
             });
+            if constexpr(IsStoreRandval)
+            {
+                move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
+            }
         });
+        if constexpr(IsStoreRandval)
+        {
+            move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
+        }
     }
 
     template <typename BlockGemm,

@@ -291,6 +294,7 @@ struct BlockDropout
               typename PComputeWindow,
               typename RandValDramWindow>
     CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
+                                 const index_t start_n0_idx,
                                  PComputeWindow& p_compute,
                                  RandValDramWindow& randval_dram_window) const
     {

@@ -308,25 +312,48 @@ struct BlockDropout
         // register distribute
         auto randval =
             make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
-        static_assert(randval.kThreadElementSpaceSize == 16);
+        if constexpr(IsWG32)
+            static_assert(randval.kThreadElementSpaceSize == 16);
+        else
+            static_assert(randval.kThreadElementSpaceSize == 4);
 
-        const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{});
         static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
             static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
-                int block_row_start = (start_m0_idx / WG::kM) + i_m0;
-                int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
-                uint2 rowcol        = make_uint2(block_row_start, block_col_start);
+                int block_row_start, block_col_start;
+                if constexpr(IsWG32)
+                {
+                    block_row_start = (start_m0_idx / WG::kM) + i_m0;
+                    block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
+                }
+                else
+                {
+                    block_row_start = start_m0_idx / 32;
+                    block_col_start = (start_n0_idx / 32) + get_warp_id() / 2;
+                }
+                uint2 rowcol = make_uint2(block_row_start, block_col_start);
 
                 // generate random number
                 uint8_t random_uint8_t[16];
+                uint8_t* random_uint8_t_ = random_uint8_t;
                 ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
+                if constexpr(!IsWG32)
+                {
+                    // m0t0 ~m0t15/m0t32~m0t47: 0
+                    // m0t16~m0t31/m0t48~m0t63: 1
+                    // m1t0 ~m1t15/m1t32~m1t47: 2
+                    // m1t16~m1t31/m1t48~m1t63: 3
+                    int start_idx =
+                        ((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
+                    uint32_t* random_uint32_t = reinterpret_cast<uint32_t*>(random_uint8_t);
+                    random_uint8_t_ = reinterpret_cast<uint8_t*>(&random_uint32_t[start_idx]);
+                }
 
                 constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
                 int i_random_idx             = 0;
                 sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
                     sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
                         constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
-                        randval(r_idx)       = random_uint8_t[i_random_idx++];
+                        randval(r_idx)       = random_uint8_t_[i_random_idx++];
                         constexpr auto p_idx0 = tile_distributed_index<i_m0,
                                                                        idx0.impl_.at(1),
                                                                        idx0.impl_.at(2)>{};
                         constexpr auto p_idx1 = tile_distributed_index<i_n0>{};

@@ -337,19 +364,19 @@ struct BlockDropout
                     });
                 });
                 // save to Global
-                if(is_store_randval)
+                if constexpr(IsStoreRandval)
                 {
                     const auto randval_store = cast_tile<RandValOutputDataType>(randval);
                     store_tile(randval_dram_window, randval_store);
                     move_tile_window(randval_dram_window, {kMPerStep, 0});
                 }
             });
-            if(is_store_randval)
+            if constexpr(IsStoreRandval)
            {
                 move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
             }
         });
-        if(is_store_randval)
+        if constexpr(IsStoreRandval)
         {
             move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
         }

@@ -358,7 +385,6 @@ struct BlockDropout
     ck_tile::philox ph;
     const float rp_undrop;
     const uint8_t p_undrop_in_uint8_t;
-    const bool is_store_randval;
 };
 
 } // namespace ck_tile
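Note: the new WG16 branch of the constructor remaps the Philox lane offset so that two 16x16 warp-gemm warps covering one 32x32 tile draw the same random stream a single 32x32 warp gemm would. A standalone illustration of just that bit manipulation (the formula is taken verbatim from the diff; get_lane_id()/get_warp_id() are device intrinsics in ck_tile, emulated here as plain parameters):

    #include <cstdio>

    int effective_lane(int lane_id, int warp_id, bool is_wg32)
    {
        if(is_wg32)
            return lane_id;
        // keep lane bits 0-3 and 5, replace bit 4 with the warp id's parity
        return (lane_id & 47) + ((warp_id & 1) << 4);
    }

    int main()
    {
        // Under WG16, lanes 0 and 16 collapse to the same offset within a warp,
        // and the warp's parity decides which half of the 32-wide stream they use:
        std::printf("%d %d %d %d\n",
                    effective_lane(0, 0, false),   // 0
                    effective_lane(0, 1, false),   // 16
                    effective_lane(16, 0, false),  // 0
                    effective_lane(16, 1, false)); // 16
    }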
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp

@@ -59,9 +59,12 @@ struct FmhaBwdDQDKDVKernel
     static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
     static constexpr auto BiasEnum     = FmhaPipeline::BiasEnum;
     static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
-    static constexpr bool kHasDropout  = FmhaPipeline::kHasDropout;
     using FmhaMask    = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
-    static constexpr bool kHasMask     = FmhaMask::IsMasking;
+    using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
+    static constexpr bool kHasMask         = FmhaMask::IsMasking;
+    static constexpr bool kHasDropout      = FmhaDropout::IsDropout;
+    static constexpr bool kIsStoreRandval  = FmhaDropout::IsStoreRandval;
+    static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
 
     // clang-format off
     template <typename T> struct t2s;

@@ -94,7 +97,8 @@ struct FmhaBwdDQDKDVKernel
             "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
             ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) +
             (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
-            (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "");
+            (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "") +
+            (kIsStoreRandval ? "_storerandval" : "") + (kIsDeterministic ? "_deterministic" : "");
         #undef _SS_
         #undef _TS_
         // clang-format on

@@ -117,7 +121,7 @@ struct FmhaBwdDQDKDVKernel
         const void* lse_ptr;
         const void* do_ptr;
         const void* d_ptr;
-        void* dq_ptr;
+        void* dq_acc_ptr;
         void* dk_ptr;
         void* dv_ptr;

@@ -131,9 +135,7 @@ struct FmhaBwdDQDKDVKernel
         ck_tile::index_t num_head_q;
         ck_tile::index_t nhead_ratio_qk;
         float raw_scale;
-#if CK_TILE_FMHA_FWD_FAST_EXP2
         float scale;
-#endif
         ck_tile::index_t stride_q;
         ck_tile::index_t stride_k;

@@ -206,7 +208,6 @@ struct FmhaBwdDQDKDVKernel
         float rp_undrop             = 1;
         float scale_rp_undrop       = 1;
         uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
-        bool is_store_randval       = false;
         uint64_t drop_seed          = 1;
         uint64_t drop_offset        = 0;
         void* rand_val_ptr          = nullptr;

@@ -218,6 +219,10 @@ struct FmhaBwdDQDKDVKernel
     {
         ck_tile::index_t batch_stride_randval = 0;
     };
+    struct FmhaBwdDeterministicKargs
+    {
+        ck_tile::index_t split_stride_dq_acc = 0;
+    };
 
     struct FmhaBwdBatchModeKargs
         : FmhaBwdCommonKargs,

@@ -228,7 +233,8 @@ struct FmhaBwdDQDKDVKernel
                              FmhaBwdEmptyKargs<0>>>,
           std::conditional_t<kHasBiasGrad, FmhaBwdBatchModeBiasGradKargs, FmhaBwdEmptyKargs<1>>,
           std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
-          std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>
+          std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>,
+          std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
     {
         ck_tile::index_t batch_stride_q;
         ck_tile::index_t batch_stride_k;

@@ -247,7 +253,8 @@ struct FmhaBwdDQDKDVKernel
                              FmhaBwdEmptyKargs<0>>>,
           std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>,
           std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
-          std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>
+          std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>,
+          std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
     {
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
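Note: the kargs structs grow by conditional inheritance: each optional feature contributes its members only when the corresponding compile-time flag is set, with a distinct empty placeholder otherwise, so unused features cost no kernel-argument space. A minimal standalone sketch of this pattern (struct names here are illustrative, not the library's):

    #include <type_traits>

    template <int> struct EmptyKargs {};

    struct CommonKargs        { const void* q_ptr; int seqlen_q; };
    struct DropoutKargs       { void* rand_val_ptr; int stride_randval; };
    struct DeterministicKargs { int split_stride_dq_acc = 0; };

    template <bool kHasDropout, bool kIsDeterministic>
    struct Kargs : CommonKargs,
                   std::conditional_t<kHasDropout, DropoutKargs, EmptyKargs<0>>,
                   std::conditional_t<kIsDeterministic, DeterministicKargs, EmptyKargs<1>>
    {
    };

    int main()
    {
        Kargs<true, false> k{};
        k.rand_val_ptr = nullptr; // present only because kHasDropout == true
        static_assert(sizeof(Kargs<false, false>) < sizeof(Kargs<true, true>), "");
    }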
@@ -266,10 +273,10 @@ struct FmhaBwdDQDKDVKernel
                                                const void* do_ptr,
                                                const void* d_ptr,
                                                void* rand_val_ptr,
-                                               void* dq_ptr,
                                                void* dk_ptr,
                                                void* dv_ptr,
                                                void* dbias_ptr,
+                                               void* dq_acc_ptr,
                                                ck_tile::index_t seqlen_q,
                                                ck_tile::index_t seqlen_k,
                                                ck_tile::index_t hdim_q,

@@ -304,11 +311,11 @@ struct FmhaBwdDQDKDVKernel
                                                ck_tile::index_t batch_stride_dk,
                                                ck_tile::index_t batch_stride_dv,
                                                ck_tile::index_t batch_stride_dbias,
+                                               ck_tile::index_t split_stride_dq_acc,
                                                ck_tile::index_t window_size_left,
                                                ck_tile::index_t window_size_right,
                                                ck_tile::index_t mask_type,
                                                float p_drop,
                                                bool s_randval,
                                                const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
     {
         Kargs kargs{{q_ptr,

@@ -317,7 +324,7 @@ struct FmhaBwdDQDKDVKernel
                      lse_ptr,
                      do_ptr,
                      d_ptr,
-                     dq_ptr,
+                     dq_acc_ptr,
                      dk_ptr,
                      dv_ptr,
                      seqlen_q,

@@ -327,9 +334,7 @@ struct FmhaBwdDQDKDVKernel
                      num_head_q,
                      nhead_ratio_qk,
                      scale,
-#if CK_TILE_FMHA_FWD_FAST_EXP2
                      static_cast<float>(scale * ck_tile::log2e_v<>),
-#endif
                      stride_q,
                      stride_k,
                      stride_v,

@@ -346,6 +351,7 @@ struct FmhaBwdDQDKDVKernel
                     {}, // placeholder for dbias
                     {}, // placeholder for mask
                     {}, // placeholder for dropout
+                    {}, // placeholder for deterministic
                     batch_stride_q,
                     batch_stride_k,
                     batch_stride_v,

@@ -384,11 +390,18 @@ struct FmhaBwdDQDKDVKernel
         if constexpr(kHasDropout)
         {
             kargs.init_dropout(p_drop, drop_seed_offset, scale);
-            kargs.rand_val_ptr         = rand_val_ptr;
-            kargs.stride_randval       = stride_randval;
-            kargs.nhead_stride_randval = nhead_stride_randval;
-            kargs.batch_stride_randval = batch_stride_randval;
-            kargs.is_store_randval     = s_randval;
+            if constexpr(kIsStoreRandval)
+            {
+                kargs.rand_val_ptr         = rand_val_ptr;
+                kargs.stride_randval       = stride_randval;
+                kargs.nhead_stride_randval = nhead_stride_randval;
+                kargs.batch_stride_randval = batch_stride_randval;
+            }
+        }
+        if constexpr(kIsDeterministic)
+        {
+            kargs.split_stride_dq_acc = split_stride_dq_acc;
         }
 
         return kargs;

@@ -404,10 +417,10 @@ struct FmhaBwdDQDKDVKernel
                                                const void* do_ptr,
                                                const void* d_ptr,
                                                void* rand_val_ptr,
-                                               void* dq_ptr,
                                                void* dk_ptr,
                                                void* dv_ptr,
                                                void* dbias_ptr,
+                                               void* dq_acc_ptr,
                                                const void* seqstart_q_ptr,
                                                const void* seqstart_k_ptr,
                                                const void* seqlen_k_ptr,

@@ -434,11 +447,11 @@ struct FmhaBwdDQDKDVKernel
                                                ck_tile::index_t nhead_stride_lsed,
                                                ck_tile::index_t nhead_stride_dbias,
                                                ck_tile::index_t batch_stride_lsed,
+                                               ck_tile::index_t split_stride_dq_acc,
                                                ck_tile::index_t window_size_left,
                                                ck_tile::index_t window_size_right,
                                                ck_tile::index_t mask_type,
                                                float p_drop,
                                                bool s_randval,
                                                const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
     {
         Kargs kargs{{q_ptr,

@@ -447,7 +460,7 @@ struct FmhaBwdDQDKDVKernel
                      lse_ptr,
                      do_ptr,
                      d_ptr,
-                     dq_ptr,
+                     dq_acc_ptr,
                      dk_ptr,
                      dv_ptr,
                      -1, // seqlen will be updated by another pointer

@@ -457,9 +470,7 @@ struct FmhaBwdDQDKDVKernel
                      num_head_q,
                      nhead_ratio_qk,
                      scale,
-#if CK_TILE_FMHA_FWD_FAST_EXP2
                      static_cast<float>(scale * ck_tile::log2e_v<>),
-#endif
                      stride_q,
                      stride_k,
                      stride_v,

@@ -476,6 +487,7 @@ struct FmhaBwdDQDKDVKernel
                     {}, // placeholder for dbias
                     {}, // placeholder for mask
                     {}, // placeholder for dropout
+                    {}, // placeholder for deterministic
                     reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                     reinterpret_cast<const int32_t*>(seqstart_k_ptr),
                     reinterpret_cast<const int32_t*>(seqlen_k_ptr)};

@@ -506,10 +518,16 @@ struct FmhaBwdDQDKDVKernel
         if constexpr(kHasDropout)
         {
             kargs.init_dropout(p_drop, drop_seed_offset, scale);
-            kargs.rand_val_ptr         = rand_val_ptr;
-            kargs.stride_randval       = stride_randval;
-            kargs.nhead_stride_randval = nhead_stride_randval;
-            kargs.is_store_randval     = s_randval;
+            if constexpr(kIsStoreRandval)
+            {
+                kargs.rand_val_ptr         = rand_val_ptr;
+                kargs.stride_randval       = stride_randval;
+                kargs.nhead_stride_randval = nhead_stride_randval;
+            }
+        }
+        if constexpr(kIsDeterministic)
+        {
+            kargs.split_stride_dq_acc = split_stride_dq_acc;
         }
 
         return kargs;

@@ -576,7 +594,7 @@ struct FmhaBwdDQDKDVKernel
             {
                 batch_offset_dbias = key_start;
             }
-            if constexpr(kHasDropout)
+            if constexpr(kIsStoreRandval)
             {
                 batch_offset_randval = query_start * kargs.stride_randval;
             }

@@ -618,7 +636,7 @@ struct FmhaBwdDQDKDVKernel
             {
                 batch_offset_dbias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dbias;
             }
-            if constexpr(kHasDropout)
+            if constexpr(kIsStoreRandval)
             {
                 batch_offset_randval =
                     static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;

@@ -646,9 +664,6 @@ struct FmhaBwdDQDKDVKernel
         const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
                                       static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
                                       batch_offset_do;
-        QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
-                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
-                                batch_offset_q;
         KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
                                 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_k +
                                 batch_offset_dk;
@@ -663,45 +678,10 @@ struct FmhaBwdDQDKDVKernel
make_tuple
(
kargs
.
stride_q
,
1
),
number
<
FmhaPipeline
::
kAlignmentQ
>
{},
number
<
1
>
{});
const
auto
q_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQLoadOnce
)
{
return
pad_tensor_view
(
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}
else
{
return
pad_tensor_view
(
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}
}();
const
auto
qt_dram_naive
=
transform_tensor_view
(
q_dram_naive
,
make_tuple
(
make_pass_through_transform
(
kargs
.
hdim_q
),
make_pass_through_transform
(
kargs
.
seqlen_q
)),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
const
auto
qt_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQTLoadOnce
)
{
return
pad_tensor_view
(
qt_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kM0
>
{}),
sequence
<
kPadHeadDimQ
,
kPadSeqLenQ
>
{});
}
else
{
return
pad_tensor_view
(
qt_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kK3
>
{}),
sequence
<
kPadHeadDimQ
,
kPadSeqLenQ
>
{});
}
}();
const
auto
q_dram
=
pad_tensor_view
(
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
const
auto
k_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
k_ptr
,
...
...
@@ -709,45 +689,10 @@ struct FmhaBwdDQDKDVKernel
            make_tuple(kargs.stride_k, 1),
            number<FmhaPipeline::kAlignmentK>{},
            number<1>{});
-       const auto k_dram = [&]() {
-           if constexpr(FmhaPipeline::kKLoadOnce)
-           {
-               return pad_tensor_view(
-                   k_dram_naive,
-                   make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-                   sequence<kPadSeqLenK, kPadHeadDimQ>{});
-           }
-           else
-           {
-               return pad_tensor_view(
-                   k_dram_naive,
-                   make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
-                   sequence<kPadSeqLenK, kPadHeadDimQ>{});
-           }
-       }();
-       const auto kt_dram_naive = transform_tensor_view(
-           k_dram_naive,
-           make_tuple(make_pass_through_transform(kargs.hdim_q),
-                      make_pass_through_transform(kargs.seqlen_k)),
-           make_tuple(sequence<1>{}, sequence<0>{}),
-           make_tuple(sequence<0>{}, sequence<1>{}));
-       const auto kt_dram = [&]() {
-           if constexpr(FmhaPipeline::kKTLoadOnce)
-           {
-               return pad_tensor_view(
-                   kt_dram_naive,
-                   make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kN0>{}),
-                   sequence<kPadHeadDimQ, kPadSeqLenK>{});
-           }
-           else
-           {
-               return pad_tensor_view(
-                   kt_dram_naive,
-                   make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK4>{}),
-                   sequence<kPadHeadDimQ, kPadSeqLenK>{});
-           }
-       }();
+       const auto k_dram = pad_tensor_view(
+           k_dram_naive,
+           make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+           sequence<kPadSeqLenK, kPadHeadDimQ>{});
        const auto v_dram = [&]() {
            const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
...
@@ -756,20 +701,10 @@ struct FmhaBwdDQDKDVKernel
                make_tuple(kargs.stride_v, 1),
                number<FmhaPipeline::kAlignmentV>{},
                number<1>{});
-           if constexpr(FmhaPipeline::kVLoadOnce)
-           {
-               return pad_tensor_view(
-                   v_dram_naive,
-                   make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
-                   sequence<kPadSeqLenK, kPadHeadDimV>{});
-           }
-           else
-           {
-               return pad_tensor_view(
-                   v_dram_naive,
-                   make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK2>{}),
-                   sequence<kPadSeqLenK, kPadHeadDimV>{});
-           }
+           return pad_tensor_view(
+               v_dram_naive,
+               make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
+               sequence<kPadSeqLenK, kPadHeadDimV>{});
        }();
        const auto lse_dram = [&]() {
...
@@ -792,145 +727,88 @@ struct FmhaBwdDQDKDVKernel
            make_tuple(kargs.stride_do, 1),
            number<FmhaPipeline::kAlignmentOGrad>{},
            number<1>{});
-       const auto do_dram = [&]() {
-           if constexpr(FmhaPipeline::kOGradLoadOnce)
-           {
-               return pad_tensor_view(
-                   do_dram_naive,
-                   make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
-                   sequence<kPadSeqLenQ, kPadHeadDimV>{});
-           }
-           else
-           {
-               return pad_tensor_view(
-                   do_dram_naive,
-                   make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK2>{}),
-                   sequence<kPadSeqLenQ, kPadHeadDimV>{});
-           }
-       }();
-       const auto dot_dram_naive = transform_tensor_view(
-           do_dram_naive,
-           make_tuple(make_pass_through_transform(kargs.hdim_v),
-                      make_pass_through_transform(kargs.seqlen_q)),
-           make_tuple(sequence<1>{}, sequence<0>{}),
-           make_tuple(sequence<0>{}, sequence<1>{}));
-       const auto dot_dram = [&]() {
-           if constexpr(FmhaPipeline::kOGradTLoadOnce)
-           {
-               return pad_tensor_view(
-                   dot_dram_naive,
-                   make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kM0>{}),
-                   sequence<kPadHeadDimV, kPadSeqLenQ>{});
-           }
-           else
-           {
-               return pad_tensor_view(
-                   dot_dram_naive,
-                   make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kK1>{}),
-                   sequence<kPadHeadDimV, kPadSeqLenQ>{});
-           }
-       }();
-       auto dq_dram = [&]() {
-           const auto dq_dram_naive =
-               make_naive_tensor_view<address_space_enum::global,
-                                      memory_operation_enum::atomic_add>(
-                   dq_ptr,
-                   make_tuple(kargs.seqlen_q, kargs.hdim_q),
-                   make_tuple(kargs.stride_q, 1),
-                   number<FmhaPipeline::kAlignmentQGrad>{},
-                   number<1>{});
-           return pad_tensor_view(
-               dq_dram_naive,
-               make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-               sequence<kPadSeqLenQ, kPadHeadDimQ>{});
-       }();
+       const auto do_dram = pad_tensor_view(
+           do_dram_naive,
+           make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
+           sequence<kPadSeqLenQ, kPadHeadDimV>{});
        auto q_dram_window = make_tile_window(
            q_dram,
-           [&]() {
-               if constexpr(FmhaPipeline::kQLoadOnce)
-                   return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{});
-               else
-                   return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
-           }(),
+           make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
            {0, 0});
-       auto qt_dram_window = make_tile_window(
-           qt_dram,
-           [&]() {
-               if constexpr(FmhaPipeline::kQTLoadOnce)
-                   return make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kM0>{});
-               else
-                   return make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK3>{});
-           }(),
-           {0, 0});
        auto k_dram_window = make_tile_window(
            k_dram,
-           [&]() {
-               if constexpr(FmhaPipeline::kKLoadOnce)
-                   return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{});
-               else
-                   return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
-           }(),
+           make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
            {i_n0, 0});
-       auto kt_dram_window = make_tile_window(
-           kt_dram,
-           [&]() {
-               if constexpr(FmhaPipeline::kKTLoadOnce)
-                   return make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kN0>{});
-               else
-                   return make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK4>{});
-           }(),
-           {0, i_n0});
        auto v_dram_window = make_tile_window(
            v_dram,
-           [&]() {
-               if constexpr(FmhaPipeline::kVLoadOnce)
-                   return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{});
-               else
-                   return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK2>{});
-           }(),
+           make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
            {i_n0, 0});
        auto do_dram_window = make_tile_window(
            do_dram,
-           [&]() {
-               if constexpr(FmhaPipeline::kOGradLoadOnce)
-                   return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{});
-               else
-                   return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK2>{});
-           }(),
+           make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
            {0, 0});
-       auto dot_dram_window = make_tile_window(
-           dot_dram,
-           [&]() {
-               if constexpr(FmhaPipeline::kOGradTLoadOnce)
-                   return make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kM0>{});
-               else
-                   return make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kK1>{});
-           }(),
-           {0, 0});
-       auto dq_dram_window = make_tile_window(
-           dq_dram,
-           make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-           {0, 0});
+       auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
+           if constexpr(kIsDeterministic)
+           {
+               AccDataType* dq_acc_ptr =
+                   reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
+                   static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q +
+                   static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
+                   batch_offset_q;
+               auto dq_acc_dram = [&]() {
+                   const auto dq_acc_dram_naive =
+                       make_naive_tensor_view<address_space_enum::global>(
+                           dq_acc_ptr,
+                           make_tuple(kargs.seqlen_q, kargs.hdim_q),
+                           make_tuple(kargs.hdim_q, 1),
+                           number<FmhaPipeline::kAlignmentQGrad>{},
+                           number<1>{});
+                   return pad_tensor_view(
+                       dq_acc_dram_naive,
+                       make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+                       sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+               }();
+               return make_tile_window(
+                   dq_acc_dram,
+                   make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+                   {0, 0});
+           }
+           else
+           {
+               AccDataType* dq_acc_ptr =
+                   reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
+                   static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q +
+                   batch_offset_q;
+               auto dq_acc_dram = [&]() {
+                   const auto dq_acc_dram_naive =
+                       make_naive_tensor_view<address_space_enum::global,
+                                              memory_operation_enum::atomic_add>(
+                           dq_acc_ptr,
+                           make_tuple(kargs.seqlen_q, kargs.hdim_q),
+                           make_tuple(kargs.stride_q, 1),
+                           number<FmhaPipeline::kAlignmentQGrad>{},
+                           number<1>{});
+                   return pad_tensor_view(
+                       dq_acc_dram_naive,
+                       make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+                       sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+               }();
+               return make_tile_window(
+                   dq_acc_dram,
+                   make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+                   {0, 0});
+           }
+       }();
        auto lse_dram_window =
            make_tile_window(lse_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});
...
@@ -1008,9 +886,7 @@ struct FmhaBwdDQDKDVKernel
                // TODO: how to use s_read?
                AccDataType slope = *(reinterpret_cast<const AccDataType*>(kargs.alibi_slope_ptr) +
                                      i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                slope *= ck_tile::log2e_v<>;
#endif
                if constexpr(kHasMask)
                {
                    return make_alibi_from_lr_mask<AccDataType, false>(slope,
...
@@ -1038,7 +914,6 @@ struct FmhaBwdDQDKDVKernel
        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
        uint64_t drop_seed          = 0;
        uint64_t drop_offset        = 0;
-       bool is_store_randval       = false;
        if constexpr(kHasDropout)
        {
...
@@ -1047,21 +922,19 @@ struct FmhaBwdDQDKDVKernel
            p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
            drop_seed           = kargs.drop_seed;
            drop_offset         = kargs.drop_offset;
-           is_store_randval    = kargs.is_store_randval;
        }
-       BlockDropout dropout(i_batch,
-                            i_nhead,
-                            kargs.num_head_q,
-                            drop_seed,
-                            drop_offset,
-                            rp_undrop,
-                            p_undrop_in_uint8_t,
-                            is_store_randval);
+       FmhaDropout dropout(i_batch,
+                           i_nhead,
+                           kargs.num_head_q,
+                           drop_seed,
+                           drop_offset,
+                           rp_undrop,
+                           p_undrop_in_uint8_t);
        auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto randval_dram_window_lengths =
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
-           if constexpr(kHasDropout)
+           if constexpr(kIsStoreRandval)
            {
                RandValOutputDataType* rand_val_ptr =
                    reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
...
@@ -1103,14 +976,11 @@ struct FmhaBwdDQDKDVKernel
        }();
        auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
-                                                        qt_dram_window,
                                                         k_dram_window,
-                                                        kt_dram_window,
                                                         v_dram_window,
                                                         bias_dram_window,
                                                         randval_dram_window,
                                                         do_dram_window,
-                                                        dot_dram_window,
                                                         lse_dram_window,
                                                         d_dram_window,
                                                         dq_dram_window,
...
@@ -1118,9 +988,7 @@ struct FmhaBwdDQDKDVKernel
                                                         mask,
                                                         position_encoding,
                                                         kargs.raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
                                                         kargs.scale,
#endif
                                                         rp_undrop,
                                                         scale_rp_undrop,
                                                         smem_ptr,
...
@@ -1418,4 +1286,285 @@ struct FmhaBwdOGradDotOKernel
    }
};

template <typename TilePartitioner_, typename FmhaBwdConvertQGrad_>
struct FmhaBwdConvertQGradKernel
{
    using TilePartitioner     = ck_tile::remove_cvref_t<TilePartitioner_>;
    using FmhaBwdConvertQGrad = ck_tile::remove_cvref_t<FmhaBwdConvertQGrad_>;
    static constexpr ck_tile::index_t kBlockSize  = FmhaBwdConvertQGrad::kBlockSize;
    static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu;
    static constexpr ck_tile::index_t kM0         = FmhaBwdConvertQGrad::kM0;
    static constexpr ck_tile::index_t kN0         = FmhaBwdConvertQGrad::kN0;
    static constexpr ck_tile::index_t kQKHeaddim  = FmhaBwdConvertQGrad::kQKHeaddim;

    using AccDataType   = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::AccDataType>;
    using QGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::QGradDataType>;

    static constexpr bool kIsGroupMode     = FmhaBwdConvertQGrad::kIsGroupMode;
    static constexpr bool kPadSeqLenQ      = FmhaBwdConvertQGrad::kPadSeqLenQ;
    static constexpr bool kPadHeadDimQ     = FmhaBwdConvertQGrad::kPadHeadDimQ;
    static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    // clang-format on

    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        #define _SS_ std::string
        #define _TS_ std::to_string
        auto pn = [&] () {
            std::string n;
            if(kPadSeqLenQ) n += "s";
            if(kPadHeadDimQ) n += "d";
            return n.empty() ? n : std::string("p") + n;
        }();
        return
            _SS_("fmha_bwd_convert_dq_d") + _TS_(kQKHeaddim) + "_" + _SS_(t2s<QGradDataType>::name) +
            "_" + (kIsGroupMode ? "group" : "batch") + (kIsDeterministic ? "_deterministic" : "") +
            "_" + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "" : "_" + pn);
        #undef _SS_
        #undef _TS_
        // clang-format on
    }

    // to avoid the duplicated-base-class problem, introduce a template arg
    template <ck_tile::index_t I>
    struct FmhaBwdConvertQGradEmptyKargs
    {
    };

    // kargs use aggregate initializer, so no constructor will be provided;
    // use inheritance to minimize karg size.
    // users need to use the MakeKargs() function to create kargs.
    struct FmhaBwdConvertQGradCommonKargs
    {
        const void* dq_acc_ptr;
        void* dq_ptr;

        ck_tile::index_t seqlen_q;
        ck_tile::index_t seqlen_k;
        ck_tile::index_t hdim_q;

        ck_tile::index_t stride_dq;
        ck_tile::index_t nhead_stride_dq;
    };

    struct FmhaBwdConvertQGradDeterministicKargs
    {
        ck_tile::index_t split_stride_dq_acc = 0;
    };

    struct FmhaBwdConvertQGradBatchModeKargs
        : FmhaBwdConvertQGradCommonKargs,
          std::conditional_t<kIsDeterministic,
                             FmhaBwdConvertQGradDeterministicKargs,
                             FmhaBwdConvertQGradEmptyKargs<0>>
    {
        ck_tile::index_t batch_stride_dq;
    };

    struct FmhaBwdConvertQGradGroupModeKargs
        : FmhaBwdConvertQGradCommonKargs,
          std::conditional_t<kIsDeterministic,
                             FmhaBwdConvertQGradDeterministicKargs,
                             FmhaBwdConvertQGradEmptyKargs<0>>
    {
        const int32_t* seqstart_q_ptr;
        const int32_t* seqlen_k_ptr;
    };

    using Kargs = std::
        conditional_t<kIsGroupMode, FmhaBwdConvertQGradGroupModeKargs, FmhaBwdConvertQGradBatchModeKargs>;

    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* dq_acc_ptr,
              void* dq_ptr,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t seqlen_k,
              ck_tile::index_t hdim_q,
              ck_tile::index_t stride_dq,
              ck_tile::index_t nhead_stride_dq,
              ck_tile::index_t batch_stride_dq,
              ck_tile::index_t split_stride_dq_acc)
    {
        Kargs kargs{{dq_acc_ptr, dq_ptr, seqlen_q, seqlen_k, hdim_q, stride_dq, nhead_stride_dq},
                    {},
                    batch_stride_dq};
        if constexpr(kIsDeterministic)
        {
            kargs.split_stride_dq_acc = split_stride_dq_acc;
        }
        return kargs;
    }

    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* dq_acc_ptr,
              void* dq_ptr,
              const void* seqstart_q_ptr,
              const void* seqlen_k_ptr,
              ck_tile::index_t hdim_q,
              ck_tile::index_t stride_dq,
              ck_tile::index_t nhead_stride_dq,
              ck_tile::index_t split_stride_dq_acc)
    {
        Kargs kargs{{dq_acc_ptr,
                     dq_ptr,
                     -1, // seqlen will be updated by another pointer
                     -1, //
                     hdim_q,
                     stride_dq,
                     nhead_stride_dq},
                    {},
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                    reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
        if constexpr(kIsDeterministic)
        {
            kargs.split_stride_dq_acc = split_stride_dq_acc;
        }
        return kargs;
    }

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_)
    {
        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // divide problem
        const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q);

        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);

        long_index_t batch_offset_dq = 0;
        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
            batch_offset_dq                = query_start * kargs.stride_dq;

            // get real # queries & # keys under group mode
            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
            kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];

            // # of required blocks differs between groups; terminate unnecessary blocks
            // earlier
            if(kargs.seqlen_q <= i_m0)
            {
                return;
            }
        }
        else
        {
            batch_offset_dq = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
        }

        // for simplicity, we just fold the batch stride into the pointer
        QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dq +
                                batch_offset_dq;

        // dQAcc/dQ DRAM and DRAM window
        const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() {
            if constexpr(kIsDeterministic)
            {
                const AccDataType* dq_acc_ptr =
                    reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
                    static_cast<long_index_t>(i_nhead_) * (kargs.seqlen_q * kargs.hdim_q) +
                    batch_offset_dq;

                const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);

                constexpr auto dq_fold = 4;
                auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    dq_acc_ptr,
                    make_tuple(nsplits, kargs.seqlen_q / dq_fold, kargs.hdim_q * dq_fold),
                    make_tuple(kargs.split_stride_dq_acc, kargs.hdim_q * dq_fold, 1),
                    number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
                    number<1>{});
                return pad_tensor_view(
                    dq_acc_dram_naive,
                    make_tuple(number<1>{}, number<kM0 / dq_fold>{}, number<kQKHeaddim * dq_fold>{}),
                    sequence<false, kPadSeqLenQ, kPadHeadDimQ>{});
            }
            else
            {
                const AccDataType* dq_acc_ptr =
                    reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
                    static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq) +
                    batch_offset_dq;

                auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    dq_acc_ptr,
                    make_tuple(kargs.seqlen_q, kargs.hdim_q),
                    make_tuple(kargs.stride_dq, 1),
                    number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
                    number<1>{});
                return pad_tensor_view(dq_acc_dram_naive,
                                       make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                                       sequence<kPadSeqLenQ, kPadHeadDimQ>{});
            }
        }();

        auto dq_dram = [&]() {
            auto dq_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                dq_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_q),
                make_tuple(kargs.stride_dq, 1),
                number<FmhaBwdConvertQGrad::kAlignmentQGrad>{},
                number<1>{});
            return pad_tensor_view(dq_dram_naive,
                                   make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                                   sequence<kPadSeqLenQ, kPadHeadDimQ>{});
        }();

        auto dq_acc_dram_window = [&]() {
            if constexpr(kIsDeterministic)
            {
                constexpr auto dq_fold = 4;
                return make_tile_window(
                    dq_acc_dram,
                    make_tuple(number<1>{}, number<kM0 / dq_fold>{}, number<kQKHeaddim * dq_fold>{}),
                    {0, i_m0 / dq_fold, 0});
            }
            else
            {
                return make_tile_window(
                    dq_acc_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
            }
        }();

        auto dq_dram_window =
            make_tile_window(dq_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});

        if constexpr(kIsDeterministic)
        {
            const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
            FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits);
        }
        else
        {
            FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window);
        }
    }
};

} // namespace ck_tile
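For orientation, the deterministic path above reads dq_acc through a folded 3-D view of shape (nsplits, seqlen_q / dq_fold, hdim_q * dq_fold) with dq_fold = 4, which widens each logical row so more contiguous elements can be read per access. A minimal standalone sketch (the sizes below are illustrative assumptions, not taken from the commit) showing that the fold is a pure regrouping of the row-major buffer, with no data movement:

// Why the dq_fold reshape is layout-preserving for one row-major
// [seqlen_q, hdim_q] slice of dq_acc: the flat byte offset of every
// element is identical through both views.
#include <cassert>

int main()
{
    const int seqlen_q = 8, hdim_q = 64, dq_fold = 4; // illustrative values
    // flat offset of element (m, d) in the plain 2-D row-major view
    auto flat2d = [&](int m, int d) { return m * hdim_q + d; };
    // the same element addressed through the folded [seqlen_q/dq_fold, hdim_q*dq_fold] view
    auto flat3d = [&](int m, int d) {
        int pos = m * hdim_q + d;
        int row = pos / (hdim_q * dq_fold);
        int col = pos % (hdim_q * dq_fold);
        return row * (hdim_q * dq_fold) + col;
    };
    for(int m = 0; m < seqlen_q; ++m)
        for(int d = 0; d < hdim_q; ++d)
            assert(flat2d(m, d) == flat3d(m, d)); // fold only regroups rows; bytes are untouched
    return 0;
}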
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp
View file @ 74f1516c
...
@@ -7,38 +7,34 @@
namespace ck_tile {

-template <typename BlockFmhaShape_>
-struct FmhaBwdTilePartitioner
+template <ck_tile::index_t kN0>
+struct FmhaBwdKTilePartitioner
{
-   using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
-
-   static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
-
    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
    {
        // TODO: this may need tuning
-       return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
+       return dim3(batch_size_, nhead_, ck_tile::integer_divide_ceil(seqlen_k_, kN0));
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
    {
-       const index_t i_block = blockIdx.x;
+       const index_t i_block = blockIdx.z;
        const index_t i_nhead = blockIdx.y;
-       const index_t i_batch = blockIdx.z;
+       const index_t i_batch = blockIdx.x;

        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }
};

-template <ck_tile::index_t kBlockSize>
-struct FmhaBwdOGradDotOTilePartitioner
+template <ck_tile::index_t kM0>
+struct FmhaBwdQTilePartitioner
{
    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
    {
        // TODO: this may need tuning
-       return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
+       return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
...
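To make the remapping concrete, here is a small host-side sketch (a hypothetical helper, not part of the commit) that mirrors the new FmhaBwdKTilePartitioner::GridSize and records which blockIdx component now carries which index:

// Host-side mirror of the new launch geometry: batch on x, head on y,
// seqlen_k tile on z; the device decode above must stay in sync with this.
#include <cstdio>

struct Dim3 { unsigned x, y, z; };

Dim3 grid_size(int batch, int nhead, int seqlen_k, int kN0)
{
    return {static_cast<unsigned>(batch),
            static_cast<unsigned>(nhead),
            static_cast<unsigned>((seqlen_k + kN0 - 1) / kN0)};
}

int main()
{
    Dim3 g = grid_size(/*batch=*/4, /*nhead=*/16, /*seqlen_k=*/1000, /*kN0=*/128);
    // device side now reads: i_batch = blockIdx.x, i_nhead = blockIdx.y, i_block = blockIdx.z
    std::printf("launch %u x %u x %u blocks; 8 k-tiles expected, got %u\n", g.x, g.y, g.z, g.z);
    return 0;
}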
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @ 74f1516c
...
@@ -47,10 +47,12 @@ struct FmhaFwdKernel
    static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
    static constexpr auto BiasEnum     = FmhaPipeline::BiasEnum;
    static constexpr bool kStoreLSE    = FmhaPipeline::kStoreLSE;
-   static constexpr bool kHasDropout  = FmhaPipeline::kHasDropout;
    static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
-   using FmhaMask                 = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
-   static constexpr bool kHasMask = FmhaMask::IsMasking;
+   using FmhaMask    = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
+   using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
+   static constexpr bool kHasMask        = FmhaMask::IsMasking;
+   static constexpr bool kHasDropout     = FmhaDropout::IsDropout;
+   static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;

    // clang-format off
    template <typename T> struct t2s;
...
@@ -87,7 +89,8 @@ struct FmhaFwdKernel
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
            "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
            (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
-           (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "") + (kHasDropout ? "_dropout" : "") + (kDoFp8StaticQuant ? "_squant" : "");
+           (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "") + (kHasDropout ? "_dropout" : "") + (kIsStoreRandval ? "_storerandval" : "") + (kDoFp8StaticQuant ? "_squant" : "");
        #undef _SS_
        #undef _TS_
        // clang-format on
...
@@ -185,7 +188,6 @@ struct FmhaFwdKernel
        }
        float rp_undrop             = 1;
        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
-       bool is_store_randval       = false;
        uint64_t drop_seed          = 1;
        uint64_t drop_offset        = 0;
        void* rand_val_ptr          = nullptr;
...
@@ -277,7 +279,6 @@ struct FmhaFwdKernel
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
-             bool s_randval,
              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
    {
        Kargs kargs{{q_ptr,
...
@@ -345,11 +346,13 @@ struct FmhaFwdKernel
        if constexpr(kHasDropout)
        {
            kargs.init_dropout(p_drop, drop_seed_offset);
-           kargs.rand_val_ptr         = rand_val_ptr;
-           kargs.stride_randval       = stride_randval;
-           kargs.nhead_stride_randval = nhead_stride_randval;
-           kargs.batch_stride_randval = batch_stride_randval;
-           kargs.is_store_randval     = s_randval;
+           if constexpr(kIsStoreRandval)
+           {
+               kargs.rand_val_ptr         = rand_val_ptr;
+               kargs.stride_randval       = stride_randval;
+               kargs.nhead_stride_randval = nhead_stride_randval;
+               kargs.batch_stride_randval = batch_stride_randval;
+           }
        }

        return kargs;
...
@@ -392,7 +395,6 @@ struct FmhaFwdKernel
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
-             bool s_randval,
              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
    {
        Kargs kargs{{q_ptr,
...
@@ -458,10 +460,12 @@ struct FmhaFwdKernel
        if constexpr(kHasDropout)
        {
            kargs.init_dropout(p_drop, drop_seed_offset);
-           kargs.rand_val_ptr         = rand_val_ptr;
-           kargs.stride_randval       = stride_randval;
-           kargs.nhead_stride_randval = nhead_stride_randval;
-           kargs.is_store_randval     = s_randval;
+           if constexpr(kIsStoreRandval)
+           {
+               kargs.rand_val_ptr         = rand_val_ptr;
+               kargs.stride_randval       = stride_randval;
+               kargs.nhead_stride_randval = nhead_stride_randval;
+           }
        }

        return kargs;
...
@@ -526,7 +530,7 @@ struct FmhaFwdKernel
            {
                batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
            }
-           if constexpr(kHasDropout)
+           if constexpr(kIsStoreRandval)
            {
                batch_offset_randval = query_start * kargs.stride_randval;
            }
...
@@ -566,7 +570,7 @@ struct FmhaFwdKernel
            {
                batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
            }
-           if constexpr(kHasDropout)
+           if constexpr(kIsStoreRandval)
            {
                batch_offset_randval =
                    static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
...
@@ -744,28 +748,31 @@ struct FmhaFwdKernel
        }();

-       auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
-           if constexpr(kHasDropout)
-           {
-               return BlockDropout{i_batch_,
-                                   i_nhead_,
-                                   kargs.num_head_q,
-                                   kargs.drop_seed,
-                                   kargs.drop_offset,
-                                   kargs.rp_undrop,
-                                   kargs.p_undrop_in_uint8_t,
-                                   kargs.is_store_randval};
-           }
-           else
-           {
-               return NullBlockDropout{};
-           };
-       }();
+       // dropout
+       float rp_undrop             = 1;
+       uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
+       uint64_t drop_seed          = 0;
+       uint64_t drop_offset        = 0;
+       if constexpr(kHasDropout)
+       {
+           rp_undrop           = kargs.rp_undrop;
+           p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
+           drop_seed           = kargs.drop_seed;
+           drop_offset         = kargs.drop_offset;
+       }
+       FmhaDropout dropout(
+           i_batch, i_nhead, kargs.num_head_q, drop_seed, drop_offset, rp_undrop, p_undrop_in_uint8_t);

        auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto randval_dram_window_lengths =
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
-           if constexpr(kHasDropout)
+           if constexpr(kIsStoreRandval)
            {
                RandValOutputDataType* rand_val_ptr =
                    reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
...
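The forward kernel now derives both compile-time flags from a single dropout type instead of a pipeline boolean. A minimal sketch of the shape such a type could take; the member names IsDropout/IsStoreRandval and the 7-argument constructor come from the diff above, everything else here is an assumption for illustration, not the library's definition:

// Sketch of a dropout policy type carrying both compile-time flags.
#include <cstdint>

template <bool kIsDropout_, bool kIsStoreRandval_>
struct SketchBlockDropout
{
    static constexpr bool IsDropout      = kIsDropout_;
    static constexpr bool IsStoreRandval = kIsStoreRandval_;

    // matches the 7-argument construction sites in the kernels above
    SketchBlockDropout(int /*i_batch*/, int /*i_nhead*/, int /*num_head_q*/,
                       uint64_t /*seed*/, uint64_t /*offset*/,
                       float /*rp_undrop*/, uint8_t /*p_undrop_in_uint8_t*/) {}
};

// a kernel can then fold both flags into its name / kargs at compile time:
using DropoutCfg = SketchBlockDropout</*IsDropout=*/true, /*IsStoreRandval=*/false>;
static_assert(DropoutCfg::IsDropout && !DropoutCfg::IsStoreRandval, "flags are compile-time");

int main() { return 0; }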
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @ 74f1516c
...
@@ -46,10 +46,12 @@ struct FmhaFwdSplitKVKernel
    static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
    static constexpr auto BiasEnum     = FmhaPipeline::BiasEnum;
-   static constexpr bool kHasDropout  = FmhaPipeline::kHasDropout;
    static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
-   using FmhaMask                 = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
-   static constexpr bool kHasMask = FmhaMask::IsMasking;
+   using FmhaMask    = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
+   using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
+   static constexpr bool kHasMask        = FmhaMask::IsMasking;
+   static constexpr bool kHasDropout     = FmhaDropout::IsDropout;
+   static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;

    // clang-format off
    template <typename T> struct t2s;
...
@@ -86,7 +88,8 @@ struct FmhaFwdSplitKVKernel
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
            "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
            (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
-           (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "") + (kDoFp8StaticQuant ? "_squant" : "");
+           (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "") + (kIsStoreRandval ? "_storerandval" : "") + (kDoFp8StaticQuant ? "_squant" : "");
        #undef _SS_
        #undef _TS_
        // clang-format on
...
@@ -189,7 +192,6 @@ struct FmhaFwdSplitKVKernel
        }
        float rp_undrop             = 1;
        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
-       bool is_store_randval       = false;
        uint64_t drop_seed          = 1;
        uint64_t drop_offset        = 0;
        void* rand_val_ptr          = nullptr;
...
@@ -282,7 +284,6 @@ struct FmhaFwdSplitKVKernel
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
-             bool s_randval,
              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
    {
        Kargs kargs{{q_ptr,
...
@@ -350,11 +351,13 @@ struct FmhaFwdSplitKVKernel
        if constexpr(kHasDropout)
        {
            kargs.init_dropout(p_drop, drop_seed_offset);
-           kargs.rand_val_ptr         = rand_val_ptr;
-           kargs.stride_randval       = stride_randval;
-           kargs.nhead_stride_randval = nhead_stride_randval;
-           kargs.batch_stride_randval = batch_stride_randval;
-           kargs.is_store_randval     = s_randval;
+           if constexpr(kIsStoreRandval)
+           {
+               kargs.rand_val_ptr         = rand_val_ptr;
+               kargs.stride_randval       = stride_randval;
+               kargs.nhead_stride_randval = nhead_stride_randval;
+               kargs.batch_stride_randval = batch_stride_randval;
+           }
        }

        return kargs;
...
@@ -402,7 +405,6 @@ struct FmhaFwdSplitKVKernel
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
-             bool s_randval,
              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
    {
        Kargs kargs{{q_ptr,
...
@@ -469,10 +471,12 @@ struct FmhaFwdSplitKVKernel
        if constexpr(kHasDropout)
        {
            kargs.init_dropout(p_drop, drop_seed_offset);
-           kargs.rand_val_ptr         = rand_val_ptr;
-           kargs.stride_randval       = stride_randval;
-           kargs.nhead_stride_randval = nhead_stride_randval;
-           kargs.is_store_randval     = s_randval;
+           if constexpr(kIsStoreRandval)
+           {
+               kargs.rand_val_ptr         = rand_val_ptr;
+               kargs.stride_randval       = stride_randval;
+               kargs.nhead_stride_randval = nhead_stride_randval;
+           }
        }

        return kargs;
...
@@ -536,7 +540,7 @@ struct FmhaFwdSplitKVKernel
            {
                batch_offset_bias = query_start * kargs.stride_bias + key_start;
            }
-           if constexpr(kHasDropout)
+           if constexpr(kIsStoreRandval)
            {
                batch_offset_randval = query_start * kargs.stride_randval;
            }
...
@@ -571,7 +575,7 @@ struct FmhaFwdSplitKVKernel
            {
                batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
            }
-           if constexpr(kHasDropout)
+           if constexpr(kIsStoreRandval)
            {
                batch_offset_randval =
                    static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
...
@@ -747,7 +751,6 @@ struct FmhaFwdSplitKVKernel
        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
        uint64_t drop_seed          = 0;
        uint64_t drop_offset        = 0;
-       bool is_store_randval       = false;
        if constexpr(kHasDropout)
        {
...
@@ -755,21 +758,19 @@ struct FmhaFwdSplitKVKernel
            p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
            drop_seed           = kargs.drop_seed;
            drop_offset         = kargs.drop_offset;
-           is_store_randval    = kargs.is_store_randval;
        }
-       BlockDropout dropout(
-           i_batch, i_nhead, kargs.num_head_q, drop_seed, drop_offset, rp_undrop, p_undrop_in_uint8_t, is_store_randval);
+       FmhaDropout dropout(
+           i_batch, i_nhead, kargs.num_head_q, drop_seed, drop_offset, rp_undrop, p_undrop_in_uint8_t);
        auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto randval_dram_window_lengths =
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
-           if constexpr(kHasDropout)
+           if constexpr(kIsStoreRandval)
            {
                RandValOutputDataType* rand_val_ptr =
                    reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
...
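Both kernel files above build their kargs with the same space-saving pattern that the new convert-dq kernel spells out: optional field groups enter the karg struct through std::conditional_t, so a disabled feature costs zero bytes. A self-contained sketch of that trick under illustrative names (the EmptyKargs-indexing idea mirrors the diff; the concrete structs are assumptions, and the empty-base optimization it relies on holds on mainstream ABIs but is a judgment call, not a standard guarantee here):

// Conditional karg composition: disabled features add no bytes.
#include <cstdint>
#include <type_traits>

template <int I>
struct EmptyKargs {};  // distinct index avoids the duplicated-base-class problem

struct CommonKargs  { const void* q_ptr; int32_t seqlen_q; };
struct RandvalKargs { void* rand_val_ptr; int32_t stride_randval; };

template <bool kIsStoreRandval>
struct Kargs : CommonKargs,
               std::conditional_t<kIsStoreRandval, RandvalKargs, EmptyKargs<0>> {};

static_assert(sizeof(Kargs<false>) == sizeof(CommonKargs),
              "disabled feature adds no bytes (empty-base optimization)");
static_assert(sizeof(Kargs<true>) > sizeof(CommonKargs),
              "enabled feature carries its fields");

int main() { return 0; }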
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp 0 → 100644
View file @ 74f1516c
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

namespace ck_tile {

template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdConvertQGrad
{
    using AccDataType   = remove_cvref_t<typename Problem::AccDataType>;
    using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;

    static constexpr index_t kM0         = Problem::Shape::kM0;
    static constexpr index_t kN0         = Problem::Shape::kN0;
    static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
    static constexpr index_t kBlockSize  = Problem::kBlockSize;
    static constexpr index_t kQKHeaddim  = Problem::Shape::kQKHeaddim;

    static constexpr bool kIsGroupMode     = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ      = Problem::kPadSeqLenQ;
    static constexpr bool kPadHeadDimQ     = Problem::kPadHeadDimQ;
    static constexpr bool kIsDeterministic = Problem::kIsDeterministic;

    static constexpr index_t kAlignmentQGradAcc =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc<Problem>();
    static constexpr index_t kAlignmentQGrad =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGrad<Problem>();

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

    // Convert only
    template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE void
    operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
               QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const
    {
        static_assert(
            std::is_same_v<AccDataType,
                           remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType,
                               remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");
        static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
                      "wrong!");

        auto dq_acc_dram_window =
            make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                             dq_acc_dram_block_window_tmp.get_window_lengths(),
                             dq_acc_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakePostQGradAccDramTileDistribution<Problem>());

        auto dq_acc   = load_tile(dq_acc_dram_window);
        const auto dq = cast_tile<QGradDataType>(dq_acc);
        store_tile(dq_dram_block_window_tmp, dq);
    }

    // Reduce + Convert
    template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE void
    operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
               QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
               index_t nsplits) const
    {
        static_assert(
            std::is_same_v<AccDataType,
                           remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType,
                               remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");
        static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
                      "wrong!");

        auto dq_acc_dram_window = make_tile_window(
            dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
            dq_acc_dram_block_window_tmp.get_window_lengths(),
            dq_acc_dram_block_window_tmp.get_window_origin(),
            Policy::template MakePostQGradAccDeterministicDramTileDistribution<Problem>());

        auto dq_acc = decltype(load_tile(dq_acc_dram_window)){};
        clear_tile(dq_acc);

        constexpr auto dq_acc_spans = decltype(dq_acc)::get_distributed_spans();

        index_t i_total_loops = 0;
        auto dq_acc_buf       = load_tile(dq_acc_dram_window);
        move_tile_window(dq_acc_dram_window, {1, 0, 0});
        do
        {
            sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                    sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                        constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                        dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
                    });
                });
            });
            dq_acc_buf = load_tile(dq_acc_dram_window);
            move_tile_window(dq_acc_dram_window, {1, 0, 0});
            i_total_loops += 1;
        } while(i_total_loops < (nsplits - 1));
        sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
            sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                    constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                    dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
                });
            });
        });

        // declare dq
        constexpr auto dq_converted_dstr =
            Policy::template MakePostQGradAccDeterministicDramTileDistribution<Problem>();
        auto dq_converted = make_static_distributed_tensor<QGradDataType>(dq_converted_dstr);
        sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
            sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                    constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                    dq_converted(n_i_j_idx) = type_convert<QGradDataType>(dq_acc[n_i_j_idx]);
                });
            });
        });
        constexpr auto dq_dstr =
            Policy::template MakePostQGradDeterministicDramTileDistribution<Problem>();
        auto dq = make_static_distributed_tensor<QGradDataType>(dq_dstr);
        dq.get_thread_buffer() = dq_converted.get_thread_buffer();
        store_tile(dq_dram_block_window_tmp, dq);
    }
};

} // namespace ck_tile
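The Reduce + Convert overload above walks nsplits partial dq tiles along the leading window axis, accumulates them in AccDataType, and only casts to QGradDataType once at the end. As a scalar analogue (illustrative code, not part of the library):

// Scalar analogue of the Reduce + Convert loop: sum nsplits partial
// results in float, convert after the loop.
#include <cassert>

float reduce_splits(const float* dq_acc, int nsplits, int split_stride, int idx)
{
    float acc = 0.f;
    for(int s = 0; s < nsplits; ++s)  // one pass per seqlen_k tile
        acc += dq_acc[s * split_stride + idx];
    return acc;                        // cast to fp16/bf16 would happen here, once
}

int main()
{
    float buf[6] = {1, 2, 3, 4, 5, 6};           // 3 splits x 2 elements
    assert(reduce_splits(buf, 3, 2, 0) == 9.f);  // 1 + 3 + 5
    assert(reduce_splits(buf, 3, 2, 1) == 12.f); // 2 + 4 + 6
    return 0;
}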
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp
View file @ 74f1516c
...
@@ -4,11 +4,11 @@
#pragma once

#include "ck_tile/core.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

namespace ck_tile {

-template <typename Problem, typename Policy = BlockFmhaBwdOGradDotODefaultPolicy>
+template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdOGradDotO
{
    using ODataType = remove_cvref_t<typename Problem::ODataType>;
...
@@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO
    static constexpr index_t kAlignmentO =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
    static constexpr index_t kAlignmentOGrad =
-       kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
+       kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp deleted 100644 → 0
View file @ 497ccb87
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"

namespace ck_tile {

// These templates are not used here.
using BlockFmhaBwdOGradDotODefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
                                      /* QTLoadOnce_ = */ false,
                                      /* KLoadOnce_ = */ false,
                                      /* KTLoadOnce_ = */ false,
                                      /* VLoadOnce_ = */ false,
                                      /* OGradLoadOnce_ = */ false,
                                      /* OGradTLoadOnce_ = */ false>;

} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp → include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
View file @ 74f1516c
...
@@ -6,13 +6,13 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"

namespace ck_tile {

-template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy>
-struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
+template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
+struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
{
    using QDataType = remove_cvref_t<typename Problem::QDataType>;
    using KDataType = remove_cvref_t<typename Problem::KDataType>;
...
@@ -30,6 +30,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
    using VGradDataType    = remove_cvref_t<typename Problem::VGradDataType>;
    using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
    using FmhaMask         = remove_cvref_t<typename Problem::FmhaMask>;
+   using FmhaDropout      = remove_cvref_t<typename Problem::FmhaDropout>;
+   using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
    using BlockFmhaShape   = remove_cvref_t<typename Problem::BlockFmhaShape>;
...
@@ -46,22 +48,14 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
    static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
    static constexpr index_t kVHeaddim  = BlockFmhaShape::kVHeaddim;

-   static constexpr bool kQLoadOnce      = false;
-   static constexpr bool kQTLoadOnce     = false;
-   static constexpr bool kKLoadOnce      = true;
-   static constexpr bool kKTLoadOnce     = true;
-   static constexpr bool kVLoadOnce      = true;
-   static constexpr bool kOGradLoadOnce  = false;
-   static constexpr bool kOGradTLoadOnce = false;
-
-   static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
-   static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
-   static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
-   static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
-   static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
-   static constexpr auto BiasEnum     = Problem::BiasEnum;
-   static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
-   static constexpr bool kHasDropout  = Problem::kHasDropout;
+   static constexpr bool kIsGroupMode     = Problem::kIsGroupMode;
+   static constexpr bool kPadSeqLenQ      = Problem::kPadSeqLenQ;
+   static constexpr bool kPadSeqLenK      = Problem::kPadSeqLenK;
+   static constexpr bool kPadHeadDimQ     = Problem::kPadHeadDimQ;
+   static constexpr bool kPadHeadDimV     = Problem::kPadHeadDimV;
+   static constexpr auto BiasEnum         = Problem::BiasEnum;
+   static constexpr bool kHasBiasGrad     = Problem::kHasBiasGrad;
+   static constexpr bool kIsDeterministic = Problem::kIsDeterministic;

    // last dimension vector length used to create tensor view(and decide buffer_load vector length)
    // ... together with tensor distribution. tensor dist should able to overwrite this
...
@@ -71,12 +65,10 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
    static constexpr index_t kAlignmentV =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
-   static constexpr index_t kAlignmentO =
-       kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
    static constexpr index_t kAlignmentOGrad =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
    static constexpr index_t kAlignmentQGrad =
-       kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
+       kPadHeadDimQ ? 1 : Policy::template GetAlignmentQGrad<Problem>();
    static constexpr index_t kAlignmentKGrad =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
    static constexpr index_t kAlignmentVGrad =
...
@@ -84,7 +76,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
    static constexpr index_t kAlignmentBias =
        kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();

-   static constexpr const char* name = "ks_kts_vr";
+   static constexpr const char* name = "kr_ktr_vr";

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
...
@@ -92,14 +84,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
    }

    template <typename QDramBlockWindowTmp,
-             typename QTDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
-             typename KTDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
              typename RandValDramBlockWindowTmp,
              typename OGradDramBlockWindowTmp,
-             typename OGradTDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp,
              typename DDramBlockWindowTmp,
              typename QGradDramBlockWindowTmp,
...
@@ -107,14 +96,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
              typename PositionEncoding>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
-              const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
               const KDramBlockWindowTmp& k_dram_block_window_tmp,
-              const KTDramBlockWindowTmp& kt_dram_block_window_tmp,
               const VDramBlockWindowTmp& v_dram_block_window_tmp,
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
               const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
               const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
-              const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
               const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
               const DDramBlockWindowTmp& d_dram_block_window_tmp,
               const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
...
@@ -122,43 +108,29 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
               FmhaMask mask,
               PositionEncoding position_encoding,
               float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
               float scale,
#endif
               float rp_undrop,
               float scale_rp_undrop,
               void* smem_ptr,
-              BlockDropout& dropout) const
+              FmhaDropout& dropout) const
    {
        static_assert(
            std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
-               std::is_same_v<QDataType, remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
                std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
-               std::is_same_v<KDataType, remove_cvref_t<typename KTDramBlockWindowTmp::DataType>> &&
                std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
                std::is_same_v<OGradDataType, remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
-               std::is_same_v<OGradDataType, remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
                std::is_same_v<LSEDataType, remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
-               std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
-               std::is_same_v<QGradDataType, remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
+               std::is_same_v<QGradDataType, remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>> &&
+               std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
            "wrong!");
        static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
-                     kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
-                     kQKHeaddim == KTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                      kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
-                     kVHeaddim == OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
...
@@ -166,83 +138,6 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
                      kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

-       // Q tile in LDS
-       QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
-           static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
-           Policy::template GetSmemSizeKT<Problem>()));
-       auto q_lds = make_tensor_view<address_space_enum::lds>(
-           q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
-       auto q_lds_window =
-           make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
-       // QT tile in LDS
-       QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
-           static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
-           Policy::template GetSmemSizeKT<Problem>()));
-       auto qt_lds = make_tensor_view<address_space_enum::lds>(
-           qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
-       auto qt_lds_window =
-           make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});
-       // K tile in LDS
-       auto k_lds = make_tensor_view<address_space_enum::lds>(
-           reinterpret_cast<KDataType*>(smem_ptr),
-           Policy::template MakeKLdsBlockDescriptor<Problem>());
-       auto k_lds_window =
-           make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
-       // KT tile in LDS
-       KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
-           static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
-       auto kt_lds = make_tensor_view<address_space_enum::lds>(
-           kt_lds_ptr, Policy::template MakeKTLdsBlockDescriptor<Problem>());
-       auto kt_lds_window =
-           make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
-       // OGrad tile in LDS
-       OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
-           static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
-           Policy::template GetSmemSizeKT<Problem>()));
-       auto do_lds = make_tensor_view<address_space_enum::lds>(
-           do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
-       auto do_lds_window =
-           make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
-       // OGradT tile in LDS
-       OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
-           static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
-           Policy::template GetSmemSizeKT<Problem>()));
-       auto dot_lds = make_tensor_view<address_space_enum::lds>(
-           dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
-       auto dot_lds_window =
-           make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});
-       // SGrad tile in LDS
-       GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
-           static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
-           Policy::template GetSmemSizeKT<Problem>()));
-       auto ds_lds = make_tensor_view<address_space_enum::lds>(
-           ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
-       auto ds_lds_window =
-           make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
-       // BiasT/BiasGradT tile in LDS, use the same size and layout
-       BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
-           static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
-           Policy::template GetSmemSizeKT<Problem>()));
-       auto biast_lds = make_tensor_view<address_space_enum::lds>(
-           biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
-       auto biast_lds_shuffle_window =
-           make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
-       auto dbiast_lds_shuffle_window =
-           make_tile_window(biast_lds,
-                            make_tuple(number<kM0>{}, number<kN0>{}),
-                            {0, 0},
-                            Policy::template MakeShuffledBiasTileDistribution<Problem>());
-       static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
-                     "BiasDataType and BiasGradDataType should be the same!");

        // Block GEMM
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
...
@@ -250,34 +145,19 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
        constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
        constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();

-       auto v_dram_window = make_tile_window(
-           v_dram_block_window_tmp.get_bottom_tensor_view(),
-           v_dram_block_window_tmp.get_window_lengths(),
-           v_dram_block_window_tmp.get_window_origin(),
-           Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
-
-       auto v = load_tile(v_dram_window); // persistent V register tile

        using SPTBlockTileType     = decltype(gemm_0.MakeCBlockTile());
        using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
+       using QGradBlockTileType   = decltype(gemm_4.MakeCBlockTile());

        // init VGrad & KGrad
        auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
        auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};

        clear_tile(dv_acc);
        clear_tile(dk_acc);

-       auto k_dram_window = make_tile_window(
-           k_dram_block_window_tmp.get_bottom_tensor_view(),
-           k_dram_block_window_tmp.get_window_lengths(),
-           k_dram_block_window_tmp.get_window_origin(),
-           Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for load
+       // K, HBM ->LDS ->Reg
+       auto k_dram_window = make_tile_window(
+           k_dram_block_window_tmp.get_bottom_tensor_view(),
+           k_dram_block_window_tmp.get_window_lengths(),
+           k_dram_block_window_tmp.get_window_origin(),
+           Policy::template MakeKDramTileDistribution<Problem>());
+       __builtin_amdgcn_sched_barrier(0);

        const auto k_origin = k_dram_window.get_window_origin();
        // Early termination
        const auto [seqlen_q_start, seqlen_q_end] =
            mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
...
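The rename from ks_kts_vr to kr_ktr_vr reflects the pipeline now keeping K, K^T and V register-resident: K is loaded from HBM once, staged through LDS, and K^T is derived from the same tile by a shuffle instead of a second HBM load. A toy, plain-array illustration of that idea (not ck_tile code; shapes and names are assumptions):

// K is fetched once; K^T is the same data with swapped axes, produced by a
// shuffle_tile-style regrouping rather than another global-memory read.
#include <cassert>

int main()
{
    constexpr int kN0 = 4, kK0 = 3;
    float k_hbm[kN0][kK0]; // one K tile as loaded from HBM
    for(int n = 0; n < kN0; ++n)
        for(int k = 0; k < kK0; ++k)
            k_hbm[n][k] = n * 10.f + k;

    float k_reg[kN0][kK0];  // "k_reg_tensor": K kept in registers
    float kt_reg[kK0][kN0]; // "kt_reg_tensor": K^T, produced by shuffle
    for(int n = 0; n < kN0; ++n)
        for(int k = 0; k < kK0; ++k)
        {
            k_reg[n][k]  = k_hbm[n][k];
            kt_reg[k][n] = k_hbm[n][k]; // shuffle analogue: same data, swapped axes
        }
    assert(kt_reg[2][3] == k_reg[3][2]); // a single HBM load feeds both operands
    return 0;
}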
@@ -290,205 +170,415 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
{
// Note: here dk_acc&dv_acc are all cleard, return it
// Note: v loaded but no fence, ignore it.
return
ck_tile
::
make_tuple
(
dk_acc
,
dv_acc
);
return
make_tuple
(
dk_acc
,
dv_acc
);
}
}
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
auto
k_block_tile
=
load_tile
(
k_dram_window
);
auto
k_lds_write_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
k_lds_read_window
=
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
k_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeKRegSliceBlockDescriptor
<
Problem
>());
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// V, HBM ->LDS ->Reg
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
v_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
VDataType
*
v_lds_ptr
=
static_cast
<
VDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)));
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
store_tile
(
k_lds_window
,
k_block_tile
);
// // persistent K in LDS
auto
v_lds_write_window
=
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
{
0
,
0
});
auto
kt_dram_block_window
=
kt_dram_block_window_tmp
;
auto
v_lds_read_window
=
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
v_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeVRegSliceBlockDescriptor
<
Problem
>());
auto
kt_dram_window
=
make_tile_window
(
kt_dram_block_window
.
get_bottom_tensor_view
(),
kt_dram_block_window
.
get_window_lengths
(),
kt_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKTDramTileDistribution
<
Problem
>());
// K^T DRAM tile window for
// load
auto
v_reg_tensor
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
auto
kt_block_tile
=
load_tile
(
kt_dram_window
);
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
auto
kt_block_tile
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeKTRegWriteBlockDescriptor
<
Problem
>());
KDataType
*
kt_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeK
<
Problem
>()));
auto
kt_lds_write
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsWriteBlockDescriptor
<
Problem
>());
auto
kt_lds_write_window
=
make_tile_window
(
kt_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
{
0
,
0
});
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
auto
kt_lds_read_window
=
make_tile_window
(
kt_lds_read
,
make_tuple
(
number
<
kQKHeaddim
>
{},
number
<
kN0
>
{}),
{
0
,
0
},
Policy
::
template
MakeKTRegBlockDescriptor
<
Problem
>());
        //------------------------------------------------------------------
        // Pre-load K and V into registers
        auto k_block_tile = load_tile(k_dram_window);
        auto v_block_tile = load_tile(v_dram_window);
        store_tile(k_lds_write_window, k_block_tile);
        shuffle_tile(kt_block_tile, k_block_tile);
        store_tile(kt_lds_write_window, kt_block_tile); // persistent K^T in LDS
        block_sync_lds();
        k_reg_tensor = load_tile(k_lds_read_window);
        block_sync_lds();
        auto kt_reg_tensor = load_tile(kt_lds_read_window);
        store_tile(v_lds_write_window, v_block_tile);
        block_sync_lds();
        v_reg_tensor = load_tile(v_lds_read_window);
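        // The extra block_sync_lds() calls here appear to serialize reuse of the
        // same LDS region: k_lds_ptr and v_lds_ptr both alias the start of
        // smem_ptr, so the K tile must be fully read into registers before the V
        // tile is stored over it.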
        //---------------------------- Loop Load in ----------------------------//
        // Q: HBM -> Reg -> LDS
        auto q_dram_window =
            make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                             q_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, 0},
                             Policy::template MakeQDramTileDistribution<Problem>());
        QDataType* q_lds_ptr = static_cast<QDataType*>(
            static_cast<void*>(static_cast<char*>(smem_ptr) +
                               Policy::template GetSmemSizeQT<Problem>() +
                               Policy::template GetSmemSizeOGrad<Problem>() +
                               Policy::template GetSmemSizeOGradT<Problem>()));
        auto q_lds = make_tensor_view<address_space_enum::lds>(
            q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
        auto q_lds_window =
            make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
        auto q_lds_read_window =
            make_tile_window(q_lds_window.get_bottom_tensor_view(),
                             make_tuple(number<kM0>{}, number<kK0>{}),
                             q_lds_window.get_window_origin(),
                             Policy::template MakeQRegSliceBlockDescriptor<Problem>());
        auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
            Policy::template MakePTRegSliceBlockDescriptor<Problem>());

        // QT: Reg -> Reg -> LDS
        auto qt_block_tile = make_static_distributed_tensor<QDataType>(
            Policy::template MakeQTRegWriteBlockDescriptor<Problem>());
        QDataType* qt_lds_ptr =
            static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
        auto qt_lds_write = make_tensor_view<address_space_enum::lds>(
            qt_lds_ptr, Policy::template MakeQTLdsWriteBlockDescriptor<Problem>());
        auto qt_lds_write_window =
            make_tile_window(qt_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
        auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
            qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
        auto qt_lds_read_window =
            make_tile_window(qt_lds_read,
                             make_tuple(number<kQKHeaddim>{}, number<kM0>{}),
                             {0, 0},
                             Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
        // dO: HBM -> Reg -> LDS
        auto do_dram_window =
            make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
                             do_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, 0},
                             Policy::template MakeOGradDramTileDistribution<Problem>());
        OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(
            static_cast<void*>(static_cast<char*>(smem_ptr) +
                               Policy::template GetSmemSizeQT<Problem>()));
        auto do_lds = make_tensor_view<address_space_enum::lds>(
            do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
        auto do_lds_window =
            make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
        auto do_lds_read_window =
            make_tile_window(do_lds_window.get_bottom_tensor_view(),
                             make_tuple(number<kM0>{}, number<kK2>{}),
                             do_lds_window.get_window_origin(),
                             Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());

        // dOT: Reg -> Reg -> LDS
        auto dot_block_tile = make_static_distributed_tensor<OGradDataType>(
            Policy::template MakeOGradTRegWriteBlockDescriptor<Problem>());
        OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(
            static_cast<void*>(static_cast<char*>(smem_ptr) +
                               Policy::template GetSmemSizeQT<Problem>() +
                               Policy::template GetSmemSizeOGrad<Problem>()));
        auto dot_write_lds = make_tensor_view<address_space_enum::lds>(
            dot_lds_ptr, Policy::template MakeOGradTLdsWriteBlockDescriptor<Problem>());
        auto dot_lds_write_window =
            make_tile_window(dot_write_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
        auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
            dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
        auto dot_lds_read_window =
            make_tile_window(dot_read_lds,
                             make_tuple(number<kVHeaddim>{}, number<kM0>{}),
                             {0, 0},
                             Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
        // dS: Reg -> Reg -> LDS
        GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(
            static_cast<void*>(static_cast<char*>(smem_ptr) +
                               Policy::template GetSmemSizeQT<Problem>() +
                               Policy::template GetSmemSizeOGrad<Problem>() +
                               Policy::template GetSmemSizeOGradT<Problem>() +
                               Policy::template GetSmemSizeQ<Problem>() +
                               Policy::template GetSmemSizeLSE<Problem>() +
                               Policy::template GetSmemSizeD<Problem>()));
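        // Sketch of the LDS layout implied by the pointer offsets above (an
        // inferred reading, not an authoritative map): the scratch buffer is
        // carved up as
        //   [ QT | OGrad | OGradT | Q | LSE | D | SGrad ]
        // with each region sized by the matching Policy::GetSmemSize*() helper;
        // the K/V tiles and the transposed bias appear to reuse the front of the
        // buffer during earlier phases of the pipeline.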
        auto ds_lds = make_tensor_view<address_space_enum::lds>(
            ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
        auto ds_lds_window =
            make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
        auto ds_lds_read_window =
            make_tile_window(ds_lds_window.get_bottom_tensor_view(),
                             make_tuple(number<kM0>{}, number<kK4>{}),
                             ds_lds_window.get_window_origin(),
                             Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
        auto dst_reg_tensor = make_static_distributed_tensor<GemmDataType>(
            Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
        // Bias: HBM -> Reg -> Reg -> LDS
        const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
        auto bias_dram_window =
            make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
                             bias_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, bias_origin.at(number<1>{})}, // M/N
                             Policy::template MakeBiasTileDistribution<Problem>());
        BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(
            static_cast<void*>(static_cast<char*>(smem_ptr) +
                               Policy::template GetSmemSizeQT<Problem>()));
        auto biast_lds = make_tensor_view<address_space_enum::lds>(
            biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
        auto biast_lds_shuffle_window =
            make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
        auto biast_lds_window =
            make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
                             biast_lds_shuffle_window.get_window_lengths(),
                             biast_lds_shuffle_window.get_window_origin(),
                             Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
        static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
                      "BiasDataType and BiasGradDataType should be the same!");
        // LSE: HBM -> LDS -> Reg
        auto lse_dram_window = make_tile_window(
            lse_dram_block_window_tmp.get_bottom_tensor_view(),
            lse_dram_block_window_tmp.get_window_lengths(),
            {seqlen_q_start},
            Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
        LSEDataType* lse_lds_ptr = static_cast<LSEDataType*>(
            static_cast<void*>(static_cast<char*>(smem_ptr) +
                               Policy::template GetSmemSizeQT<Problem>() +
                               Policy::template GetSmemSizeOGrad<Problem>() +
                               Policy::template GetSmemSizeOGradT<Problem>() +
                               Policy::template GetSmemSizeQ<Problem>()));
        auto lse_lds = make_tensor_view<address_space_enum::lds>(
            lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
        auto lse_lds_write_window =
            make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
        auto lse_lds_read_window = make_tile_window(
            lse_lds,
            make_tuple(number<kM0>{}),
            {0},
            Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());

        // D: HBM -> Reg
        auto d_dram_window = make_tile_window(
            d_dram_block_window_tmp.get_bottom_tensor_view(),
            d_dram_block_window_tmp.get_window_lengths(),
            {seqlen_q_start},
            Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
        DDataType* d_lds_ptr = static_cast<DDataType*>(
            static_cast<void*>(static_cast<char*>(smem_ptr) +
                               Policy::template GetSmemSizeQT<Problem>() +
                               Policy::template GetSmemSizeOGrad<Problem>() +
                               Policy::template GetSmemSizeOGradT<Problem>() +
                               Policy::template GetSmemSizeQ<Problem>() +
                               Policy::template GetSmemSizeLSE<Problem>()));
        auto d_lds = make_tensor_view<address_space_enum::lds>(
            d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
        auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
        auto d_lds_read_window = make_tile_window(
            d_lds,
            make_tuple(number<kM0>{}),
            {0},
            Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());

        // RandVal: HBM -> Reg
        auto randval_dram_window =
            dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
                randval_dram_block_window_tmp, seqlen_q_start);
        // BiasGrad: Reg -> LDS -> Reg -> HBM
        const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
        auto dbias_dram_window =
            make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
                             dbias_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
        auto dbiast_lds_shuffle_window =
            make_tile_window(biast_lds,
                             make_tuple(number<kM0>{}, number<kN0>{}),
                             {0, 0},
                             Policy::template MakeShuffledBiasTileDistribution<Problem>());
        // ---------------------------- Loop write out ----------------------------//
        auto dq_dram_window =
            make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
                             dq_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, 0});

        // Deterministic mode stuff
        auto dq_buffer_view =
            dq_dram_block_window_tmp.get_bottom_tensor_view().get_buffer_view();
        auto dq_tensor_desc =
            dq_dram_block_window_tmp.get_bottom_tensor_view().get_tensor_descriptor();
        auto seqlen_q = dq_tensor_desc.get_lengths()[number<0>{}];
        auto hdim_q   = dq_tensor_desc.get_lengths()[number<1>{}];
        constexpr auto dq_fold = 4;
        auto dq_write_tensor_desc =
            make_naive_tensor_descriptor(make_tuple(seqlen_q / dq_fold, hdim_q * dq_fold),
                                         make_tuple(hdim_q * dq_fold, 1),
                                         number<kAlignmentQGrad>{},
                                         number<1>{});
        auto dq_tensor_view =
            tensor_view<decltype(dq_buffer_view), decltype(dq_write_tensor_desc)>{
                dq_buffer_view, dq_write_tensor_desc};
        auto dq_dram_window_deterministic = make_tile_window(
            dq_tensor_view,
            make_tuple(number<kM0 / dq_fold>{}, number<kQKHeaddim * dq_fold>{}),
            {seqlen_q_start / dq_fold, 0});
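        // The "fold" view above is a pure reinterpretation of the same row-major
        // dQ buffer: (seqlen_q, hdim_q) is read as (seqlen_q / dq_fold,
        // hdim_q * dq_fold), so one folded row covers dq_fold original rows and a
        // kM0 x kQKHeaddim tile becomes a (kM0 / dq_fold) x (kQKHeaddim * dq_fold)
        // tile. Presumably this widens each store for better alignment and
        // coalescing on the deterministic path; it assumes dQ is contiguous with
        // row stride hdim_q and that seqlen_q and kM0 are divisible by dq_fold.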
        using SPTBlockTileType     = decltype(gemm_0.MakeCBlockTile());
        using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
        using QGradBlockTileType   = decltype(gemm_4.MakeCBlockTile());
        index_t i_total_loops = 0;
        index_t seqlen_q_step = seqlen_q_start;
        static_assert(kQKHeaddim == kK0, "kQKHeaddim should be equal to kK0");
        static_assert(kM0 == kK1, "kM0 should be equal to kK1");
        static_assert(kVHeaddim == kK2, "kVHeaddim should be equal to kK2");
        static_assert(kM0 == kK3, "kM0 should be equal to kK3");
        constexpr index_t k4_loops = kN0 / kK4;
        /*
         * Prefetch Q, LSE, dO, D
         */
        auto q_block_tile = load_tile(q_dram_window);
        move_tile_window(q_dram_window, {kM0, 0});
        auto lse_block_tile = load_tile(lse_dram_window);
        move_tile_window(lse_dram_window, {kM0});
        auto do_block_tile = load_tile(do_dram_window);
        move_tile_window(do_dram_window, {kM0, 0});
        auto d_block_tile = load_tile(d_dram_window);
        move_tile_window(d_dram_window, {kM0});

        /*
         * Store prefetched data into LDS
         */
        store_tile(q_lds_window, q_block_tile);
        shuffle_tile(qt_block_tile, q_block_tile);
        store_tile(qt_lds_write_window, qt_block_tile);
        store_tile(lse_lds_write_window, lse_block_tile);
        store_tile(do_lds_window, do_block_tile);
        shuffle_tile(dot_block_tile, do_block_tile);
        store_tile(dot_lds_write_window, dot_block_tile);
        store_tile(d_lds_write_window, d_block_tile);
        block_sync_lds();

        /*
         * Prefetch LDS data into Reg to overlap asynchronous data movement
         * with the MFMA pipeline
         */
        auto q_reg_tensor  = load_tile(q_lds_read_window);
        auto lse           = load_tile(lse_lds_read_window);
        auto do_reg_tensor = load_tile(do_lds_read_window);
        auto d             = load_tile(d_lds_read_window);
        clear_tile(dv_acc);
        clear_tile(dk_acc);
        __builtin_amdgcn_sched_barrier(0);
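        // Pipeline shape (as reconstructed): a prologue that prefetches the first
        // Q/LSE/dO/D tiles into LDS and registers, a software-pipelined hot loop
        // that overlaps the next tile's global loads with the current tile's five
        // GEMM stages, and a peeled tail iteration that consumes the last
        // prefetched tile without issuing further loads.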
        // Hot loop
        do
        {
            // STAGE 1, Q@K Gemm0
            auto st_acc = SPTBlockTileType{};
            clear_tile(st_acc);
            q_block_tile = load_tile(q_dram_window);
            move_tile_window(q_dram_window, {kM0, 0});
            lse_block_tile = load_tile(lse_dram_window);
            move_tile_window(lse_dram_window, {kM0});
            do_block_tile = load_tile(do_dram_window);
            move_tile_window(do_dram_window, {kM0, 0});
            d_block_tile = load_tile(d_dram_window);
            move_tile_window(d_dram_window, {kM0});
            gemm_0(st_acc, q_reg_tensor, k_reg_tensor);
            auto dot_reg_tensor = load_tile(dot_lds_read_window);
            HotLoopScheduler::template GemmStagedScheduler<0>();
            __builtin_amdgcn_sched_barrier(0);

            // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
            if constexpr (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                const auto bias_tile = load_tile(bias_dram_window);
                block_sync_lds();
                auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
                    Policy::template MakeShuffledBiasTileDistribution<Problem>());
...
...
@@ -498,11 +588,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
                auto biast_tile = load_tile(biast_lds_window);
                tile_elementwise_inout(
                    [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                        x = raw_scale * x + type_convert<AccDataType>(y);
#else
                        x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
                    },
                    st_acc,
                    biast_tile);
...
...
@@ -510,52 +596,36 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
            }
            else if constexpr (BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
                sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
                        const auto tile_idx = get_x_indices_from_distributed_indices(
                            st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
                        const auto row = seqlen_q_step + tile_idx.at(number<0>{});
                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                        st_acc(i_j_idx) *= raw_scale;
#else
                        st_acc(i_j_idx) *= scale;
#endif
                        position_encoding.update(st_acc(i_j_idx), row, col);
                    });
                });
            }
            else
            {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
            }
            if constexpr (kPadSeqLenK || FmhaMask::IsMasking)
            {
                bool need_perpixel_check = mask.IsEdgeTile(
                    seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
                if(need_perpixel_check)
                {
                    set_tile_if(
                        st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
                            const auto row = seqlen_q_step + tile_idx.at(number<0>{});
                            const auto col =
                                k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                            return mask.IsOutOfBound(row, col);
                        });
                }
            }
            static const auto get_validated_lse = [](LSEDataType raw_lse) {
                if constexpr (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                              FmhaMask::IsMasking)
...
...
@@ -574,12 +644,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
            constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
            sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
                sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                    if constexpr (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                                  BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
...
...
@@ -589,31 +658,16 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
                    {
                        pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
                    }
#else
                    pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
                });
            });
            if constexpr (FmhaDropout::IsDropout)
            {
                dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
                    seqlen_q_step, k_origin.at(number<0>{}), pt, randval_dram_window);
            }

            // STAGE 3, P^T@OGrad^T Gemm1
            const auto pt_gemm = [&]() {
                if constexpr (FmhaDropout::IsDropout)
                {
                    return tile_elementwise_in(
                        [](const auto& x) {
                            return type_convert<GemmDataType>(x > 0.f ? x : 0.f);
                        },
...
...
@@ -625,87 +679,37 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
                }
            }();
            pt_reg_tensor.get_thread_buffer() = pt_gemm.get_thread_buffer();
            gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
            auto qt_reg_tensor = load_tile(qt_lds_read_window);
            HotLoopScheduler::template GemmStagedScheduler<1>();
            __builtin_amdgcn_sched_barrier(0);

            // STAGE 4, OGrad@V Gemm2
            auto dpt_acc = SPGradTBlockTileType{};
            clear_tile(dpt_acc);
            gemm_2(dpt_acc, do_reg_tensor, v_reg_tensor);
            store_tile(q_lds_window, q_block_tile);
            shuffle_tile(qt_block_tile, q_block_tile);
            store_tile(qt_lds_write_window, qt_block_tile);
            store_tile(lse_lds_write_window, lse_block_tile);
            store_tile(do_lds_window, do_block_tile);
            shuffle_tile(dot_block_tile, do_block_tile);
            store_tile(dot_lds_write_window, dot_block_tile);
            store_tile(d_lds_write_window, d_block_tile);
            HotLoopScheduler::template GemmStagedScheduler<2>();
            __builtin_amdgcn_sched_barrier(0);

            // STAGE 5, P^T(PGrad^T - D)
            auto dst = SPGradTBlockTileType{};
            constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
            sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
...
...
@@ -713,16 +717,16 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
                sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    bool undrop_flag = pt[i_j_idx] >= 0;
                    dst(i_j_idx) = pt[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
                                                      ? (dpt_acc[i_j_idx] - d[i_idx])
                                                      : d[i_idx]);
                });
            });
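            // This is the usual FlashAttention backward identity,
            // dS^T = P^T o (dP^T - D) with D_i = sum_j dO_ij * O_ij; under
            // dropout the kept/dropped decision appears to be encoded in the sign
            // of p^T (a negative entry marks a dropped element), so dropped
            // positions take the d[i_idx] branch instead.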
            if constexpr (kHasBiasGrad)
            {
                const auto dbiast = [&]() {
                    if constexpr (FmhaDropout::IsDropout)
                    {
                        return tile_elementwise_in(
                            [&rp_undrop](const auto& x) {
...
...
@@ -741,107 +745,321 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
                auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
                    Policy::template MakeBiasTileDistribution<Problem>());
                shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
                store_tile(dbias_dram_window, dbiast_shuffle_tmp);
                move_tile_window(dbias_dram_window, {kM0, 0});
            }

            // STAGE 6, SGrad^T@Q^T Gemm3
            const auto dst_gemm = cast_tile<GemmDataType>(dst);
            dst_reg_tensor.get_thread_buffer() = dst_gemm.get_thread_buffer();
            gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
            store_tile(ds_lds_window, dst_gemm);
            block_sync_lds();
            auto ds_reg_tensor = load_tile(ds_lds_read_window);
            auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
            move_tile_window(ds_lds_read_window, {0, kK4});
            q_reg_tensor = load_tile(q_lds_read_window);
            lse = load_tile(lse_lds_read_window);
            HotLoopScheduler::template GemmStagedScheduler<3>();
            __builtin_amdgcn_sched_barrier(0);

            // STAGE 7, SGrad@K^T Gemm4
            auto dq_acc = QGradBlockTileType{};
            clear_tile(dq_acc);
            static_for<0, k4_loops, 1>{}([&](auto i_k4) {
                if constexpr (i_k4 < k4_loops - 1)
                {
                    ds_reg_tensor_next = load_tile(ds_lds_read_window);
                    move_tile_window(ds_lds_read_window, {0, kK4});
                }
                auto kt_reg_tensor_slice =
                    get_slice_tile(kt_reg_tensor,
                                   sequence<0, i_k4 * kK4>{},
                                   sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
                gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
                if constexpr (i_k4 < k4_loops - 1)
                {
                    ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
                }
            });
            move_tile_window(ds_lds_read_window, {0, -kN0});
            do_reg_tensor = load_tile(do_lds_read_window);
            d = load_tile(d_lds_read_window);
            HotLoopScheduler::template GemmStagedScheduler<4>();

            // QGrad Scale
            if constexpr (FmhaDropout::IsDropout)
            {
                tile_elementwise_inout(
                    [&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, dq_acc);
            }
            else
            {
                tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
            }
            if constexpr (kIsDeterministic)
            {
                auto dq_write_reg_tensor = make_static_distributed_tensor<AccDataType>(
                    Policy::template MakeQGradWriteBlockDescriptor<Problem>());
                dq_write_reg_tensor.get_thread_buffer() = dq_acc.get_thread_buffer();
                store_tile(dq_dram_window_deterministic, dq_write_reg_tensor);
                move_tile_window(dq_dram_window_deterministic, {kM0 / dq_fold, 0});
            }
            else
            {
                update_tile(dq_dram_window, dq_acc);
                move_tile_window(dq_dram_window, {kM0, 0});
            }
            i_total_loops += 1;
            seqlen_q_step += kM0;
        } while(i_total_loops < (num_total_loop - 1));
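        // The hot loop intentionally stops one iteration early: the final
        // iteration is peeled into the tail below, which reuses the data already
        // staged in LDS and registers and skips the now-unnecessary global
        // prefetches and window moves.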
        __builtin_amdgcn_sched_barrier(0);

        // Tail
        auto st_acc = SPTBlockTileType{};
        clear_tile(st_acc);
        gemm_0(st_acc, q_reg_tensor, k_reg_tensor);

        // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
        if constexpr (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            const auto bias_tile = load_tile(bias_dram_window);
            block_sync_lds();
            auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
                Policy::template MakeShuffledBiasTileDistribution<Problem>());
            shuffle_tile(bias_shuffle_tmp, bias_tile);
            store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
            block_sync_lds();
            auto biast_tile = load_tile(biast_lds_window);
            tile_elementwise_inout(
                [&](auto& x, const auto& y) {
                    x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
                },
                st_acc,
                biast_tile);
            move_tile_window(bias_dram_window, {kM0, 0});
        }
        else if constexpr (BiasEnum == BlockAttentionBiasEnum::ALIBI)
        {
            constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
            sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
                    const auto tile_idx = get_x_indices_from_distributed_indices(
                        st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
                    const auto row = seqlen_q_step + tile_idx.at(number<0>{});
                    const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    st_acc(i_j_idx) *= scale;
                    position_encoding.update(st_acc(i_j_idx), row, col);
                });
            });
        }
        if constexpr (kPadSeqLenK || FmhaMask::IsMasking)
        {
            bool need_perpixel_check = mask.IsEdgeTile(
                seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
            if(need_perpixel_check)
            {
                set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
                    const auto row = seqlen_q_step + tile_idx.at(number<0>{});
                    const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                    return mask.IsOutOfBound(row, col);
                });
            }
        }
        static const auto get_validated_lse = [](LSEDataType raw_lse) {
            if constexpr (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                          FmhaMask::IsMasking)
            {
                return raw_lse == -numeric<LSEDataType>::infinity()
                           ? type_convert<LSEDataType>(0.f)
                           : raw_lse;
            }
            else
            {
                return raw_lse;
            }
        };
        auto pt = SPTBlockTileType{};
        constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
        sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
            sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                if constexpr (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                              BiasEnum == BlockAttentionBiasEnum::ALIBI)
                {
                    pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
                }
                else
                {
                    pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
                }
            });
        });
        if constexpr (FmhaDropout::IsDropout)
        {
            dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
                seqlen_q_step, k_origin.at(number<0>{}), pt, randval_dram_window);
        }
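        // Softmax recomputation in log2 space: with scale presumed to fold in
        // log2(e) (i.e. scale = raw_scale * log2e) and row_lse = log2e * LSE_i,
        //   exp2(scale * s - row_lse) = exp(raw_scale * s - LSE_i) = P_ij,
        // so the forward probabilities are rebuilt from S^T and the saved LSE
        // rather than stored. When the bias/ALIBI path has already folded the
        // scale into st_acc, the exp2 is applied without it.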
        // STAGE 3, P^T@OGrad^T Gemm1
        const auto pt_gemm = [&]() {
            if constexpr (FmhaDropout::IsDropout)
            {
                return tile_elementwise_in(
                    [](const auto& x) {
                        return type_convert<GemmDataType>(x > 0.f ? x : 0.f);
                    },
                    pt);
            }
            else
            {
                return cast_tile<GemmDataType>(pt);
            }
        }();
        pt_reg_tensor.get_thread_buffer() = pt_gemm.get_thread_buffer();
        auto dot_reg_tensor = load_tile(dot_lds_read_window);
        gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
        HotLoopScheduler::template GemmStagedScheduler<1>();

        // STAGE 4, OGrad@V Gemm2
        auto dpt_acc = SPGradTBlockTileType{};
        clear_tile(dpt_acc);
        auto qt_reg_tensor = load_tile(qt_lds_read_window);
        gemm_2(dpt_acc, do_reg_tensor, v_reg_tensor);
        HotLoopScheduler::template GemmStagedScheduler<2>();

        // STAGE 5, P^T(PGrad^T - D)
        auto dst = SPGradTBlockTileType{};
        constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
        sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                bool undrop_flag = pt[i_j_idx] >= 0;
                dst(i_j_idx) = pt[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
                                                  ? (dpt_acc[i_j_idx] - d[i_idx])
                                                  : d[i_idx]);
            });
        });
        if constexpr (kHasBiasGrad)
        {
            const auto dbiast = [&]() {
                if constexpr (FmhaDropout::IsDropout)
                {
                    return tile_elementwise_in(
                        [&rp_undrop](const auto& x) {
                            return type_convert<BiasGradDataType>(x * rp_undrop);
                        },
                        dst);
                }
                else
                {
                    return cast_tile<BiasGradDataType>(dst);
                }
            }();
            store_tile(biast_lds_shuffle_window, dbiast);
            block_sync_lds();
            auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
            auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
                Policy::template MakeBiasTileDistribution<Problem>());
            shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
            store_tile(dbias_dram_window, dbiast_shuffle_tmp);
            move_tile_window(dbias_dram_window, {kM0, 0});
        }
        // STAGE 6, SGrad^T@Q^T Gemm3
        const auto dst_gemm = cast_tile<GemmDataType>(dst);
        dst_reg_tensor.get_thread_buffer() = dst_gemm.get_thread_buffer();
        gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
        store_tile(ds_lds_window, dst_gemm);
        block_sync_lds();
        auto ds_reg_tensor = load_tile(ds_lds_read_window);
        auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
        move_tile_window(ds_lds_read_window, {0, kK4});
        HotLoopScheduler::template GemmStagedScheduler<3>();
        // STAGE 7, SGrad@K^T Gemm4
        auto dq_acc = QGradBlockTileType{};
        clear_tile(dq_acc);
        static_for<0, k4_loops, 1>{}([&](auto i_k4) {
            if constexpr (i_k4 < k4_loops - 1)
            {
                ds_reg_tensor_next = load_tile(ds_lds_read_window);
                move_tile_window(ds_lds_read_window, {0, kK4});
            }
            auto kt_reg_tensor_slice =
                get_slice_tile(kt_reg_tensor,
                               sequence<0, i_k4 * kK4>{},
                               sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
            gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
            if constexpr (i_k4 < k4_loops - 1)
            {
                ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
            }
        });
        HotLoopScheduler::template GemmStagedScheduler<4>();
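        // The k4 loop ping-pongs dS between two register tiles: while gemm_4
        // consumes ds_reg_tensor, the next kK4-wide slice is already being read
        // from LDS into ds_reg_tensor_next, hiding the LDS read latency behind
        // the MFMA work.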
        // Results Scale
        if constexpr (FmhaDropout::IsDropout)
        {
            tile_elementwise_inout(
                [&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, dq_acc);
            tile_elementwise_inout(
                [&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, dk_acc);
            tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
        }
        else
        {
            tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
            tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
        }
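        // Reading of the final scaling: dQ and dK each carry the softmax scale
        // (folded together with 1/(1-p) as scale_rp_undrop when dropout is on),
        // while dV only needs the dropout correction rp_undrop = 1/(1-p), since
        // no softmax scale enters the P^T@dO^T product.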
        if constexpr (kIsDeterministic)
        {
            auto dq_write_reg_tensor = make_static_distributed_tensor<AccDataType>(
                Policy::template MakeQGradWriteBlockDescriptor<Problem>());
            dq_write_reg_tensor.get_thread_buffer() = dq_acc.get_thread_buffer();
            store_tile(dq_dram_window_deterministic, dq_write_reg_tensor);
        }
        else
        {
            update_tile(dq_dram_window, dq_acc);
        }
        return make_tuple(dk_acc, dv_acc);
    }
};
...
...