"...composable_kernel_rocm.git" did not exist on "f2398f612d3bb4432abaf1bad3aa598387116205"
Commit f31fad7d authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Revert "Use async splitkv pipeline for hdim<256 problems"

This reverts commit 658350b3.
parent 658350b3
...@@ -669,10 +669,13 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -669,10 +669,13 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if dtype in ['fp16', 'bf16']: if dtype in ['fp16', 'bf16']:
for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
# TODO: use async pipeline when compiler is more stable # TODO: use async pipeline when compiler is more stable
if hdim == 256: ### [32, 64, 96, 128]: if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
else: else:
...@@ -716,7 +719,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -716,7 +719,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue continue
is_prefill = (mode == "group") is_prefill = (mode == "group" and pipeline.F_pagedkv == 't')
tile = prefill_tile if is_prefill else decode_tile tile = prefill_tile if is_prefill else decode_tile
k = Kernel(F_idx=0, k = Kernel(F_idx=0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment