Unverified commit fafe76b4, authored by Matthew Bonanni and committed by GitHub
Browse files

[Async][Spec Decoding] Zero-bubble async scheduling + spec decoding (#32951)


Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Co-authored-by: zhrrr <43847754+izhuhaoran@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
parent ffb5b32b
...@@ -177,7 +177,7 @@ def test_prepare_next_token_ids(): ...@@ -177,7 +177,7 @@ def test_prepare_next_token_ids():
next_token_ids_from_padded, valid_sampled_tokens_count = ( next_token_ids_from_padded, valid_sampled_tokens_count = (
proposer.prepare_next_token_ids_padded( proposer.prepare_next_token_ids_padded(
common_attn_metadata, common_attn_metadata.seq_lens_cpu,
sampled_token_ids_tensor, sampled_token_ids_tensor,
mock_requests, mock_requests,
mock_input_batch, mock_input_batch,
......
...@@ -187,7 +187,7 @@ def test_prepare_next_token_ids_padded(): ...@@ -187,7 +187,7 @@ def test_prepare_next_token_ids_padded():
) )
next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded( next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
common_attn_metadata, common_attn_metadata.seq_lens_cpu,
sampled_token_ids, sampled_token_ids,
mock_requests, mock_requests,
mock_input_batch, mock_input_batch,
......
...@@ -766,6 +766,19 @@ class VllmConfig: # type: ignore[misc] ...@@ -766,6 +766,19 @@ class VllmConfig: # type: ignore[misc]
else: else:
self.parallel_config.disable_nccl_for_dp_synchronization = False self.parallel_config.disable_nccl_for_dp_synchronization = False
if (
self.speculative_config is not None
and self.scheduler_config.async_scheduling
and self.model_config is not None
and not self.model_config.disable_cascade_attn
):
logger.warning_once(
"Disabling cascade attention (not yet compatible with "
"async speculative decoding).",
scope="local",
)
self.model_config.disable_cascade_attn = True
if ( if (
self.model_config is not None self.model_config is not None
and self.model_config.multimodal_config is not None and self.model_config.multimodal_config is not None
......
...@@ -71,7 +71,6 @@ class SpecDecodeBaseProposer: ...@@ -71,7 +71,6 @@ class SpecDecodeBaseProposer:
self.method = self.speculative_config.method self.method = self.speculative_config.method
self.pass_hidden_states_to_model = pass_hidden_states_to_model self.pass_hidden_states_to_model = pass_hidden_states_to_model
self.runner = runner
self.device = device self.device = device
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
...@@ -424,8 +423,6 @@ class SpecDecodeBaseProposer: ...@@ -424,8 +423,6 @@ class SpecDecodeBaseProposer:
) )
) )
assert self.runner is not None
per_layer_attn_metadata: dict[str, object] = {} per_layer_attn_metadata: dict[str, object] = {}
for attn_group in self.draft_attn_groups: for attn_group in self.draft_attn_groups:
attn_metadata = attn_group.get_metadata_builder().build_for_drafting( attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
...@@ -821,7 +818,7 @@ class SpecDecodeBaseProposer: ...@@ -821,7 +818,7 @@ class SpecDecodeBaseProposer:
def prepare_next_token_ids_padded( def prepare_next_token_ids_padded(
self, self,
common_attn_metadata: CommonAttentionMetadata, seq_lens_cpu: torch.Tensor,
sampled_token_ids: torch.Tensor, sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState], requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch, gpu_input_batch: InputBatch,
...@@ -836,11 +833,10 @@ class SpecDecodeBaseProposer: ...@@ -836,11 +833,10 @@ class SpecDecodeBaseProposer:
""" """
# Precompute get_token_id for when there is no valid next token # Precompute get_token_id for when there is no valid next token
num_reqs = gpu_input_batch.num_reqs num_reqs = gpu_input_batch.num_reqs
seq_lens_list = seq_lens_cpu[:num_reqs].tolist()
self.backup_next_token_ids.np[:num_reqs] = np.array( self.backup_next_token_ids.np[:num_reqs] = np.array(
[ [
requests[gpu_input_batch.req_ids[i]].get_token_id( requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
common_attn_metadata.seq_lens_cpu[i].item()
)
for i in range(num_reqs) for i in range(num_reqs)
], ],
dtype=np.int32, dtype=np.int32,
...@@ -925,7 +921,7 @@ class SpecDecodeBaseProposer: ...@@ -925,7 +921,7 @@ class SpecDecodeBaseProposer:
num_reqs=common_attn_metadata.num_reqs, num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens, num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(), max_query_len=new_query_len_per_req.max().item(),
max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), max_seq_len=common_attn_metadata.max_seq_len,
block_table_tensor=common_attn_metadata.block_table_tensor, block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens], slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
causal=True, causal=True,
......
...@@ -286,7 +286,7 @@ class ExtractHiddenStatesProposer: ...@@ -286,7 +286,7 @@ class ExtractHiddenStatesProposer:
def prepare_next_token_ids_padded( def prepare_next_token_ids_padded(
self, self,
common_attn_metadata: CommonAttentionMetadata, seq_lens: torch.Tensor,
sampled_token_ids: torch.Tensor, sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState], requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch, gpu_input_batch: InputBatch,
...@@ -303,11 +303,10 @@ class ExtractHiddenStatesProposer: ...@@ -303,11 +303,10 @@ class ExtractHiddenStatesProposer:
device = sampled_token_ids.device device = sampled_token_ids.device
# Compute backup tokens for discarded / invalid requests # Compute backup tokens for discarded / invalid requests
seq_lens_list = seq_lens[:num_reqs].tolist()
backup_tokens_gpu = torch.tensor( backup_tokens_gpu = torch.tensor(
[ [
requests[gpu_input_batch.req_ids[i]].get_token_id( requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
common_attn_metadata.seq_lens_cpu[i].item()
)
for i in range(num_reqs) for i in range(num_reqs)
], ],
dtype=torch.int32, dtype=torch.int32,
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import torch import torch
from vllm.config import VllmConfig, replace from vllm.config import VllmConfig, replace
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, CommonAttentionMetadata,
...@@ -463,3 +464,36 @@ def copy_and_expand_eagle_inputs_kernel( ...@@ -463,3 +464,36 @@ def copy_and_expand_eagle_inputs_kernel(
out_idx, out_idx,
mask=is_new_token_region & in_bounds, mask=is_new_token_region & in_bounds,
) )
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def update_num_computed_tokens_for_batch_change(
    num_computed_tokens: torch.Tensor,
    num_accepted_tokens: torch.Tensor,
    prev_positions: torch.Tensor,
    valid_sampled_token_count: torch.Tensor,
    prev_num_draft_tokens: torch.Tensor,
    cpu_num_computed_tokens: torch.Tensor,
) -> None:
    """Repair num_computed_tokens drift caused by async speculative decoding.

    Rows that map to a previous-batch slot which carried draft tokens get
    ``previous GPU value + valid sampled count``. Every other row (new
    requests, or rows without drafts such as prefills) takes the CPU value
    unchanged. ``num_accepted_tokens`` receives the valid sampled count for
    the first kind of row and keeps its current value otherwise.

    Both ``num_computed_tokens`` and ``num_accepted_tokens`` are updated in
    place; nothing is returned.
    """
    batch_len = prev_positions.shape[0]
    has_prev_slot = prev_positions >= 0
    # New requests carry -1 in prev_positions; clamp so the gathers below
    # stay in bounds. Those rows are masked out via `has_prev_slot` anyway.
    src_rows = prev_positions.clamp(min=0)

    sampled_counts = valid_sampled_token_count[src_rows]
    gpu_computed = num_computed_tokens[src_rows]
    had_drafts = prev_num_draft_tokens[src_rows] > 0
    use_gpu_value = has_prev_slot & had_drafts
    gpu_corrected = gpu_computed + sampled_counts.int()

    new_computed = torch.where(use_gpu_value, gpu_corrected, cpu_num_computed_tokens)
    num_computed_tokens[:batch_len].copy_(new_computed)

    new_accepted = torch.where(use_gpu_value, sampled_counts, num_accepted_tokens)
    num_accepted_tokens.copy_(new_accepted)
...@@ -6,7 +6,9 @@ import torch ...@@ -6,7 +6,9 @@ import torch
from vllm.distributed import get_dcp_group, get_pcp_group from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.cp_utils import get_total_cp_world_size from vllm.v1.worker.cp_utils import get_total_cp_world_size
...@@ -131,71 +133,33 @@ class BlockTable: ...@@ -131,71 +133,33 @@ class BlockTable:
self.block_table.np[src_tgt] = self.block_table.np[tgt_src] self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
def compute_slot_mapping( def compute_slot_mapping(
self, req_indices: np.ndarray, positions: np.ndarray self,
num_reqs: int,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> None: ) -> None:
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] num_tokens = positions.shape[0]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
total_cp_world_size = self.pcp_world_size * self.dcp_world_size total_cp_world_size = self.pcp_world_size * self.dcp_world_size
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
if total_cp_world_size > 1: _compute_slot_mapping_kernel[(num_reqs + 1,)](
# Note(hc): The DCP implement store kvcache with an interleave num_tokens,
# style, the kvcache for the token whose token_idx is i is self.max_num_batched_tokens,
# always stored on the GPU whose dcp_rank equals i % cp_world_size: query_start_loc,
positions,
# Use a "virtual block" which equals to world_size * block_size self.block_table.gpu,
# for block_table_indices calculation. self.block_table.gpu.stride(0),
virtual_block_size = self.block_size * total_cp_world_size self.block_size,
block_table_indices = ( self.slot_mapping.gpu,
req_indices * self.max_num_blocks_per_req TOTAL_CP_WORLD_SIZE=total_cp_world_size,
+ positions // virtual_block_size TOTAL_CP_RANK=total_cp_rank,
) CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
PAD_ID=PAD_SLOT_ID,
block_numbers = self.block_table.np.ravel()[block_table_indices] BLOCK_SIZE=1024,
# Use virtual_block_size for mask calculation, which marks local )
# tokens.
virtual_block_offsets = positions % virtual_block_size
mask = (
virtual_block_offsets
// self.cp_kv_cache_interleave_size
% total_cp_world_size
== total_cp_rank
)
# Calculate local block_offsets
block_offsets = (
virtual_block_offsets
// (total_cp_world_size * self.cp_kv_cache_interleave_size)
* self.cp_kv_cache_interleave_size
+ virtual_block_offsets % self.cp_kv_cache_interleave_size
)
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
self.slot_mapping.np[: req_indices.shape[0]] = np.where(
mask, slot_mapping, -1
)
else:
block_table_indices = (
req_indices * self.max_num_blocks_per_req + positions // self.block_size
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
np.add(
block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping.np[: req_indices.shape[0]],
)
def commit_block_table(self, num_reqs: int) -> None: def commit_block_table(self, num_reqs: int) -> None:
self.block_table.copy_to_gpu(num_reqs) self.block_table.copy_to_gpu(num_reqs)
def commit_slot_mapping(self, num_tokens: int) -> None:
self.slot_mapping.copy_to_gpu(num_tokens)
def clear(self) -> None: def clear(self) -> None:
self.block_table.gpu.fill_(0) self.block_table.gpu.fill_(0)
self.block_table.cpu.fill_(0) self.block_table.cpu.fill_(0)
...@@ -320,19 +284,18 @@ class MultiGroupBlockTable: ...@@ -320,19 +284,18 @@ class MultiGroupBlockTable:
block_table.swap_row(src, tgt) block_table.swap_row(src, tgt)
def compute_slot_mapping( def compute_slot_mapping(
self, req_indices: np.ndarray, positions: np.ndarray self,
num_reqs: int,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> None: ) -> None:
for block_table in self.block_tables: for block_table in self.block_tables:
block_table.compute_slot_mapping(req_indices, positions) block_table.compute_slot_mapping(num_reqs, query_start_loc, positions)
def commit_block_table(self, num_reqs: int) -> None: def commit_block_table(self, num_reqs: int) -> None:
for block_table in self.block_tables: for block_table in self.block_tables:
block_table.commit_block_table(num_reqs) block_table.commit_block_table(num_reqs)
def commit_slot_mapping(self, num_tokens: int) -> None:
for block_table in self.block_tables:
block_table.commit_slot_mapping(num_tokens)
def clear(self) -> None: def clear(self) -> None:
for block_table in self.block_tables: for block_table in self.block_tables:
block_table.clear() block_table.clear()
...@@ -340,3 +303,61 @@ class MultiGroupBlockTable: ...@@ -340,3 +303,61 @@ class MultiGroupBlockTable:
def __getitem__(self, idx: int) -> "BlockTable": def __getitem__(self, idx: int) -> "BlockTable":
"""Returns the BlockTable for the i-th KV cache group.""" """Returns the BlockTable for the i-th KV cache group."""
return self.block_tables[idx] return self.block_tables[idx]
@triton.jit
def _compute_slot_mapping_kernel(
    num_tokens,
    max_num_tokens,
    query_start_loc_ptr,  # [num_reqs + 1], int32
    positions_ptr,  # [num_tokens], int64
    block_table_ptr,  # [max_num_reqs, max_num_blocks_per_req], int32 (flat)
    block_table_stride,  # max_num_blocks_per_req
    block_size,
    slot_mapping_ptr,  # [max_num_tokens], int64
    TOTAL_CP_WORLD_SIZE: tl.constexpr,
    TOTAL_CP_RANK: tl.constexpr,
    CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr,
    PAD_ID: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Writes one slot id per token into slot_mapping_ptr.
    #
    # Work split along grid axis 0: every program except the last handles the
    # tokens of one request, i.e. the range
    # [query_start_loc[req_idx], query_start_loc[req_idx + 1]). The last
    # program does no request work and only fills the tail
    # [num_tokens, max_num_tokens) with PAD_ID.
    # NOTE(review): this relies on the launch grid having exactly one more
    # program than there are requests -- confirm at the call site.
    req_idx = tl.program_id(0)

    if req_idx == tl.num_programs(0) - 1:
        # Pad remaining slots for CUDA graph compatibility.
        for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
            offsets = i + tl.arange(0, BLOCK_SIZE)
            tl.store(
                slot_mapping_ptr + offsets,
                PAD_ID,
                mask=offsets < max_num_tokens,
            )
        return

    # Token range owned by this request.
    start_idx = tl.load(query_start_loc_ptr + req_idx).to(tl.int64)
    end_idx = tl.load(query_start_loc_ptr + req_idx + 1).to(tl.int64)

    # One block-table entry covers block_size * TOTAL_CP_WORLD_SIZE positions;
    # with TOTAL_CP_WORLD_SIZE == 1 this is simply block_size.
    virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE
    # Start of this request's row in the flattened block table.
    row_offset = req_idx * block_table_stride
    for i in range(start_idx, end_idx, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < end_idx
        # Masked-out lanes read position 0, so the unmasked block-table load
        # below reads column 0 of this request's row for them; their results
        # are dropped by the masked store at the end.
        pos = tl.load(positions_ptr + offsets, mask=mask, other=0)

        block_indices = pos // virtual_block_size
        block_numbers = tl.load(block_table_ptr + row_offset + block_indices).to(
            tl.int64
        )

        # Offset of the position inside its virtual block.
        virtual_block_offsets = pos - block_indices * virtual_block_size
        # A token is kept only when its interleave chunk (of
        # CP_KV_CACHE_INTERLEAVE_SIZE positions) maps to TOTAL_CP_RANK.
        is_local = (
            virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE
        ) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK
        # Offset inside the block: index of the interleave chunk among those
        # selected for one rank, times the chunk size, plus the offset inside
        # the chunk.
        local_block_offsets = (
            virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE)
        ) * CP_KV_CACHE_INTERLEAVE_SIZE + (
            virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE
        )

        slot_ids = block_numbers * block_size + local_block_offsets
        # Tokens that fail the is_local test get PAD_ID instead of a slot.
        slot_ids = tl.where(is_local, slot_ids, PAD_ID)
        tl.store(slot_mapping_ptr + offsets, slot_ids, mask=mask)
...@@ -219,7 +219,7 @@ class InputBatch: ...@@ -219,7 +219,7 @@ class InputBatch:
# Speculative decoding # Speculative decoding
self.num_accepted_tokens_cpu_tensor = torch.ones( self.num_accepted_tokens_cpu_tensor = torch.ones(
(max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
) )
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy() self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
...@@ -989,13 +989,15 @@ class InputBatch: ...@@ -989,13 +989,15 @@ class InputBatch:
continue continue
num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1) num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
# Also account for case where there may be a smaller number of # Also account for case where there may be a smaller number of
# output placeholders (tokens can be discarded after a kv-load failure). # output placeholders (tokens can be discarded after kv-load
# failure) or a larger number (async spec decode adds optimistic
# placeholders that may exceed the actual acceptance count).
first_placeholder = req_output_token_ids.index(-1) first_placeholder = req_output_token_ids.index(-1)
num_placeholders = len(req_output_token_ids) - first_placeholder num_placeholders = len(req_output_token_ids) - first_placeholder
num_to_replace = min(num_sampled_ids, num_placeholders) num_to_replace = min(num_sampled_ids, num_placeholders)
del new_ids[num_to_replace:] del new_ids[num_to_replace:]
end_index = first_placeholder + num_to_replace req_output_token_ids[first_placeholder:] = new_ids
req_output_token_ids[first_placeholder:end_index] = new_ids # ^ Implicitly resizes to (first_placeholder + num_to_replace)
def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None: def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
""" """
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment