Commit 4633f073 authored by Po Yen Chen
Browse files

Use vlayout=col for chunked prefill

parent bb093470
......@@ -733,8 +733,11 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if receipt == 2:
is_chunked_prefill = (mode == 'group' and pipeline.F_pagedkv == 't')
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
# use vlayout=col for chunked prefill (group mode + paged KV), vlayout=row otherwise
cond = cond and ((pipeline.F_vlayout == 'row' and not is_chunked_prefill) or (pipeline.F_vlayout == 'col' and is_chunked_prefill))
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
if not cond:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment