Commit 4d7efba1 authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Use simpler type hint for backward compatibility

parent 626ab5b6
...@@ -574,7 +574,7 @@ class FmhaFwdSplitKVCombineKernel: ...@@ -574,7 +574,7 @@ class FmhaFwdSplitKVCombineKernel:
# TODO: design a more practical way to do it # TODO: design a more practical way to do it
# this is current supported tile size per hdim # this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict[str, FmhaFwdTileSize]]: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16': if dtype == 'fp16' or dtype == 'bf16':
return { return {
'32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1), '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1),
...@@ -592,7 +592,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict[str, FmhaFwd ...@@ -592,7 +592,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict[str, FmhaFwd
else: else:
return None return None
def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict[str, List[FmhaFwdSplitKVCombineTileSize]]]: def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16': if dtype == 'fp16' or dtype == 'bf16':
return { return {
# tile size for decode tile size for prefill # tile size for decode tile size for prefill
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment