Commit 4af3f889 authored by Lianmin Zheng's avatar Lianmin Zheng
Browse files

Simplify flashinfer indices update for prefill (#2074)


Co-authored-by: default avatarkavioyu <kavioyu@tencent.com>
Co-authored-by: default avatarkavioyu <kavioyu@gmail.com>
parent df7fe452
...@@ -8,7 +8,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an ...@@ -8,7 +8,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
""" """
from enum import Enum, auto from enum import Enum, auto
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, List
import torch import torch
import triton import triton
...@@ -136,15 +136,17 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -136,15 +136,17 @@ class FlashInferAttnBackend(AttentionBackend):
prefix_lens = forward_batch.extend_prefix_lens prefix_lens = forward_batch.extend_prefix_lens
# Some heuristics to check whether to use ragged forward # Some heuristics to check whether to use ragged forward
use_ragged = False
if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1: if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
use_ragged = True use_ragged = True
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item() else:
use_ragged = False
extend_no_prefix = False
self.indices_updater_prefill.update( self.indices_updater_prefill.update(
forward_batch.req_pool_indices, forward_batch.req_pool_indices,
forward_batch.seq_lens, forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens, prefix_lens,
use_ragged=use_ragged, use_ragged=use_ragged,
encoder_lens=forward_batch.encoder_lens, encoder_lens=forward_batch.encoder_lens,
...@@ -334,7 +336,12 @@ class FlashInferIndicesUpdaterDecode: ...@@ -334,7 +336,12 @@ class FlashInferIndicesUpdaterDecode:
self.update = self.update_single_wrapper self.update = self.update_single_wrapper
def update( def update(
self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrappers: List,
encoder_lens: torch.Tensor,
): ):
# Keep the signature for type checking. It will be assigned during runtime. # Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError() raise NotImplementedError()
...@@ -344,8 +351,8 @@ class FlashInferIndicesUpdaterDecode: ...@@ -344,8 +351,8 @@ class FlashInferIndicesUpdaterDecode:
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
decode_wrappers=None, decode_wrappers: List,
encoder_lens=None, encoder_lens: torch.Tensor,
): ):
decode_wrappers = decode_wrappers or self.decode_wrappers decode_wrappers = decode_wrappers or self.decode_wrappers
self.call_begin_forward( self.call_begin_forward(
...@@ -362,8 +369,8 @@ class FlashInferIndicesUpdaterDecode: ...@@ -362,8 +369,8 @@ class FlashInferIndicesUpdaterDecode:
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
decode_wrappers=None, decode_wrappers: List,
encoder_lens=None, encoder_lens: torch.Tensor,
): ):
decode_wrappers = decode_wrappers or self.decode_wrappers decode_wrappers = decode_wrappers or self.decode_wrappers
...@@ -393,11 +400,11 @@ class FlashInferIndicesUpdaterDecode: ...@@ -393,11 +400,11 @@ class FlashInferIndicesUpdaterDecode:
def update_cross_attention( def update_cross_attention(
self, self,
req_pool_indices, req_pool_indices: torch.Tensor,
seq_lens, seq_lens: torch.Tensor,
seq_lens_sum, seq_lens_sum: int,
decode_wrappers=None, decode_wrappers: List,
encoder_lens=None, encoder_lens: torch.Tensor,
): ):
decode_wrappers = decode_wrappers or self.decode_wrappers decode_wrappers = decode_wrappers or self.decode_wrappers
...@@ -424,11 +431,11 @@ class FlashInferIndicesUpdaterDecode: ...@@ -424,11 +431,11 @@ class FlashInferIndicesUpdaterDecode:
def call_begin_forward( def call_begin_forward(
self, self,
wrapper, wrapper,
req_pool_indices, req_pool_indices: torch.Tensor,
paged_kernel_lens, paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum, paged_kernel_lens_sum: int,
kv_indptr, kv_indptr: torch.Tensor,
kv_start_idx, kv_start_idx: torch.Tensor,
): ):
bs = len(req_pool_indices) bs = len(req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
...@@ -494,23 +501,40 @@ class FlashInferIndicesUpdaterPrefill: ...@@ -494,23 +501,40 @@ class FlashInferIndicesUpdaterPrefill:
assert self.attn_backend.num_wrappers == 1 assert self.attn_backend.num_wrappers == 1
self.update = self.update_single_wrapper self.update = self.update_single_wrapper
def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens): def update(
self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
use_ragged: bool,
encoder_lens: torch.Tensor,
):
# Keep the signature for type checking. It will be assigned during runtime. # Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError() raise NotImplementedError()
def update_single_wrapper( def update_single_wrapper(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
use_ragged: bool,
encoder_lens: torch.Tensor,
): ):
if use_ragged: if use_ragged:
paged_kernel_lens = prefix_lens paged_kernel_lens = prefix_lens
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
else: else:
paged_kernel_lens = seq_lens paged_kernel_lens = seq_lens
paged_kernel_lens_sum = seq_lens_sum
self.call_begin_forward( self.call_begin_forward(
self.wrapper_ragged, self.wrapper_ragged,
self.wrappers_paged[0], self.wrappers_paged[0],
req_pool_indices, req_pool_indices,
paged_kernel_lens, paged_kernel_lens,
paged_kernel_lens_sum,
seq_lens, seq_lens,
prefix_lens, prefix_lens,
None, None,
...@@ -520,7 +544,13 @@ class FlashInferIndicesUpdaterPrefill: ...@@ -520,7 +544,13 @@ class FlashInferIndicesUpdaterPrefill:
) )
def update_sliding_window( def update_sliding_window(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
use_ragged: bool,
encoder_lens: torch.Tensor,
): ):
for wrapper_id in range(2): for wrapper_id in range(2):
if wrapper_id == 0: if wrapper_id == 0:
...@@ -529,9 +559,12 @@ class FlashInferIndicesUpdaterPrefill: ...@@ -529,9 +559,12 @@ class FlashInferIndicesUpdaterPrefill:
seq_lens, seq_lens,
torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens, torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
) )
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
else: else:
# full attention # full attention
paged_kernel_lens = seq_lens paged_kernel_lens = seq_lens
paged_kernel_lens_sum = seq_lens_sum
kv_start_idx = seq_lens - paged_kernel_lens kv_start_idx = seq_lens - paged_kernel_lens
self.call_begin_forward( self.call_begin_forward(
...@@ -539,6 +572,7 @@ class FlashInferIndicesUpdaterPrefill: ...@@ -539,6 +572,7 @@ class FlashInferIndicesUpdaterPrefill:
self.wrappers_paged[wrapper_id], self.wrappers_paged[wrapper_id],
req_pool_indices, req_pool_indices,
paged_kernel_lens, paged_kernel_lens,
paged_kernel_lens_sum,
seq_lens, seq_lens,
prefix_lens, prefix_lens,
kv_start_idx, kv_start_idx,
...@@ -548,23 +582,32 @@ class FlashInferIndicesUpdaterPrefill: ...@@ -548,23 +582,32 @@ class FlashInferIndicesUpdaterPrefill:
) )
def update_cross_attention( def update_cross_attention(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
use_ragged: bool,
encoder_lens: torch.Tensor,
): ):
for wrapper_id in range(2): for wrapper_id in range(2):
if wrapper_id == 0: if wrapper_id == 0:
# normal attention # normal attention
paged_kernel_lens = seq_lens paged_kernel_lens = seq_lens
kv_start_idx = encoder_lens kv_start_idx = encoder_lens
paged_kernel_lens_sum = seq_lens_sum
else: else:
# cross attention # cross attention
paged_kernel_lens = encoder_lens paged_kernel_lens = encoder_lens
kv_start_idx = torch.zeros_like(encoder_lens) kv_start_idx = torch.zeros_like(encoder_lens)
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
self.call_begin_forward( self.call_begin_forward(
self.wrapper_ragged, self.wrapper_ragged,
self.wrappers_paged[wrapper_id], self.wrappers_paged[wrapper_id],
req_pool_indices, req_pool_indices,
paged_kernel_lens, paged_kernel_lens,
paged_kernel_lens_sum,
seq_lens, seq_lens,
prefix_lens, prefix_lens,
kv_start_idx, kv_start_idx,
...@@ -577,19 +620,22 @@ class FlashInferIndicesUpdaterPrefill: ...@@ -577,19 +620,22 @@ class FlashInferIndicesUpdaterPrefill:
self, self,
wrapper_ragged, wrapper_ragged,
wrapper_paged, wrapper_paged,
req_pool_indices, req_pool_indices: torch.Tensor,
paged_kernel_lens, paged_kernel_lens: torch.Tensor,
seq_lens, paged_kernel_lens_sum: int,
prefix_lens, seq_lens: torch.Tensor,
kv_start_idx, prefix_lens: torch.Tensor,
kv_indptr, kv_start_idx: torch.Tensor,
qo_indptr, kv_indptr: torch.Tensor,
use_ragged, qo_indptr: torch.Tensor,
use_ragged: bool,
): ):
bs = len(req_pool_indices) bs = len(req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1] kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda") kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
create_flashinfer_kv_indices_triton[(bs,)]( create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token, self.req_to_token,
req_pool_indices, req_pool_indices,
......
...@@ -64,8 +64,7 @@ class TritonAttnBackend(AttentionBackend): ...@@ -64,8 +64,7 @@ class TritonAttnBackend(AttentionBackend):
max_extend_len = None max_extend_len = None
else: else:
start_loc = attn_logits = max_seq_len = None start_loc = attn_logits = max_seq_len = None
prefix_lens = forward_batch.extend_prefix_lens max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
......
...@@ -109,6 +109,7 @@ class ForwardBatch: ...@@ -109,6 +109,7 @@ class ForwardBatch:
extend_seq_lens: Optional[torch.Tensor] = None extend_seq_lens: Optional[torch.Tensor] = None
extend_prefix_lens: Optional[torch.Tensor] = None extend_prefix_lens: Optional[torch.Tensor] = None
extend_start_loc: Optional[torch.Tensor] = None extend_start_loc: Optional[torch.Tensor] = None
extend_prefix_lens_cpu: Optional[List[int]] = None
extend_seq_lens_cpu: Optional[List[int]] = None extend_seq_lens_cpu: Optional[List[int]] = None
extend_logprob_start_lens_cpu: Optional[List[int]] = None extend_logprob_start_lens_cpu: Optional[List[int]] = None
...@@ -250,6 +251,7 @@ class ForwardBatch: ...@@ -250,6 +251,7 @@ class ForwardBatch:
ret.positions, ret.extend_start_loc = compute_position_triton( ret.positions, ret.extend_start_loc = compute_position_triton(
ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
) )
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
ret.extend_seq_lens_cpu = batch.extend_seq_lens ret.extend_seq_lens_cpu = batch.extend_seq_lens
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
......
...@@ -345,7 +345,7 @@ class LlavaBaseForCausalLM(nn.Module): ...@@ -345,7 +345,7 @@ class LlavaBaseForCausalLM(nn.Module):
# Fill in the placeholder for the image # Fill in the placeholder for the image
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy() prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
pt = 0 pt = 0
for i in range(bs): for i in range(bs):
if not need_vision[i]: if not need_vision[i]:
......
...@@ -169,7 +169,7 @@ class LlavaVidForCausalLM(nn.Module): ...@@ -169,7 +169,7 @@ class LlavaVidForCausalLM(nn.Module):
# Fill in the placeholder for the image # Fill in the placeholder for the image
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy() prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
pt = 0 pt = 0
for i in range(bs): for i in range(bs):
if not need_vision[i]: if not need_vision[i]:
......
...@@ -616,7 +616,7 @@ class Qwen2VLForConditionalGeneration(nn.Module): ...@@ -616,7 +616,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
inputs_embeds = self.model.embed_tokens(input_ids) inputs_embeds = self.model.embed_tokens(input_ids)
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy() prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
for i, image in enumerate(forward_batch.image_inputs): for i, image in enumerate(forward_batch.image_inputs):
if image is None: if image is None:
continue continue
......
...@@ -66,7 +66,7 @@ class TestEvalAccuracyLarge(unittest.TestCase): ...@@ -66,7 +66,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
) )
metrics = run_eval(args) metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.84) self.assertGreater(metrics["score"], 0.835)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -37,7 +37,7 @@ class TestMLA(unittest.TestCase): ...@@ -37,7 +37,7 @@ class TestMLA(unittest.TestCase):
) )
metrics = run_eval(args) metrics = run_eval(args)
assert metrics["score"] >= 0.5 self.assertGreater(metrics["score"], 0.5)
def test_mgsm_en(self): def test_mgsm_en(self):
args = SimpleNamespace( args = SimpleNamespace(
...@@ -49,7 +49,7 @@ class TestMLA(unittest.TestCase): ...@@ -49,7 +49,7 @@ class TestMLA(unittest.TestCase):
) )
metrics = run_eval(args) metrics = run_eval(args)
assert metrics["score"] >= 0.8 self.assertGreater(metrics["score"], 0.8)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment