Commit 4af3f889 authored by Lianmin Zheng

Simplify flashinfer indices update for prefill (#2074)


Co-authored-by: kavioyu <kavioyu@tencent.com>
Co-authored-by: kavioyu <kavioyu@gmail.com>
parent df7fe452
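
The theme of this commit is avoiding GPU-to-CPU synchronizations in the prefill indices update: length sums and prefix checks that used to be read back from CUDA tensors (`.item()`, `torch.any(...)`) are instead taken from values the scheduler already holds on the host. A minimal sketch of the pattern, with invented lengths (not code from this commit):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = torch.tensor([5, 9, 3], device=device)

# Before: reading a reduction result out of a device tensor blocks the CPU
# until all queued GPU work has finished.
total_with_sync = seq_lens.sum().item()

# After: the scheduler built the batch from host-side lengths, so the sum
# can be carried along as a plain Python int and used without any sync.
seq_lens_cpu = [5, 9, 3]
seq_lens_sum = sum(seq_lens_cpu)

assert total_with_sync == seq_lens_sum
```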
......@@ -8,7 +8,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
"""
from enum import Enum, auto
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List
import torch
import triton
......@@ -136,15 +136,17 @@ class FlashInferAttnBackend(AttentionBackend):
prefix_lens = forward_batch.extend_prefix_lens
# Some heuristics to check whether to use ragged forward
use_ragged = False
if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
use_ragged = True
extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
else:
use_ragged = False
extend_no_prefix = False
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens,
use_ragged=use_ragged,
encoder_lens=forward_batch.encoder_lens,
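
This hunk swaps a device-side reduction (`torch.any(forward_batch.extend_prefix_lens).item()`) for a check on the new CPU mirror `extend_prefix_lens_cpu`, and threads `seq_lens_sum` into the prefill updater. A sketch of the equivalence, with made-up values:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
extend_prefix_lens = torch.tensor([0, 0, 4], device=device)
extend_prefix_lens_cpu = [0, 0, 4]  # host-side mirror kept by the scheduler

# Old: launches a kernel, then synchronizes to fetch a single boolean.
extend_no_prefix_old = not torch.any(extend_prefix_lens).item()

# New: pure-Python check over data the CPU already has.
extend_no_prefix_new = not any(extend_prefix_lens_cpu)

assert extend_no_prefix_old == extend_no_prefix_new  # False: one request has a prefix
```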
......@@ -334,7 +336,12 @@ class FlashInferIndicesUpdaterDecode:
self.update = self.update_single_wrapper
def update(
self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrappers: List,
encoder_lens: torch.Tensor,
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
......@@ -344,8 +351,8 @@ class FlashInferIndicesUpdaterDecode:
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrappers=None,
encoder_lens=None,
decode_wrappers: List,
encoder_lens: torch.Tensor,
):
decode_wrappers = decode_wrappers or self.decode_wrappers
self.call_begin_forward(
......@@ -362,8 +369,8 @@ class FlashInferIndicesUpdaterDecode:
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrappers=None,
encoder_lens=None,
decode_wrappers: List,
encoder_lens: torch.Tensor,
):
decode_wrappers = decode_wrappers or self.decode_wrappers
......@@ -393,11 +400,11 @@ class FlashInferIndicesUpdaterDecode:
def update_cross_attention(
self,
req_pool_indices,
seq_lens,
seq_lens_sum,
decode_wrappers=None,
encoder_lens=None,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrappers: List,
encoder_lens: torch.Tensor,
):
decode_wrappers = decode_wrappers or self.decode_wrappers
......@@ -424,11 +431,11 @@ class FlashInferIndicesUpdaterDecode:
def call_begin_forward(
self,
wrapper,
req_pool_indices,
paged_kernel_lens,
paged_kernel_lens_sum,
kv_indptr,
kv_start_idx,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
kv_indptr: torch.Tensor,
kv_start_idx: torch.Tensor,
):
bs = len(req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
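
`kv_indptr` is a CSR-style offset array: entry *i* marks where request *i*'s KV page indices begin in the flat `kv_indices` buffer, so filling positions `1..bs` with the cumulative sum of per-request lengths is all the bookkeeping needed. A standalone illustration with invented lengths:

```python
import torch

paged_kernel_lens = torch.tensor([3, 5, 2])  # KV-cache length per request
bs = len(paged_kernel_lens)

kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)

# Request i owns kv_indices[kv_indptr[i] : kv_indptr[i + 1]].
print(kv_indptr.tolist())  # [0, 3, 8, 10]
```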
......@@ -494,23 +501,40 @@ class FlashInferIndicesUpdaterPrefill:
assert self.attn_backend.num_wrappers == 1
self.update = self.update_single_wrapper
def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
def update(
self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
use_ragged: bool,
encoder_lens: torch.Tensor,
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
def update_single_wrapper(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
use_ragged: bool,
encoder_lens: torch.Tensor,
):
if use_ragged:
paged_kernel_lens = prefix_lens
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
else:
paged_kernel_lens = seq_lens
paged_kernel_lens_sum = seq_lens_sum
self.call_begin_forward(
self.wrapper_ragged,
self.wrappers_paged[0],
req_pool_indices,
paged_kernel_lens,
paged_kernel_lens_sum,
seq_lens,
prefix_lens,
None,
......@@ -520,7 +544,13 @@ class FlashInferIndicesUpdaterPrefill:
)
def update_sliding_window(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
use_ragged: bool,
encoder_lens: torch.Tensor,
):
for wrapper_id in range(2):
if wrapper_id == 0:
......@@ -529,9 +559,12 @@ class FlashInferIndicesUpdaterPrefill:
seq_lens,
torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
)
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
else:
# full attention
paged_kernel_lens = seq_lens
paged_kernel_lens_sum = seq_lens_sum
kv_start_idx = seq_lens - paged_kernel_lens
self.call_begin_forward(
......@@ -539,6 +572,7 @@ class FlashInferIndicesUpdaterPrefill:
self.wrappers_paged[wrapper_id],
req_pool_indices,
paged_kernel_lens,
paged_kernel_lens_sum,
seq_lens,
prefix_lens,
kv_start_idx,
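
In the sliding-window branch each request keeps at most `sliding_window_size + extend_len` KV entries, so the clipped lengths (and their sum) must be recomputed per batch, while the full-attention branch can reuse the precomputed `seq_lens_sum`. A sketch of the window arithmetic with invented numbers:

```python
import torch

sliding_window_size = 4
seq_lens = torch.tensor([10, 3])    # total tokens per request
prefix_lens = torch.tensor([8, 0])  # cached prefix per request

# Keep at most window + newly-extended tokens of KV per request.
paged_kernel_lens = torch.minimum(
    seq_lens, sliding_window_size + (seq_lens - prefix_lens)
)
paged_kernel_lens_sum = paged_kernel_lens.sum().item()  # only sync left here

# kv_start_idx skips the tokens that slid out of the window.
kv_start_idx = seq_lens - paged_kernel_lens

print(paged_kernel_lens.tolist(), kv_start_idx.tolist())  # [6, 3] [4, 0]
```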
......@@ -548,23 +582,32 @@ class FlashInferIndicesUpdaterPrefill:
)
def update_cross_attention(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
use_ragged: bool,
encoder_lens: torch.Tensor,
):
for wrapper_id in range(2):
if wrapper_id == 0:
# normal attention
paged_kernel_lens = seq_lens
kv_start_idx = encoder_lens
paged_kernel_lens_sum = seq_lens_sum
else:
# cross attention
paged_kernel_lens = encoder_lens
kv_start_idx = torch.zeros_like(encoder_lens)
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
self.call_begin_forward(
self.wrapper_ragged,
self.wrappers_paged[wrapper_id],
req_pool_indices,
paged_kernel_lens,
paged_kernel_lens_sum,
seq_lens,
prefix_lens,
kv_start_idx,
......@@ -577,19 +620,22 @@ class FlashInferIndicesUpdaterPrefill:
self,
wrapper_ragged,
wrapper_paged,
req_pool_indices,
paged_kernel_lens,
seq_lens,
prefix_lens,
kv_start_idx,
kv_indptr,
qo_indptr,
use_ragged,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
seq_lens: torch.Tensor,
prefix_lens: torch.Tensor,
kv_start_idx: torch.Tensor,
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
):
bs = len(req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
......
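
The payoff is the allocation in the last hunk: `torch.empty(kv_indptr[-1], ...)` sizes the buffer from a device scalar, which forces an implicit `.item()` and a full device sync, while `torch.empty(paged_kernel_lens_sum, ...)` sizes it from a Python int already on the host. A hedged sketch of the difference (it degrades to CPU when no CUDA device is available):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
paged_kernel_lens = torch.tensor([3, 5, 2], device=device)

kv_indptr = torch.zeros(4, dtype=torch.int32, device=device)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

# Before: the size lives on the device, so torch.empty must first copy it
# back to the host, stalling the CPU behind every queued kernel.
kv_indices_old = torch.empty(kv_indptr[-1], dtype=torch.int32, device=device)

# After: the sum was computed from host-side lengths, so no sync happens.
paged_kernel_lens_sum = 10
kv_indices_new = torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device=device)

assert kv_indices_old.numel() == kv_indices_new.numel() == 10
```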
......@@ -64,8 +64,7 @@ class TritonAttnBackend(AttentionBackend):
max_extend_len = None
else:
start_loc = attn_logits = max_seq_len = None
prefix_lens = forward_batch.extend_prefix_lens
max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
......
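
The Triton backend makes the matching simplification: `extend_seq_lens` is already `seq_lens - extend_prefix_lens` by construction, so the maximum extend length can be reduced from it directly instead of re-deriving the difference. A small check with invented lengths:

```python
import torch

seq_lens = torch.tensor([12, 7, 20])
extend_prefix_lens = torch.tensor([8, 0, 16])
extend_seq_lens = seq_lens - extend_prefix_lens  # maintained by the batch

old = torch.max(seq_lens - extend_prefix_lens).item()  # re-derived
new = torch.max(extend_seq_lens).item()                # read directly

assert old == new == 7
```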
......@@ -109,6 +109,7 @@ class ForwardBatch:
extend_seq_lens: Optional[torch.Tensor] = None
extend_prefix_lens: Optional[torch.Tensor] = None
extend_start_loc: Optional[torch.Tensor] = None
extend_prefix_lens_cpu: Optional[List[int]] = None
extend_seq_lens_cpu: Optional[List[int]] = None
extend_logprob_start_lens_cpu: Optional[List[int]] = None
......@@ -250,6 +251,7 @@ class ForwardBatch:
ret.positions, ret.extend_start_loc = compute_position_triton(
ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
)
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
ret.extend_seq_lens_cpu = batch.extend_seq_lens
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
......
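
`ForwardBatch` now carries `extend_prefix_lens_cpu`, a plain Python list mirroring the device tensor, filled in when the batch is built; downstream code (the attention heuristic above, the multimodal models below) can then read prefix lengths without a `.cpu()` round trip. A stand-in dataclass sketch of the pattern (not the real `ForwardBatch`):

```python
from dataclasses import dataclass
from typing import List, Optional

import torch

@dataclass
class MiniForwardBatch:
    # Device tensor for kernels, host list for Python-side logic.
    extend_prefix_lens: Optional[torch.Tensor] = None
    extend_prefix_lens_cpu: Optional[List[int]] = None

prefix_lens = [0, 4, 0]
batch = MiniForwardBatch(
    extend_prefix_lens=torch.tensor(prefix_lens),
    extend_prefix_lens_cpu=prefix_lens,  # kept alongside, no sync needed later
)

# e.g. the LLaVA placeholder loop can index the list directly
# instead of calling .cpu().numpy() on the tensor every forward pass.
assert batch.extend_prefix_lens_cpu[1] == 4
```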
......@@ -345,7 +345,7 @@ class LlavaBaseForCausalLM(nn.Module):
# Fill in the placeholder for the image
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
pt = 0
for i in range(bs):
if not need_vision[i]:
......
......@@ -169,7 +169,7 @@ class LlavaVidForCausalLM(nn.Module):
# Fill in the placeholder for the image
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
pt = 0
for i in range(bs):
if not need_vision[i]:
......
......@@ -616,7 +616,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
inputs_embeds = self.model.embed_tokens(input_ids)
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
for i, image in enumerate(forward_batch.image_inputs):
if image is None:
continue
......
......@@ -66,7 +66,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
)
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.84)
self.assertGreater(metrics["score"], 0.835)
if __name__ == "__main__":
......
......@@ -37,7 +37,7 @@ class TestMLA(unittest.TestCase):
)
metrics = run_eval(args)
assert metrics["score"] >= 0.5
self.assertGreater(metrics["score"], 0.5)
def test_mgsm_en(self):
args = SimpleNamespace(
......@@ -49,7 +49,7 @@ class TestMLA(unittest.TestCase):
)
metrics = run_eval(args)
assert metrics["score"] >= 0.8
self.assertGreater(metrics["score"], 0.8)
if __name__ == "__main__":
......