Unverified commit 4ea48fb3, authored by Woosuk Kwon and committed by GitHub
Browse files

[V1][Minor] Move cascade attn logic outside _prepare_inputs (#12943)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent e31498bd
...@@ -476,12 +476,82 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -476,12 +476,82 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.device, non_blocking=True).long() self.device, non_blocking=True).long()
# Prepare for cascade attention if needed. # Prepare for cascade attention if needed.
common_prefix_len = (scheduler_output.num_common_prefix_blocks * common_prefix_len = self._compute_cascade_attn_prefix_len(
self.block_size) num_scheduled_tokens,
scheduler_output.num_common_prefix_blocks,
)
use_cascade = common_prefix_len > 0
if use_cascade:
# TODO: Optimize.
cu_prefix_query_lens = torch.tensor(
[0, total_num_scheduled_tokens],
dtype=torch.int32,
device=self.device)
prefix_kv_lens = torch.tensor([common_prefix_len],
dtype=torch.int32,
device=self.device)
suffix_kv_lens = (self.seq_lens_np[:num_reqs] - common_prefix_len)
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.device)
else:
cu_prefix_query_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=(
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
)
# Hot-Swap lora model
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
# NOTE(woosuk): Due to chunked prefills, the batch may contain partial
# requests. While we should not sample any token from these partial
# requests, we do so for simplicity. We will ignore the sampled
# tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
return attn_metadata, logits_indices
def _compute_cascade_attn_prefix_len(
self,
num_scheduled_tokens: np.ndarray,
num_common_prefix_blocks: int,
) -> int:
"""Compute the length of the common prefix for cascade attention.
NOTE(woosuk): The common prefix length returned by this function
represents the length used specifically for cascade attention, not the
actual number of tokens shared between requests. When cascade attention
is disabled (use_cascade=False), this function returns 0 even if
requests share common tokens. Additionally, the common prefix length is
truncated to a multiple of the block size and may be further truncated
due to implementation details explained below.
Args:
num_scheduled_tokens: Number of tokens scheduled per request.
num_common_prefix_blocks: Number of shared KV cache blocks.
Returns:
int: Length of common prefix in tokens.
"""
common_prefix_len = num_common_prefix_blocks * self.block_size
if common_prefix_len == 0: if common_prefix_len == 0:
# Common case. # Common case.
use_cascade = False return 0
else:
# NOTE(woosuk): Cascade attention uses two attention kernels: one # NOTE(woosuk): Cascade attention uses two attention kernels: one
# for the common prefix and the other for the rest. For the first # for the common prefix and the other for the rest. For the first
# kernel, we concatenate all the query tokens (possibly from # kernel, we concatenate all the query tokens (possibly from
...@@ -521,6 +591,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -521,6 +591,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# and the second kernel will get an empty input. While this is not # and the second kernel will get an empty input. While this is not
# a fundamental problem, our current implementation does not support # a fundamental problem, our current implementation does not support
# this case. # this case.
num_reqs = len(num_scheduled_tokens)
common_prefix_len = min( common_prefix_len = min(
common_prefix_len, common_prefix_len,
self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
...@@ -536,50 +607,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -536,50 +607,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
use_sliding_window=self.sliding_window is not None, use_sliding_window=self.sliding_window is not None,
num_sms=self.num_sms, num_sms=self.num_sms,
) )
return common_prefix_len if use_cascade else 0
if use_cascade:
# TODO: Optimize.
cu_prefix_query_lens = torch.tensor(
[0, total_num_scheduled_tokens],
dtype=torch.int32,
device=self.device)
prefix_kv_lens = torch.tensor([common_prefix_len],
dtype=torch.int32,
device=self.device)
suffix_kv_lens = (self.seq_lens_np[:num_reqs] - common_prefix_len)
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.device)
else:
cu_prefix_query_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=(
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
)
# Hot-Swap lora model
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
# NOTE(woosuk): Due to chunked prefills, the batch may contain partial
# requests. While we should not sample any token from these partial
# requests, we do so for simplicity. We will ignore the sampled
# tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
return attn_metadata, logits_indices
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
mrope_pos_ptr = 0 mrope_pos_ptr = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment