"docs/git@developer.sourcefind.cn:change/sglang.git" did not exist on "325a06c2deef25067f8b37e73358b4569e13def7"
Commit 658350b3 authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Use async splitkv pipeline for hdim<256 problems

parent 76b31460
...@@ -669,13 +669,10 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -669,13 +669,10 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if dtype in ['fp16', 'bf16']: if dtype in ['fp16', 'bf16']:
for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
# TODO: use async pipeline when compiler is more stable # TODO: use async pipeline when compiler is more stable
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: if hdim == 256: ### [32, 64, 96, 128]:
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
else: else:
...@@ -719,7 +716,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -719,7 +716,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue continue
is_prefill = (mode == "group" and pipeline.F_pagedkv == 't') is_prefill = (mode == "group")
tile = prefill_tile if is_prefill else decode_tile tile = prefill_tile if is_prefill else decode_tile
k = Kernel(F_idx=0, k = Kernel(F_idx=0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment