Commit f1793462 authored by rocking

Add rowwise dynamic quant kernel

parent b5d201dd
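
Unlike the existing static-quant path (do_fp8_static_quant, where scale_s/scale_p/scale_o are precomputed from known value ranges), rowwise dynamic quantization derives a separate scale per row from the data at runtime. A minimal NumPy sketch of the idea, not taken from the kernel; the fp8 maximum is an assumption (e4m3fnuz-style, max 240.0):

```python
import numpy as np

FP8_MAX = 240.0  # assumed e4m3fnuz range; other fp8 formats change only this

def rowwise_dynamic_quant(x):
    """Per-row dynamic quantization: each row gets its own scale from its
    runtime absmax instead of one static, ahead-of-time tensor scale."""
    row_absmax = np.abs(x).max(axis=-1, keepdims=True)
    scale = np.maximum(row_absmax / FP8_MAX, 1e-12)  # one scale per row
    q = np.clip(x / scale, -FP8_MAX, FP8_MAX)        # fp8 payload (fp32 stands in)
    return q, scale

x = np.random.randn(4, 128).astype(np.float32)
q, scale = rowwise_dynamic_quant(x)
print(np.abs(q * scale - x).max())  # near-zero reconstruction error before fp8 rounding
```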
@@ -6,7 +6,6 @@ FWD_DTYPE_MAP = {
     "fp16"   : "FmhaFwdFp16",
     "bf16"   : "FmhaFwdBf16",
     "fp8"    : "FmhaFwdFp8",
-    "fp8fp16": "FmhaFwdFp8Fp16",
     "fp8bf16": "FmhaFwdFp8Bf16"
 }

...
@@ -56,6 +56,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
     {F_lse},
     {F_dropout},
     {F_squant},
+    {F_rdquant},
     {F_occupancy}>;
 using fmha_mask_{F_idx} = {F_mask};
@@ -88,7 +89,7 @@ using fmha_kernel_{F_idx} =
     ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;

 using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
-    {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+    {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_rdquant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

 #include <iostream>
@@ -123,9 +124,9 @@ FMHA_FWD_API_PER_HDIM_CASE="""    {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
     }}
 """
-FMHA_FWD_API_INNER_DISPATCH="""        {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
+FMHA_FWD_API_INNER_DISPATCH="""        {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.do_fp8_rowwise_dynamic_quant == {F_rdquant})&&
         ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
-        using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+        using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_rdquant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
         return fmha_fwd_<trait_>(s, a);
     }}
 """
@@ -149,6 +150,7 @@ class FmhaFwdApiTrait:
     lse     : str #
     dropout : str
     squant  : str #
+    rdquant : str #
    spad    : str
     skpad   : str
     dpad    : str
@@ -157,7 +159,7 @@ class FmhaFwdApiTrait:
     @property
     def name(self) -> str:
         return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
-            f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
+            f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.rdquant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'

     @property
     def scheck(self) -> str:
@@ -218,6 +220,7 @@ class FmhaFwdPipeline:
     F_lse     : str #
     F_dropout : str #
     F_squant  : str #
+    F_rdquant : str #
     F_mask    : str # value from MASK_MAP

     @property
@@ -241,6 +244,7 @@ class FmhaFwdPipeline:
         if self.F_lse == 't' : n += '_lse'
         if self.F_dropout == 't' : n += '_dropout'
         if self.F_squant == 't' : n += '_squant'
+        if self.F_rdquant == 't' : n += '_rdquant'
         return n

 class FmhaFwdApiPool:
@@ -271,7 +275,7 @@ class FmhaFwdApiPool:
                 F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                 F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
                 F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout],
-                F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
+                F_squant=BOOL_MAP[trait.squant], F_rdquant=BOOL_MAP[trait.rdquant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                 F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                 F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
                 F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
@@ -357,6 +361,7 @@ class FmhaFwdKernel:
                 F_lse           = BOOL_MAP[self.F_pipeline.F_lse],
                 F_dropout       = BOOL_MAP[self.F_pipeline.F_dropout],
                 F_squant        = BOOL_MAP[self.F_pipeline.F_squant],
+                F_rdquant       = BOOL_MAP[self.F_pipeline.F_rdquant],
                 F_occupancy     = self.F_tile.F_occupancy,
                 F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
                 F_mask          = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
@@ -391,6 +396,7 @@ class FmhaFwdKernel:
             lse=self.F_pipeline.F_lse,
             dropout=self.F_pipeline.F_dropout,
             squant=self.F_pipeline.F_squant,
+            rdquant=self.F_pipeline.F_rdquant,
             spad=self.F_pipeline.F_spad,
             skpad=self.F_pipeline.F_skpad,
             dpad=self.F_pipeline.F_dpad,
@@ -407,7 +413,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
             '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
             '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
         }
-    elif dtype == 'fp8' or dtype == 'bf8':
+    elif dtype == 'fp8' or dtype == 'bf8' or dtype == 'fp8bf16':
         return {
             '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
             '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
@@ -424,39 +430,39 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
         # TODO: the order of List matters! the later in this list will be also be checked later
         # TODO: currently for qr pipeline, let 't' padding to appear later!!
         # TODO: how to design this more generic?
-        squant = 't' if dtype == 'fp8' else 'f'
         pipelines = []
         if dtype in ['fp16', 'bf16']:
             for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
                 if hdim == 256:
                 # if True:
-                    pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
-                    pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
-                    pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
+                    pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, 'f', 'f', mask))
+                    pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, 'f', 'f', mask))
+                    pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
+                    pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
                 else:
                     if bias == "bias":
                         # TODO: rocm 6.2 compiler problem if using qr_async for bias case
-                        pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
-                        pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
-                        pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
-                        pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
+                        pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, 'f', 'f', mask))
+                        pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
+                        pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, 'f', 'f', mask))
+                        pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
                     else:
-                        pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
-                        pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
-                        pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
-                        pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
+                        pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, 'f', 'f', mask))
+                        pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
+                        pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, 'f', 'f', mask))
+                        pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
                     if receipt == 1 and bias != "bias":
-                        pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
-                        pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
+                        pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask)) # TODO: cover arbitraty hdim
+                        pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, 'f', 'f', mask)) # TODO: cover arbitraty hdim
-        elif dtype in ['fp8', 'bf8']:
+        elif dtype in ['fp8']:
             # no need lse/dropout kernels
+            # squant
             for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
-                pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
+                pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', 't', 'f', mask))
-        elif dtype in ['fp8fp16', 'fp8bf16']:
-            # TODO
-            None
+        elif dtype in ['fp8bf16']:
+            # rdquant
+            pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', 'no', 'f', 'f', 'f', 't', 's_no'))
         else:
             assert False
         return pipelines
@@ -492,6 +498,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
                 cond &= pipeline.F_vlayout == 'row'
                 cond &= pipeline.F_bias in ['no', 'alibi']
                 cond &= pipeline.F_squant == 'f'
+                cond &= pipeline.F_rdquant == 'f'
                 if not cond:
                     continue
             api_pool.register_traits(k.api_trait())

...
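
The generator-side changes above all thread one new flag, rdquant, from FmhaFwdPipeline through FmhaFwdApiTrait into the {F_rdquant} placeholder of the template strings, which Python's str.format expands into the generated C++ (doubled braces {{ }} survive as literal braces). A hypothetical, abbreviated illustration of that mechanism; the template and values below are invented stand-ins for the real ones above:

```python
# Abbreviated stand-in for FMHA_FWD_API_INNER_DISPATCH above; invented fields.
INNER_DISPATCH = (
    "    {F_if}((t.do_fp8_static_quant == {F_squant}) && "
    "(t.do_fp8_rowwise_dynamic_quant == {F_rdquant})) {{\n"
    "        using trait_ = fmha_fwd_traits_<{F_hdim}, {F_squant}, {F_rdquant}>;\n"
    "        return fmha_fwd_<trait_>(s, a);\n"
    "    }}"
)

BOOL_MAP = {'t': 'true', 'f': 'false'}  # same role as the generator's BOOL_MAP

# One fp8bf16 rdquant instance at hdim 128 (example values only).
print(INNER_DISPATCH.format(F_if="if", F_hdim=128,
                            F_squant=BOOL_MAP['f'], F_rdquant=BOOL_MAP['t']))
```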
@@ -85,6 +85,10 @@ auto create_args(int argc, char* argv[])
                 "P and O.\n"
                 "calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, "
                 "range_p, range_o")
+        .insert("rdquant",
+                "0",
+                "if using rowwise dynamic quantization fusion or not.\n"
+                "0: no dynamic quant. 1: apply rowwise dynamic quantization with respect Q, K and V.\n")
         .insert("iperm",
                 "1",
                 "permute input\n"
@@ -95,7 +99,7 @@ auto create_args(int argc, char* argv[])
                 "n or 0, no bias\n"
                 "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
                 "a(libi) or 2, alibi with 1*h. a:1, b*h")
-        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
+        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/fp8bf16")
         .insert("mask",
                 "0",
                 "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
@@ -176,6 +180,14 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
     }
 }

+template <>
+auto get_elimit<FmhaFwdFp8Bf16>(std::string init_method)
+{
+    double rtol = 1e-2;
+    double atol = 1e-2;
+    return ck_tile::make_tuple(rtol, atol);
+}
+
 int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
 {
     // If we have enough to almost fill the SMs, then just use 1 split
@@ -1580,6 +1592,10 @@ int main(int argc, char* argv[])
     {
         return run<FmhaFwdFp8>(arg_parser) ? 0 : -2;
     }
+    else if(data_type == "fp8bf16")
+    {
+        return run<FmhaFwdFp8Bf16>(arg_parser) ? 0 : -2;
+    }

     return -3;
 }
@@ -107,6 +107,22 @@ struct FmhaFwdTypeConfig<FmhaFwdBf8>
     using ODataType = ck_tile::bf8_t;
 };

+template <>
+struct FmhaFwdTypeConfig<FmhaFwdFp8Bf16>
+{
+    using QDataType             = ck_tile::fp8_t;
+    using KDataType             = ck_tile::fp8_t;
+    using VDataType             = ck_tile::fp8_t;
+    using BiasDataType          = float;
+    using RandValOutputDataType = uint8_t;
+    using LSEDataType           = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
+    using SaccDataType          = float; // data type for first gemm accumulation
+    using SMPLComputeDataType   = float; // data type for reduction, softmax
+    using PDataType             = ck_tile::fp8_t; // data type for A matrix of second gemm
+    using OaccDataType          = float; // data type for second gemm accumulation
+    using ODataType             = ck_tile::bf16_t;
+};
+
 struct FmhaMasks
 {
     using NoMask = ck_tile::GenericAttentionMask<false>;
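
The FmhaFwdFp8Bf16 config above keeps Q/K/V and P in fp8 but accumulates both GEMMs in float and stores O as bf16. A rough NumPy emulation of that numeric path, with float32 arrays standing in for the fp8/bf16 storage types (NumPy has neither) and purely schematic scale handling:

```python
import numpy as np

def fmha_fwd_fp8bf16_emulated(q8, k8, v8, scale_q, scale_k, scale_v):
    # SaccDataType = float: first gemm accumulates in fp32, then dequantizes
    s = (q8 @ k8.T) * (scale_q * scale_k)
    # SMPLComputeDataType = float: softmax reduction in fp32
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    # PDataType is fp8 in the kernel; kept in fp32 for this emulation.
    # OaccDataType = float: second gemm accumulates in fp32
    o = (p @ v8) * scale_v
    return o  # ODataType = bf16: the kernel rounds on store; fp32 stands in here

# Tiny smoke test with emulated fp8 payloads already scaled into range.
q8 = np.random.randn(4, 64).astype(np.float32)
k8 = np.random.randn(8, 64).astype(np.float32)
v8 = np.random.randn(8, 64).astype(np.float32)
print(fmha_fwd_fp8bf16_emulated(q8, k8, v8, 0.01, 0.01, 0.01).shape)  # (4, 64)
```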
@@ -635,6 +651,7 @@ template <ck_tile::index_t HDim_,
           bool kStoreLse_,
           bool kHasDropout_,
           bool kDoFp8StaticQuant_,
+          bool kDoFp8RowwiseDynamicQuant_,
           bool kPadS_,
           bool kPadSK_,
           bool kPadD_,
@@ -657,6 +674,7 @@ struct fmha_fwd_traits_
     static constexpr bool kStoreLse = kStoreLse_;
     static constexpr bool kHasDropout = kHasDropout_;
     static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
+    static constexpr bool kDoFp8RowwiseDynamicQuant = kDoFp8RowwiseDynamicQuant_;
     static constexpr bool kPadS = kPadS_;
     static constexpr bool kPadSK = kPadSK_;
     static constexpr bool kPadD = kPadD_;
@@ -789,6 +807,7 @@ struct fmha_fwd_traits
     bool has_lse;
     bool has_dropout;
     bool do_fp8_static_quant;
+    bool do_fp8_rowwise_dynamic_quant;
     // TODO: padding check is inside this api
 };
 float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
@@ -804,6 +823,7 @@ struct fmha_fwd_splitkv_traits
     bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
     bool has_lse;
     bool do_fp8_static_quant;
+    bool do_fp8_rowwise_dynamic_quant;
     // TODO: padding check is inside this api
 };
 float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,

...
@@ -51,6 +51,8 @@ struct FmhaFwdKernel
     static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
     static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
     static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
+    static constexpr bool kDoFp8RowwiseDynamicQuant =
+        FmhaPipeline::Problem::kDoFp8RowwiseDynamicQuant;

     using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
     static constexpr bool kHasMask = FmhaMask::IsMasking;
@@ -93,7 +95,8 @@ struct FmhaFwdKernel
             (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
             "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
             (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
-            (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
+            (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) +
+            (kDoFp8StaticQuant ? "_squant" : "" ) + (kDoFp8RowwiseDynamicQuant ? "_rdquant" : "" );
 #undef _SS_
 #undef _TS_
         // clang-format on
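
With this change the `_rdquant` suffix joins the other optional feature suffixes in the generated kernel name. A hypothetical illustration of how the suffixes compose; the base name below is invented, the real prefix being built from the tile and pipeline fields above:

```python
def kernel_name(base, store_lse=False, dropout=False, squant=False, rdquant=False):
    # Mirrors the C++ concatenation above: optional suffixes, in order.
    return (base + ("_lse" if store_lse else "") + ("_dropout" if dropout else "")
                 + ("_squant" if squant else "") + ("_rdquant" if rdquant else ""))

print(kernel_name("fmha_fwd_d128_fp8bf16_qr_vc", rdquant=True))
# -> fmha_fwd_d128_fp8bf16_qr_vc_rdquant
```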
...

@@ -46,15 +46,16 @@ struct BlockFmhaPipelineProblem
     static constexpr bool kIsGroupMode = kIsGroupMode_;

     // attributes from traits
     static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
     static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
     static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
     static constexpr auto BiasEnum = Traits::BiasEnum;
     static constexpr bool kStoreLSE = Traits::kStoreLSE;
     static constexpr bool kHasDropout = Traits::kHasDropout;
     static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
+    static constexpr bool kDoFp8RowwiseDynamicQuant = Traits::kDoFp8RowwiseDynamicQuant;
     static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
 };

 template <typename QDataType_,

...
@@ -18,19 +18,21 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
           bool kStoreLSE_,
           bool kHasDropout_,
           bool kDoFp8StaticQuant_,
+          bool kDoFp8RowwiseDynamicQuant_,
           index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
 struct TileFmhaTraits
 {
     static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
     static constexpr bool kPadSeqLenK = kPadSeqLenK_;
     static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
     static constexpr bool kPadHeadDimV = kPadHeadDimV_;
     static constexpr auto BiasEnum = BiasEnum_;
     static constexpr bool kHasBiasGrad = kHasBiasGrad_;
     static constexpr bool kStoreLSE = kStoreLSE_;
     static constexpr bool kHasDropout = kHasDropout_;
     static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
+    static constexpr bool kDoFp8RowwiseDynamicQuant = kDoFp8RowwiseDynamicQuant_;
     static constexpr index_t kBlockPerCu = kBlockPerCu_;
 };

 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,

...