Commit 2da4b185 authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Use simpler type hint for backward compatibility

parent 59138975
......@@ -574,7 +574,7 @@ class FmhaFwdSplitKVCombineKernel:
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict[str, FmhaFwdTileSize]]:
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1),
......@@ -592,7 +592,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict[str, FmhaFwd
else:
return None
def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict[str, List[FmhaFwdSplitKVCombineTileSize]]]:
def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
# tile size for decode tile size for prefill
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment