# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

import argparse
import itertools
from pathlib import Path
from typing import List, Optional, Tuple
from dataclasses import dataclass
import copy
import fnmatch

# dtype tag -> C++ element type substituted into the generated source
DTYPE_MAP = {
    "fp16": "ck_tile::fp16_t",
    "bf16": "ck_tile::bf16_t",
    "fp8" : "ck_tile::fp8_t"
}

# dtype tag -> width in bits (used to derive the vector length checks)
DTYPE_BITS = {
    "fp32": 32,
    "fp16": 16,
    "bf16": 16,
    "fp8" : 8,
    "bf8" : 8
}

MASK_IMPL = {
    "generic"    : "ck_tile::GenericAttentionMask",
    "simplified" : "ck_tile::SimplifiedGenericAttentionMask"
}

MASK_SIMPLIFIED_MAP = {
    "s_no"   : "ck_tile::SimplifiedGenericAttentionMask",
    "s_mask" : "ck_tile::SimplifiedGenericAttentionMask",
}

MASK_MAP = {
    "no"      : "FmhaMasks::NoMask",
    "causal"  : "FmhaMasks::CausalMask",
    "generic" : "FmhaMasks::GenericMask"
}

BIAS_MAP = {
    "no"    : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
    "bias"  : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
    "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
}

# TODO: this is ugly
BIAS_CHECK_MAP = {
    "no"    : "bias_enum::no_bias",
    "bias"  : "bias_enum::elementwise_bias",
    "alibi" : "bias_enum::alibi"
}

MODE_MAP = {
    "batch" : "false",
    "group" : "true"
}

LAYOUT_MAP = {
    "row" : "true",
    "col" : "false"
}

PIPELINE_MAP = {
    "qr"       : "ck_tile::BlockFmhaPipelineQRKSVS",
    "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
}

PIPELINE_ENUM_MAP = {
    "qr"       : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
    "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}

BOOL_MAP = {
    "t" : "true",
    "f" : "false"
}

TILE_PARTITIONER_MAP = {
    "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB",
    "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS",
}

GEN_DIR = ""    # in Cmake, have to generate files in same folder

# NOTE(review): every C++ template literal below is reproduced exactly as found,
# including its whitespace. The text looks whitespace-collapsed (e.g. a bare
# "#include" with no header name follows) -- verify against the upstream
# generator before relying on the emitted C++.
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. \nAll rights reserved.\n // auto generated by generate.py #include "fmha_fwd.hpp" """

FMHA_FWD_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; using fmha_shape_{F_idx} = ck_tile::TileFmhaShape; using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_bias}, false, {F_lse}, {F_dropout}, {F_squant}, {F_occupancy}>; using fmha_mask_{F_idx} = {F_mask}; using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< typename FmhaFwdTypeConfig::QDataType, typename FmhaFwdTypeConfig::KDataType, typename FmhaFwdTypeConfig::VDataType, typename FmhaFwdTypeConfig::SaccDataType, typename FmhaFwdTypeConfig::SMPLComputeDataType, typename FmhaFwdTypeConfig::BiasDataType, typename FmhaFwdTypeConfig::RandValOutputDataType, typename FmhaFwdTypeConfig::LSEDataType, typename FmhaFwdTypeConfig::PDataType, typename FmhaFwdTypeConfig::OaccDataType, typename FmhaFwdTypeConfig::ODataType, fmha_shape_{F_idx}, {F_mode}, fmha_mask_{F_idx}, fmha_trait_{F_idx}>; using fmha_pipeline_{F_idx} = {F_pipeline}< fmha_pipeline_problem_{F_idx}>; using fmha_epilogue_{F_idx} = ck_tile::Default2DEpilogue::OaccDataType, typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, {F_spad}, {F_dvpad}>>; using fmha_kernel_{F_idx} = ck_tile::FmhaFwdKernel<{F_tile_partitioner}, fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include template<> float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) {{ using k_ = fmha_kernel_{F_idx}; \nif(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} """

FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
FMHA_FWD_API=""" float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ float r = -1; {F_dispatch} return r; }} """

FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} """
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ {F_inner_dispatch} }} """

MASK_CHECK_MAP = {
    "no"      : "t.mask_type == mask_enum::no_mask",
    "causal"  : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
    "generic" : "t.mask_type == mask_enum::window_generic",
}

MASK_SIMPLIFIED_CHECK_MAP = {
    "s_no"   : "t.mask_type == mask_enum::no_mask",
    "s_mask" : "t.mask_type != mask_enum::no_mask",
}

FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; return fmha_fwd_(s, a); }} """


def get_mask_map(mask : str):
    """Return the mask-tag -> C++ mask type table for a mask implementation.

    `mask` must be "generic" or "simplified"; anything else raises
    AssertionError (raised explicitly so the guard survives `python -O`).
    """
    if mask == "generic":
        return MASK_MAP
    if mask == "simplified":
        return MASK_SIMPLIFIED_MAP
    raise AssertionError(f"unknown mask implementation: {mask!r}")


def get_mask_check_map(mask : str):
    """Return the mask-tag -> C++ runtime-check table for a mask implementation.

    `mask` must be "generic" or "simplified"; anything else raises
    AssertionError (raised explicitly so the guard survives `python -O`).
    """
    if mask == "generic":
        return MASK_CHECK_MAP
    if mask == "simplified":
        return MASK_SIMPLIFIED_CHECK_MAP
    raise AssertionError(f"unknown mask implementation: {mask!r}")


@dataclass
class FmhaFwdApiTrait:
    """One forward-kernel configuration as seen by the generated dispatch code.

    The `*check` properties return C++ boolean expressions (as strings) that the
    dispatcher evaluates at runtime to decide whether this kernel applies.
    """
    pipeline_tag : str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim      : str
    dtype     : str  # data type
    mode      : str  # value from MODE_MAP
    bm0       : int  # tile size along q seqlen (block size)
    bn0       : int  # tile size along qk seqlen
    bk0       : int  # tile size along qk gemm unroll
    bn1       : int  # tile size along v head_dim
    bk1       : int  # tile size along kv gemm unroll
    bk0blen   : int
    vlayout   : str
    mask      : str
    bias      : str  #
    lse       : str  #
    dropout   : str
    squant    : str  #
    spad      : str
    skpad     : str
    dpad      : str
    dvpad     : str

    @property
    def name(self) -> str:
        """Dash-joined rendering of every field except `pipeline_tag`."""
        # The fourth tile component is bn1 (the original repeated bn0 here, so
        # traits that differed only in bn1 produced identical names).
        return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0blen}-'+\
                    f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'

    @property
    def scheck(self) -> str:
        """C++ condition on `a.seqlen_q` under which this kernel is usable."""
        if self.mode == 'group':
            return 'true/*group mode spad always true*/'    # group mode only generate spad/skpad == true
        if self.pipeline_tag == 'qr_async':
            return 'true'   # always support, whatever spad is
        if self.pipeline_tag in ['qr']:
            if self.spad == 't':
                return f'true /*a.seqlen_q % {self.bm0} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
            return f'a.seqlen_q % {self.bm0} == 0'
        raise AssertionError(f"unsupported pipeline tag: {self.pipeline_tag!r}")

    @property
    def skcheck(self) -> str:
        """C++ condition on `a.seqlen_k` under which this kernel is usable."""
        if self.mode == 'group':
            return 'true/*group mode skpad always true*/'   # group mode only generate spad/skpad == true
        if self.pipeline_tag == 'qr_async':
            if self.skpad == 't':
                return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
            return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
        if self.pipeline_tag in ['qr', 'qr_fp8']:
            if self.skpad == 't':
                return f'true /*a.seqlen_k % {self.bn0} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
            return f'a.seqlen_k % {self.bn0} == 0'
        raise AssertionError(f"unsupported pipeline tag: {self.pipeline_tag!r}")

    @property
    def dcheck(self) -> str:
        """C++ condition on `a.hdim_q` under which this kernel is usable."""
        if self.pipeline_tag == 'qr_async':
            # elements per 4 x 32-bit chunk for this dtype
            vec = (32 * 4) // DTYPE_BITS[self.dtype]
            if self.dpad == 't':
                return f'a.hdim_q % {vec} == 0'
            raise AssertionError("qr_async pipeline requires dpad == 't'")
        if self.pipeline_tag in ['qr']:
            if self.dpad == 't':
                return f'true /*a.hdim_q % {self.bk0blen} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
            return f'a.hdim_q % {self.bk0blen} == 0'
        raise AssertionError(f"unsupported pipeline tag: {self.pipeline_tag!r}")

    @property
    def dvcheck(self) -> str:
        """C++ condition on `a.hdim_v` under which this kernel is usable."""
        if self.pipeline_tag == 'qr_async':
            # elements per 4 x 32-bit chunk for this dtype
            vec = (32 * 4) // DTYPE_BITS[self.dtype]
            if self.dvpad == 't':
                return f'a.hdim_v % {vec} == 0'
            raise AssertionError("qr_async pipeline requires dvpad == 't'")
        if self.pipeline_tag in ['qr']:
            if self.dvpad == 't':
                return f'true /*a.hdim_v % {self.bk0blen} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
            return f'a.hdim_v % {self.bk0blen} == 0'
        raise AssertionError(f"unsupported pipeline tag: {self.pipeline_tag!r}")


@dataclass
class FmhaFwdPipeline:
    """Pipeline-level choices of a forward kernel (layout, padding, features)."""
    tag : str

    F_vlayout   : str  # row/col
    F_spad      : str  # true/false
    F_skpad     : str  #
    F_dpad      : str  #
    F_dvpad     : str  #
    F_bias      : str  # true/false
    F_lse       : str  #
    F_dropout   : str  #
    F_squant    : str  #
    F_mask      : str  # value from MASK_MAP

    @property
    def name(self) -> str:
        """Underscore-joined tag encoding only the options that are switched on."""
        def pad_name() -> str:
            n = ''
            if self.F_spad == 't': n += 's'
            if self.F_skpad == 't' : n += 'sk'
            if self.F_dpad == 't' : n += 'd'
            if self.F_dvpad == 't' : n += 'dv'
            if n != '' : n = 'p' + n
            return n
        pn = pad_name()
        n = f'{self.tag}_v{self.F_vlayout[0]}'
        if pn != '' : n += f'_{pn}'
        if self.F_bias != 'no' : n += f'_{self.F_bias}'
        if self.F_mask[0:2] == 's_':
            # simplified mask tags: only "s_mask" contributes a suffix
            if self.F_mask == 's_mask': n += f'_mask'
        else:
            if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
        if self.F_lse == 't' : n += '_lse'
        if self.F_dropout == 't' : n += '_dropout'
        if self.F_squant == 't' : n += '_squant'
        return n


class FmhaFwdApiPool:
    """Collects forward traits per dtype/hdim and renders the C++ dispatch API."""

    def __init__(self, mask_impl):
        self.pool = dict()          # dtype -> hdim -> [FmhaFwdApiTrait], insertion-ordered
        self.mask_impl = mask_impl

    def register_traits(self, trait : FmhaFwdApiTrait) -> None:
        """Store a shallow copy of `trait` under its dtype and hdim."""
        # TODO: do we need to check duplication?
        self.pool.setdefault(trait.dtype, dict()).setdefault(trait.hdim, list()).append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Render the dispatch source: header + nested if / else-if chains."""
        per_dtypes = []
        for i, (dtype, hdim_pool) in enumerate(self.pool.items()):
            # a dtype entry only exists once a trait was registered, so the
            # mask tables are looked up only when they are actually needed
            mask_map = get_mask_map(self.mask_impl)
            mask_check_map = get_mask_check_map(self.mask_impl)
            per_hdim_case = []
            for j, (hdim, traits) in enumerate(hdim_pool.items()):
                inners = []
                for k, trait in enumerate(traits):
                    if_k = 'if' if k == 0 else 'else if'
                    inners.append(FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                                   F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=mask_map[trait.mask],
                                   F_mask_check=mask_check_map[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
                                   F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
                                   F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                                   F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                                   F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
                                   F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]))
                if_j = 'if' if j == 0 else 'else if'
                per_hdim_case.append(FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=''.join(inners)))
            if_i = 'if' if i == 0 else 'else if'
            per_dtypes.append(FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=''.join(per_hdim_case)))
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = ''.join(per_dtypes))


@dataclass
class FmhaFwdTileSize:
    """Block/warp tile sizes of a forward kernel."""
    F_bm0       : int  # tile size along q seqlen (block size)
    F_bn0       : int  # tile size along k seqlen
    F_bk0       : int  # tile size along qk gemm unroll
    F_bn1       : int  # tile size along v head_dim
    F_bk1       : int  # tile size along kv gemm unroll
    F_bk0blen   : int  # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
    F_rm        : int  # number of warps along q seqlen (block warps)
    F_rn        : int  # number of warps along k seqlen(not used)
    F_rk        : int  # number of warps along gemm-k(not used)
    F_wm        : int  # warp size along m (warp size)
    F_wn        : int  # warp size along n
    F_wk        : int  # warp size along k
    F_occupancy : int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy

    @property
    def name(self) -> str:
        """Tile tag `b..._r..._w...`, with `_o<N>` appended unless occupancy is -1."""
        return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\
        f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}" +\
            ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")


@dataclass
class FmhaFwdKernel:
    """One forward kernel instance: renders its source text and its API trait."""
    direction       : str
    F_idx           : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim          : int  # hdim
    F_dtype         : str  # data type
    F_mode          : str  # value from MODE_MAP
    F_tile          : FmhaFwdTileSize
    F_pipeline      : FmhaFwdPipeline
    mask_impl       : str

    def get_tp(self) -> str:
        """Tile-partitioner key: 'hbs' for group mode, 'shb' otherwise."""
        if self.F_mode == 'group':
            return 'hbs'
        else:
            return 'shb'

    @property
    def template(self) -> str:
        """Header plus the kernel-body template with every field substituted."""
        return FMHA_FWD_KERNEL_HEADER + \
            FMHA_FWD_KERNEL_BODY.format(
                F_idx           = self.F_idx,
                F_hdim          = self.F_hdim,
                F_dtype         = DTYPE_MAP[self.F_dtype],
                F_bm0           = self.F_tile.F_bm0,
                F_bn0           = self.F_tile.F_bn0,
                F_bk0           = self.F_tile.F_bk0,
                F_bn1           = self.F_tile.F_bn1,
                F_bk1           = self.F_tile.F_bk1,
                F_bk0blen       = self.F_tile.F_bk0blen,
                F_rm            = self.F_tile.F_rm,
                F_rn            = self.F_tile.F_rn,
                F_rk            = self.F_tile.F_rk,
                F_wm            = self.F_tile.F_wm,
                F_wn            = self.F_tile.F_wn,
                F_wk            = self.F_tile.F_wk,
                F_vlayout       = LAYOUT_MAP[self.F_pipeline.F_vlayout],
                F_spad          = BOOL_MAP[self.F_pipeline.F_spad],
                F_skpad         = BOOL_MAP[self.F_pipeline.F_skpad],
                F_dpad          = BOOL_MAP[self.F_pipeline.F_dpad],
                F_dvpad         = BOOL_MAP[self.F_pipeline.F_dvpad],
                F_bias          = BIAS_MAP[self.F_pipeline.F_bias],
                F_lse           = BOOL_MAP[self.F_pipeline.F_lse],
                F_dropout       = BOOL_MAP[self.F_pipeline.F_dropout],
                F_squant        = BOOL_MAP[self.F_pipeline.F_squant],
                F_occupancy     = self.F_tile.F_occupancy,
                F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
                F_mask          = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
                F_mode          = MODE_MAP[self.F_mode],
                F_pipeline      = PIPELINE_MAP[self.F_pipeline.tag],
                F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()])

    @property
    def name(self) -> str:
        """Unique-ish kernel name built from direction, hdim, dtype, mode, tile and pipeline."""
        # TODO: we don't encode idx here
        return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \
                self.F_tile.name + '_' + self.F_pipeline.name

    @property
    def filename(self) -> str:
        """File name of the generated translation unit."""
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        """Project this kernel onto the fields the dispatcher needs."""
        return FmhaFwdApiTrait(
                pipeline_tag=self.F_pipeline.tag,
                hdim=str(self.F_hdim),
                dtype=self.F_dtype,
                mode=self.F_mode,
                bm0=self.F_tile.F_bm0,
                bn0=self.F_tile.F_bn0,
                bk0=self.F_tile.F_bk0,
                bn1=self.F_tile.F_bn1,
                bk1=self.F_tile.F_bk1,
                bk0blen=self.F_tile.F_bk0blen,
                vlayout=self.F_pipeline.F_vlayout,
                mask=self.F_pipeline.F_mask,
                bias=self.F_pipeline.F_bias,
                lse=self.F_pipeline.F_lse,
                dropout=self.F_pipeline.F_dropout,
                squant=self.F_pipeline.F_squant,
                spad=self.F_pipeline.F_spad,
                skpad=self.F_pipeline.F_skpad,
                dpad=self.F_pipeline.F_dpad,
                dvpad=self.F_pipeline.F_dvpad)


# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]:
    """Return {hdim string: FmhaFwdTileSize} for a dtype, or None if unsupported.

    Only direction 'fwd' with dtype fp16/bf16/fp8/bf8 yields a table.
    """
    if direction != 'fwd':
        return None
    if dtype == 'fp16' or dtype == 'bf16':
        return {
            '32'  : FmhaFwdTileSize(128, 64, 16, 32, 32, 32,     2, 1, 1, 32, 32, 16, -1),
            '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64,     4, 1, 1, 32, 32, 16, -1),
            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128,  4, 1, 1, 32, 32, 16, -1),
            '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256,  4, 1, 1, 32, 32, 16, -1),
        }
    if dtype == 'fp8' or dtype == 'bf8':
        return {
            '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64,     2, 1, 1, 32, 32, 32, -1),
            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128,  4, 1, 1, 32, 32, 32, -1),
            '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256,  4, 1, 1, 32, 32, 32, -1)
        }
    return None


def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    """Enumerate forward kernels and the matching dispatch pool.

    kernel_filter: optional fnmatch pattern applied to each kernel name.
    receipt:       1 adds extra 'qr' pipelines for hdim != 256; 2 keeps only
                   fp16/bf16, row-major V, no/alibi bias, no static quant.
    mask_impl:     "generic" or "simplified" (see get_mask_map).
    Returns (api_pool, kernels); both contain exactly the kernels that survive
    the filters, in generation order.
    """
    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    #       support this in future
    def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
        # this function will populate a list possible pipelines
        # TODO: the order of List matters! the later in this list will be also be checked later
        # TODO: currently for qr pipeline, let 't' padding to appear later!!
        # TODO: how to design this more generic?
        squant = 't' if dtype == 'fp8' else 'f'
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
                if hdim == 256:
                # if True:
                    pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))

                    pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                else:
                    pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                    # NOTE(review): nesting inferred -- placed in the non-256 branch,
                    # otherwise hdim 256 would register duplicate 'qr' pipelines.
                    if receipt == 1:
                        pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
                        pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
        elif dtype in ['fp8', 'bf8']:
            # no need lse/dropout kernels
            for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
                pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
        else:
            raise AssertionError(f"unsupported dtype: {dtype!r}")
        return pipelines

    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)

    for direction, dtype in itertools.product(["fwd"], DTYPE_MAP.keys()):
        d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype)
        if d is None:
            continue
        #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
        for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
            tile = d[hdim_str]
            hdim = int(hdim_str)
            for pipeline in get_pipelines(dtype, hdim):
                if mode == "group" and (pipeline.F_spad != 't' or pipeline.F_skpad != 't'):
                    # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
                    continue
                k = FmhaFwdKernel(direction=direction,
                                  F_idx=0,
                                  F_hdim=hdim,
                                  F_dtype=dtype,
                                  F_mode=mode,
                                  F_tile=tile,
                                  F_pipeline=pipeline,
                                  mask_impl=mask_impl)
                if kernel_filter is not None and not fnmatch.fnmatch(k.name, kernel_filter):
                    continue
                if receipt == 2:
                    cond = dtype in ['fp16', 'bf16']
                    cond &= pipeline.F_vlayout == 'row'
                    cond &= pipeline.F_bias in ['no', 'alibi']
                    cond &= pipeline.F_squant == 'f'
                    if not cond:
                        continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)

    return (api_pool, gen)


BWD_DQDKDV_PIPELINE_MAP = {
    "ks_kts_vr"    : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR",
    "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS",
    "ks_vr"        : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR",
}

BWD_DQDKDV_PIPELINE_ENUM_MAP = {
    "ks_kts_vr"    : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR",
    "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS",
    "ks_vr"        : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR",
}

FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. \nAll rights reserved.\n // auto generated by generate.py #include "fmha_bwd.hpp" """

FMHA_BWD_DQ_DK_DV_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>; using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>; using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>; using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>; using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; // TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape // G0&G2 -> GSdP // G1&G3 -> GdKV // G4 -> GdQ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_bias}, {F_dbias}, false, {F_dropout}, false, {F_occupancy}>; using fmha_mask_{F_idx} = {F_mask}; using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< typename FmhaBwdTypeConfig::QDataType, typename FmhaBwdTypeConfig::KDataType, typename FmhaBwdTypeConfig::VDataType, typename FmhaBwdTypeConfig::GemmDataType, typename FmhaBwdTypeConfig::LSEDataType, typename FmhaBwdTypeConfig::AccDataType, typename FmhaBwdTypeConfig::DDataType, typename FmhaBwdTypeConfig::BiasDataType, typename FmhaBwdTypeConfig::RandValOutputDataType, typename FmhaBwdTypeConfig::ODataType, typename FmhaBwdTypeConfig::OGradDataType, typename FmhaBwdTypeConfig::QGradDataType, typename FmhaBwdTypeConfig::KGradDataType, typename FmhaBwdTypeConfig::VGradDataType, typename FmhaBwdTypeConfig::BiasGradDataType, fmha_bwd_shape_{F_idx}, {F_mode}, fmha_mask_{F_idx}, fmha_bwd_trait_{F_idx}>; using fmha_bwd_pipeline_{F_idx} = {F_pipeline}< fmha_bwd_pipeline_problem_{F_idx}>; using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, false, false>>; using \nfmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, false, false>>; using fmha_bwd_dq_dk_dv_kernel_{F_idx} = ck_tile::FmhaBwdDQDKDVKernel, fmha_bwd_pipeline_{F_idx}, fmha_bwd_dk_epilogue_{F_idx}, fmha_bwd_dv_epilogue_{F_idx}>; using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include template<> float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} template<> void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); }} template<> std::string fmha_bwd_dq_dk_dv_get_name_() {{ using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; return k_::GetName(); }} """

FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
FMHA_BWD_API=""" #include template float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ if(s.log_level_ > 0) std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << std::flush; return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }} ); }} float \nfmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ float r = -1; {F_dispatch} return r; }} """

FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} """
FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ {F_inner_dispatch} }} """

FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; r = fmha_bwd_(s, a); return r; }} """


@dataclass
class FmhaBwdDQDKDVApiTrait:
    """One backward dq/dk/dv kernel configuration as seen by the dispatch code.

    The `*check` members return C++ boolean expressions (as strings).
    """
    pipeline  : str
    # sync with fmha_bwd_traits<>, to generate fallback calls
    hdim      : str
    dtype     : str  # data type
    mode      : str  # value from MODE_MAP
    bm0       : int  # tile size along q seqlen (block size)
    bn0       : int  # tile size along k seqlen
    bhdq      : int  # q head_dim
    bhdv      : int  # v head_dim
    mask      : str
    bias      : str
    dbias     : str
    dropout   : str
    spad      : str
    skpad     : str
    dpad      : str
    dvpad     : str

    @property
    def name(self) -> str:
        """Dash-joined rendering of pipeline, hdim, dtype, mode and feature flags."""
        return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'

    def scheck(self, spad1 : str) -> str:
        """C++ condition on `a.seqlen_q`; `spad1` is the dot_do_o kernel's spad flag."""
        if self.mode == 'group':
            return 'true' # always support
        elif self.spad == 't' and spad1 == 't':
            return f'a.seqlen_q % {self.bm0} != 0'
        elif self.spad == 'f' and spad1 == 't':
            return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize
        else: # remaining combinations, i.e. spad1 == 'f'
            return f'a.seqlen_q % 256 == 0' # BlockSize

    @property
    def skcheck(self) -> str:
        """C++ condition on `a.seqlen_k`."""
        if self.mode == 'group':
            return 'true' # always support
        elif self.skpad == 't':
            return f'a.seqlen_k % {self.bn0} != 0'
        else:
            return f'a.seqlen_k % {self.bn0} == 0'

    @property
    def dcheck(self) -> str:
        """C++ condition on `a.hdim_q`."""
        if self.dpad == 't':
            return f'a.hdim_q % {self.bhdq} != 0'
        else:
            return f'a.hdim_q % {self.bhdq} == 0'

    @property
    def dvcheck(self) -> str:
        """C++ condition on `a.hdim_v`."""
        if self.dvpad == 't':
            return f'a.hdim_v % {self.bhdv} != 0'
        else:
            return f'a.hdim_v % {self.bhdv} == 0'


class FmhaBwdApiPool:
    """Collects backward dq/dk/dv traits per dtype/hdim and renders the C++ dispatch API."""

    def __init__(self, mask_impl):
        self.dq_dk_dv_pool = dict()     # dtype -> hdim -> [FmhaBwdDQDKDVApiTrait]
        self.mask_impl = mask_impl

    def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None:
        """Store a shallow copy of `trait` under its dtype and hdim."""
        # TODO: do we need to check duplication?
        self.dq_dk_dv_pool.setdefault(trait.dtype, dict()).setdefault(trait.hdim, list()).append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Render the dispatch source: header + nested if / else-if chains."""
        per_dtypes = []
        for i, (dtype, hdim_pool) in enumerate(self.dq_dk_dv_pool.items()):
            # a dtype entry only exists once a trait was registered
            mask_map = get_mask_map(self.mask_impl)
            mask_check_map = get_mask_check_map(self.mask_impl)
            per_hdim_case = []
            for j, (hdim, traits) in enumerate(hdim_pool.items()):
                inners = []
                for k, trait in enumerate(traits):
                    if_k = 'if' if k == 0 else 'else if'
                    # spad1 is the padding flag of the companion dot_do_o kernel;
                    # both variants emitted for one trait share the same if_k
                    for spad1 in ["t", "f"]:
                        if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")):
                            continue
                        inners.append(FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=mask_map[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
                                       F_mask_check=mask_check_map[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout],
                                       F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
                                       F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad]))
                if_j = 'if' if j == 0 else 'else if'
                per_hdim_case.append(FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=''.join(inners)))
            if_i = 'if' if i == 0 else 'else if'
            per_dtypes.append(FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=''.join(per_hdim_case)))
        return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = ''.join(per_dtypes))


# GEMM0: Q@K=S^T
# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order)
# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk)
# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk)
# Is it necessary to distinguish between K0~K4?
@dataclass
class FmhaBwdDQDKDVTileSize:
    """Block/warp tile sizes of a backward dq/dk/dv kernel."""
    F_bm0       : int  # tile size along q seqlen (block size)
    F_bn0       : int  # tile size along k seqlen
    F_bk0       : int  # tile size along gemm0 unroll(F_bhdq)
    F_bk1       : int  # tile size along gemm1 unroll(F_bm0)
    F_bk2       : int  # tile size along gemm2 unroll(F_bhdv)
    F_bk3       : int  # tile size along gemm3 unroll(F_bm0)
    F_bk4       : int  # tile size along gemm4 unroll(F_bn0)
    F_bhdq      : int  # q head_dim
    F_bhdv      : int  # v head_dim
    F_rm0       : int  # number of warps along q seqlen (block warps) in gemm0/gemm2
    F_rn0       : int  # number of warps along k seqlen (block warps) in gemm0/gemm2
    F_rk0       : int  # number of warps along gemm-k (not used) in gemm0/gemm2
    F_rm1       : int  # number of warps along k seqlen (block warps) in gemm1/gemm3
    F_rn1       : int  # number of warps along q seqlen (block warps) in gemm1/gemm3
    F_rk1       : int  # number of warps along gemm-k (not used) in gemm1/gemm3
    F_rm2       : int  # number of warps along k seqlen (block warps) in gemm4
    F_rn2       : int  # number of warps along q seqlen (block warps) in gemm4
    F_rk2       : int  # number of warps along gemm-k (not used) in gemm4
    F_wm        : int  # warp size along m (warp size)
    F_wn        : int  # warp size along n
    F_wk        : int  # warp size along k
    F_occupancy : int  # occupancy

    @property
    def name(self) -> str:
        """Tile tag `b..._r..._r..._r..._w..._o<occupancy>`."""
        return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\
        f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
        f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}"


@dataclass
class FmhaBwdDQDKDVKernel:
    """One backward dq/dk/dv kernel instance: renders its source text and API trait."""
    direction   : str
    F_idx       : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim      : int  # hdim
    F_dtype     : str  # data type
    F_tile      : FmhaBwdDQDKDVTileSize
    F_spad      : str  # true/false
    F_skpad     : str  #
    F_dpad      : str  #
    F_dvpad     : str  #
    F_bias      : str  #
    F_dbias     : str  #
    F_dropout   : str  #
    F_mask      : str  # value from MASK_MAP
    F_mode      : str  # value from MODE_MAP
    F_pipeline  : str
    mask_impl   : str

    @property
    def template(self) -> str:
        """Header plus the kernel-body template with every field substituted."""
        return FMHA_BWD_KERNEL_HEADER + \
            FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
                F_idx       = self.F_idx,
                F_hdim      = self.F_hdim,
                F_dtype     = DTYPE_MAP[self.F_dtype],
                F_bm0       = self.F_tile.F_bm0,
                F_bn0       = self.F_tile.F_bn0,
                F_bk0       = self.F_tile.F_bk0,
                F_bk1       = self.F_tile.F_bk1,
                F_bk2       = self.F_tile.F_bk2,
                F_bk3       = self.F_tile.F_bk3,
                F_bk4       = self.F_tile.F_bk4,
                F_bhdq      = self.F_tile.F_bhdq,
                F_bhdv      = self.F_tile.F_bhdv,
                F_rm0       = self.F_tile.F_rm0,
                F_rn0       = self.F_tile.F_rn0,
                F_rk0       = self.F_tile.F_rk0,
                F_rm1       = self.F_tile.F_rm1,
                F_rn1       = self.F_tile.F_rn1,
                F_rk1       = self.F_tile.F_rk1,
                F_rm2       = self.F_tile.F_rm2,
                F_rn2       = self.F_tile.F_rn2,
                F_rk2       = self.F_tile.F_rk2,
                F_wm        = self.F_tile.F_wm,
                F_wn        = self.F_tile.F_wn,
                F_wk        = self.F_tile.F_wk,
                F_spad      = BOOL_MAP[self.F_spad],
                F_skpad     = BOOL_MAP[self.F_skpad],
                F_dpad      = BOOL_MAP[self.F_dpad],
                F_dvpad     = BOOL_MAP[self.F_dvpad],
                F_bias      = BIAS_MAP[self.F_bias],
                F_dbias     = BOOL_MAP[self.F_dbias],
                F_dropout   = BOOL_MAP[self.F_dropout],
                F_occupancy = self.F_tile.F_occupancy,
                F_mask      = get_mask_map(self.mask_impl)[self.F_mask],
                F_mode      = MODE_MAP[self.F_mode],
                F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline],
                F_pipeline  = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline])

    @property
    def name(self) -> str:
        """Kernel name: direction, hdim, dtype, mode, tile tag, then enabled options."""
        def pad_name() -> str:
            n = ''
            if self.F_spad == 't': n += 's'
            if self.F_skpad == 't' : n += 'sk'
            if self.F_dpad == 't' : n += 'd'
            if self.F_dvpad == 't' : n += 'dv'
            if n != '' : n = 'p' + n
            return n
        pn = pad_name()
        n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name
        if pn != '' : n += f'_{pn}'
        if self.F_bias != 'no' : n += f'_{self.F_bias}'
        if self.F_dbias == 't' : n += '_dbias'
        if self.F_mask[0:2] == 's_':
            # simplified mask tags: only "s_mask" contributes a suffix
            if self.F_mask == 's_mask': n += f'_mask'
        else:
            if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
        if self.F_dropout == 't' : n += '_dropout'
        return n

    @property
    def filename(self) -> str:
        """File name of the generated translation unit."""
        return self.name + ".cpp"

    def api_trait(self) -> FmhaBwdDQDKDVApiTrait:
        """Project this kernel onto the fields the dispatcher needs."""
        return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline,
                hdim=str(self.F_hdim),
                dtype=self.F_dtype,
                mode=self.F_mode,
                bm0=self.F_tile.F_bm0,
                bn0=self.F_tile.F_bn0,
                bhdq=self.F_tile.F_bhdq,
                bhdv=self.F_tile.F_bhdv,
                mask=self.F_mask,
                bias=self.F_bias,
                dbias=self.F_dbias,
                dropout=self.F_dropout,
                spad=self.F_spad,
                skpad=self.F_skpad,
                dpad=self.F_dpad,
                dvpad=self.F_dvpad)


# TODO: design a more practical way to do it
# this is current supported tile size & pipeline.
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]:
    """Return the supported tile size and pipeline for each head dimension.

    Args:
        direction: kernel direction; only 'bwd' has entries.
        dtype: data type name; only 'fp16' and 'bf16' have entries.

    Returns:
        A dict keyed by head dimension (as a string) whose values are
        ``[FmhaBwdDQDKDVTileSize, pipeline_key]`` lists, or ``None`` when the
        direction/dtype combination is not supported.
    """
    if direction != 'bwd' or dtype not in ('fp16', 'bf16'):
        return None
    return {
        '32'  : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32,  32,  32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1),
                 "qs_ks_vr_dos"],
        '64'  : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32,  64,  64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1),
                 "qs_ks_vr_dos"],
        '128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1),
                 "ks_vr"],
    }

def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]:
    """Enumerate the backward dQ/dK/dV kernels to generate.

    Args:
        kernel_filter: optional fnmatch pattern; kernels whose name does not
            match are dropped. ``None`` keeps every kernel.
        receipt: codegen receipt; the value 2 keeps only fp16/bf16 kernels
            whose bias is 'no' or 'alibi'.
        mask_impl: mask implementation name, passed to ``get_mask_map`` and
            ``FmhaBwdApiPool``.

    Returns:
        ``(api_pool, kernels)``: the pool holding the trait of every selected
        kernel, and the list of selected kernels.
    """
    # TODO: we don't support tuning yet, so pick up one value for pad
    #       support this in future
    gen = []
    api_pool = FmhaBwdApiPool(mask_impl)
    flags = ["t", "f"]

    for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()):
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype)
        if d is None:
            # unsupported direction/dtype combination
            continue
        for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(
                d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(),
                flags, flags, flags, flags, flags, flags):
            tile = d[hdim_str][0]
            ppl = d[hdim_str][1]
            hdim = int(hdim_str)
            # group mode is only generated with both seqlen paddings enabled
            if mode == "group" and (spad == "f" or skpad == "f"):
                continue
            # dbias is only generated together with the elementwise 'bias' kind
            if bias in ("no", "alibi") and dbias == "t":
                continue
            k = FmhaBwdDQDKDVKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
                                    F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
                                    F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
                                    F_pipeline=ppl, mask_impl=mask_impl)
            if kernel_filter is not None and not fnmatch.fnmatch(k.name, kernel_filter):
                continue
            if receipt == 2:
                cond = dtype in ['fp16', 'bf16']
                cond &= bias in ['no', 'alibi']
                if not cond:
                    continue
            api_pool.register_dq_dk_dv_traits(k.api_trait())
            gen.append(k)

    return (api_pool, gen)
# Source template for one dot_do_o kernel instance; the {F_*} placeholders are
# filled in by FmhaBwdOGradDotOKernel.template.
# NOTE(review): the template text is carried over character-for-character from
# the file as seen; it appears to contain no line breaks (e.g. after `#include`)
# - confirm against the canonical generator before relying on the emitted text.
FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, {F_dvpad}, {F_occupancy}>; using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< typename FmhaBwdTypeConfig::ODataType, typename FmhaBwdTypeConfig::OGradDataType, typename FmhaBwdTypeConfig::DDataType, /* BlockSize = */ 256, {F_hdim}, {F_mode}, fmha_bwd_dot_do_o_trait_{F_idx}>; using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO< fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>; using fmha_bwd_dot_do_o_kernel_{F_idx} = ck_tile::FmhaBwdOGradDotOKernel, fmha_bwd_dot_do_o_{F_idx}>; using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; #include template<> float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} template<> void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); }} template<> std::string fmha_bwd_dot_do_o_get_name_() {{ using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; return k_::GetName(); }} """

@dataclass
class FmhaBwdOGradDotOKernel:
    """One backward dot_do_o kernel instance to be generated.

    `template` renders the source text, `name` / `filename` derive the kernel
    and file names.
    """
    direction   : str
    F_idx       : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim      : int  # hdim
    F_dtype     : str  # data type, key of DTYPE_MAP
    F_spad      : str  # key of BOOL_MAP ('t'/'f')
    F_dvpad     : str  # key of BOOL_MAP
    F_mode      : str  # key of MODE_MAP
    F_occupancy : int

    @property
    def template(self) -> str:
        """Header plus the dot_do_o kernel body with every placeholder filled in."""
        return FMHA_BWD_KERNEL_HEADER + \
            FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(
                F_idx       = self.F_idx,
                F_hdim      = self.F_hdim,
                F_dtype     = DTYPE_MAP[self.F_dtype],
                F_spad      = BOOL_MAP[self.F_spad],
                F_dvpad     = BOOL_MAP[self.F_dvpad],
                F_mode      = MODE_MAP[self.F_mode],
                F_occupancy = self.F_occupancy)

    @property
    def name(self) -> str:
        """Kernel name; a '_p...' suffix lists the paddings set to 't'."""
        def pad_name() -> str:
            n = ''
            if self.F_spad == 't': n += 's'
            if self.F_dvpad == 't' : n += 'dv'
            if n != '' : n = 'p' + n
            return n
        pn = pad_name()
        n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}"
        if pn != '' : n += f'_{pn}'
        return n

    @property
    def filename(self) -> str:
        """File name for this kernel: the kernel name with a '.cpp' extension."""
        return self.name + ".cpp"

