Commit 74f1516c authored by danyao12's avatar danyao12
Browse files

tmp save

parent 497ccb87
...@@ -55,11 +55,10 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS) ...@@ -55,11 +55,10 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
# ... because they are auto-generated # ... because they are auto-generated
if(FMHA_FWD_FAST_EXP2) if(FMHA_FWD_FAST_EXP2)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else() else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif() endif()
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
# Allow comparing floating points directly in order to check sentinel values # Allow comparing floating points directly in order to check sentinel values
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
......
...@@ -66,6 +66,22 @@ BIAS_CHECK_MAP = { ...@@ -66,6 +66,22 @@ BIAS_CHECK_MAP = {
"alibi" : "bias_enum::alibi" "alibi" : "bias_enum::alibi"
} }
DROPOUT_MAP = {
"no" : "ck_tile::BlockDropout<false, true, false>",
"dropout_wg32" : "ck_tile::BlockDropout<true, true, false>",
"dropout_wg32_storerandval" : "ck_tile::BlockDropout<true, true, true >",
"dropout_wg16" : "ck_tile::BlockDropout<true, false, false>",
"dropout_wg16_storerandval" : "ck_tile::BlockDropout<true, false, true >"
}
DROPOUT_CHECK_MAP = {
"no" : "t.has_dropout == false",
"dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
"dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
}
MODE_MAP = { MODE_MAP = {
"batch" : "false", "batch" : "false",
"group" : "true" "group" : "true"
......
...@@ -53,10 +53,10 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, ...@@ -53,10 +53,10 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_bias}, {F_bias},
false, false,
{F_lse}, {F_lse},
{F_dropout},
{F_squant}, {F_squant},
{F_occupancy}>; {F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask}; using fmha_mask_{F_idx} = {F_mask};
using fmha_dropout_{F_idx} = {F_dropout};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType, typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
...@@ -73,6 +73,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< ...@@ -73,6 +73,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
fmha_shape_{F_idx}, fmha_shape_{F_idx},
{F_mode}, {F_mode},
fmha_mask_{F_idx}, fmha_mask_{F_idx},
fmha_dropout_{F_idx},
fmha_trait_{F_idx}>; fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = {F_pipeline}< using fmha_pipeline_{F_idx} = {F_pipeline}<
...@@ -89,7 +90,7 @@ using fmha_kernel_{F_idx} = ...@@ -89,7 +90,7 @@ using fmha_kernel_{F_idx} =
fmha_epilogue_{F_idx}>; fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; {F_pipeline_enum}, fmha_mask_{F_idx}, fmha_dropout_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
#include <iostream> #include <iostream>
...@@ -124,9 +125,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < ...@@ -124,9 +125,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
}} }}
""" """
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && ({F_dropout_check}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a); return fmha_fwd_<trait_>(s, a);
}} }}
""" """
...@@ -238,7 +239,7 @@ class FmhaFwdPipeline: ...@@ -238,7 +239,7 @@ class FmhaFwdPipeline:
else: else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_lse == 't' : n += '_lse' if self.F_lse == 't' : n += '_lse'
if self.F_dropout == 't' : n += '_dropout' if self.F_dropout != 'no' : n += f'_{self.F_dropout}'
if self.F_squant == 't' : n += '_squant' if self.F_squant == 't' : n += '_squant'
return n return n
...@@ -269,7 +270,7 @@ class FmhaFwdApiPool: ...@@ -269,7 +270,7 @@ class FmhaFwdApiPool:
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , F_lse=BOOL_MAP[trait.lse], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
...@@ -344,7 +345,7 @@ class FmhaFwdKernel: ...@@ -344,7 +345,7 @@ class FmhaFwdKernel:
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], F_dropout = DROPOUT_MAP[self.F_pipeline.F_dropout],
F_squant = BOOL_MAP[self.F_pipeline.F_squant], F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_occupancy = self.F_tile.F_occupancy, F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
...@@ -416,7 +417,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm ...@@ -416,7 +417,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
squant = 't' if dtype == 'fp8' else 'f' squant = 't' if dtype == 'fp8' else 'f'
pipelines = [] pipelines = []
if dtype in ['fp16', 'bf16']: if dtype in ['fp16', 'bf16']:
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], list(DROPOUT_MAP.keys())[:3]):
if hdim == 256: if hdim == 256:
# if True: # if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
......
...@@ -29,6 +29,7 @@ FMHA_FWD_SPLITKV_PIPELINE_MAP = { ...@@ -29,6 +29,7 @@ FMHA_FWD_SPLITKV_PIPELINE_MAP = {
FMHA_FWD_SPLITKV_KERNEL_BODY=""" FMHA_FWD_SPLITKV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype}; using fmha_dtype_{F_idx} = {F_dtype};
using fmha_mask_{F_idx} = {F_mask}; using fmha_mask_{F_idx} = {F_mask};
using fmha_dropout_{F_idx} = {F_dropout};
namespace {{ namespace {{
template <bool kHasUnevenSplits> template <bool kHasUnevenSplits>
...@@ -51,7 +52,6 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, ...@@ -51,7 +52,6 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
{F_bias}, {F_bias},
false, false,
{F_lse}, {F_lse},
{F_dropout},
{F_squant}, {F_squant},
kHasUnevenSplits, kHasUnevenSplits,
{F_occupancy}>; {F_occupancy}>;
...@@ -71,6 +71,7 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< ...@@ -71,6 +71,7 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
fmha_shape, fmha_shape,
{F_mode}, {F_mode},
fmha_mask_{F_idx}, fmha_mask_{F_idx},
fmha_dropout_{F_idx},
fmha_trait>; fmha_trait>;
using fmha_pipeline = {F_pipeline}< using fmha_pipeline = {F_pipeline}<
...@@ -98,7 +99,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_args a) ...@@ -98,7 +99,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
}} }}
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; {F_pipeline_enum}, fmha_mask_{F_idx}, fmha_dropout_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
#include <iostream> #include <iostream>
...@@ -224,9 +225,9 @@ float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream ...@@ -224,9 +225,9 @@ float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream
}} }}
""" """
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && ({F_dropout_check}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a); return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
...@@ -267,7 +268,7 @@ class FmhaFwdSplitKVPipeline: ...@@ -267,7 +268,7 @@ class FmhaFwdSplitKVPipeline:
else: else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_lse == 't' : n += '_lse' if self.F_lse == 't' : n += '_lse'
if self.F_dropout == 't' : n += '_dropout' if self.F_dropout != 'no' : n += f'_{self.F_dropout}'
if self.F_squant == 't' : n += '_squant' if self.F_squant == 't' : n += '_squant'
return n return n
...@@ -322,7 +323,7 @@ class FmhaFwdSplitKVApiPool: ...@@ -322,7 +323,7 @@ class FmhaFwdSplitKVApiPool:
inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , F_lse=BOOL_MAP[trait.lse], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
...@@ -380,7 +381,7 @@ class FmhaFwdSplitKVKernel: ...@@ -380,7 +381,7 @@ class FmhaFwdSplitKVKernel:
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], F_dropout = DROPOUT_MAP[self.F_pipeline.F_dropout],
F_squant = BOOL_MAP[self.F_pipeline.F_squant], F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_occupancy = self.F_tile.F_occupancy, F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
...@@ -531,7 +532,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -531,7 +532,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
pipelines = [] pipelines = []
if dtype in ['fp16', 'bf16']: if dtype in ['fp16', 'bf16']:
# splitkv kernel donot support dropout # splitkv kernel donot support dropout
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"]): for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], list(DROPOUT_MAP.keys())[:1]):
if hdim == 256: if hdim == 256:
# if True: # if True:
pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
......
...@@ -87,7 +87,11 @@ auto create_args(int argc, char* argv[]) ...@@ -87,7 +87,11 @@ auto create_args(int argc, char* argv[])
.insert("drop_offset", "0", "offset for random number generator") .insert("drop_offset", "0", "offset for random number generator")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel"); .insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("deterministic",
"0",
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
"will not be used");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
...@@ -177,9 +181,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -177,9 +181,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
seed.reset(); seed.reset();
} }
int stream_warmup = arg_parser.get_int("warmup"); int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat"); int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname"); bool kname = arg_parser.get_bool("kname");
bool deterministic = arg_parser.get_bool("deterministic");
ck_tile::stream_config stream_config{nullptr, ck_tile::stream_config stream_config{nullptr,
true, true,
...@@ -265,6 +270,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -265,6 +270,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
const ck_tile::index_t shape_seqlen_k = const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
const ck_tile::index_t kN0 = (hdim_q > 32 & hdim_q <= 128) ? 128 : 64;
const ck_tile::index_t nsplits =
deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
ck_tile::HostTensor<QDataType> q_host( ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
...@@ -302,6 +310,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -302,6 +310,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
use_dbias use_dbias
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */); : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> dq_acc_host(
i_perm
? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
if(init_method == 0) if(init_method == 0)
{ {
...@@ -362,6 +374,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -362,6 +374,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data()); q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data()); k_buf.ToDevice(k_host.data());
...@@ -387,8 +400,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -387,8 +400,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
<< ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval
<< std::flush; << ", deterministic:" << deterministic << ", mask:" << mask << std::flush;
std::size_t workspace_size =
dq_acc_host.get_element_space_size_in_bytes() * sizeof(AccDataType) / (1024 * 1024);
if(deterministic == 1)
{
std::cout << "\nDeterministic mode ON: " << workspace_size
<< " MByte memory workspace allocated" << std::endl;
}
auto fmha_traits = fmha_bwd_traits{hdim_q, auto fmha_traits = fmha_bwd_traits{hdim_q,
hdim_v, hdim_v,
...@@ -397,7 +419,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -397,7 +419,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask.type, mask.type,
bias.type, bias.type,
use_dbias, use_dbias,
p_drop > 0.0f}; p_drop > 0.0f,
s_randval,
deterministic};
auto fmha_args = [&]() { auto fmha_args = [&]() {
assert(nhead % nhead_k == 0); assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
...@@ -437,6 +461,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -437,6 +461,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q); const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v); const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t split_stride_dq_acc =
(shape_batch * nhead * shape_seqlen_q * hdim_q);
return fmha_bwd_args{q_buf.GetDeviceBuffer(), return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(),
...@@ -452,6 +478,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -452,6 +478,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
dk_buf.GetDeviceBuffer(), dk_buf.GetDeviceBuffer(),
dv_buf.GetDeviceBuffer(), dv_buf.GetDeviceBuffer(),
dbias_buf.GetDeviceBuffer(), dbias_buf.GetDeviceBuffer(),
dq_acc_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(), seqstart_k.GetDeviceBuffer(),
nullptr, nullptr,
...@@ -496,12 +523,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -496,12 +523,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch_stride_dk, batch_stride_dk,
batch_stride_dv, batch_stride_dv,
batch_stride_dbias, batch_stride_dbias,
split_stride_dq_acc,
mask.left, mask.left,
mask.right, mask.right,
static_cast<ck_tile::index_t>(mask.type), static_cast<ck_tile::index_t>(mask.type),
p_drop, p_drop,
p_undrop, p_undrop,
s_randval,
{drop_seed, drop_offset}}; {drop_seed, drop_offset}};
}(); }();
...@@ -738,6 +765,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -738,6 +765,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
lse_buf.ToDevice(lse_host.data()); lse_buf.ToDevice(lse_host.data());
dq_buf.SetZero(); dq_buf.SetZero();
dbias_buf.SetZero(); dbias_buf.SetZero();
dq_acc_buf.SetZero();
ck_tile::stream_config stream_config_v{ ck_tile::stream_config stream_config_v{
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
......
...@@ -77,6 +77,7 @@ struct fmha_bwd_args ...@@ -77,6 +77,7 @@ struct fmha_bwd_args
void* dk_ptr; void* dk_ptr;
void* dv_ptr; void* dv_ptr;
void* dbias_ptr; void* dbias_ptr;
void* dq_acc_ptr;
const void* seqstart_q_ptr; const void* seqstart_q_ptr;
const void* seqstart_k_ptr; const void* seqstart_k_ptr;
const void* seqlen_k_ptr; const void* seqlen_k_ptr;
...@@ -120,12 +121,12 @@ struct fmha_bwd_args ...@@ -120,12 +121,12 @@ struct fmha_bwd_args
ck_tile::index_t batch_stride_dk; ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv; ck_tile::index_t batch_stride_dv;
ck_tile::index_t batch_stride_dbias; ck_tile::index_t batch_stride_dbias;
ck_tile::index_t split_stride_dq_acc;
ck_tile::index_t window_size_left; ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right; ck_tile::index_t window_size_right;
ck_tile::index_t mask_type; ck_tile::index_t mask_type;
float p_drop; float p_drop;
float p_undrop; float p_undrop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset; std::tuple<uint64_t, uint64_t> drop_seed_offset;
}; };
...@@ -145,10 +146,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -145,10 +146,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.do_ptr, args.do_ptr,
args.d_ptr, args.d_ptr,
args.rand_val_ptr, args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr, args.dk_ptr,
args.dv_ptr, args.dv_ptr,
args.dbias_ptr, args.dbias_ptr,
args.dq_acc_ptr,
args.seqstart_q_ptr, args.seqstart_q_ptr,
args.seqstart_k_ptr, args.seqstart_k_ptr,
args.seqlen_k_ptr, args.seqlen_k_ptr,
...@@ -175,11 +176,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -175,11 +176,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_lsed, args.nhead_stride_lsed,
args.nhead_stride_dbias, args.nhead_stride_dbias,
args.batch_stride_lsed, args.batch_stride_lsed,
args.split_stride_dq_acc,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
args.p_drop, args.p_drop,
args.s_randval,
args.drop_seed_offset); args.drop_seed_offset);
} }
else else
...@@ -192,10 +193,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -192,10 +193,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.do_ptr, args.do_ptr,
args.d_ptr, args.d_ptr,
args.rand_val_ptr, args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr, args.dk_ptr,
args.dv_ptr, args.dv_ptr,
args.dbias_ptr, args.dbias_ptr,
args.dq_acc_ptr,
args.seqlen_q, args.seqlen_q,
args.seqlen_k, args.seqlen_k,
args.hdim_q, args.hdim_q,
...@@ -230,11 +231,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -230,11 +231,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.batch_stride_dk, args.batch_stride_dk,
args.batch_stride_dv, args.batch_stride_dv,
args.batch_stride_dbias, args.batch_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
args.p_drop, args.p_drop,
args.s_randval,
args.drop_seed_offset); args.drop_seed_offset);
} }
}(); }();
...@@ -286,19 +287,54 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) ...@@ -286,19 +287,54 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
return ck_tile::make_tuple(kargs, grids); return ck_tile::make_tuple(kargs, grids);
} }
template <typename FmhaBwdConvertQGradKernel>
auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
{
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
{
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqstart_q_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.stride_q,
args.nhead_stride_q,
args.split_stride_dq_acc);
}
else
{ // create batch mode kernel arguments
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.stride_q,
args.nhead_stride_q,
args.batch_stride_q,
args.split_stride_dq_acc);
}
}();
dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internl kernel implementation, not to instantiate kernel // this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_, template <ck_tile::index_t HDim_,
typename DataType_, typename DataType_,
bool kIsGroupMode_, bool kIsGroupMode_,
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_, ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
typename FmhaMask_, typename FmhaMask_,
typename FmhaDropout_,
ck_tile::BlockAttentionBiasEnum BiasEnum_, ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_, bool kHasBiasGrad_,
bool kHasDropout_,
bool kPadS_, bool kPadS_,
bool kPadSK_, bool kPadSK_,
bool kPadD_, bool kPadD_,
bool kPadDv_> bool kPadDv_,
bool kIsDeterministic_>
struct fmha_bwd_dq_dk_dv_traits_ struct fmha_bwd_dq_dk_dv_traits_
{ {
static constexpr ck_tile::index_t HDim = HDim_; static constexpr ck_tile::index_t HDim = HDim_;
...@@ -306,13 +342,14 @@ struct fmha_bwd_dq_dk_dv_traits_ ...@@ -306,13 +342,14 @@ struct fmha_bwd_dq_dk_dv_traits_
static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_; static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>; using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
static constexpr auto BiasEnum = BiasEnum_; static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_; static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kPadS = kPadS_; static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_; static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_; static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_; static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
}; };
template <typename Traits_> template <typename Traits_>
...@@ -343,6 +380,31 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); ...@@ -343,6 +380,31 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_> template <typename Traits_>
std::string fmha_bwd_dot_do_o_get_name_(); std::string fmha_bwd_dot_do_o_get_name_();
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
bool kPadS_,
bool kPadD_,
bool kIsDeterministic_>
struct fmha_bwd_convert_dq_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_convert_dq_get_name_();
// This is the public API, will be generated by script // This is the public API, will be generated by script
struct fmha_bwd_traits struct fmha_bwd_traits
{ {
...@@ -354,6 +416,8 @@ struct fmha_bwd_traits ...@@ -354,6 +416,8 @@ struct fmha_bwd_traits
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_dbias; bool has_dbias;
bool has_dropout; bool has_dropout;
bool is_store_randval;
bool is_deterministic;
// TODO: padding check is inside this api // TODO: padding check is inside this api
}; };
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
...@@ -622,6 +622,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -622,6 +622,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
bias.type, bias.type,
lse, lse,
p_drop > 0.0f, p_drop > 0.0f,
s_randval,
squant}; squant};
auto p_compute_element_func = [&]() { auto p_compute_element_func = [&]() {
...@@ -744,7 +745,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -744,7 +745,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask.right, mask.right,
static_cast<ck_tile::index_t>(mask.type), static_cast<ck_tile::index_t>(mask.type),
p_drop, p_drop,
s_randval,
{drop_seed, drop_offset}}; {drop_seed, drop_offset}};
}(); }();
......
...@@ -143,7 +143,6 @@ struct fmha_fwd_args ...@@ -143,7 +143,6 @@ struct fmha_fwd_args
ck_tile::index_t window_size_right; ck_tile::index_t window_size_right;
ck_tile::index_t mask_type; ck_tile::index_t mask_type;
float p_drop; float p_drop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset; std::tuple<uint64_t, uint64_t> drop_seed_offset;
}; };
...@@ -190,7 +189,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -190,7 +189,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
args.p_drop, args.p_drop,
args.s_randval,
args.drop_seed_offset); args.drop_seed_offset);
} }
else else
...@@ -235,7 +233,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -235,7 +233,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
args.p_drop, args.p_drop,
args.s_randval,
args.drop_seed_offset); args.drop_seed_offset);
} }
}(); }();
...@@ -292,7 +289,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) ...@@ -292,7 +289,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
args.p_drop, args.p_drop,
args.s_randval,
args.drop_seed_offset); args.drop_seed_offset);
} }
else else
...@@ -341,7 +337,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) ...@@ -341,7 +337,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
args.p_drop, args.p_drop,
args.s_randval,
args.drop_seed_offset); args.drop_seed_offset);
} }
}(); }();
...@@ -427,9 +422,9 @@ template <ck_tile::index_t HDim_, ...@@ -427,9 +422,9 @@ template <ck_tile::index_t HDim_,
bool kIsVLayoutRowMajor_, bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_, ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
typename FmhaMask_, typename FmhaMask_,
typename FmhaDropout_,
ck_tile::BlockAttentionBiasEnum BiasEnum_, ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_, bool kStoreLse_,
bool kHasDropout_,
bool kDoFp8StaticQuant_, bool kDoFp8StaticQuant_,
bool kPadS_, bool kPadS_,
bool kPadSK_, bool kPadSK_,
...@@ -449,9 +444,9 @@ struct fmha_fwd_traits_ ...@@ -449,9 +444,9 @@ struct fmha_fwd_traits_
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>; using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
static constexpr auto BiasEnum = BiasEnum_; static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_; static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_; static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_; static constexpr bool kPadSK = kPadSK_;
...@@ -508,6 +503,7 @@ struct fmha_fwd_traits ...@@ -508,6 +503,7 @@ struct fmha_fwd_traits
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse; bool has_lse;
bool has_dropout; bool has_dropout;
bool is_store_randval;
bool do_fp8_static_quant; bool do_fp8_static_quant;
// TODO: padding check is inside this api // TODO: padding check is inside this api
}; };
......
...@@ -1341,7 +1341,7 @@ struct modulo : public base_transform<1, 1> ...@@ -1341,7 +1341,7 @@ struct modulo : public base_transform<1, 1>
}; };
// 2D XOR, NOTE: "xor" is a keyword // 2D XOR, NOTE: "xor" is a keyword
template <typename LowLengths, typename RightShift> template <typename LowLengths>
struct xor_t : public base_transform<2, 2> struct xor_t : public base_transform<2, 2>
{ {
static constexpr auto type_enum = coord_transform_enum::xor_t; static constexpr auto type_enum = coord_transform_enum::xor_t;
...@@ -1352,15 +1352,10 @@ struct xor_t : public base_transform<2, 2> ...@@ -1352,15 +1352,10 @@ struct xor_t : public base_transform<2, 2>
using UpLengths = LowLengths; using UpLengths = LowLengths;
UpLengths up_lengths_; UpLengths up_lengths_;
RightShift right_shift_;
CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {} CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{} {}
CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths, CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}
const RightShift& right_shift)
: up_lengths_{low_lengths}, right_shift_{right_shift}
{
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{ {
...@@ -1378,13 +1373,8 @@ struct xor_t : public base_transform<2, 2> ...@@ -1378,13 +1373,8 @@ struct xor_t : public base_transform<2, 2>
idx_low(number<0>{}) = idx_up[number<0>{}]; idx_low(number<0>{}) = idx_up[number<0>{}];
const auto idx_low_1_tmp = idx_low(number<1>{}) =
(idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}]; idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
const auto idx_low_1 =
(idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp;
idx_low(number<1>{}) = idx_low_1;
} }
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx> template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
...@@ -1419,8 +1409,7 @@ struct xor_t : public base_transform<2, 2> ...@@ -1419,8 +1409,7 @@ struct xor_t : public base_transform<2, 2>
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{ {
return ck_tile::is_known_at_compile_time<UpLengths>::value && return ck_tile::is_known_at_compile_time<UpLengths>::value;
ck_tile::is_known_at_compile_time<RightShift>::value;
} }
// MUST be static function // MUST be static function
...@@ -1432,14 +1421,6 @@ struct xor_t : public base_transform<2, 2> ...@@ -1432,14 +1421,6 @@ struct xor_t : public base_transform<2, 2>
array<index_t, 2> up_vector_lengths = low_vector_lengths; array<index_t, 2> up_vector_lengths = low_vector_lengths;
array<index_t, 2> up_vector_strides = low_vector_strides; array<index_t, 2> up_vector_strides = low_vector_strides;
if constexpr(ck_tile::is_known_at_compile_time<RightShift>::value)
{
if(low_vector_lengths[1] != -1)
{
up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
}
}
return make_tuple(up_vector_lengths, up_vector_strides); return make_tuple(up_vector_lengths, up_vector_strides);
} }
...@@ -1452,10 +1433,6 @@ struct xor_t : public base_transform<2, 2> ...@@ -1452,10 +1433,6 @@ struct xor_t : public base_transform<2, 2>
print(up_lengths_); print(up_lengths_);
printf(", "); printf(", ");
//
printf("right_shift_: ");
print(right_shift_);
printf("}"); printf("}");
} }
}; };
...@@ -1655,11 +1632,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus, ...@@ -1655,11 +1632,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
return modulo<Modulus, UpLength>{modulus, up_length}; return modulo<Modulus, UpLength>{modulus, up_length};
} }
template <typename LowLengths, typename RightShift> template <typename LowLengths>
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths, CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
const RightShift& right_shift)
{ {
return xor_t<LowLengths, RightShift>{low_lengths, right_shift}; return xor_t<LowLengths>{low_lengths};
} }
template <typename LowLength, typename OffsetLength> template <typename LowLength, typename OffsetLength>
......
...@@ -746,8 +746,9 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x( ...@@ -746,8 +746,9 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
return make_tuple( return make_tuple(
make_static_tile_distribution( make_static_tile_distribution(
tile_distribution_encoding<typename Encoding::RsLengths, tile_distribution_encoding<typename Encoding::RsLengths,
decltype(sliced_h_lengths), // only need to change the remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
// h_lengths type // change the
// h_lengths type
typename Encoding::Ps2RHssMajor, typename Encoding::Ps2RHssMajor,
typename Encoding::Ps2RHssMinor, typename Encoding::Ps2RHssMinor,
typename Encoding::Ys2RHsMajor, typename Encoding::Ys2RHsMajor,
......
...@@ -16,13 +16,8 @@ ...@@ -16,13 +16,8 @@
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
......
...@@ -22,20 +22,27 @@ struct NullBlockDropout ...@@ -22,20 +22,27 @@ struct NullBlockDropout
} }
}; };
template <bool IsDropout_ = true, bool IsWG32_ = true, bool IsStoreRandval_ = false>
struct BlockDropout struct BlockDropout
{ {
static constexpr bool IsDropout = IsDropout_;
// true: 32*32 warp gemm
// false: 16*16 warp gemm
static constexpr bool IsWG32 = IsWG32_;
static constexpr bool IsStoreRandval = IsStoreRandval_;
CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch,
index_t i_head, index_t i_head,
index_t nheads, index_t nheads,
unsigned long long seed, unsigned long long seed,
unsigned long long offset, unsigned long long offset,
float rp_undrop_, float rp_undrop_,
uint8_t p_undrop_in_uint8_t_, uint8_t p_undrop_in_uint8_t_)
bool is_store_randval_) : ph(seed,
: ph(seed, offset + (i_batch * nheads + i_head) * get_warp_size() + get_lane_id()), offset + (i_batch * nheads + i_head) * get_warp_size() +
(IsWG32 ? get_lane_id() : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))),
rp_undrop(rp_undrop_), rp_undrop(rp_undrop_),
p_undrop_in_uint8_t(p_undrop_in_uint8_t_), p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
is_store_randval(is_store_randval_)
{ {
} }
...@@ -44,33 +51,43 @@ struct BlockDropout ...@@ -44,33 +51,43 @@ struct BlockDropout
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
index_t seqlen_qk_start) index_t seqlen_qk_start)
{ {
constexpr auto config = if constexpr(IsDropout)
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>(); {
using WG = remove_cvref_t<decltype(config.template at<0>())>; constexpr auto config =
constexpr index_t MWarp = config.template at<1>(); BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
constexpr index_t NWarp = config.template at<2>(); using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t kMPerStep = MWarp * WG::kM; constexpr index_t MWarp = config.template at<1>();
constexpr index_t kNPerStep = NWarp * WG::kN; constexpr index_t NWarp = config.template at<2>();
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
auto randval_dram_window = [&]() {
if constexpr(IsFwd)
{
return make_tile_window(
randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
}
else
{
return make_tile_window(
randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
}
}();
const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); return randval_dram_window;
auto randval_dram_window = [&]() { }
if constexpr(IsFwd) else
{ {
return make_tile_window( (void)randval_dram_block_window_tmp;
randval_dram_block_window_tmp.get_bottom_tensor_view(), (void)seqlen_qk_start;
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
}
else
{
return make_tile_window(
randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
}
}();
return randval_dram_window; return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
}
} }
template <typename BlockGemm> template <typename BlockGemm>
...@@ -122,16 +139,23 @@ struct BlockDropout ...@@ -122,16 +139,23 @@ struct BlockDropout
sequence<0, 0>>{}; sequence<0, 0>>{};
// Use Bwd WarpGemm to ensure that Fwd's random values ​​are consistent with Bwd. // Use Bwd WarpGemm to ensure that Fwd's random values ​​are consistent with Bwd.
// except headdim256.
constexpr auto randval_block_inner_part_dstr_encoding = []() { constexpr auto randval_block_inner_part_dstr_encoding = []() {
if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> && if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
std::is_same_v<typename BlockGemm::BDataType, half_t> && std::is_same_v<typename BlockGemm::BDataType, half_t> &&
std::is_same_v<typename BlockGemm::CDataType, float>) std::is_same_v<typename BlockGemm::CDataType, float>)
{ {
return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; if constexpr(IsWG32)
return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
else
return typename WarpGemmMfmaF16F16F32M16N16K16::CWarpDstrEncoding{};
} }
else else
{ {
return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; if constexpr(IsWG32)
return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
else
return typename WarpGemmMfmaBf16Bf16F32M16N16K16::CWarpDstrEncoding{};
} }
}(); }();
...@@ -175,6 +199,7 @@ struct BlockDropout ...@@ -175,6 +199,7 @@ struct BlockDropout
typename PComputeWindow, typename PComputeWindow,
typename RandValDramWindow> typename RandValDramWindow>
CK_TILE_HOST_DEVICE void Run(void* randval_ptr, CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
const index_t start_m0_idx,
const index_t start_n0_idx, const index_t start_n0_idx,
PComputeWindow& p_compute, PComputeWindow& p_compute,
RandValDramWindow& randval_dram_window) const RandValDramWindow& randval_dram_window) const
...@@ -208,43 +233,6 @@ struct BlockDropout ...@@ -208,43 +233,6 @@ struct BlockDropout
randval_lds_window.get_window_origin(), randval_lds_window.get_window_origin(),
MakeRandValLdsShuffleTileDistribution<BlockGemm>()); MakeRandValLdsShuffleTileDistribution<BlockGemm>());
const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{});
if(is_store_randval)
{
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
int block_col_start = (start_n0_idx / WG::kN) + i_n0;
uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t,
reinterpret_cast<unsigned long long&>(rowcol));
constexpr auto randval_dist_generated_spans =
decltype(randval_dist_generated)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
});
});
// save to LDS
store_tile(randval_lds_window, randval_dist_generated);
block_sync_lds();
// read from LDS to register
auto randval = load_tile(randval_lds_read_window);
// save to Global
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {0, kNPerStep});
});
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
});
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
};
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
...@@ -282,8 +270,23 @@ struct BlockDropout ...@@ -282,8 +270,23 @@ struct BlockDropout
: PComputeDataType(0); : PComputeDataType(0);
}); });
}); });
// save to Global
if constexpr(IsStoreRandval)
{
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {0, kNPerStep});
}
}); });
if constexpr(IsStoreRandval)
{
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
}
}); });
if constexpr(IsStoreRandval)
{
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
}
} }
template <typename BlockGemm, template <typename BlockGemm,
...@@ -291,6 +294,7 @@ struct BlockDropout ...@@ -291,6 +294,7 @@ struct BlockDropout
typename PComputeWindow, typename PComputeWindow,
typename RandValDramWindow> typename RandValDramWindow>
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
const index_t start_n0_idx,
PComputeWindow& p_compute, PComputeWindow& p_compute,
RandValDramWindow& randval_dram_window) const RandValDramWindow& randval_dram_window) const
{ {
...@@ -308,25 +312,48 @@ struct BlockDropout ...@@ -308,25 +312,48 @@ struct BlockDropout
// register distribute // register distribute
auto randval = auto randval =
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>()); make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
static_assert(randval.kThreadElementSpaceSize == 16); if constexpr(IsWG32)
static_assert(randval.kThreadElementSpaceSize == 16);
else
static_assert(randval.kThreadElementSpaceSize == 4);
const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{});
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
int block_row_start = (start_m0_idx / WG::kM) + i_m0; int block_row_start, block_col_start;
int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id(); if constexpr(IsWG32)
uint2 rowcol = make_uint2(block_row_start, block_col_start); {
block_row_start = (start_m0_idx / WG::kM) + i_m0;
block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
}
else
{
block_row_start = start_m0_idx / 32;
block_col_start = (start_n0_idx / 32) + get_warp_id() / 2;
}
uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number // generate random number
uint8_t random_uint8_t[16]; uint8_t random_uint8_t[16];
uint8_t* random_uint8_t_ = random_uint8_t;
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol)); ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
if constexpr(!IsWG32)
{
// m0t0 ~m0t15/m0t32~m0t47: 0
// m0t16~m0t31/m0t48~m0t63: 1
// m1t0 ~m1t15/m1t32~m1t47: 2
// m1t16~m1t31/m1t48~m1t63: 3
int start_idx = ((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
uint32_t* random_uint32_t = reinterpret_cast<uint32_t*>(random_uint8_t);
random_uint8_t_ = reinterpret_cast<uint8_t*>(&random_uint32_t[start_idx]);
}
constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
int i_random_idx = 0; int i_random_idx = 0;
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
randval(r_idx) = random_uint8_t[i_random_idx++]; randval(r_idx) = random_uint8_t_[i_random_idx++];
constexpr auto p_idx0 = constexpr auto p_idx0 =
tile_distributed_index<i_m0, idx0.impl_.at(1), idx0.impl_.at(2)>{}; tile_distributed_index<i_m0, idx0.impl_.at(1), idx0.impl_.at(2)>{};
constexpr auto p_idx1 = tile_distributed_index<i_n0>{}; constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
...@@ -337,19 +364,19 @@ struct BlockDropout ...@@ -337,19 +364,19 @@ struct BlockDropout
}); });
}); });
// save to Global // save to Global
if(is_store_randval) if constexpr(IsStoreRandval)
{ {
const auto randval_store = cast_tile<RandValOutputDataType>(randval); const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store); store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {kMPerStep, 0}); move_tile_window(randval_dram_window, {kMPerStep, 0});
} }
}); });
if(is_store_randval) if constexpr(IsStoreRandval)
{ {
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep}); move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
} }
}); });
if(is_store_randval) if constexpr(IsStoreRandval)
{ {
move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock}); move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
} }
...@@ -358,7 +385,6 @@ struct BlockDropout ...@@ -358,7 +385,6 @@ struct BlockDropout
ck_tile::philox ph; ck_tile::philox ph;
const float rp_undrop; const float rp_undrop;
const uint8_t p_undrop_in_uint8_t; const uint8_t p_undrop_in_uint8_t;
const bool is_store_randval;
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -59,9 +59,12 @@ struct FmhaBwdDQDKDVKernel ...@@ -59,9 +59,12 @@ struct FmhaBwdDQDKDVKernel
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad; static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>; using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking; using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kHasDropout = FmhaDropout::IsDropout;
static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
// clang-format off // clang-format off
template <typename T> struct t2s; template <typename T> struct t2s;
...@@ -94,7 +97,8 @@ struct FmhaBwdDQDKDVKernel ...@@ -94,7 +97,8 @@ struct FmhaBwdDQDKDVKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) + ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ); (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) +
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "" );
#undef _SS_ #undef _SS_
#undef _TS_ #undef _TS_
// clang-format on // clang-format on
...@@ -117,7 +121,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -117,7 +121,7 @@ struct FmhaBwdDQDKDVKernel
const void* lse_ptr; const void* lse_ptr;
const void* do_ptr; const void* do_ptr;
const void* d_ptr; const void* d_ptr;
void* dq_ptr; void* dq_acc_ptr;
void* dk_ptr; void* dk_ptr;
void* dv_ptr; void* dv_ptr;
...@@ -131,9 +135,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -131,9 +135,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t num_head_q; ck_tile::index_t num_head_q;
ck_tile::index_t nhead_ratio_qk; ck_tile::index_t nhead_ratio_qk;
float raw_scale; float raw_scale;
#if CK_TILE_FMHA_FWD_FAST_EXP2
float scale; float scale;
#endif
ck_tile::index_t stride_q; ck_tile::index_t stride_q;
ck_tile::index_t stride_k; ck_tile::index_t stride_k;
...@@ -206,7 +208,6 @@ struct FmhaBwdDQDKDVKernel ...@@ -206,7 +208,6 @@ struct FmhaBwdDQDKDVKernel
float rp_undrop = 1; float rp_undrop = 1;
float scale_rp_undrop = 1; float scale_rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max(); uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
bool is_store_randval = false;
uint64_t drop_seed = 1; uint64_t drop_seed = 1;
uint64_t drop_offset = 0; uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr; void* rand_val_ptr = nullptr;
...@@ -218,6 +219,10 @@ struct FmhaBwdDQDKDVKernel ...@@ -218,6 +219,10 @@ struct FmhaBwdDQDKDVKernel
{ {
ck_tile::index_t batch_stride_randval = 0; ck_tile::index_t batch_stride_randval = 0;
}; };
struct FmhaBwdDeterministicKargs
{
ck_tile::index_t split_stride_dq_acc = 0;
};
struct FmhaBwdBatchModeKargs struct FmhaBwdBatchModeKargs
: FmhaBwdCommonKargs, : FmhaBwdCommonKargs,
...@@ -228,7 +233,8 @@ struct FmhaBwdDQDKDVKernel ...@@ -228,7 +233,8 @@ struct FmhaBwdDQDKDVKernel
FmhaBwdEmptyKargs<0>>>, FmhaBwdEmptyKargs<0>>>,
std::conditional_t<kHasBiasGrad, FmhaBwdBatchModeBiasGradKargs, FmhaBwdEmptyKargs<1>>, std::conditional_t<kHasBiasGrad, FmhaBwdBatchModeBiasGradKargs, FmhaBwdEmptyKargs<1>>,
std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>, std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>> std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>,
std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
{ {
ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k;
...@@ -247,7 +253,8 @@ struct FmhaBwdDQDKDVKernel ...@@ -247,7 +253,8 @@ struct FmhaBwdDQDKDVKernel
FmhaBwdEmptyKargs<0>>>, FmhaBwdEmptyKargs<0>>>,
std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>, std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>,
std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>, std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>> std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>,
std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
{ {
const int32_t* seqstart_q_ptr; const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr; const int32_t* seqstart_k_ptr;
...@@ -266,10 +273,10 @@ struct FmhaBwdDQDKDVKernel ...@@ -266,10 +273,10 @@ struct FmhaBwdDQDKDVKernel
const void* do_ptr, const void* do_ptr,
const void* d_ptr, const void* d_ptr,
void* rand_val_ptr, void* rand_val_ptr,
void* dq_ptr,
void* dk_ptr, void* dk_ptr,
void* dv_ptr, void* dv_ptr,
void* dbias_ptr, void* dbias_ptr,
void* dq_acc_ptr,
ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k, ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q, ck_tile::index_t hdim_q,
...@@ -304,11 +311,11 @@ struct FmhaBwdDQDKDVKernel ...@@ -304,11 +311,11 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias, ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left, ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
...@@ -317,7 +324,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -317,7 +324,7 @@ struct FmhaBwdDQDKDVKernel
lse_ptr, lse_ptr,
do_ptr, do_ptr,
d_ptr, d_ptr,
dq_ptr, dq_acc_ptr,
dk_ptr, dk_ptr,
dv_ptr, dv_ptr,
seqlen_q, seqlen_q,
...@@ -327,9 +334,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -327,9 +334,7 @@ struct FmhaBwdDQDKDVKernel
num_head_q, num_head_q,
nhead_ratio_qk, nhead_ratio_qk,
scale, scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale * ck_tile::log2e_v<>), static_cast<float>(scale * ck_tile::log2e_v<>),
#endif
stride_q, stride_q,
stride_k, stride_k,
stride_v, stride_v,
...@@ -346,6 +351,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -346,6 +351,7 @@ struct FmhaBwdDQDKDVKernel
{}, // placeholder for dbias {}, // placeholder for dbias
{}, // placeholder for mask {}, // placeholder for mask
{}, // placeholder for dropout {}, // placeholder for dropout
{}, // placeholder for deterministic
batch_stride_q, batch_stride_q,
batch_stride_k, batch_stride_k,
batch_stride_v, batch_stride_v,
...@@ -384,11 +390,18 @@ struct FmhaBwdDQDKDVKernel ...@@ -384,11 +390,18 @@ struct FmhaBwdDQDKDVKernel
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset, scale); kargs.init_dropout(p_drop, drop_seed_offset, scale);
kargs.rand_val_ptr = rand_val_ptr; if constexpr(kIsStoreRandval)
kargs.stride_randval = stride_randval; {
kargs.nhead_stride_randval = nhead_stride_randval; kargs.rand_val_ptr = rand_val_ptr;
kargs.batch_stride_randval = batch_stride_randval; kargs.stride_randval = stride_randval;
kargs.is_store_randval = s_randval; kargs.nhead_stride_randval = nhead_stride_randval;
kargs.batch_stride_randval = batch_stride_randval;
}
}
if constexpr(kIsDeterministic)
{
kargs.split_stride_dq_acc = split_stride_dq_acc;
} }
return kargs; return kargs;
...@@ -404,10 +417,10 @@ struct FmhaBwdDQDKDVKernel ...@@ -404,10 +417,10 @@ struct FmhaBwdDQDKDVKernel
const void* do_ptr, const void* do_ptr,
const void* d_ptr, const void* d_ptr,
void* rand_val_ptr, void* rand_val_ptr,
void* dq_ptr,
void* dk_ptr, void* dk_ptr,
void* dv_ptr, void* dv_ptr,
void* dbias_ptr, void* dbias_ptr,
void* dq_acc_ptr,
const void* seqstart_q_ptr, const void* seqstart_q_ptr,
const void* seqstart_k_ptr, const void* seqstart_k_ptr,
const void* seqlen_k_ptr, const void* seqlen_k_ptr,
...@@ -434,11 +447,11 @@ struct FmhaBwdDQDKDVKernel ...@@ -434,11 +447,11 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dbias, ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_lsed, ck_tile::index_t batch_stride_lsed,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left, ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
...@@ -447,7 +460,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -447,7 +460,7 @@ struct FmhaBwdDQDKDVKernel
lse_ptr, lse_ptr,
do_ptr, do_ptr,
d_ptr, d_ptr,
dq_ptr, dq_acc_ptr,
dk_ptr, dk_ptr,
dv_ptr, dv_ptr,
-1, // seqlen will be updated by another pointer -1, // seqlen will be updated by another pointer
...@@ -457,9 +470,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -457,9 +470,7 @@ struct FmhaBwdDQDKDVKernel
num_head_q, num_head_q,
nhead_ratio_qk, nhead_ratio_qk,
scale, scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale * ck_tile::log2e_v<>), static_cast<float>(scale * ck_tile::log2e_v<>),
#endif
stride_q, stride_q,
stride_k, stride_k,
stride_v, stride_v,
...@@ -476,6 +487,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -476,6 +487,7 @@ struct FmhaBwdDQDKDVKernel
{}, // placeholder for dbias {}, // placeholder for dbias
{}, // placeholder for mask {}, // placeholder for mask
{}, // placeholder for dropout {}, // placeholder for dropout
{}, // placeholder for deterministic
reinterpret_cast<const int32_t*>(seqstart_q_ptr), reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr), reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)}; reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
...@@ -506,10 +518,16 @@ struct FmhaBwdDQDKDVKernel ...@@ -506,10 +518,16 @@ struct FmhaBwdDQDKDVKernel
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset, scale); kargs.init_dropout(p_drop, drop_seed_offset, scale);
kargs.rand_val_ptr = rand_val_ptr; if constexpr(kIsStoreRandval)
kargs.stride_randval = stride_randval; {
kargs.nhead_stride_randval = nhead_stride_randval; kargs.rand_val_ptr = rand_val_ptr;
kargs.is_store_randval = s_randval; kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
}
}
if constexpr(kIsDeterministic)
{
kargs.split_stride_dq_acc = split_stride_dq_acc;
} }
return kargs; return kargs;
...@@ -576,7 +594,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -576,7 +594,7 @@ struct FmhaBwdDQDKDVKernel
{ {
batch_offset_dbias = key_start; batch_offset_dbias = key_start;
} }
if constexpr(kHasDropout) if constexpr(kIsStoreRandval)
{ {
batch_offset_randval = query_start * kargs.stride_randval; batch_offset_randval = query_start * kargs.stride_randval;
} }
...@@ -618,7 +636,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -618,7 +636,7 @@ struct FmhaBwdDQDKDVKernel
{ {
batch_offset_dbias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dbias; batch_offset_dbias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dbias;
} }
if constexpr(kHasDropout) if constexpr(kIsStoreRandval)
{ {
batch_offset_randval = batch_offset_randval =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval; static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
...@@ -646,9 +664,6 @@ struct FmhaBwdDQDKDVKernel ...@@ -646,9 +664,6 @@ struct FmhaBwdDQDKDVKernel
const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) + const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do + static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
batch_offset_do; batch_offset_do;
QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
batch_offset_q;
KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) + KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_k + static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_k +
batch_offset_dk; batch_offset_dk;
...@@ -663,45 +678,10 @@ struct FmhaBwdDQDKDVKernel ...@@ -663,45 +678,10 @@ struct FmhaBwdDQDKDVKernel
make_tuple(kargs.stride_q, 1), make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQ>{}, number<FmhaPipeline::kAlignmentQ>{},
number<1>{}); number<1>{});
const auto q_dram = [&]() { const auto q_dram = pad_tensor_view(
if constexpr(FmhaPipeline::kQLoadOnce) q_dram_naive,
{ make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
return pad_tensor_view( sequence<kPadSeqLenQ, kPadHeadDimQ>{});
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
}();
const auto qt_dram_naive =
transform_tensor_view(q_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_q),
make_pass_through_transform(kargs.seqlen_q)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto qt_dram = [&]() {
if constexpr(FmhaPipeline::kQTLoadOnce)
{
return pad_tensor_view(
qt_dram_naive,
make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kM0>{}),
sequence<kPadHeadDimQ, kPadSeqLenQ>{});
}
else
{
return pad_tensor_view(
qt_dram_naive,
make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK3>{}),
sequence<kPadHeadDimQ, kPadSeqLenQ>{});
}
}();
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr, k_ptr,
...@@ -709,45 +689,10 @@ struct FmhaBwdDQDKDVKernel ...@@ -709,45 +689,10 @@ struct FmhaBwdDQDKDVKernel
make_tuple(kargs.stride_k, 1), make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{}, number<FmhaPipeline::kAlignmentK>{},
number<1>{}); number<1>{});
const auto k_dram = [&]() { const auto k_dram = pad_tensor_view(
if constexpr(FmhaPipeline::kKLoadOnce) k_dram_naive,
{ make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
return pad_tensor_view( sequence<kPadSeqLenK, kPadHeadDimQ>{});
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}
}();
const auto kt_dram_naive =
transform_tensor_view(k_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_q),
make_pass_through_transform(kargs.seqlen_k)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto kt_dram = [&]() {
if constexpr(FmhaPipeline::kKTLoadOnce)
{
return pad_tensor_view(
kt_dram_naive,
make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimQ, kPadSeqLenK>{});
}
else
{
return pad_tensor_view(
kt_dram_naive,
make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK4>{}),
sequence<kPadHeadDimQ, kPadSeqLenK>{});
}
}();
const auto v_dram = [&]() { const auto v_dram = [&]() {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
...@@ -756,20 +701,10 @@ struct FmhaBwdDQDKDVKernel ...@@ -756,20 +701,10 @@ struct FmhaBwdDQDKDVKernel
make_tuple(kargs.stride_v, 1), make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{}, number<FmhaPipeline::kAlignmentV>{},
number<1>{}); number<1>{});
if constexpr(FmhaPipeline::kVLoadOnce) return pad_tensor_view(
{ v_dram_naive,
return pad_tensor_view( make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
v_dram_naive, sequence<kPadSeqLenK, kPadHeadDimV>{});
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
sequence<kPadSeqLenK, kPadHeadDimV>{});
}
else
{
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK2>{}),
sequence<kPadSeqLenK, kPadHeadDimV>{});
}
}(); }();
const auto lse_dram = [&]() { const auto lse_dram = [&]() {
...@@ -792,145 +727,88 @@ struct FmhaBwdDQDKDVKernel ...@@ -792,145 +727,88 @@ struct FmhaBwdDQDKDVKernel
make_tuple(kargs.stride_do, 1), make_tuple(kargs.stride_do, 1),
number<FmhaPipeline::kAlignmentOGrad>{}, number<FmhaPipeline::kAlignmentOGrad>{},
number<1>{}); number<1>{});
const auto do_dram = [&]() { const auto do_dram = pad_tensor_view(
if constexpr(FmhaPipeline::kOGradLoadOnce) do_dram_naive,
{ make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
return pad_tensor_view( sequence<kPadSeqLenQ, kPadHeadDimV>{});
do_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
}
else
{
return pad_tensor_view(
do_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK2>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
}
}();
const auto dot_dram_naive =
transform_tensor_view(do_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_q)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto dot_dram = [&]() {
if constexpr(FmhaPipeline::kOGradTLoadOnce)
{
return pad_tensor_view(
dot_dram_naive,
make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kM0>{}),
sequence<kPadHeadDimV, kPadSeqLenQ>{});
}
else
{
return pad_tensor_view(
dot_dram_naive,
make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenQ>{});
}
}();
auto dq_dram = [&]() {
const auto dq_dram_naive = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(
dq_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
return pad_tensor_view(
dq_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
auto q_dram_window = make_tile_window( auto q_dram_window = make_tile_window(
q_dram, q_dram,
[&]() { make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
if constexpr(FmhaPipeline::kQLoadOnce)
return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kQKHeaddim>{});
else
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
}(),
{0, 0}); {0, 0});
auto qt_dram_window =
make_tile_window(qt_dram,
[&]() {
if constexpr(FmhaPipeline::kQTLoadOnce)
return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
number<FmhaPipeline::kM0>{});
else
return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
number<FmhaPipeline::kK3>{});
}(),
{0, 0});
auto k_dram_window = make_tile_window( auto k_dram_window = make_tile_window(
k_dram, k_dram,
[&]() { make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
if constexpr(FmhaPipeline::kKLoadOnce)
return make_tuple(number<FmhaPipeline::kN0>{},
number<FmhaPipeline::kQKHeaddim>{});
else
return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
}(),
{i_n0, 0}); {i_n0, 0});
auto kt_dram_window =
make_tile_window(kt_dram,
[&]() {
if constexpr(FmhaPipeline::kKTLoadOnce)
return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
number<FmhaPipeline::kN0>{});
else
return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
number<FmhaPipeline::kK4>{});
}(),
{0, i_n0});
auto v_dram_window = make_tile_window( auto v_dram_window = make_tile_window(
v_dram, v_dram,
[&]() { make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
if constexpr(FmhaPipeline::kVLoadOnce)
return make_tuple(number<FmhaPipeline::kN0>{},
number<FmhaPipeline::kVHeaddim>{});
else
return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK2>{});
}(),
{i_n0, 0}); {i_n0, 0});
auto do_dram_window = make_tile_window( auto do_dram_window = make_tile_window(
do_dram, do_dram,
[&]() { make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
if constexpr(FmhaPipeline::kOGradLoadOnce)
return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kVHeaddim>{});
else
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK2>{});
}(),
{0, 0}); {0, 0});
auto dot_dram_window = auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
make_tile_window(dot_dram, if constexpr(kIsDeterministic)
[&]() { {
if constexpr(FmhaPipeline::kOGradTLoadOnce) AccDataType* dq_acc_ptr =
return make_tuple(number<FmhaPipeline::kVHeaddim>{}, reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
number<FmhaPipeline::kM0>{}); static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q +
else static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
return make_tuple(number<FmhaPipeline::kVHeaddim>{}, batch_offset_q;
number<FmhaPipeline::kK1>{});
}(), auto dq_acc_dram = [&]() {
{0, 0}); const auto dq_acc_dram_naive =
make_naive_tensor_view<address_space_enum::global>(
auto dq_dram_window = make_tile_window( dq_acc_ptr,
dq_dram, make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}), make_tuple(kargs.hdim_q, 1),
{0, 0}); number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
return pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
return make_tile_window(
dq_acc_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
{0, 0});
}
else
{
AccDataType* dq_acc_ptr =
reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q + batch_offset_q;
auto dq_acc_dram = [&]() {
const auto dq_acc_dram_naive =
make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
return pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
return make_tile_window(
dq_acc_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
{0, 0});
}
}();
auto lse_dram_window = auto lse_dram_window =
make_tile_window(lse_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0}); make_tile_window(lse_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});
...@@ -1008,9 +886,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -1008,9 +886,7 @@ struct FmhaBwdDQDKDVKernel
// TODO: how to use s_read? // TODO: how to use s_read?
AccDataType slope = *(reinterpret_cast<const AccDataType*>(kargs.alibi_slope_ptr) + AccDataType slope = *(reinterpret_cast<const AccDataType*>(kargs.alibi_slope_ptr) +
i_batch_ * kargs.alibi_slope_stride + i_nhead_); i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope *= ck_tile::log2e_v<>; slope *= ck_tile::log2e_v<>;
#endif
if constexpr(kHasMask) if constexpr(kHasMask)
{ {
return make_alibi_from_lr_mask<AccDataType, false>(slope, return make_alibi_from_lr_mask<AccDataType, false>(slope,
...@@ -1038,7 +914,6 @@ struct FmhaBwdDQDKDVKernel ...@@ -1038,7 +914,6 @@ struct FmhaBwdDQDKDVKernel
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max(); uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 0; uint64_t drop_seed = 0;
uint64_t drop_offset = 0; uint64_t drop_offset = 0;
bool is_store_randval = false;
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
...@@ -1047,21 +922,19 @@ struct FmhaBwdDQDKDVKernel ...@@ -1047,21 +922,19 @@ struct FmhaBwdDQDKDVKernel
p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
drop_seed = kargs.drop_seed; drop_seed = kargs.drop_seed;
drop_offset = kargs.drop_offset; drop_offset = kargs.drop_offset;
is_store_randval = kargs.is_store_randval;
} }
BlockDropout dropout(i_batch, FmhaDropout dropout(i_batch,
i_nhead, i_nhead,
kargs.num_head_q, kargs.num_head_q,
drop_seed, drop_seed,
drop_offset, drop_offset,
rp_undrop, rp_undrop,
p_undrop_in_uint8_t, p_undrop_in_uint8_t);
is_store_randval);
auto randval_dram_window = [&, i_nhead_ = i_nhead]() { auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto randval_dram_window_lengths = constexpr auto randval_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{}); make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
if constexpr(kHasDropout) if constexpr(kIsStoreRandval)
{ {
RandValOutputDataType* rand_val_ptr = RandValOutputDataType* rand_val_ptr =
reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) + reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
...@@ -1103,14 +976,11 @@ struct FmhaBwdDQDKDVKernel ...@@ -1103,14 +976,11 @@ struct FmhaBwdDQDKDVKernel
}(); }();
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window, auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
qt_dram_window,
k_dram_window, k_dram_window,
kt_dram_window,
v_dram_window, v_dram_window,
bias_dram_window, bias_dram_window,
randval_dram_window, randval_dram_window,
do_dram_window, do_dram_window,
dot_dram_window,
lse_dram_window, lse_dram_window,
d_dram_window, d_dram_window,
dq_dram_window, dq_dram_window,
...@@ -1118,9 +988,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -1118,9 +988,7 @@ struct FmhaBwdDQDKDVKernel
mask, mask,
position_encoding, position_encoding,
kargs.raw_scale, kargs.raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
kargs.scale, kargs.scale,
#endif
rp_undrop, rp_undrop,
scale_rp_undrop, scale_rp_undrop,
smem_ptr, smem_ptr,
...@@ -1418,4 +1286,285 @@ struct FmhaBwdOGradDotOKernel ...@@ -1418,4 +1286,285 @@ struct FmhaBwdOGradDotOKernel
} }
}; };
// Post-processing kernel for FMHA backward: converts the fp32 dQ accumulation
// buffer (dq_acc) into the final QGradDataType dQ tensor. In deterministic
// mode the accumulator holds one split per k-tile (nsplits = ceil(seqlen_k/kN0))
// and this kernel also reduces over the split dimension.
template <typename TilePartitioner_, typename FmhaBwdConvertQGrad_>
struct FmhaBwdConvertQGradKernel
{
    using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
    using FmhaBwdConvertQGrad = ck_tile::remove_cvref_t<FmhaBwdConvertQGrad_>;
    static constexpr ck_tile::index_t kBlockSize = FmhaBwdConvertQGrad::kBlockSize;
    static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu;
    static constexpr ck_tile::index_t kM0 = FmhaBwdConvertQGrad::kM0;
    static constexpr ck_tile::index_t kN0 = FmhaBwdConvertQGrad::kN0;
    static constexpr ck_tile::index_t kQKHeaddim = FmhaBwdConvertQGrad::kQKHeaddim;
    using AccDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::AccDataType>;
    using QGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::QGradDataType>;

    static constexpr bool kIsGroupMode = FmhaBwdConvertQGrad::kIsGroupMode;
    static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ;
    static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ;
    static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;

    // Maps the element type to the short name used in the generated kernel name.
    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    // clang-format on

    // Builds the canonical kernel-instance name; must stay in sync with the
    // name produced by the code generator (generate.py).
    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        #define _SS_ std::string
        #define _TS_ std::to_string
        // "p" + the list of padded dimensions (s = seqlen_q, d = hdim_q),
        // or empty when nothing is padded.
        auto pn = [&] () {
            std::string n;
            if (kPadSeqLenQ) n += "s";
            if (kPadHeadDimQ) n += "d";
            return n.empty() ? n : std::string("p") + n; }();
        return
            _SS_("fmha_bwd_convert_dq_d") + _TS_(kQKHeaddim) + "_" + _SS_(t2s<QGradDataType>::name) +
            "_" + (kIsGroupMode ? "group" : "batch") + (kIsDeterministic ? "_deterministic" : "") + "_" +
            ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "" : "_" + pn);
        #undef _SS_
        #undef _TS_
        // clang-format on
    }

    // to avoid the duplicated-base-class problem, introduce a template arg
    template <ck_tile::index_t I>
    struct FmhaBwdConvertQGradEmptyKargs
    {
    };

    // kargs use aggregate initializer, so no constructor will be provided
    // use inheritance to minimize karg size
    // user need to use MakeKargs() function to create kargs.
    struct FmhaBwdConvertQGradCommonKargs
    {
        const void* dq_acc_ptr; // fp32 accumulation input
        void* dq_ptr;           // final QGradDataType output
        ck_tile::index_t seqlen_q;
        ck_tile::index_t seqlen_k;
        ck_tile::index_t hdim_q;
        ck_tile::index_t stride_dq;
        ck_tile::index_t nhead_stride_dq;
    };

    // Only present when kIsDeterministic: stride between dq_acc splits.
    struct FmhaBwdConvertQGradDeterministicKargs
    {
        ck_tile::index_t split_stride_dq_acc = 0;
    };

    struct FmhaBwdConvertQGradBatchModeKargs
        : FmhaBwdConvertQGradCommonKargs,
          std::conditional_t<kIsDeterministic,
                             FmhaBwdConvertQGradDeterministicKargs,
                             FmhaBwdConvertQGradEmptyKargs<0>>
    {
        ck_tile::index_t batch_stride_dq;
    };

    struct FmhaBwdConvertQGradGroupModeKargs
        : FmhaBwdConvertQGradCommonKargs,
          std::conditional_t<kIsDeterministic,
                             FmhaBwdConvertQGradDeterministicKargs,
                             FmhaBwdConvertQGradEmptyKargs<0>>
    {
        const int32_t* seqstart_q_ptr; // per-batch query start offsets (group mode)
        const int32_t* seqlen_k_ptr;   // per-batch key lengths (group mode)
    };

    using Kargs = std::conditional_t<kIsGroupMode,
                                     FmhaBwdConvertQGradGroupModeKargs,
                                     FmhaBwdConvertQGradBatchModeKargs>;

    // Batch-mode karg builder: fixed seqlen_q/seqlen_k per batch.
    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* dq_acc_ptr,
              void* dq_ptr,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t seqlen_k,
              ck_tile::index_t hdim_q,
              ck_tile::index_t stride_dq,
              ck_tile::index_t nhead_stride_dq,
              ck_tile::index_t batch_stride_dq,
              ck_tile::index_t split_stride_dq_acc)
    {
        Kargs kargs{{dq_acc_ptr, dq_ptr, seqlen_q, seqlen_k, hdim_q, stride_dq, nhead_stride_dq},
                    {}, // placeholder for the deterministic (or empty) base
                    batch_stride_dq};
        if constexpr(kIsDeterministic)
        {
            kargs.split_stride_dq_acc = split_stride_dq_acc;
        }
        return kargs;
    }

    // Group-mode karg builder: sequence lengths come from the seqstart/seqlen
    // pointers at run time, so they are initialized to -1 here.
    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* dq_acc_ptr,
              void* dq_ptr,
              const void* seqstart_q_ptr,
              const void* seqlen_k_ptr,
              ck_tile::index_t hdim_q,
              ck_tile::index_t stride_dq,
              ck_tile::index_t nhead_stride_dq,
              ck_tile::index_t split_stride_dq_acc)
    {
        Kargs kargs{{dq_acc_ptr,
                     dq_ptr,
                     -1, // seqlen will be updated by another pointer
                     -1, //
                     hdim_q,
                     stride_dq,
                     nhead_stride_dq},
                    {}, // placeholder for the deterministic (or empty) base
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                    reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
        if constexpr(kIsDeterministic)
        {
            kargs.split_stride_dq_acc = split_stride_dq_acc;
        }
        return kargs;
    }

    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
    {
        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    // This kernel uses no LDS.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // divide problem: one workgroup handles a kM0-row tile of one head of
        // one batch
        const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q);

        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);

        long_index_t batch_offset_dq = 0;
        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
            batch_offset_dq = query_start * kargs.stride_dq;

            // get real # queries & # keys under group mode
            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
            kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];

            // # of required blocks is different in each groups, terminate unnecessary blocks
            // earlier
            if(kargs.seqlen_q <= i_m0)
            {
                return;
            }
        }
        else
        {
            batch_offset_dq = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
        }

        // for simplicity, batch stride we just modify the pointer
        QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dq +
                                batch_offset_dq;

        // dQAcc/dQ DRAM and DRAM window
        const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() {
            if constexpr(kIsDeterministic)
            {
                // Deterministic layout has an extra leading "split" dimension
                // (one slice per k-tile of the bwd kernel).
                // NOTE(review): the head stride here is seqlen_q * hdim_q and
                // the batch offset reuses stride_dq — assumes a packed
                // (stride_dq == hdim_q) accumulation buffer; confirm against
                // the allocator of dq_acc.
                const AccDataType* dq_acc_ptr =
                    reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
                    static_cast<long_index_t>(i_nhead_) * (kargs.seqlen_q * kargs.hdim_q) +
                    batch_offset_dq;

                const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);

                // Fold 4 consecutive rows into one (seqlen_q/4, hdim_q*4) row —
                // presumably to widen contiguous accesses; TODO confirm intent.
                constexpr auto dq_fold = 4;
                auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    dq_acc_ptr,
                    make_tuple(nsplits, kargs.seqlen_q / dq_fold, kargs.hdim_q * dq_fold),
                    make_tuple(kargs.split_stride_dq_acc, kargs.hdim_q * dq_fold, 1),
                    number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
                    number<1>{});

                return pad_tensor_view(dq_acc_dram_naive,
                                       make_tuple(number<1>{},
                                                  number<kM0 / dq_fold>{},
                                                  number<kQKHeaddim * dq_fold>{}),
                                       sequence<false, kPadSeqLenQ, kPadHeadDimQ>{});
            }
            else
            {
                // Non-deterministic: plain 2D (seqlen_q, hdim_q) view sharing
                // dQ's strides.
                const AccDataType* dq_acc_ptr =
                    reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
                    static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq) + batch_offset_dq;

                auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    dq_acc_ptr,
                    make_tuple(kargs.seqlen_q, kargs.hdim_q),
                    make_tuple(kargs.stride_dq, 1),
                    number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
                    number<1>{});

                return pad_tensor_view(dq_acc_dram_naive,
                                       make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                                       sequence<kPadSeqLenQ, kPadHeadDimQ>{});
            }
        }();

        auto dq_dram = [&]() {
            auto dq_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                dq_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_q),
                make_tuple(kargs.stride_dq, 1),
                number<FmhaBwdConvertQGrad::kAlignmentQGrad>{},
                number<1>{});

            return pad_tensor_view(dq_dram_naive,
                                   make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                                   sequence<kPadSeqLenQ, kPadHeadDimQ>{});
        }();

        auto dq_acc_dram_window = [&]() {
            if constexpr(kIsDeterministic)
            {
                // 3D window starting at split 0; row index is divided by the
                // fold factor to match the folded view above.
                constexpr auto dq_fold = 4;
                return make_tile_window(dq_acc_dram,
                                        make_tuple(number<1>{},
                                                   number<kM0 / dq_fold>{},
                                                   number<kQKHeaddim * dq_fold>{}),
                                        {0, i_m0 / dq_fold, 0});
            }
            else
            {
                return make_tile_window(
                    dq_acc_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
            }
        }();

        auto dq_dram_window =
            make_tile_window(dq_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});

        if constexpr(kIsDeterministic)
        {
            // Deterministic path also reduces across the nsplits partial-dQ
            // slices while converting.
            const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
            FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits);
        }
        else
        {
            FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window);
        }
    }
};
} // namespace ck_tile } // namespace ck_tile
...@@ -7,38 +7,34 @@ ...@@ -7,38 +7,34 @@
namespace ck_tile { namespace ck_tile {
template <typename BlockFmhaShape_> template <ck_tile::index_t kN0>
struct FmhaBwdTilePartitioner struct FmhaBwdKTilePartitioner
{ {
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
CK_TILE_HOST static constexpr auto CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{ {
// TODO: this may need tuning // TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_); return dim3(batch_size_, nhead_, ck_tile::integer_divide_ceil(seqlen_k_, kN0));
} }
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/) CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
{ {
const index_t i_block = blockIdx.x; const index_t i_block = blockIdx.z;
const index_t i_nhead = blockIdx.y; const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z; const index_t i_batch = blockIdx.x;
return ck_tile::make_tuple(i_block, i_nhead, i_batch); return ck_tile::make_tuple(i_block, i_nhead, i_batch);
} }
}; };
template <ck_tile::index_t kBlockSize> template <ck_tile::index_t kM0>
struct FmhaBwdOGradDotOTilePartitioner struct FmhaBwdQTilePartitioner
{ {
CK_TILE_HOST static constexpr auto CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{ {
// TODO: this may need tuning // TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_); return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
} }
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/) CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
......
...@@ -47,10 +47,12 @@ struct FmhaFwdKernel ...@@ -47,10 +47,12 @@ struct FmhaFwdKernel
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>; using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking; using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kHasDropout = FmhaDropout::IsDropout;
static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
// clang-format off // clang-format off
template <typename T> struct t2s; template <typename T> struct t2s;
...@@ -87,7 +89,8 @@ struct FmhaFwdKernel ...@@ -87,7 +89,8 @@ struct FmhaFwdKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) +
(kIsStoreRandval ? "_storerandval" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_ #undef _SS_
#undef _TS_ #undef _TS_
// clang-format on // clang-format on
...@@ -185,7 +188,6 @@ struct FmhaFwdKernel ...@@ -185,7 +188,6 @@ struct FmhaFwdKernel
} }
float rp_undrop = 1; float rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max(); uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
bool is_store_randval = false;
uint64_t drop_seed = 1; uint64_t drop_seed = 1;
uint64_t drop_offset = 0; uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr; void* rand_val_ptr = nullptr;
...@@ -277,7 +279,6 @@ struct FmhaFwdKernel ...@@ -277,7 +279,6 @@ struct FmhaFwdKernel
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
...@@ -345,11 +346,13 @@ struct FmhaFwdKernel ...@@ -345,11 +346,13 @@ struct FmhaFwdKernel
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset); kargs.init_dropout(p_drop, drop_seed_offset);
kargs.rand_val_ptr = rand_val_ptr; if constexpr(kIsStoreRandval)
kargs.stride_randval = stride_randval; {
kargs.nhead_stride_randval = nhead_stride_randval; kargs.rand_val_ptr = rand_val_ptr;
kargs.batch_stride_randval = batch_stride_randval; kargs.stride_randval = stride_randval;
kargs.is_store_randval = s_randval; kargs.nhead_stride_randval = nhead_stride_randval;
kargs.batch_stride_randval = batch_stride_randval;
}
} }
return kargs; return kargs;
...@@ -392,7 +395,6 @@ struct FmhaFwdKernel ...@@ -392,7 +395,6 @@ struct FmhaFwdKernel
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
...@@ -458,10 +460,12 @@ struct FmhaFwdKernel ...@@ -458,10 +460,12 @@ struct FmhaFwdKernel
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset); kargs.init_dropout(p_drop, drop_seed_offset);
kargs.rand_val_ptr = rand_val_ptr; if constexpr(kIsStoreRandval)
kargs.stride_randval = stride_randval; {
kargs.nhead_stride_randval = nhead_stride_randval; kargs.rand_val_ptr = rand_val_ptr;
kargs.is_store_randval = s_randval; kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
}
} }
return kargs; return kargs;
...@@ -526,7 +530,7 @@ struct FmhaFwdKernel ...@@ -526,7 +530,7 @@ struct FmhaFwdKernel
{ {
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse; batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
} }
if constexpr(kHasDropout) if constexpr(kIsStoreRandval)
{ {
batch_offset_randval = query_start * kargs.stride_randval; batch_offset_randval = query_start * kargs.stride_randval;
} }
...@@ -566,7 +570,7 @@ struct FmhaFwdKernel ...@@ -566,7 +570,7 @@ struct FmhaFwdKernel
{ {
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse; batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
} }
if constexpr(kHasDropout) if constexpr(kIsStoreRandval)
{ {
batch_offset_randval = batch_offset_randval =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval; static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
...@@ -744,28 +748,31 @@ struct FmhaFwdKernel ...@@ -744,28 +748,31 @@ struct FmhaFwdKernel
} }
}(); }();
auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { // dropout
if constexpr(kHasDropout) float rp_undrop = 1;
{ uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
return BlockDropout{i_batch_, uint64_t drop_seed = 0;
i_nhead_, uint64_t drop_offset = 0;
kargs.num_head_q,
kargs.drop_seed, if constexpr(kHasDropout)
kargs.drop_offset, {
kargs.rp_undrop, rp_undrop = kargs.rp_undrop;
kargs.p_undrop_in_uint8_t, p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
kargs.is_store_randval}; drop_seed = kargs.drop_seed;
} drop_offset = kargs.drop_offset;
else }
{ FmhaDropout dropout(i_batch,
return NullBlockDropout{}; i_nhead,
}; kargs.num_head_q,
}(); drop_seed,
drop_offset,
rp_undrop,
p_undrop_in_uint8_t);
auto randval_dram_window = [&, i_nhead_ = i_nhead]() { auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto randval_dram_window_lengths = constexpr auto randval_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{}); make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
if constexpr(kHasDropout) if constexpr(kIsStoreRandval)
{ {
RandValOutputDataType* rand_val_ptr = RandValOutputDataType* rand_val_ptr =
reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) + reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
......
...@@ -46,10 +46,12 @@ struct FmhaFwdSplitKVKernel ...@@ -46,10 +46,12 @@ struct FmhaFwdSplitKVKernel
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>; using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking; using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kHasDropout = FmhaDropout::IsDropout;
static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
// clang-format off // clang-format off
template <typename T> struct t2s; template <typename T> struct t2s;
...@@ -86,7 +88,8 @@ struct FmhaFwdSplitKVKernel ...@@ -86,7 +88,8 @@ struct FmhaFwdSplitKVKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) +
(kIsStoreRandval ? "_storerandval" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_ #undef _SS_
#undef _TS_ #undef _TS_
// clang-format on // clang-format on
...@@ -189,7 +192,6 @@ struct FmhaFwdSplitKVKernel ...@@ -189,7 +192,6 @@ struct FmhaFwdSplitKVKernel
} }
float rp_undrop = 1; float rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max(); uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
bool is_store_randval = false;
uint64_t drop_seed = 1; uint64_t drop_seed = 1;
uint64_t drop_offset = 0; uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr; void* rand_val_ptr = nullptr;
...@@ -282,7 +284,6 @@ struct FmhaFwdSplitKVKernel ...@@ -282,7 +284,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
...@@ -350,11 +351,13 @@ struct FmhaFwdSplitKVKernel ...@@ -350,11 +351,13 @@ struct FmhaFwdSplitKVKernel
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset); kargs.init_dropout(p_drop, drop_seed_offset);
kargs.rand_val_ptr = rand_val_ptr; if constexpr(kIsStoreRandval)
kargs.stride_randval = stride_randval; {
kargs.nhead_stride_randval = nhead_stride_randval; kargs.rand_val_ptr = rand_val_ptr;
kargs.batch_stride_randval = batch_stride_randval; kargs.stride_randval = stride_randval;
kargs.is_store_randval = s_randval; kargs.nhead_stride_randval = nhead_stride_randval;
kargs.batch_stride_randval = batch_stride_randval;
}
} }
return kargs; return kargs;
...@@ -402,7 +405,6 @@ struct FmhaFwdSplitKVKernel ...@@ -402,7 +405,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
...@@ -469,10 +471,12 @@ struct FmhaFwdSplitKVKernel ...@@ -469,10 +471,12 @@ struct FmhaFwdSplitKVKernel
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset); kargs.init_dropout(p_drop, drop_seed_offset);
kargs.rand_val_ptr = rand_val_ptr; if constexpr(kIsStoreRandval)
kargs.stride_randval = stride_randval; {
kargs.nhead_stride_randval = nhead_stride_randval; kargs.rand_val_ptr = rand_val_ptr;
kargs.is_store_randval = s_randval; kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
}
} }
return kargs; return kargs;
...@@ -536,7 +540,7 @@ struct FmhaFwdSplitKVKernel ...@@ -536,7 +540,7 @@ struct FmhaFwdSplitKVKernel
{ {
batch_offset_bias = query_start * kargs.stride_bias + key_start; batch_offset_bias = query_start * kargs.stride_bias + key_start;
} }
if constexpr(kHasDropout) if constexpr(kIsStoreRandval)
{ {
batch_offset_randval = query_start * kargs.stride_randval; batch_offset_randval = query_start * kargs.stride_randval;
} }
...@@ -571,7 +575,7 @@ struct FmhaFwdSplitKVKernel ...@@ -571,7 +575,7 @@ struct FmhaFwdSplitKVKernel
{ {
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias; batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
} }
if constexpr(kHasDropout) if constexpr(kIsStoreRandval)
{ {
batch_offset_randval = batch_offset_randval =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval; static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
...@@ -747,7 +751,6 @@ struct FmhaFwdSplitKVKernel ...@@ -747,7 +751,6 @@ struct FmhaFwdSplitKVKernel
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max(); uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 0; uint64_t drop_seed = 0;
uint64_t drop_offset = 0; uint64_t drop_offset = 0;
bool is_store_randval = false;
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
...@@ -755,21 +758,19 @@ struct FmhaFwdSplitKVKernel ...@@ -755,21 +758,19 @@ struct FmhaFwdSplitKVKernel
p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
drop_seed = kargs.drop_seed; drop_seed = kargs.drop_seed;
drop_offset = kargs.drop_offset; drop_offset = kargs.drop_offset;
is_store_randval = kargs.is_store_randval;
} }
BlockDropout dropout(i_batch, FmhaDropout dropout(i_batch,
i_nhead, i_nhead,
kargs.num_head_q, kargs.num_head_q,
drop_seed, drop_seed,
drop_offset, drop_offset,
rp_undrop, rp_undrop,
p_undrop_in_uint8_t, p_undrop_in_uint8_t);
is_store_randval);
auto randval_dram_window = [&, i_nhead_ = i_nhead]() { auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto randval_dram_window_lengths = constexpr auto randval_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{}); make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
if constexpr(kHasDropout) if constexpr(kIsStoreRandval)
{ {
RandValOutputDataType* rand_val_ptr = RandValOutputDataType* rand_val_ptr =
reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) + reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// Epilogue functor of the FMHA backward pass: converts the AccDataType (fp32)
// dQ accumulator into the final QGradDataType dQ tile.
//
// Two entry points:
//  * operator()(acc, out)          - plain type conversion (non-deterministic path)
//  * operator()(acc, out, nsplits) - deterministic path: sums `nsplits` partial
//                                    dQAcc tiles (the window's leading dimension
//                                    indexes the split) before converting.
template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdConvertQGrad
{
    using AccDataType   = remove_cvref_t<typename Problem::AccDataType>;
    using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;

    static constexpr index_t kM0         = Problem::Shape::kM0;
    static constexpr index_t kN0         = Problem::Shape::kN0;
    static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
    static constexpr index_t kBlockSize  = Problem::kBlockSize;
    static constexpr index_t kQKHeaddim  = Problem::Shape::kQKHeaddim;

    static constexpr bool kIsGroupMode     = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ      = Problem::kPadSeqLenQ;
    static constexpr bool kPadHeadDimQ     = Problem::kPadHeadDimQ;
    static constexpr bool kIsDeterministic = Problem::kIsDeterministic;

    // vectorized-access widths; fall back to scalar access when the head dim is
    // padded (a partial tail element would break the wide load/store)
    static constexpr index_t kAlignmentQGradAcc =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc<Problem>();
    static constexpr index_t kAlignmentQGrad =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGrad<Problem>();

    // this epilogue works entirely in registers; no LDS needed
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

    // Convert only: load one dQAcc tile, cast element-wise, store as dQ.
    template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE void
    operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
               QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const
    {
        static_assert(
            std::is_same_v<AccDataType,
                           remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType,
                               remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");

        auto dq_acc_dram_window =
            make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                             dq_acc_dram_block_window_tmp.get_window_lengths(),
                             dq_acc_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakePostQGradAccDramTileDistribution<Problem>());

        auto dq_acc   = load_tile(dq_acc_dram_window);
        const auto dq = cast_tile<QGradDataType>(dq_acc);
        store_tile(dq_dram_block_window_tmp, dq);
    }

    // Reduce + Convert: accumulate `nsplits` partial dQAcc tiles, then cast and
    // store. `nsplits` must be >= 1.
    template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE void
    operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
               QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
               index_t nsplits) const
    {
        static_assert(
            std::is_same_v<AccDataType,
                           remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType,
                               remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");

        auto dq_acc_dram_window = make_tile_window(
            dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
            dq_acc_dram_block_window_tmp.get_window_lengths(),
            dq_acc_dram_block_window_tmp.get_window_origin(),
            Policy::template MakePostQGradAccDeterministicDramTileDistribution<Problem>());

        auto dq_acc = decltype(load_tile(dq_acc_dram_window)){};
        clear_tile(dq_acc);

        constexpr auto dq_acc_spans = decltype(dq_acc)::get_distributed_spans();

        // dq_acc += dq_acc_buf, element-wise over the distributed tile
        auto dq_acc_buf    = load_tile(dq_acc_dram_window);
        auto accumulate_buf = [&]() {
            sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                    sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                        constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                        dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
                    });
                });
            });
        };

        // Software-pipelined reduction: the next split is prefetched while the
        // current one is folded in.
        // NOTE(fix): a pre-checked while loop is used instead of the previous
        // do/while so that nsplits == 1 does not prefetch (and later accumulate)
        // a tile one window past the last split; for nsplits >= 2 the load /
        // accumulate sequence is identical to before.
        move_tile_window(dq_acc_dram_window, {1, 0, 0});
        index_t i_total_loops = 0;
        while(i_total_loops < nsplits - 1)
        {
            accumulate_buf();
            dq_acc_buf = load_tile(dq_acc_dram_window);
            move_tile_window(dq_acc_dram_window, {1, 0, 0});
            i_total_loops += 1;
        }
        // fold in the last prefetched split
        accumulate_buf();

        // cast the reduced accumulator to the output element type
        constexpr auto dq_converted_dstr =
            Policy::template MakePostQGradAccDeterministicDramTileDistribution<Problem>();
        auto dq_converted = make_static_distributed_tensor<QGradDataType>(dq_converted_dstr);
        sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
            sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                    constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                    dq_converted(n_i_j_idx) = type_convert<QGradDataType>(dq_acc[n_i_j_idx]);
                });
            });
        });

        // re-wrap the converted values in the store-side distribution; the two
        // distributions share the same thread-buffer layout, so a raw buffer
        // copy is sufficient
        constexpr auto dq_dstr =
            Policy::template MakePostQGradDeterministicDramTileDistribution<Problem>();
        auto dq = make_static_distributed_tensor<QGradDataType>(dq_dstr);
        dq.get_thread_buffer() = dq_converted.get_thread_buffer();

        store_tile(dq_dram_block_window_tmp, dq);
    }
};
} // namespace ck_tile
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile { namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdOGradDotODefaultPolicy> template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdOGradDotO struct BlockFmhaBwdOGradDotO
{ {
using ODataType = remove_cvref_t<typename Problem::ODataType>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
...@@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO ...@@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO
static constexpr index_t kAlignmentO = static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad = static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// Default policy for the standalone OGrad·O ("dot(dO, O)") preprocessing kernel.
// The *LoadOnce_ template parameters are not used by that kernel, so they are
// all set to false here; their semantics are defined by
// BlockFmhaBwdPipelineDefaultPolicy (presumably whether each operand is loaded
// once and kept resident — confirm against the policy header).
using BlockFmhaBwdOGradDotODefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
                                      /* QTLoadOnce_ = */ false,
                                      /* KLoadOnce_ = */ false,
                                      /* KTLoadOnce_ = */ false,
                                      /* VLoadOnce_ = */ false,
                                      /* OGradLoadOnce_ = */ false,
                                      /* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
...@@ -6,13 +6,13 @@ ...@@ -6,13 +6,13 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile { namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy> template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineKSKTSVR struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
{ {
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>; using KDataType = remove_cvref_t<typename Problem::KDataType>;
...@@ -30,6 +30,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -30,6 +30,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>; using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>; using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>; using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
...@@ -46,22 +48,14 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -46,22 +48,14 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kQLoadOnce = false; static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kQTLoadOnce = false; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kKLoadOnce = true; static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kKTLoadOnce = true; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kVLoadOnce = true; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kOGradLoadOnce = false; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kOGradTLoadOnce = false; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this // ... together with tensor distribution. tensor dist should able to overwrite this
...@@ -71,12 +65,10 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -71,12 +65,10 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>(); kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad = static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad = static constexpr index_t kAlignmentQGrad =
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>(); kPadHeadDimQ ? 1 : Policy::template GetAlignmentQGrad<Problem>();
static constexpr index_t kAlignmentKGrad = static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>(); kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad = static constexpr index_t kAlignmentVGrad =
...@@ -84,7 +76,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -84,7 +76,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
static constexpr index_t kAlignmentBias = static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>(); kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
static constexpr const char* name = "ks_kts_vr"; static constexpr const char* name = "kr_ktr_vr";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
...@@ -92,14 +84,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -92,14 +84,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
} }
template <typename QDramBlockWindowTmp, template <typename QDramBlockWindowTmp,
typename QTDramBlockWindowTmp,
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename KTDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp, typename RandValDramBlockWindowTmp,
typename OGradDramBlockWindowTmp, typename OGradDramBlockWindowTmp,
typename OGradTDramBlockWindowTmp,
typename LSEDramBlockWindowTmp, typename LSEDramBlockWindowTmp,
typename DDramBlockWindowTmp, typename DDramBlockWindowTmp,
typename QGradDramBlockWindowTmp, typename QGradDramBlockWindowTmp,
...@@ -107,14 +96,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -107,14 +96,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
typename PositionEncoding> typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp, const KDramBlockWindowTmp& k_dram_block_window_tmp,
const KTDramBlockWindowTmp& kt_dram_block_window_tmp,
const VDramBlockWindowTmp& v_dram_block_window_tmp, const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp, const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
const DDramBlockWindowTmp& d_dram_block_window_tmp, const DDramBlockWindowTmp& d_dram_block_window_tmp,
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
...@@ -122,43 +108,29 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -122,43 +108,29 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float raw_scale, float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float scale, float scale,
#endif
float rp_undrop, float rp_undrop,
float scale_rp_undrop, float scale_rp_undrop,
void* smem_ptr, void* smem_ptr,
BlockDropout& dropout) const FmhaDropout& dropout) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<QDataType,
remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> && std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType,
remove_cvref_t<typename KTDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> && std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType, std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> && remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
std::is_same_v<LSEDataType, std::is_same_v<LSEDataType,
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> && remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> && std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!"); "wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kQKHeaddim == KTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kVHeaddim ==
OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
...@@ -166,83 +138,6 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -166,83 +138,6 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!"); "wrong!");
// Q tile in LDS
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeKT<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
// QT tile in LDS
QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeKT<Problem>()));
auto qt_lds = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
auto qt_lds_window =
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});
// K tile in LDS
auto k_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
// KT tile in LDS
KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto kt_lds = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsBlockDescriptor<Problem>());
auto kt_lds_window =
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
// OGrad tile in LDS
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeKT<Problem>()));
auto do_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
// OGradT tile in LDS
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeKT<Problem>()));
auto dot_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
auto dot_lds_window =
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});
// SGrad tile in LDS
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeKT<Problem>()));
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeKT<Problem>()));
auto biast_lds = make_tensor_view<address_space_enum::lds>(
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
auto biast_lds_shuffle_window =
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto dbiast_lds_shuffle_window =
make_tile_window(biast_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// Block GEMM // Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>(); constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>(); constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
...@@ -250,34 +145,19 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -250,34 +145,19 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>(); constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>(); constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
auto v_dram_window = make_tile_window(
v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
auto v = load_tile(v_dram_window); // persistent V register tile
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
// init VGrad & KGrad // init VGrad & KGrad
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
clear_tile(dv_acc); // K, HBM ->LDS ->Reg
clear_tile(dk_acc); auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
auto k_dram_window = make_tile_window( k_dram_block_window_tmp.get_window_lengths(),
k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_origin(),
k_dram_block_window_tmp.get_window_lengths(), Policy::template MakeKDramTileDistribution<Problem>());
k_dram_block_window_tmp.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
__builtin_amdgcn_sched_barrier(0);
const auto k_origin = k_dram_window.get_window_origin(); const auto k_origin = k_dram_window.get_window_origin();
// Early termination
const auto [seqlen_q_start, seqlen_q_end] = const auto [seqlen_q_start, seqlen_q_end] =
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}); mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
...@@ -290,205 +170,415 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -290,205 +170,415 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
{ {
// Note: here dk_acc&dv_acc are all cleard, return it // Note: here dk_acc&dv_acc are all cleard, return it
// Note: v loaded but no fence, ignore it. // Note: v loaded but no fence, ignore it.
return ck_tile::make_tuple(dk_acc, dv_acc); return make_tuple(dk_acc, dv_acc);
} }
} }
KDataType* k_lds_ptr =
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_block_tile = load_tile(k_dram_window); auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
auto k_lds_read_window =
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK0>{}),
k_lds_write_window.get_window_origin(),
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
Policy::template MakeKRegBlockDescriptor<Problem>());
//------------------------------------------------------------------
// V, HBM ->LDS ->Reg
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVDramTileDistribution<Problem>());
VDataType* v_lds_ptr =
static_cast<VDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto v_lds = make_tensor_view<address_space_enum::lds>(
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS auto v_lds_write_window =
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
auto kt_dram_block_window = kt_dram_block_window_tmp; auto v_lds_read_window =
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK2>{}),
v_lds_write_window.get_window_origin(),
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
auto kt_dram_window = make_tile_window( auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
kt_dram_block_window.get_bottom_tensor_view(), Policy::template MakeVRegBlockDescriptor<Problem>());
kt_dram_block_window.get_window_lengths(),
kt_dram_block_window.get_window_origin(),
Policy::template MakeKTDramTileDistribution<Problem>()); // K^T DRAM tile window for
// load
auto kt_block_tile = load_tile(kt_dram_window); //------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
auto kt_block_tile = make_static_distributed_tensor<KDataType>(
Policy::template MakeKTRegWriteBlockDescriptor<Problem>());
KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto kt_lds_write = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsWriteBlockDescriptor<Problem>());
auto kt_lds_write_window =
make_tile_window(kt_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
auto kt_lds_read_window =
make_tile_window(kt_lds_read,
make_tuple(number<kQKHeaddim>{}, number<kN0>{}),
{0, 0},
Policy::template MakeKTRegBlockDescriptor<Problem>());
//------------------------------------------------------------------
// Pre-Load KV into Registers
auto k_block_tile = load_tile(k_dram_window);
auto v_block_tile = load_tile(v_dram_window);
auto kt_shuffle_tmp = make_static_distributed_tensor<KDataType>( store_tile(k_lds_write_window, k_block_tile);
Policy::template MakeShuffledKTRegBlockDescriptor<Problem>()); shuffle_tile(kt_block_tile, k_block_tile);
shuffle_tile(kt_shuffle_tmp, kt_block_tile); store_tile(kt_lds_write_window, kt_block_tile);
store_tile(kt_lds_window, kt_shuffle_tmp); // persistent K^T in LDS block_sync_lds();
k_reg_tensor = load_tile(k_lds_read_window);
block_sync_lds();
auto q_dram_block_window = auto kt_reg_tensor = load_tile(kt_lds_read_window);
store_tile(v_lds_write_window, v_block_tile);
block_sync_lds();
v_reg_tensor = load_tile(v_lds_read_window);
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0}); {seqlen_q_start, 0},
Policy::template MakeQDramTileDistribution<Problem>());
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
auto qt_dram_block_window = auto q_lds_read_window =
make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(q_lds_window.get_bottom_tensor_view(),
qt_dram_block_window_tmp.get_window_lengths(), make_tuple(number<kM0>{}, number<kK0>{}),
{0, seqlen_q_start}); q_lds_window.get_window_origin(),
Policy::template MakeQRegSliceBlockDescriptor<Problem>());
auto do_dram_block_window = auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
Policy::template MakePTRegSliceBlockDescriptor<Problem>());
// QT: Reg -> Reg-> LDS
auto qt_block_tile = make_static_distributed_tensor<QDataType>(
Policy::template MakeQTRegWriteBlockDescriptor<Problem>());
QDataType* qt_lds_ptr =
static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto qt_lds_write = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsWriteBlockDescriptor<Problem>());
auto qt_lds_write_window =
make_tile_window(qt_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
auto qt_lds_read_window =
make_tile_window(qt_lds_read,
make_tuple(number<kQKHeaddim>{}, number<kM0>{}),
{0, 0},
Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
// dO: HBM ->Reg ->LDS
auto do_dram_window =
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
do_dram_block_window_tmp.get_window_lengths(), do_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0}); {seqlen_q_start, 0},
Policy::template MakeOGradDramTileDistribution<Problem>());
auto dot_dram_block_window = OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(), static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>()));
dot_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_q_start});
auto dq_dram_block_window = auto do_lds = make_tensor_view<address_space_enum::lds>(
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto lse_dram_block_window = auto do_lds_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
auto d_dram_block_window = auto do_lds_read_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(do_lds_window.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(), make_tuple(number<kM0>{}, number<kK2>{}),
{seqlen_q_start}); do_lds_window.get_window_origin(),
Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());
// dOT: Reg ->Reg ->LDS
auto dot_block_tile = make_static_distributed_tensor<OGradDataType>(
Policy::template MakeOGradTRegWriteBlockDescriptor<Problem>());
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
auto bias_dram_block_window = static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), Policy::template GetSmemSizeOGrad<Problem>()));
bias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); auto dot_write_lds = make_tensor_view<address_space_enum::lds>(
auto dbias_dram_block_window = dot_lds_ptr, Policy::template MakeOGradTLdsWriteBlockDescriptor<Problem>());
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
dbias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
auto qt_dram_window = auto dot_lds_write_window =
make_tile_window(qt_dram_block_window.get_bottom_tensor_view(), make_tile_window(dot_write_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
qt_dram_block_window.get_window_lengths(),
qt_dram_block_window.get_window_origin(),
Policy::template MakeQTDramTileDistribution<Problem>());
auto dot_dram_window = auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
make_tile_window(dot_dram_block_window.get_bottom_tensor_view(), dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
dot_dram_block_window.get_window_lengths(),
dot_dram_block_window.get_window_origin(),
Policy::template MakeOGradTDramTileDistribution<Problem>());
auto lse_dram_window = make_tile_window( auto dot_lds_read_window =
lse_dram_block_window.get_bottom_tensor_view(), make_tile_window(dot_read_lds,
lse_dram_block_window.get_window_lengths(), make_tuple(number<kVHeaddim>{}, number<kM0>{}),
lse_dram_block_window.get_window_origin(), {0, 0},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
auto d_dram_window = make_tile_window( // dS: Reg -> Reg -> LDS
d_dram_block_window.get_bottom_tensor_view(), GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
d_dram_block_window.get_window_lengths(), static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
d_dram_block_window.get_window_origin(), Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeD<Problem>()));
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto ds_lds_read_window =
make_tile_window(ds_lds_window.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kK4>{}),
ds_lds_window.get_window_origin(),
Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
auto dst_reg_tensor = make_static_distributed_tensor<GemmDataType>(
Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
// Bias: HBM ->Reg ->Reg ->LDS
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = auto bias_dram_window =
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window.get_window_lengths(), bias_dram_block_window_tmp.get_window_lengths(),
bias_dram_block_window.get_window_origin(), {seqlen_q_start, bias_origin.at(number<1>{})},
Policy::template MakeBiasTileDistribution<Problem>()); Policy::template MakeBiasTileDistribution<Problem>());
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>()));
auto biast_lds = make_tensor_view<address_space_enum::lds>(
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
auto biast_lds_shuffle_window =
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto biast_lds_window = auto biast_lds_window =
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
biast_lds_shuffle_window.get_window_lengths(), biast_lds_shuffle_window.get_window_lengths(),
biast_lds_shuffle_window.get_window_origin(), biast_lds_shuffle_window.get_window_origin(),
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>()); Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>( static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// LSE: HBM -> LDS ->Reg
auto lse_dram_window = make_tile_window(
lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
LSEDataType* lse_lds_ptr = static_cast<LSEDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template GetSmemSizeQ<Problem>()));
auto lse_lds = make_tensor_view<address_space_enum::lds>(
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
auto lse_lds_read_window = make_tile_window(
lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
// D: HBM ->Reg
auto d_dram_window = make_tile_window(
d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
DDataType* d_lds_ptr = static_cast<DDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>()));
auto d_lds = make_tensor_view<address_space_enum::lds>(
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
auto d_lds_read_window = make_tile_window(
d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
// RandVal: HBM ->Reg
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
randval_dram_block_window_tmp, seqlen_q_start); randval_dram_block_window_tmp, seqlen_q_start);
index_t i_total_loops = 0; // BiasGrad
constexpr index_t k0_loops = kQKHeaddim / kK0; // Reg ->LDS ->Reg ->HBM
constexpr index_t k1_loops = kM0 / kK1; const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
constexpr index_t k2_loops = kVHeaddim / kK2;
constexpr index_t k3_loops = kM0 / kK3; auto dbias_dram_window =
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
dbias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
auto dbiast_lds_shuffle_window =
make_tile_window(biast_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
// ----------------------------Loop write out------------------------------//
auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
// Deterministic mode staff
auto dq_buffer_view = dq_dram_block_window_tmp.get_bottom_tensor_view().get_buffer_view();
auto dq_tensor_desc =
dq_dram_block_window_tmp.get_bottom_tensor_view().get_tensor_descriptor();
auto seqlen_q = dq_tensor_desc.get_lengths()[number<0>{}];
auto hdim_q = dq_tensor_desc.get_lengths()[number<1>{}];
constexpr auto dq_fold = 4;
auto dq_write_tensor_desc =
make_naive_tensor_descriptor(make_tuple(seqlen_q / dq_fold, hdim_q * dq_fold),
make_tuple(hdim_q * dq_fold, 1),
number<kAlignmentQGrad>{},
number<1>{});
auto dq_tensor_view = tensor_view<decltype(dq_buffer_view), decltype(dq_write_tensor_desc)>{
dq_buffer_view, dq_write_tensor_desc};
auto dq_dram_window_deterministic =
make_tile_window(dq_tensor_view,
make_tuple(number<kM0 / dq_fold>{}, number<kQKHeaddim * dq_fold>{}),
{seqlen_q_start / dq_fold, 0});
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
index_t i_total_loops = 0;
index_t seqlen_q_step = seqlen_q_start;
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
static_assert(kM0 == kK1, "kM0 should equal to kK1");
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
static_assert(kM0 == kK3, "kM0 should equal to kK3");
constexpr index_t k4_loops = kN0 / kK4; constexpr index_t k4_loops = kN0 / kK4;
do
{
auto q_dram_window = make_tile_window(
q_dram_block_window.get_bottom_tensor_view(),
q_dram_block_window.get_window_lengths(),
q_dram_block_window.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
// load
auto do_dram_window = make_tile_window(
do_dram_block_window.get_bottom_tensor_view(),
do_dram_block_window.get_window_lengths(),
do_dram_block_window.get_window_origin(),
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0 /*
auto st_acc = SPTBlockTileType{}; * Prefetch Q, LSE, dO, D
*/
auto q_block_tile = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kM0, 0});
auto lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
auto q_block_tile = load_tile(q_dram_window); auto do_block_tile = load_tile(do_dram_window);
{ move_tile_window(do_dram_window, {kM0, 0});
move_tile_window(q_dram_window, {0, kK0});
clear_tile(st_acc); // Initialize S^T auto d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
store_tile(q_lds_window, q_block_tile); // LDS write 0 /*
q_block_tile = load_tile(q_dram_window); // global read 1 * Store prefetched data into LDS
} */
store_tile(q_lds_window, q_block_tile);
shuffle_tile(qt_block_tile, q_block_tile);
store_tile(qt_lds_write_window, qt_block_tile);
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) store_tile(lse_lds_write_window, lse_block_tile);
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
if constexpr(k0_loops > 2) store_tile(do_lds_window, do_block_tile);
{ shuffle_tile(dot_block_tile, do_block_tile);
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { store_tile(dot_lds_write_window, dot_block_tile);
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{}));
block_sync_lds();
move_tile_window(q_dram_window, {0, kK0});
store_tile(q_lds_window,
q_block_tile); // LDS write i + 1
q_block_tile = load_tile(q_dram_window); // global read i + 2
});
}
const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile store_tile(d_lds_write_window, d_block_tile);
{ // tail block_sync_lds();
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, (k0_loops - 2) * kK0>{},
sequence<kN0, (k0_loops - 1) * kK0>{}));
block_sync_lds();
store_tile(q_lds_window, q_block_tile); /*
block_sync_lds(); * Prefetch LDS data into Reg to Asynchronous Data Movement and MFMA pipeline
*/
gemm_0(st_acc, auto q_reg_tensor = load_tile(q_lds_read_window);
q_lds_window, auto lse = load_tile(lse_lds_read_window);
get_slice_tile(k_lds_window, auto do_reg_tensor = load_tile(do_lds_read_window);
sequence<0, (k0_loops - 1) * kK0>{}, auto d = load_tile(d_lds_read_window);
sequence<kN0, k0_loops * kK0>{}));
} clear_tile(dv_acc);
clear_tile(dk_acc);
__builtin_amdgcn_sched_barrier(0);
// Hot loop
do
{
// STAGE 1, Q@K Gemm0
auto st_acc = SPTBlockTileType{};
clear_tile(st_acc);
q_block_tile = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kM0, 0});
lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
do_block_tile = load_tile(do_dram_window);
move_tile_window(do_dram_window, {kM0, 0});
d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
gemm_0(st_acc, q_reg_tensor, k_reg_tensor);
auto dot_reg_tensor = load_tile(dot_lds_read_window);
HotLoopScheduler::template GemmStagedScheduler<0>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
const auto bias_tile = load_tile(bias_dram_window);
block_sync_lds(); block_sync_lds();
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>( auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>()); Policy::template MakeShuffledBiasTileDistribution<Problem>());
...@@ -498,11 +588,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -498,11 +588,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
auto biast_tile = load_tile(biast_lds_window); auto biast_tile = load_tile(biast_lds_window);
tile_elementwise_inout( tile_elementwise_inout(
[&](auto& x, const auto& y) { [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = raw_scale * x + type_convert<AccDataType>(y);
#else
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y); x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
}, },
st_acc, st_acc,
biast_tile); biast_tile);
...@@ -510,52 +596,36 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -510,52 +596,36 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
} }
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
const auto q_origin = q_dram_block_window.get_window_origin();
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) { sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices( const auto tile_idx = get_x_indices_from_distributed_indices(
st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto row = seqlen_q_step + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
st_acc(i_j_idx) *= raw_scale;
#else
st_acc(i_j_idx) *= scale; st_acc(i_j_idx) *= scale;
#endif
position_encoding.update(st_acc(i_j_idx), row, col); position_encoding.update(st_acc(i_j_idx), row, col);
}); });
}); });
} }
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking) if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{ {
const auto q_origin = q_dram_block_window.get_window_origin(); bool need_perpixel_check = mask.IsEdgeTile(
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check) if(need_perpixel_check)
{ {
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) { set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto row = seqlen_q_step + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col); return mask.IsOutOfBound(row, col);
}); });
} }
} }
const auto lse = load_tile(lse_dram_window);
static const auto get_validated_lse = [](LSEDataType raw_lse) { static const auto get_validated_lse = [](LSEDataType raw_lse) {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking) FmhaMask::IsMasking)
...@@ -574,12 +644,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -574,12 +644,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2 auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI) BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
...@@ -589,31 +658,16 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -589,31 +658,16 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
{ {
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
} }
#else
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
}); });
}); });
auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>( if constexpr(FmhaDropout::IsDropout)
Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
block_sync_lds();
{ {
shuffle_tile(dot_shuffle_tmp, dot_prefetch); dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
store_tile(dot_lds_window, seqlen_q_step, k_origin.at(number<0>{}), pt, randval_dram_window);
dot_shuffle_tmp); // store the prefetch
} }
move_tile_window(dot_dram_window, {0, kK1});
if constexpr(kHasDropout)
{
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
}
// STAGE 3, P^T@OGrad^T Gemm1
const auto pt_gemm = [&]() { const auto pt_gemm = [&]() {
if constexpr(kHasDropout) if constexpr(FmhaDropout::IsDropout)
{ {
return tile_elementwise_in( return tile_elementwise_in(
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); }, [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
...@@ -625,87 +679,37 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -625,87 +679,37 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
} }
}(); }();
if constexpr(k1_loops > 1) // STAGE 3, P^T@OGrad^T Gemm1
{ pt_reg_tensor.get_thread_buffer() = pt_gemm.get_thread_buffer();
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
const auto dot = load_tile(dot_dram_window); // load next OGrad^T
block_sync_lds(); auto qt_reg_tensor = load_tile(qt_lds_read_window);
gemm_1(dv_acc,
get_slice_tile(pt_gemm,
sequence<i_k1 * kK1, 0>{},
sequence<(i_k1 + 1) * kK1, kN0>{}),
dot_lds_window);
block_sync_lds();
shuffle_tile(dot_shuffle_tmp, dot);
store_tile(dot_lds_window,
dot_shuffle_tmp); // store the prefetch
move_tile_window(dot_dram_window, {0, kK1});
});
}
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
// tail
{
block_sync_lds();
gemm_1(dv_acc,
get_slice_tile(
pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
dot_lds_window);
block_sync_lds();
}
HotLoopScheduler::template GemmStagedScheduler<1>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 4, OGrad@V Gemm2 // STAGE 4, OGrad@V Gemm2
auto dpt_acc = SPGradTBlockTileType{}; auto dpt_acc = SPGradTBlockTileType{};
clear_tile(dpt_acc);
{ gemm_2(dpt_acc, do_reg_tensor, v_reg_tensor);
move_tile_window(do_dram_window, {0, kK2});
clear_tile(dpt_acc); // Initialize PGrad^T block_sync_lds();
store_tile(do_lds_window, do_block_tile); // LDS write 0 store_tile(q_lds_window, q_block_tile);
do_block_tile = load_tile(do_dram_window); // global read 1 shuffle_tile(qt_block_tile, q_block_tile);
} store_tile(qt_lds_write_window, qt_block_tile);
if constexpr(k2_loops > 2) store_tile(lse_lds_write_window, lse_block_tile);
{
static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
block_sync_lds();
move_tile_window(do_dram_window, {0, kK2});
store_tile(do_lds_window,
do_block_tile); // LDS write i + 1
do_block_tile = load_tile(do_dram_window); // global read i + 2
});
}
const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile store_tile(do_lds_window, do_block_tile);
{ // tail shuffle_tile(dot_block_tile, do_block_tile);
block_sync_lds(); store_tile(dot_lds_write_window, dot_block_tile);
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(v,
sequence<0, (k2_loops - 2) * kK2>{},
sequence<kN0, (k2_loops - 1) * kK2>{}));
block_sync_lds();
store_tile(do_lds_window, do_block_tile); store_tile(d_lds_write_window, d_block_tile);
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(v,
sequence<0, (k2_loops - 1) * kK2>{},
sequence<kN0, k2_loops * kK2>{}));
}
HotLoopScheduler::template GemmStagedScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 5, P^T(PGrad^T - D) // STAGE 5, P^T(PGrad^T - D)
const auto d = load_tile(d_dram_window);
auto dst = SPGradTBlockTileType{}; auto dst = SPGradTBlockTileType{};
constexpr auto dst_spans = decltype(dst)::get_distributed_spans(); constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
...@@ -713,16 +717,16 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -713,16 +717,16 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = pt[i_j_idx] >= 0; bool undrop_flag = pt[i_j_idx] >= 0;
dst(i_j_idx) = dst(i_j_idx) = pt[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
pt[i_j_idx] * ? (dpt_acc[i_j_idx] - d[i_idx])
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); : d[i_idx]);
}); });
}); });
if constexpr(kHasBiasGrad) if constexpr(kHasBiasGrad)
{ {
const auto dbiast = [&]() { const auto dbiast = [&]() {
if constexpr(kHasDropout) if constexpr(FmhaDropout::IsDropout)
{ {
return tile_elementwise_in( return tile_elementwise_in(
[&rp_undrop](const auto& x) { [&rp_undrop](const auto& x) {
...@@ -741,107 +745,321 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR ...@@ -741,107 +745,321 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>( auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>()); Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); store_tile(dbias_dram_window, dbiast_shuffle_tmp);
move_tile_window(dbias_dram_block_window, {kM0, 0}); move_tile_window(dbias_dram_window, {kM0, 0});
} }
// STAGE 6, SGrad^T@Q^T Gemm3 // STAGE 6, SGrad^T@Q^T Gemm3
auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>( const auto dst_gemm = cast_tile<GemmDataType>(dst);
Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
dst_reg_tensor.get_thread_buffer() = dst_gemm.get_thread_buffer();
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
store_tile(ds_lds_window, dst_gemm);
block_sync_lds(); block_sync_lds();
auto ds_reg_tensor = load_tile(ds_lds_read_window);
auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
move_tile_window(ds_lds_read_window, {0, kK4});
q_reg_tensor = load_tile(q_lds_read_window);
lse = load_tile(lse_lds_read_window);
HotLoopScheduler::template GemmStagedScheduler<3>();
__builtin_amdgcn_sched_barrier(0);
// STAGE7 SGrad@K^T
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc);
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor_next = load_tile(ds_lds_read_window);
move_tile_window(ds_lds_read_window, {0, kK4});
}
auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
}
});
move_tile_window(ds_lds_read_window, {0, -kN0});
do_reg_tensor = load_tile(do_lds_read_window);
d = load_tile(d_lds_read_window);
HotLoopScheduler::template GemmStagedScheduler<4>();
// QGrad Scale
if constexpr(FmhaDropout::IsDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc);
}
else
{ {
shuffle_tile(qt_shuffle_tmp, qt_prefetch); tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
store_tile(qt_lds_window,
qt_shuffle_tmp); // store the prefetch
} }
move_tile_window(qt_dram_window, {0, kK3}); if constexpr(kIsDeterministic)
{
auto dq_write_reg_tensor = make_static_distributed_tensor<AccDataType>(
Policy::template MakeQGradWriteBlockDescriptor<Problem>());
const auto dst_gemm = cast_tile<GemmDataType>(dst); dq_write_reg_tensor.get_thread_buffer() = dq_acc.get_thread_buffer();
if constexpr(k3_loops > 1) store_tile(dq_dram_window_deterministic, dq_write_reg_tensor);
{ move_tile_window(dq_dram_window_deterministic, {kM0 / dq_fold, 0});
static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
const auto qt = load_tile(qt_dram_window); // load next Q^T
block_sync_lds();
gemm_3(dk_acc,
get_slice_tile(dst_gemm,
sequence<i_k3 * kK3, 0>{},
sequence<(i_k3 + 1) * kK3, kN0>{}),
qt_lds_window);
block_sync_lds();
shuffle_tile(qt_shuffle_tmp, qt);
store_tile(qt_lds_window,
qt_shuffle_tmp); // store the prefetch
move_tile_window(qt_dram_window, {0, kK3});
});
} }
// tail else
{ {
block_sync_lds(); update_tile(dq_dram_window, dq_acc);
gemm_3(dk_acc, move_tile_window(dq_dram_window, {kM0, 0});
get_slice_tile(
dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
qt_lds_window);
block_sync_lds();
} }
// STAGE 7, SGrad@K^T Gemm4 i_total_loops += 1;
store_tile(ds_lds_window, dst_gemm); seqlen_q_step += kM0;
} while(i_total_loops < (num_total_loop - 1));
__builtin_amdgcn_sched_barrier(0);
auto dq_acc = QGradBlockTileType{}; // Tail
clear_tile(dq_acc); // Initialize QGrad auto st_acc = SPTBlockTileType{};
clear_tile(st_acc);
gemm_0(st_acc, q_reg_tensor, k_reg_tensor);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
const auto bias_tile = load_tile(bias_dram_window);
block_sync_lds(); block_sync_lds();
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(bias_shuffle_tmp, bias_tile);
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
block_sync_lds();
auto biast_tile = load_tile(biast_lds_window);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
},
st_acc,
biast_tile);
move_tile_window(bias_dram_window, {kM0, 0});
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = seqlen_q_step + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
static_for<0, k4_loops, 1>{}([&](auto i_k4) { st_acc(i_j_idx) *= scale;
gemm_4(dq_acc, position_encoding.update(st_acc(i_j_idx), row, col);
get_slice_tile(ds_lds_window, });
sequence<0, i_k4 * kK4>{},
sequence<kM0, (i_k4 + 1) * kK4>{}),
get_slice_tile(kt_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
}); });
}
// QGrad Scale if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
if constexpr(kHasDropout) {
bool need_perpixel_check = mask.IsEdgeTile(
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
if(need_perpixel_check)
{ {
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
dq_acc); const auto row = seqlen_q_step + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
static const auto get_validated_lse = [](LSEDataType raw_lse) {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_lse == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
: raw_lse;
} }
else else
{ {
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); return raw_lse;
} }
const auto dq = cast_tile<QGradDataType>(dq_acc); };
update_tile(dq_dram_block_window, dq);
// move tile windows auto pt = SPTBlockTileType{};
move_tile_window(q_dram_block_window, {kM0, 0}); constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
move_tile_window(dq_dram_block_window, {kM0, 0}); sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
move_tile_window(do_dram_block_window, {kM0, 0}); constexpr auto i_idx = make_tuple(idx0);
move_tile_window(lse_dram_window, {kM0}); auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
move_tile_window(d_dram_window, {kM0});
} while(++i_total_loops < num_total_loop); sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
}
else
{
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
}
});
});
if constexpr(FmhaDropout::IsDropout)
{
dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_step, k_origin.at(number<0>{}), pt, randval_dram_window);
}
// KGrad Scale // STAGE 3, P^T@OGrad^T Gemm1
if constexpr(kHasDropout) const auto pt_gemm = [&]() {
if constexpr(FmhaDropout::IsDropout)
{
return tile_elementwise_in(
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
pt);
}
else
{
return cast_tile<GemmDataType>(pt);
}
}();
pt_reg_tensor.get_thread_buffer() = pt_gemm.get_thread_buffer();
auto dot_reg_tensor = load_tile(dot_lds_read_window);
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<1>();
// STAGE 4, OGrad@V Gemm2
auto dpt_acc = SPGradTBlockTileType{};
clear_tile(dpt_acc);
auto qt_reg_tensor = load_tile(qt_lds_read_window);
gemm_2(dpt_acc, do_reg_tensor, v_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<2>();
// STAGE 5, P^T(PGrad^T - D)
auto dst = SPGradTBlockTileType{};
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = pt[i_j_idx] >= 0;
dst(i_j_idx) = pt[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
? (dpt_acc[i_j_idx] - d[i_idx])
: d[i_idx]);
});
});
if constexpr(kHasBiasGrad)
{ {
const auto dbiast = [&]() {
if constexpr(FmhaDropout::IsDropout)
{
return tile_elementwise_in(
[&rp_undrop](const auto& x) {
return type_convert<BiasGradDataType>(x * rp_undrop);
},
dst);
}
else
{
return cast_tile<BiasGradDataType>(dst);
}
}();
store_tile(biast_lds_shuffle_window, dbiast);
block_sync_lds();
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
store_tile(dbias_dram_window, dbiast_shuffle_tmp);
move_tile_window(dbias_dram_window, {kM0, 0});
}
// STAGE 6, SGrad^T@Q^T Gemm3
const auto dst_gemm = cast_tile<GemmDataType>(dst);
dst_reg_tensor.get_thread_buffer() = dst_gemm.get_thread_buffer();
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
store_tile(ds_lds_window, dst_gemm);
block_sync_lds();
auto ds_reg_tensor = load_tile(ds_lds_read_window);
auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
move_tile_window(ds_lds_read_window, {0, kK4});
HotLoopScheduler::template GemmStagedScheduler<3>();
// STAGE 7, SGrad@K^T Gemm4
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc);
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor_next = load_tile(ds_lds_read_window);
move_tile_window(ds_lds_read_window, {0, kK4});
}
auto kt_reg_tensor_slice = get_slice_tile(
kt_reg_tensor, sequence<0, i_k4 * kK4>{}, sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
}
});
HotLoopScheduler::template GemmStagedScheduler<4>();
// Results Scale
if constexpr(FmhaDropout::IsDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc);
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dk_acc); dk_acc);
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
} }
else else
{ {
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
} }
// VGrad Scale
if constexpr(kHasDropout) if constexpr(kIsDeterministic)
{ {
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); auto dq_write_reg_tensor = make_static_distributed_tensor<AccDataType>(
Policy::template MakeQGradWriteBlockDescriptor<Problem>());
dq_write_reg_tensor.get_thread_buffer() = dq_acc.get_thread_buffer();
store_tile(dq_dram_window_deterministic, dq_write_reg_tensor);
}
else
{
update_tile(dq_dram_window, dq_acc);
} }
return ck_tile::make_tuple(dk_acc, dv_acc); return make_tuple(dk_acc, dv_acc);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment