Commit 9d8d4c61 authored by Po Yen Chen

Use larger tile size for chunk prefill

parent d46196f2
@@ -409,17 +409,17 @@ class FmhaFwdKernel:
 def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
     if dtype == 'fp16' or dtype == 'bf16':
         return {
             '32'  : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1),
             '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
             ## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
             '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
             '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
         }
     elif dtype == 'fp8' or dtype == 'bf8':
         return {
             '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1),
             '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1),
             '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1)
         }
     else:
         return None
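This hunk shows the forward-pass (prefill) tile table; the split-KV generator further down reaches it through the qualified name codegen.ops.fmha_fwd.get_fmha_fwd_tile_dict_from_dtype and reuses these larger tiles for chunked prefill. A minimal lookup sketch, assuming the example's codegen package is importable from the working directory:

    from codegen.ops.fmha_fwd import get_fmha_fwd_tile_dict_from_dtype

    # head-dim keys are strings; None comes back for dtypes without a table
    prefill_tiles = get_fmha_fwd_tile_dict_from_dtype('fp16')
    if prefill_tiles is not None:
        tile = prefill_tiles['128']   # the FmhaFwdTileSize(128, 128, 32, 128, ...) entry above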
@@ -12,9 +12,9 @@ from typing import List, Optional, Tuple, Union
 from codegen.cmake_config import *
 from codegen.cpp_symbol_map import *
+import codegen.ops.fmha_fwd
 from codegen.ops.fmha_fwd import (
     FmhaFwdTileSize,
-    FmhaFwdApiTrait,
     FMHA_FWD_KERNEL_HEADER,
     FMHA_FWD_API_PER_DTYPE,
     FMHA_FWD_API_PER_HDIM_CASE,
@@ -574,14 +574,14 @@ class FmhaFwdSplitKVCombineKernel:
 # TODO: design a more practical way to do it
 # this is current supported tile size per hdim
-def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
+def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict[str, FmhaFwdTileSize]]:
     if dtype == 'fp16' or dtype == 'bf16':
         return {
             '32'  : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1),
             '64'  : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
             ## '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
             '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
-            '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
+            '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 1),
         }
     elif dtype == 'fp8' or dtype == 'bf8':
         return {
@@ -592,20 +592,21 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
     else:
         return None

-def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
+def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict[str, List[FmhaFwdSplitKVCombineTileSize]]]:
     if dtype == 'fp16' or dtype == 'bf16':
         return {
-            '32'  : FmhaFwdSplitKVCombineTileSize(16, 16, -1),
-            '64'  : FmhaFwdSplitKVCombineTileSize(32, 32, -1),
-            ## '96' : FmhaFwdSplitKVCombineTileSize(32, 64, -1),
-            '128' : FmhaFwdSplitKVCombineTileSize(32, 64, -1),
-            '256' : FmhaFwdSplitKVCombineTileSize(32, 128, -1),
+            #        tile size for decode                       tile size for prefill
+            '32'  : [FmhaFwdSplitKVCombineTileSize(16, 16, -1), FmhaFwdSplitKVCombineTileSize(64, 16, -1)],
+            '64'  : [FmhaFwdSplitKVCombineTileSize(32, 32, -1), FmhaFwdSplitKVCombineTileSize(64, 32, -1)],
+            ## '96' : [FmhaFwdSplitKVCombineTileSize(32, 64, -1), FmhaFwdSplitKVCombineTileSize(64, 64, -1)],
+            '128' : [FmhaFwdSplitKVCombineTileSize(32, 64, -1), FmhaFwdSplitKVCombineTileSize(64, 64, -1)],
+            '256' : [FmhaFwdSplitKVCombineTileSize(32, 128, -1), FmhaFwdSplitKVCombineTileSize(64, 128, -1)],
         }
     elif dtype == 'fp8' or dtype == 'bf8':
         return {
-            '64'  : FmhaFwdSplitKVCombineTileSize(64, 32, -1),
-            '128' : FmhaFwdSplitKVCombineTileSize(64, 64, -1),
-            '256' : FmhaFwdSplitKVCombineTileSize(64, 128, -1),
+            '64'  : [FmhaFwdSplitKVCombineTileSize(64, 32, -1)],
+            '128' : [FmhaFwdSplitKVCombineTileSize(64, 64, -1)],
+            '256' : [FmhaFwdSplitKVCombineTileSize(64, 128, -1)],
         }
     else:
         return None
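Each fp16/bf16 head dim now carries two combine tile configurations instead of one; per the new column comment, the first element is the decode tile and the second the prefill tile, while fp8/bf8 keeps a single entry. A small sketch of consuming the new shape, assuming it runs inside the same module that defines the function (the file path is not shown in this diff):

    combine_tiles = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype('fp16')
    decode_tile, prefill_tile = combine_tiles['128']     # two-element list: [decode, prefill]
    (fp8_tile,) = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype('fp8')['64']   # fp8 lists hold one tile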
@@ -655,18 +656,28 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
     api_pool = FmhaFwdSplitKVApiPool(mask_impl)

     for dtype in FWD_DTYPE_MAP.keys():
-        d = get_fmha_fwd_tile_dict_from_dtype(dtype)
-        if d == None:
+        prefill_tiles = codegen.ops.fmha_fwd.get_fmha_fwd_tile_dict_from_dtype(dtype)
+        decode_tiles = get_fmha_fwd_tile_dict_from_dtype(dtype)
+        if decode_tiles == None:
             continue
+
+        # make sure if all the hdim str keys in decode_tiles are also available in prefill_tiles
+        assert all(tile in prefill_tiles.keys() for tile in decode_tiles.keys())
+
         #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
-        for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
-            tile = d[hdim_str]
+        for hdim_str, mode in itertools.product(decode_tiles.keys(), MODE_MAP.keys()):
+            prefill_tile = prefill_tiles[hdim_str]
+            decode_tile = decode_tiles[hdim_str]
             hdim = int(hdim_str)
             for pipeline in get_pipelines(dtype, hdim):
                 if mode == "group":
                     if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
                         # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
                         continue
+
+                is_prefill = (mode == "group" and pipeline.F_pagedkv == 't')
+                tile = prefill_tile if is_prefill else decode_tile
+
                 k = Kernel(F_idx=0,
                            F_hdim=hdim,
                            F_dtype=dtype,
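Two details are easy to miss here: prefill_tiles is fetched through the qualified codegen.ops.fmha_fwd name because this file defines its own get_fmha_fwd_tile_dict_from_dtype for the decode tiles, and the larger prefill tile is only chosen for group mode with paged KV, which is how chunked prefill reaches this generator. A self-contained restatement of the selection rule as a helper (hypothetical name, not part of the commit):

    def choose_splitkv_tile(mode: str, pagedkv: str, prefill_tile, decode_tile):
        # mirrors get_fwd_splitkv_blobs(): group mode + paged KV ('t') is treated
        # as chunked prefill and gets the larger forward-pass tile
        is_prefill = (mode == "group" and pagedkv == 't')
        return prefill_tile if is_prefill else decode_tile

    assert choose_splitkv_tile("batch", "f", "big", "small") == "small"   # plain decode keeps the small tile
    assert choose_splitkv_tile("group", "t", "big", "small") == "big"     # chunked prefill gets the big tile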
@@ -720,9 +731,9 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
             continue
         #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
         for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
-            tile = d[hdim_str]
+            tiles = d[hdim_str]
             hdim = int(hdim_str)
-            for pipeline in get_pipelines(dtype, hdim):
+            for tile, pipeline in itertools.product(tiles, get_pipelines(dtype, hdim)):
                 if mode == "group":
                     if pipeline.F_spad != 't':
                         # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
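Because each head dim now maps to a list of combine tiles, the combine-kernel generator iterates the cross product of tiles and pipelines, so a kernel is emitted for both the decode and the prefill combine tile size. A toy illustration of that expansion with stand-in values:

    import itertools

    tiles = ["decode_combine", "prefill_combine"]    # stand-ins for FmhaFwdSplitKVCombineTileSize entries
    pipelines = ["pipeline_0", "pipeline_1"]         # stand-ins for get_pipelines(dtype, hdim)

    # four (tile, pipeline) kernels instead of the previous two
    for tile, pipeline in itertools.product(tiles, pipelines):
        print(tile, pipeline)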