def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
    """Enumerate the backward dot_do_o kernels to generate.

    One kernel is produced per supported dtype, head dimension, mode and
    spad/dvpad combination; group mode is only generated with spad == 't'.
    """
    # TODO: we don't support tuning yet, so pick up one value for pad/occupancy
    #       support this in future
    def get_occupancy(dtype, hdim):
        return 2

    gen = []

    for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()):
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype)
        if d is None:
            # unsupported direction/dtype combination
            continue
        for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]):
            hdim = int(hdim_str)
            if (mode == "group" and spad == "f"):
                continue
            k = FmhaBwdOGradDotOKernel(direction=direction+"_dot_do_o", F_idx=0, F_hdim=hdim, F_dtype=dtype,
                                       F_spad=spad, F_dvpad=dvpad, F_mode=mode,
                                       F_occupancy=get_occupancy(dtype, hdim))
            gen.append(k)

    return gen

def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
    """Write the forward kernel's template to `autogen_dir / kernel.filename`."""
    (autogen_dir / kernel.filename).write_text(kernel.template)

def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
    """Write `api_pool.api` to `autogen_dir / FMHA_FWD_API_FILENAME`."""
    (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)

def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None:
    """Write the dQ/dK/dV kernel's template to `autogen_dir / kernel.filename`."""
    (autogen_dir / kernel.filename).write_text(kernel.template)

def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None:
    """Write the dot_do_o kernel's template to `autogen_dir / kernel.filename`."""
    (autogen_dir / kernel.filename).write_text(kernel.template)

def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
    """Write `api_pool.api` to `autogen_dir / FMHA_BWD_API_FILENAME`."""
    (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)

def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optional[str], receipt, mask_impl) -> None:
    """Generate every kernel file and the API file for one direction.

    Args:
        output_dir: destination directory; ``None`` means the directory that
            contains this script. Missing directories are created.
        direction: 'fwd' generates forward kernels; any other value generates
            the backward dot_do_o and dQ/dK/dV kernels.
        kernel_filter: optional fnmatch pattern forwarded to the blob getters
            that accept one.
        receipt: codegen receipt forwarded to the blob getters.
        mask_impl: mask implementation name forwarded to the blob getters.
    """
    if output_dir is None:
        output_dir = Path(__file__).parent
    else:
        output_dir = Path(output_dir) / GEN_DIR

    output_dir.mkdir(parents=True, exist_ok=True)
    if direction == 'fwd':
        api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
        for kernel in kernels:
            write_single_fwd_kernel(kernel, output_dir)
        write_fwd_api(api_pool, output_dir)
    else:
        for kernel in get_bwd_dot_do_o_blobs():
            write_single_bwd_dot_do_o_kernel(kernel, output_dir)
        api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
        for kernel in kernels:
            write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
        write_bwd_api(api_pool, output_dir)

# list all the files that will be generated
def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Optional[str], receipt, mask_impl) -> None:
    """Append the path of every file `write_blobs` would generate to `output_file`.

    Paths are written one per line and are rooted at the directory that
    contains `output_file` (joined with GEN_DIR). The file is opened in append
    mode, so existing content is kept.
    """
    assert output_file is not None
    file_path = Path(output_file)
    out_dir = file_path.parent / GEN_DIR
    with file_path.open('a') as f:
        if direction == 'fwd':
            _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
            for kernel in kernels:
                f.write(str(out_dir / kernel.filename) + "\n")
            f.write(str(out_dir / FMHA_FWD_API_FILENAME) + "\n")
        else:
            for kernel in get_bwd_dot_do_o_blobs():
                f.write(str(out_dir / kernel.filename) + "\n")
            _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
            for kernel in kernels:
                f.write(str(out_dir / kernel.filename) + "\n")
            f.write(str(out_dir / FMHA_BWD_API_FILENAME) + "\n")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="generate",
        description="gen api for CK fmha kernel",
    )
    parser.add_argument(
        "-d",
        "--direction",
        default='fwd',
        choices=['fwd', 'bwd'],
        required=False,
        help="choose the direction of kernels(default: fwd)"
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        required=False,
        help="write all the blobs into a directory"
    )
    parser.add_argument(
        "-l",
        "--list_blobs",
        required=False,
        help="list all the kernels to a file"
    )
    # TODO: if using filter, must apply same value to output_dir and list_blobs
    parser.add_argument(
        "-f",
        "--filter",
        required=False,
        help="filter out kernels that need to generate, using fnmatch module"
    )

    parser.add_argument(
        "-m",
        "--mask",
        default="simplified",
        required=False,
        help="mask implementation, simplified/generic"
    )

    # type=int lets argparse reject a non-integer receipt with a usage error
    # instead of a ValueError traceback from a later int() call.
    parser.add_argument(
        "-r",
        "--receipt",
        default=0,
        type=int,
        required=False,
        help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
             " 1: generate more instance to cover all hdim\n" + \
             " 2: Only generate instance for Flash attention integration"
    )

    args = parser.parse_args()
    # -l only lists the files that would be generated; otherwise generate them.
    if args.list_blobs is not None:
        list_blobs(args.list_blobs, args.direction, args.filter, args.receipt, mask_impl=args.mask)
    else:
        write_blobs(args.output_dir, args.direction, args.filter, args.receipt, mask_impl=args.mask)