Unverified Commit fafe76b4 authored by Matthew Bonanni, committed by GitHub
Browse files

[Async][Spec Decoding] Zero-bubble async scheduling + spec decoding (#32951)


Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Co-authored-by: zhrrr <43847754+izhuhaoran@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
parent ffb5b32b
...@@ -177,7 +177,7 @@ def test_prepare_next_token_ids(): ...@@ -177,7 +177,7 @@ def test_prepare_next_token_ids():
next_token_ids_from_padded, valid_sampled_tokens_count = ( next_token_ids_from_padded, valid_sampled_tokens_count = (
proposer.prepare_next_token_ids_padded( proposer.prepare_next_token_ids_padded(
common_attn_metadata, common_attn_metadata.seq_lens_cpu,
sampled_token_ids_tensor, sampled_token_ids_tensor,
mock_requests, mock_requests,
mock_input_batch, mock_input_batch,
......
...@@ -187,7 +187,7 @@ def test_prepare_next_token_ids_padded(): ...@@ -187,7 +187,7 @@ def test_prepare_next_token_ids_padded():
) )
next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded( next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
common_attn_metadata, common_attn_metadata.seq_lens_cpu,
sampled_token_ids, sampled_token_ids,
mock_requests, mock_requests,
mock_input_batch, mock_input_batch,
......
...@@ -766,6 +766,19 @@ class VllmConfig: # type: ignore[misc] ...@@ -766,6 +766,19 @@ class VllmConfig: # type: ignore[misc]
else: else:
self.parallel_config.disable_nccl_for_dp_synchronization = False self.parallel_config.disable_nccl_for_dp_synchronization = False
if (
self.speculative_config is not None
and self.scheduler_config.async_scheduling
and self.model_config is not None
and not self.model_config.disable_cascade_attn
):
logger.warning_once(
"Disabling cascade attention (not yet compatible with "
"async speculative decoding).",
scope="local",
)
self.model_config.disable_cascade_attn = True
if ( if (
self.model_config is not None self.model_config is not None
and self.model_config.multimodal_config is not None and self.model_config.multimodal_config is not None
......
...@@ -71,7 +71,6 @@ class SpecDecodeBaseProposer: ...@@ -71,7 +71,6 @@ class SpecDecodeBaseProposer:
self.method = self.speculative_config.method self.method = self.speculative_config.method
self.pass_hidden_states_to_model = pass_hidden_states_to_model self.pass_hidden_states_to_model = pass_hidden_states_to_model
self.runner = runner
self.device = device self.device = device
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
...@@ -424,8 +423,6 @@ class SpecDecodeBaseProposer: ...@@ -424,8 +423,6 @@ class SpecDecodeBaseProposer:
) )
) )
assert self.runner is not None
per_layer_attn_metadata: dict[str, object] = {} per_layer_attn_metadata: dict[str, object] = {}
for attn_group in self.draft_attn_groups: for attn_group in self.draft_attn_groups:
attn_metadata = attn_group.get_metadata_builder().build_for_drafting( attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
...@@ -821,7 +818,7 @@ class SpecDecodeBaseProposer: ...@@ -821,7 +818,7 @@ class SpecDecodeBaseProposer:
def prepare_next_token_ids_padded( def prepare_next_token_ids_padded(
self, self,
common_attn_metadata: CommonAttentionMetadata, seq_lens_cpu: torch.Tensor,
sampled_token_ids: torch.Tensor, sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState], requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch, gpu_input_batch: InputBatch,
...@@ -836,11 +833,10 @@ class SpecDecodeBaseProposer: ...@@ -836,11 +833,10 @@ class SpecDecodeBaseProposer:
""" """
# Precompute get_token_id for when there is no valid next token # Precompute get_token_id for when there is no valid next token
num_reqs = gpu_input_batch.num_reqs num_reqs = gpu_input_batch.num_reqs
seq_lens_list = seq_lens_cpu[:num_reqs].tolist()
self.backup_next_token_ids.np[:num_reqs] = np.array( self.backup_next_token_ids.np[:num_reqs] = np.array(
[ [
requests[gpu_input_batch.req_ids[i]].get_token_id( requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
common_attn_metadata.seq_lens_cpu[i].item()
)
for i in range(num_reqs) for i in range(num_reqs)
], ],
dtype=np.int32, dtype=np.int32,
...@@ -925,7 +921,7 @@ class SpecDecodeBaseProposer: ...@@ -925,7 +921,7 @@ class SpecDecodeBaseProposer:
num_reqs=common_attn_metadata.num_reqs, num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens, num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(), max_query_len=new_query_len_per_req.max().item(),
max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), max_seq_len=common_attn_metadata.max_seq_len,
block_table_tensor=common_attn_metadata.block_table_tensor, block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens], slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
causal=True, causal=True,
......
...@@ -286,7 +286,7 @@ class ExtractHiddenStatesProposer: ...@@ -286,7 +286,7 @@ class ExtractHiddenStatesProposer:
def prepare_next_token_ids_padded( def prepare_next_token_ids_padded(
self, self,
common_attn_metadata: CommonAttentionMetadata, seq_lens: torch.Tensor,
sampled_token_ids: torch.Tensor, sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState], requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch, gpu_input_batch: InputBatch,
...@@ -303,11 +303,10 @@ class ExtractHiddenStatesProposer: ...@@ -303,11 +303,10 @@ class ExtractHiddenStatesProposer:
device = sampled_token_ids.device device = sampled_token_ids.device
# Compute backup tokens for discarded / invalid requests # Compute backup tokens for discarded / invalid requests
seq_lens_list = seq_lens[:num_reqs].tolist()
backup_tokens_gpu = torch.tensor( backup_tokens_gpu = torch.tensor(
[ [
requests[gpu_input_batch.req_ids[i]].get_token_id( requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
common_attn_metadata.seq_lens_cpu[i].item()
)
for i in range(num_reqs) for i in range(num_reqs)
], ],
dtype=torch.int32, dtype=torch.int32,
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import torch import torch
from vllm.config import VllmConfig, replace from vllm.config import VllmConfig, replace
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, CommonAttentionMetadata,
...@@ -463,3 +464,36 @@ def copy_and_expand_eagle_inputs_kernel( ...@@ -463,3 +464,36 @@ def copy_and_expand_eagle_inputs_kernel(
out_idx, out_idx,
mask=is_new_token_region & in_bounds, mask=is_new_token_region & in_bounds,
) )
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def update_num_computed_tokens_for_batch_change(
    num_computed_tokens: torch.Tensor,
    num_accepted_tokens: torch.Tensor,
    prev_positions: torch.Tensor,
    valid_sampled_token_count: torch.Tensor,
    prev_num_draft_tokens: torch.Tensor,
    cpu_num_computed_tokens: torch.Tensor,
) -> None:
    """Fix up per-request token counters in place for async spec decode drift.

    ``prev_positions[i]`` is the slot that the request now at slot ``i``
    occupied in the previous batch, or -1 for a new request.

    * Carried-over slot whose previous draft count is positive:
      ``num_computed_tokens[i]`` becomes the previous slot's
      ``num_computed_tokens`` plus its ``valid_sampled_token_count``, and
      ``num_accepted_tokens[i]`` becomes that valid count.
    * Any other slot (new request, or no drafts, e.g. prefills):
      ``num_computed_tokens[i]`` is taken from ``cpu_num_computed_tokens``
      and ``num_accepted_tokens[i]`` keeps its current value.

    Both outputs are written with ``copy_``; nothing is returned.
    """
    batch_len = prev_positions.shape[0]

    # New requests carry -1; clamp so the gathers below read slot 0 for them.
    # Whatever is gathered for those rows is masked out by `use_gpu_value`.
    src_slots = prev_positions.clamp(min=0)

    sampled_per_req = valid_sampled_token_count[src_slots]
    computed_before = num_computed_tokens[src_slots]
    drafts_before = prev_num_draft_tokens[src_slots]

    # A row is corrected on-device only if it existed before AND had drafts.
    use_gpu_value = (prev_positions >= 0) & (drafts_before > 0)
    gpu_corrected = computed_before + sampled_per_req.int()

    new_computed = torch.where(
        use_gpu_value, gpu_corrected, cpu_num_computed_tokens
    )
    num_computed_tokens[:batch_len].copy_(new_computed)

    new_accepted = torch.where(
        use_gpu_value, sampled_per_req, num_accepted_tokens
    )
    num_accepted_tokens.copy_(new_accepted)
...@@ -6,7 +6,9 @@ import torch ...@@ -6,7 +6,9 @@ import torch
from vllm.distributed import get_dcp_group, get_pcp_group from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.cp_utils import get_total_cp_world_size from vllm.v1.worker.cp_utils import get_total_cp_world_size
...@@ -131,71 +133,33 @@ class BlockTable: ...@@ -131,71 +133,33 @@ class BlockTable:
self.block_table.np[src_tgt] = self.block_table.np[tgt_src] self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
def compute_slot_mapping( def compute_slot_mapping(
self, req_indices: np.ndarray, positions: np.ndarray self,
num_reqs: int,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> None: ) -> None:
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] num_tokens = positions.shape[0]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
total_cp_world_size = self.pcp_world_size * self.dcp_world_size total_cp_world_size = self.pcp_world_size * self.dcp_world_size
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
if total_cp_world_size > 1: _compute_slot_mapping_kernel[(num_reqs + 1,)](
# Note(hc): The DCP implement store kvcache with an interleave num_tokens,
# style, the kvcache for the token whose token_idx is i is self.max_num_batched_tokens,
# always stored on the GPU whose dcp_rank equals i % cp_world_size: query_start_loc,
positions,
# Use a "virtual block" which equals to world_size * block_size self.block_table.gpu,
# for block_table_indices calculation. self.block_table.gpu.stride(0),
virtual_block_size = self.block_size * total_cp_world_size self.block_size,
block_table_indices = ( self.slot_mapping.gpu,
req_indices * self.max_num_blocks_per_req TOTAL_CP_WORLD_SIZE=total_cp_world_size,
+ positions // virtual_block_size TOTAL_CP_RANK=total_cp_rank,
) CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
PAD_ID=PAD_SLOT_ID,
block_numbers = self.block_table.np.ravel()[block_table_indices] BLOCK_SIZE=1024,
# Use virtual_block_size for mask calculation, which marks local )
# tokens.
virtual_block_offsets = positions % virtual_block_size
mask = (
virtual_block_offsets
// self.cp_kv_cache_interleave_size
% total_cp_world_size
== total_cp_rank
)
# Calculate local block_offsets
block_offsets = (
virtual_block_offsets
// (total_cp_world_size * self.cp_kv_cache_interleave_size)
* self.cp_kv_cache_interleave_size
+ virtual_block_offsets % self.cp_kv_cache_interleave_size
)
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
self.slot_mapping.np[: req_indices.shape[0]] = np.where(
mask, slot_mapping, -1
)
else:
block_table_indices = (
req_indices * self.max_num_blocks_per_req + positions // self.block_size
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
np.add(
block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping.np[: req_indices.shape[0]],
)
def commit_block_table(self, num_reqs: int) -> None: def commit_block_table(self, num_reqs: int) -> None:
self.block_table.copy_to_gpu(num_reqs) self.block_table.copy_to_gpu(num_reqs)
def commit_slot_mapping(self, num_tokens: int) -> None:
self.slot_mapping.copy_to_gpu(num_tokens)
def clear(self) -> None: def clear(self) -> None:
self.block_table.gpu.fill_(0) self.block_table.gpu.fill_(0)
self.block_table.cpu.fill_(0) self.block_table.cpu.fill_(0)
...@@ -320,19 +284,18 @@ class MultiGroupBlockTable: ...@@ -320,19 +284,18 @@ class MultiGroupBlockTable:
block_table.swap_row(src, tgt) block_table.swap_row(src, tgt)
def compute_slot_mapping( def compute_slot_mapping(
self, req_indices: np.ndarray, positions: np.ndarray self,
num_reqs: int,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> None: ) -> None:
for block_table in self.block_tables: for block_table in self.block_tables:
block_table.compute_slot_mapping(req_indices, positions) block_table.compute_slot_mapping(num_reqs, query_start_loc, positions)
def commit_block_table(self, num_reqs: int) -> None: def commit_block_table(self, num_reqs: int) -> None:
for block_table in self.block_tables: for block_table in self.block_tables:
block_table.commit_block_table(num_reqs) block_table.commit_block_table(num_reqs)
def commit_slot_mapping(self, num_tokens: int) -> None:
for block_table in self.block_tables:
block_table.commit_slot_mapping(num_tokens)
def clear(self) -> None: def clear(self) -> None:
for block_table in self.block_tables: for block_table in self.block_tables:
block_table.clear() block_table.clear()
...@@ -340,3 +303,61 @@ class MultiGroupBlockTable: ...@@ -340,3 +303,61 @@ class MultiGroupBlockTable:
def __getitem__(self, idx: int) -> "BlockTable": def __getitem__(self, idx: int) -> "BlockTable":
"""Returns the BlockTable for the i-th KV cache group.""" """Returns the BlockTable for the i-th KV cache group."""
return self.block_tables[idx] return self.block_tables[idx]
@triton.jit
def _compute_slot_mapping_kernel(
    num_tokens,
    max_num_tokens,
    query_start_loc_ptr,  # [num_reqs + 1], int32
    positions_ptr,  # [num_tokens], int64
    block_table_ptr,  # [max_num_reqs, max_num_blocks_per_req], int32 (flat)
    block_table_stride,  # max_num_blocks_per_req
    block_size,
    slot_mapping_ptr,  # [max_num_tokens], int64
    TOTAL_CP_WORLD_SIZE: tl.constexpr,
    TOTAL_CP_RANK: tl.constexpr,
    CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr,
    PAD_ID: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per request, plus one trailing program that only pads.
    pid = tl.program_id(0)

    if pid == tl.num_programs(0) - 1:
        # Trailing program: fill slot_mapping[num_tokens:max_num_tokens] with
        # PAD_ID (padding kept for CUDA graph compatibility).
        for tail_start in range(num_tokens, max_num_tokens, BLOCK_SIZE):
            tail_idx = tail_start + tl.arange(0, BLOCK_SIZE)
            tl.store(
                slot_mapping_ptr + tail_idx,
                PAD_ID,
                mask=tail_idx < max_num_tokens,
            )
        return

    # Token range [tok_begin, tok_end) owned by this request.
    tok_begin = tl.load(query_start_loc_ptr + pid).to(tl.int64)
    tok_end = tl.load(query_start_loc_ptr + pid + 1).to(tl.int64)

    # Start of this request's row in the flattened block table.
    table_row_base = pid * block_table_stride
    # Number of positions covered by one block-table entry across all CP ranks.
    cp_span = block_size * TOTAL_CP_WORLD_SIZE

    for chunk_start in range(tok_begin, tok_end, BLOCK_SIZE):
        tok_idx = chunk_start + tl.arange(0, BLOCK_SIZE)
        in_range = tok_idx < tok_end

        # Masked-off lanes read position 0, so the table lookup below still
        # addresses the first entry of this request's row.
        token_pos = tl.load(positions_ptr + tok_idx, mask=in_range, other=0)

        table_col = token_pos // cp_span
        physical_block = tl.load(
            block_table_ptr + table_row_base + table_col
        ).to(tl.int64)

        # Offset of the position inside its cp_span-wide virtual block.
        offset_in_span = token_pos - table_col * cp_span

        # Positions are dealt to ranks in runs of CP_KV_CACHE_INTERLEAVE_SIZE.
        owner_rank = (
            offset_in_span // CP_KV_CACHE_INTERLEAVE_SIZE
        ) % TOTAL_CP_WORLD_SIZE
        on_this_rank = owner_rank == TOTAL_CP_RANK

        # Offset inside the local block: completed interleave rounds plus the
        # offset within the current run.
        full_rounds = offset_in_span // (
            TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE
        )
        offset_in_block = full_rounds * CP_KV_CACHE_INTERLEAVE_SIZE + (
            offset_in_span % CP_KV_CACHE_INTERLEAVE_SIZE
        )

        local_slot = physical_block * block_size + offset_in_block
        # Positions stored on another rank are written as PAD_ID.
        out_slot = tl.where(on_this_rank, local_slot, PAD_ID)
        tl.store(slot_mapping_ptr + tok_idx, out_slot, mask=in_range)
...@@ -219,7 +219,7 @@ class InputBatch: ...@@ -219,7 +219,7 @@ class InputBatch:
# Speculative decoding # Speculative decoding
self.num_accepted_tokens_cpu_tensor = torch.ones( self.num_accepted_tokens_cpu_tensor = torch.ones(
(max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
) )
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy() self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
...@@ -989,13 +989,15 @@ class InputBatch: ...@@ -989,13 +989,15 @@ class InputBatch:
continue continue
num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1) num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
# Also account for case where there may be a smaller number of # Also account for case where there may be a smaller number of
# output placeholders (tokens can be discarded after a kv-load failure). # output placeholders (tokens can be discarded after kv-load
# failure) or a larger number (async spec decode adds optimistic
# placeholders that may exceed the actual acceptance count).
first_placeholder = req_output_token_ids.index(-1) first_placeholder = req_output_token_ids.index(-1)
num_placeholders = len(req_output_token_ids) - first_placeholder num_placeholders = len(req_output_token_ids) - first_placeholder
num_to_replace = min(num_sampled_ids, num_placeholders) num_to_replace = min(num_sampled_ids, num_placeholders)
del new_ids[num_to_replace:] del new_ids[num_to_replace:]
end_index = first_placeholder + num_to_replace req_output_token_ids[first_placeholder:] = new_ids
req_output_token_ids[first_placeholder:end_index] = new_ids # ^ Implicitly resizes to (first_placeholder + num_to_replace)
def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None: def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
""" """
......
...@@ -7,7 +7,7 @@ import itertools ...@@ -7,7 +7,7 @@ import itertools
import threading import threading
import time import time
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable, Iterator, Sequence from collections.abc import Callable, Iterable, Iterator, Sequence
from contextlib import contextmanager from contextlib import contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
...@@ -172,6 +172,7 @@ from vllm.v1.spec_decode.ngram_proposer_gpu import ( ...@@ -172,6 +172,7 @@ from vllm.v1.spec_decode.ngram_proposer_gpu import (
update_scheduler_for_invalid_drafts, update_scheduler_for_invalid_drafts,
) )
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.spec_decode.utils import update_num_computed_tokens_for_batch_change
from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker import mamba_utils from vllm.v1.worker import mamba_utils
...@@ -570,6 +571,7 @@ class GPUModelRunner( ...@@ -570,6 +571,7 @@ class GPUModelRunner(
self.rejection_sampler = RejectionSampler(self.sampler) self.rejection_sampler = RejectionSampler(self.sampler)
self.num_spec_tokens = 0 self.num_spec_tokens = 0
self.valid_sampled_token_count_gpu: torch.Tensor | None = None
if self.speculative_config: if self.speculative_config:
self.num_spec_tokens = self.speculative_config.num_speculative_tokens self.num_spec_tokens = self.speculative_config.num_speculative_tokens
draft_config = self.speculative_config.draft_model_config draft_config = self.speculative_config.draft_model_config
...@@ -577,6 +579,9 @@ class GPUModelRunner( ...@@ -577,6 +579,9 @@ class GPUModelRunner(
self.effective_drafter_max_model_len = draft_config.max_model_len self.effective_drafter_max_model_len = draft_config.max_model_len
else: else:
self.effective_drafter_max_model_len = self.max_model_len self.effective_drafter_max_model_len = self.max_model_len
self.use_async_spec_decode = (
self.use_async_scheduling and self.num_spec_tokens > 0
)
# Request states. # Request states.
self.requests: dict[str, CachedRequestState] = {} self.requests: dict[str, CachedRequestState] = {}
...@@ -659,11 +664,31 @@ class GPUModelRunner( ...@@ -659,11 +664,31 @@ class GPUModelRunner(
# Persistent buffers for CUDA graphs. # Persistent buffers for CUDA graphs.
self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64) self.positions = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device=self.device
)
self.query_start_loc = self._make_buffer( self.query_start_loc = self._make_buffer(
self.max_num_reqs + 1, dtype=torch.int32 self.max_num_reqs + 1, dtype=torch.int32
) )
self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) self.seq_lens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=self.device
)
self.optimistic_seq_lens_cpu = torch.zeros(
self.max_num_reqs, dtype=torch.int32, pin_memory=self.pin_memory
)
self.num_computed_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=self.device
)
self.prev_num_draft_tokens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
)
self.req_indices = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
# Maps current batch position -> previous batch position (-1 for new reqs)
self.prev_positions = self._make_buffer(self.max_num_reqs, dtype=torch.int64)
self.num_scheduled_tokens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
)
self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
if self.dcp_world_size > 1: if self.dcp_world_size > 1:
self.dcp_local_seq_lens = self._make_buffer( self.dcp_local_seq_lens = self._make_buffer(
...@@ -683,7 +708,7 @@ class GPUModelRunner( ...@@ -683,7 +708,7 @@ class GPUModelRunner(
self.max_num_reqs, dtype=torch.int32 self.max_num_reqs, dtype=torch.int32
) )
self.num_accepted_tokens = self._make_buffer( self.num_accepted_tokens = self._make_buffer(
self.max_num_reqs, dtype=torch.int64 self.max_num_reqs, dtype=torch.int32
) )
# Only relevant for multimodal models # Only relevant for multimodal models
...@@ -722,12 +747,14 @@ class GPUModelRunner( ...@@ -722,12 +747,14 @@ class GPUModelRunner(
# None in the first PP rank. The rest are set after load_model. # None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: IntermediateTensors | None = None self.intermediate_tensors: IntermediateTensors | None = None
# OPTIMIZATION: Cache the tensors rather than creating them every step. # OPTIMIZATION: Cache the arange tensors rather than creating them
# Keep in int64 to avoid overflow with long context # every step. Keep in int64 to avoid overflow with long context.
self.arange_np = np.arange( # - arange_np: immutable [0, 1, 2, ...] used as source for batched computation
max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens), # - query_pos: CpuGpuBuffer for the computed batched arange result
dtype=np.int64, arange_size = max(self.max_num_reqs + 1, self.max_num_tokens)
) self.arange_np = np.arange(arange_size, dtype=np.int64)
self.query_pos = self._make_buffer(arange_size, dtype=torch.int64)
self._arange_scratch = np.empty(arange_size, dtype=np.int64)
# Layer pairings for cross-layer KV sharing. # Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it # If an Attention layer `layer_name` is in the keys of this dict, it
...@@ -812,7 +839,7 @@ class GPUModelRunner( ...@@ -812,7 +839,7 @@ class GPUModelRunner(
self.valid_sampled_token_count_copy_stream = torch.cuda.Stream() self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
self.valid_sampled_token_count_cpu = torch.empty( self.valid_sampled_token_count_cpu = torch.empty(
self.max_num_reqs, self.max_num_reqs,
dtype=torch.int64, dtype=torch.int32,
device="cpu", device="cpu",
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
) )
...@@ -903,13 +930,13 @@ class GPUModelRunner( ...@@ -903,13 +930,13 @@ class GPUModelRunner(
return self.mrope_positions.gpu[:, :num_tokens] return self.mrope_positions.gpu[:, :num_tokens]
if self.uses_xdrope_dim > 0: if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, :num_tokens] return self.xdrope_positions.gpu[:, :num_tokens]
return self.positions.gpu[:num_tokens] return self.positions[:num_tokens]
else: else:
if self.uses_mrope: if self.uses_mrope:
return self.mrope_positions.gpu[:, num_tokens] return self.mrope_positions.gpu[:, num_tokens]
if self.uses_xdrope_dim > 0: if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, num_tokens] return self.xdrope_positions.gpu[:, num_tokens]
return self.positions.gpu[num_tokens] return self.positions[num_tokens]
def _make_buffer( def _make_buffer(
self, *size: int | torch.SymInt, dtype: torch.dtype, numpy: bool = True self, *size: int | torch.SymInt, dtype: torch.dtype, numpy: bool = True
...@@ -953,7 +980,7 @@ class GPUModelRunner( ...@@ -953,7 +980,7 @@ class GPUModelRunner(
if len(token_type_id_requests) == 0: if len(token_type_id_requests) == 0:
return model_kwargs return model_kwargs
seq_lens = self.seq_lens.gpu[:num_reqs] seq_lens = self.seq_lens[:num_reqs]
token_type_ids = [] token_type_ids = []
for i in range(num_reqs): for i in range(num_reqs):
...@@ -1021,7 +1048,7 @@ class GPUModelRunner( ...@@ -1021,7 +1048,7 @@ class GPUModelRunner(
def _sync_device(self) -> None: def _sync_device(self) -> None:
torch.accelerator.synchronize() torch.accelerator.synchronize()
def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None:
"""Update the cached states and the persistent batch with the scheduler """Update the cached states and the persistent batch with the scheduler
output. output.
...@@ -1086,6 +1113,8 @@ class GPUModelRunner( ...@@ -1086,6 +1113,8 @@ class GPUModelRunner(
ngram_gpu_new_reqs: list[CachedRequestState] = [] ngram_gpu_new_reqs: list[CachedRequestState] = []
reqs_to_add: list[CachedRequestState] = [] reqs_to_add: list[CachedRequestState] = []
deferred_spec_decode_corrections = []
# Add new requests to the cached states. # Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs: for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id req_id = new_req_data.req_id
...@@ -1172,10 +1201,8 @@ class GPUModelRunner( ...@@ -1172,10 +1201,8 @@ class GPUModelRunner(
scheduler_output, scheduler_output,
self.input_batch.req_id_to_index, self.input_batch.req_id_to_index,
) )
if self.use_async_spec_decode:
# Wait until valid_sampled_tokens_count is copied to cpu, self.prev_num_draft_tokens.np.fill(0)
# then use it to update actual num_computed_tokens of each request.
valid_sampled_token_count = self._get_valid_sampled_token_count()
for i, req_id in enumerate(req_data.req_ids): for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id] req_state = self.requests[req_id]
...@@ -1202,15 +1229,30 @@ class GPUModelRunner( ...@@ -1202,15 +1229,30 @@ class GPUModelRunner(
if req_index is None: if req_index is None:
req_state.prev_num_draft_len = 0 req_state.prev_num_draft_len = 0
else: else:
assert self.input_batch.prev_req_id_to_index is not None # Optimistically assume all accepted; queue up a correction
prev_req_index = self.input_batch.prev_req_id_to_index[req_id] # to be called after the model forward to preserve async
num_accepted = valid_sampled_token_count[prev_req_index] - 1 # scheduling. Corrected on GPU in _prepare_inputs.
num_rejected = req_state.prev_num_draft_len - num_accepted optimistic_num_accepted = req_state.prev_num_draft_len
num_computed_tokens -= num_rejected req_state.output_token_ids.extend([-1] * optimistic_num_accepted)
req_state.output_token_ids.extend([-1] * num_accepted)
deferred_spec_decode_corrections.append(
(req_id, optimistic_num_accepted, req_state)
)
prev_req_index = (
self.input_batch.prev_req_id_to_index.get(req_id)
if self.input_batch.prev_req_id_to_index
else None
)
if prev_req_index is not None:
self.prev_num_draft_tokens.np[prev_req_index] = (
optimistic_num_accepted
)
if is_ngram_gpu and num_accepted > 0 and req_index is not None: if is_ngram_gpu and optimistic_num_accepted > 0:
self.input_batch.num_tokens_no_spec[req_index] += num_accepted self.input_batch.num_tokens_no_spec[req_index] += (
optimistic_num_accepted
)
# Update the cached states. # Update the cached states.
req_state.num_computed_tokens = num_computed_tokens req_state.num_computed_tokens = num_computed_tokens
...@@ -1238,7 +1280,8 @@ class GPUModelRunner( ...@@ -1238,7 +1280,8 @@ class GPUModelRunner(
) )
elif num_output_tokens < len(req_state.output_token_ids): elif num_output_tokens < len(req_state.output_token_ids):
# Some output tokens were discarded due to a sync-KV-load # Some output tokens were discarded due to a sync-KV-load
# failure. Align the cached state. # failure, or output_token_ids was inflated by the optimistic
# extend above (async spec decode). Align the cached state.
del req_state.output_token_ids[num_output_tokens:] del req_state.output_token_ids[num_output_tokens:]
if req_index is not None: if req_index is not None:
end_idx = ( end_idx = (
...@@ -1326,6 +1369,40 @@ class GPUModelRunner( ...@@ -1326,6 +1369,40 @@ class GPUModelRunner(
_pinned_val_buf=self._ngram_pinned_val_buf, _pinned_val_buf=self._ngram_pinned_val_buf,
) )
if deferred_spec_decode_corrections:
def correct_spec_decode_token_counts():
valid_sampled_token_count = self._get_valid_sampled_token_count()
if not valid_sampled_token_count:
return
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
if not prev_req_id_to_index:
return
for (
req_id,
optimistic_num_accepted,
req_state,
) in deferred_spec_decode_corrections:
prev_req_index = prev_req_id_to_index.get(req_id)
if prev_req_index is None:
continue
num_accepted = valid_sampled_token_count[prev_req_index] - 1
correction = optimistic_num_accepted - num_accepted
req_state.num_computed_tokens -= correction
cur_req_index = self.input_batch.req_id_to_index.get(req_id)
if cur_req_index is None:
continue
self.input_batch.num_computed_tokens_cpu[cur_req_index] -= (
correction
)
if is_ngram_gpu and correction > 0:
self.input_batch.num_tokens_no_spec[cur_req_index] -= correction
self.num_tokens_no_spec_gpu[cur_req_index] -= correction
return correct_spec_decode_token_counts
else:
return None
def _update_states_after_model_execute( def _update_states_after_model_execute(
self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput" self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput"
) -> None: ) -> None:
...@@ -1340,6 +1417,9 @@ class GPUModelRunner( ...@@ -1340,6 +1417,9 @@ class GPUModelRunner(
if not self.speculative_config or not self.model_config.is_hybrid: if not self.speculative_config or not self.model_config.is_hybrid:
return return
# TODO: Remove .cpu() sync to enable fully async for hybrid model;
# Use num_computed_tokens.gpu instead of req.num_computed_tokens to
# support aligned mamba cache mode.
# Find the number of accepted tokens for each sequence. # Find the number of accepted tokens for each sequence.
num_reqs = output_token_ids.size(0) num_reqs = output_token_ids.size(0)
self.num_accepted_tokens.gpu[:num_reqs] = ( self.num_accepted_tokens.gpu[:num_reqs] = (
...@@ -1486,12 +1566,14 @@ class GPUModelRunner( ...@@ -1486,12 +1566,14 @@ class GPUModelRunner(
def _get_cumsum_and_arange( def _get_cumsum_and_arange(
self, self,
num_tokens: np.ndarray, num_tokens: np.ndarray,
arange_out: np.ndarray,
cumsum_dtype: np.dtype | None = None, cumsum_dtype: np.dtype | None = None,
) -> tuple[np.ndarray, np.ndarray]: ) -> np.ndarray:
"""Get the cumulative sum and batched arange of the given array. """Get the cumulative sum and batched arange of the given array.
# E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) E.g., [2, 5, 3] -> [2, 7, 10], arange written to
# Equivalent to but faster than: arange_out[:10] as [0, 1, 0, 1, 2, 3, 4, 0, 1, 2].
# np.concatenate([np.arange(n) for n in num_tokens]) Equivalent to but faster than:
np.concatenate([np.arange(n) for n in num_tokens])
""" """
# Step 1. [2, 5, 3] -> [2, 7, 10] # Step 1. [2, 5, 3] -> [2, 7, 10]
cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
...@@ -1499,13 +1581,33 @@ class GPUModelRunner( ...@@ -1499,13 +1581,33 @@ class GPUModelRunner(
# Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
# Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange = self.arange_np[:total_num_tokens] - cumsums_offsets np.subtract(
self.arange_np[:total_num_tokens],
cumsums_offsets,
out=arange_out[:total_num_tokens],
)
return cu_num_tokens
def _compute_prev_positions(self, num_reqs: int) -> None:
    """Build the mapping from current batch position to previous position.

    Populates ``self.prev_positions.np[:num_reqs]`` in place: entry ``i``
    receives the index that request ``self.input_batch.req_ids[i]`` had in
    ``self.input_batch.prev_req_id_to_index``, or ``-1`` when the request
    is new (absent from that map). Entries at or beyond ``num_reqs`` are
    left untouched.

    Args:
        num_reqs: Number of leading batch positions to fill.
    """
    prev_req_id_to_index = self.input_batch.prev_req_id_to_index
    # Slicing yields a view, so writes land in the backing buffer.
    prev_positions = self.prev_positions.np[:num_reqs]

    if not prev_req_id_to_index:
        # No previous-step map (None or empty): every request is new.
        prev_positions.fill(-1)
        return

    for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
        prev_positions[i] = prev_req_id_to_index.get(req_id, -1)
def _prepare_input_ids( def _prepare_input_ids(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
num_reqs: int,
total_num_scheduled_tokens: int, total_num_scheduled_tokens: int,
cu_num_tokens: np.ndarray, cu_num_tokens: np.ndarray,
) -> None: ) -> None:
...@@ -1513,7 +1615,11 @@ class GPUModelRunner( ...@@ -1513,7 +1615,11 @@ class GPUModelRunner(
Carefully handles the `prev_sampled_token_ids` which can be cached Carefully handles the `prev_sampled_token_ids` which can be cached
from the previous engine iteration, in which case those tokens on the from the previous engine iteration, in which case those tokens on the
GPU need to be copied into the corresponding slots into input_ids.""" GPU need to be copied into the corresponding slots into input_ids.
Uses self.prev_positions[:num_reqs] which maps current pos -> prev pos
(-1 for new requests).
"""
if self.input_batch.prev_sampled_token_ids is None: if self.input_batch.prev_sampled_token_ids is None:
# Normal scheduling case # Normal scheduling case
...@@ -1526,47 +1632,50 @@ class GPUModelRunner( ...@@ -1526,47 +1632,50 @@ class GPUModelRunner(
# Async scheduling case, where some decode requests from the previous # Async scheduling case, where some decode requests from the previous
# iteration won't have entries in input_ids_cpu and need to be copied # iteration won't have entries in input_ids_cpu and need to be copied
# on the GPU from prev_sampled_token_ids. # on the GPU from prev_sampled_token_ids.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index prev_positions = self.prev_positions.np[:num_reqs]
assert prev_req_id_to_index is not None scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
sample_flattened_indices: list[int] = [] sample_flattened_indices: list[int] = []
spec_flattened_indices: list[int] = [] spec_flattened_indices: list[int] = []
prev_common_req_indices: list[int] = []
prev_draft_token_indices: list[int] = [] prev_draft_token_indices: list[int] = []
indices_match = True prev_indices: list[int] = []
common_indices_match = True
max_flattened_index = -1 max_flattened_index = -1
total_num_spec_tokens = 0 total_num_spec_tokens = 0
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id, cur_index in self.input_batch.req_id_to_index.items(): for cur_index in range(num_reqs):
if (prev_index := prev_req_id_to_index.get(req_id)) is not None: prev_index = prev_positions[cur_index]
prev_common_req_indices.append(prev_index) if prev_index < 0:
# We need to compute the flattened input_ids index of the continue
# last token in each common request. prev_indices.append(prev_index)
draft_len = len(scheduled_spec_tokens.get(req_id, ())) req_id = self.input_batch.req_ids[cur_index]
total_num_spec_tokens += draft_len # We need to compute the flattened input_ids index of the
flattened_index = cu_num_tokens[cur_index].item() - 1 # last token in each common request.
# example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2] draft_len = len(scheduled_spec_tokens.get(req_id, ()))
# sample_flattened_indices = [0, 2, 5] total_num_spec_tokens += draft_len
# spec_flattened_indices = [1, 3, 4, 6, 7] flattened_index = cu_num_tokens[cur_index].item() - 1
sample_flattened_indices.append(flattened_index - draft_len) # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
spec_flattened_indices.extend( # sample_flattened_indices = [0, 2, 5]
range(flattened_index - draft_len + 1, flattened_index + 1) # spec_flattened_indices = [1, 3, 4, 6, 7]
) sample_flattened_indices.append(flattened_index - draft_len)
start = prev_index * self.num_spec_tokens spec_flattened_indices.extend(
# prev_draft_token_indices is used to find which draft_tokens_id range(flattened_index - draft_len + 1, flattened_index + 1)
# should be copied to input_ids )
# example: prev draft_tokens_id [[1,2], [3,4], [5, 6]] start = prev_index * self.num_spec_tokens
# flatten draft_tokens_id [1,2,3,4,5,6] # prev_draft_token_indices is used to find which draft_tokens_id
# draft_len of each request [1, 2, 1] # should be copied to input_ids
# then prev_draft_token_indices is [0, 2, 3, 4] # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]]
prev_draft_token_indices.extend(range(start, start + draft_len)) # flatten draft_tokens_id [1,2,3,4,5,6]
indices_match &= prev_index == flattened_index # draft_len of each request [1, 2, 1]
max_flattened_index = max(max_flattened_index, flattened_index) # then prev_draft_token_indices is [0, 2, 3, 4]
prev_draft_token_indices.extend(range(start, start + draft_len))
common_indices_match &= prev_index == flattened_index
max_flattened_index = max(max_flattened_index, flattened_index)
num_common_tokens = len(sample_flattened_indices) num_common_tokens = len(sample_flattened_indices)
total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens
if num_common_tokens < total_without_spec: if num_common_tokens < total_without_spec:
# If not all requests are decodes from the last iteration, # If not all requests are decodes from the last iteration,
# We need to copy the input_ids_cpu to the GPU first. # we need to copy the input_ids_cpu to the GPU first.
self.input_ids.copy_to_gpu(total_num_scheduled_tokens) self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
if self.enable_prompt_embeds: if self.enable_prompt_embeds:
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
...@@ -1575,7 +1684,7 @@ class GPUModelRunner( ...@@ -1575,7 +1684,7 @@ class GPUModelRunner(
# No requests in common with the previous iteration # No requests in common with the previous iteration
# So input_ids.cpu will have all the input ids. # So input_ids.cpu will have all the input ids.
return return
if indices_match and max_flattened_index == (num_common_tokens - 1): if common_indices_match and max_flattened_index == (num_common_tokens - 1):
# Common-case optimization: the batch is unchanged # Common-case optimization: the batch is unchanged
# and no reordering happened. # and no reordering happened.
# The indices are both the same permutation of 0..N-1 so # The indices are both the same permutation of 0..N-1 so
...@@ -1592,7 +1701,7 @@ class GPUModelRunner( ...@@ -1592,7 +1701,7 @@ class GPUModelRunner(
sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True) ).to(self.device, non_blocking=True)
prev_common_req_indices_tensor = torch.tensor( prev_common_req_indices_tensor = torch.tensor(
prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory prev_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True) ).to(self.device, non_blocking=True)
self.input_ids.gpu.scatter_( self.input_ids.gpu.scatter_(
dim=0, dim=0,
...@@ -1696,15 +1805,15 @@ class GPUModelRunner( ...@@ -1696,15 +1805,15 @@ class GPUModelRunner(
req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10] # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # self.query_pos.np[:10]: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) cu_num_tokens = self._get_cumsum_and_arange(
num_scheduled_tokens, self.query_pos.np
)
# Get positions. # Get positions.
positions_np = self.positions.np[:total_num_scheduled_tokens] positions_np = (
np.add( self.input_batch.num_computed_tokens_cpu[req_indices]
self.input_batch.num_computed_tokens_cpu[req_indices], + self.query_pos.np[: cu_num_tokens[-1]]
arange,
out=positions_np,
) )
# Calculate M-RoPE positions. # Calculate M-RoPE positions.
...@@ -1782,9 +1891,6 @@ class GPUModelRunner( ...@@ -1782,9 +1891,6 @@ class GPUModelRunner(
output_idx += num_sched output_idx += num_sched
self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)
# Prepare the attention metadata. # Prepare the attention metadata.
self.query_start_loc.np[0] = 0 self.query_start_loc.np[0] = 0
self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
...@@ -1794,12 +1900,21 @@ class GPUModelRunner( ...@@ -1794,12 +1900,21 @@ class GPUModelRunner(
self.query_start_loc.copy_to_gpu() self.query_start_loc.copy_to_gpu()
query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
self.seq_lens.np[:num_reqs] = ( # Compute optimistic seq_lens (assumes all draft tokens from previous
self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens # iteration accepted). Store in optimistic_seq_lens_cpu for use by
# _build_attention_metadata (max_seq_len) and discard_request_mask.
# seq_lens (GPU) will be computed later using the same optimistic values.
torch.add(
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs],
torch.from_numpy(num_scheduled_tokens),
out=self.optimistic_seq_lens_cpu[:num_reqs],
) )
# Fill unused with 0 for full cuda graph mode. self.optimistic_seq_lens_cpu[num_reqs:].fill_(0)
self.seq_lens.np[num_reqs:].fill(0)
self.seq_lens.copy_to_gpu() # Build prev_positions mapping: current pos -> prev pos (-1 if new).
# Used for gathering from previous iteration's GPU tensors.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
self._compute_prev_positions(num_reqs)
num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
num_tokens_np = np.array(num_tokens, dtype=np.int32) num_tokens_np = np.array(num_tokens, dtype=np.int32)
...@@ -1807,13 +1922,78 @@ class GPUModelRunner( ...@@ -1807,13 +1922,78 @@ class GPUModelRunner(
# Record which requests should not be sampled, # Record which requests should not be sampled,
# so that we could clear the sampled tokens before returning # so that we could clear the sampled tokens before returning
self.discard_request_mask.np[:num_reqs] = ( self.discard_request_mask.np[:num_reqs] = (
self.seq_lens.np[:num_reqs] < num_tokens_np self.optimistic_seq_lens_cpu[:num_reqs].numpy() < num_tokens_np
) )
self.discard_request_mask.copy_to_gpu(num_reqs) self.discard_request_mask.copy_to_gpu(num_reqs)
# Sync num_accepted_tokens from CPU (set by
# _update_states_after_model_execute for hybrid models).
if self.num_accepted_tokens_event is not None:
self.num_accepted_tokens_event.synchronize()
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
)
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
else:
self.num_accepted_tokens.np.fill(1)
self.num_accepted_tokens.gpu.fill_(1)
# Update num_computed_tokens on GPU. In async spec decode,
# CPU values are optimistic (all drafts accepted). The kernel
# corrects on GPU using the previous step's
# valid_sampled_token_count_gpu. Otherwise, just copy from CPU.
if (
self.use_async_spec_decode
and self.valid_sampled_token_count_gpu is not None
and prev_req_id_to_index
):
self.prev_positions.copy_to_gpu(num_reqs)
self.prev_num_draft_tokens.copy_to_gpu()
cpu_values = self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs].to(
device=self.device, non_blocking=True
)
update_num_computed_tokens_for_batch_change(
self.num_computed_tokens,
self.num_accepted_tokens.gpu[:num_reqs],
self.prev_positions.gpu[:num_reqs],
self.valid_sampled_token_count_gpu,
self.prev_num_draft_tokens.gpu,
cpu_values,
)
else:
self.num_computed_tokens[:num_reqs].copy_(
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs],
non_blocking=True,
)
self.req_indices.np[:total_num_scheduled_tokens] = req_indices
self.req_indices.copy_to_gpu(total_num_scheduled_tokens)
req_indices_gpu = self.req_indices.gpu[:total_num_scheduled_tokens]
self.query_pos.copy_to_gpu(total_num_scheduled_tokens)
self.num_scheduled_tokens.np[:num_reqs] = num_scheduled_tokens
self.num_scheduled_tokens.copy_to_gpu(num_reqs)
num_scheduled_tokens_gpu = self.num_scheduled_tokens.gpu[:num_reqs]
self.positions[:total_num_scheduled_tokens] = (
self.num_computed_tokens[req_indices_gpu].to(torch.int64)
+ self.query_pos.gpu[:total_num_scheduled_tokens]
)
self.seq_lens[:num_reqs] = (
self.num_computed_tokens[:num_reqs] + num_scheduled_tokens_gpu
)
self.seq_lens[num_reqs:].fill_(0)
self.input_batch.block_table.compute_slot_mapping(
num_reqs,
self.query_start_loc.gpu[: num_reqs + 1],
self.positions[:total_num_scheduled_tokens],
)
# Copy the tensors to the GPU. # Copy the tensors to the GPU.
self._prepare_input_ids( self._prepare_input_ids(
scheduler_output, scheduler_output,
num_reqs,
total_num_scheduled_tokens, total_num_scheduled_tokens,
cu_num_tokens, cu_num_tokens,
) )
...@@ -1830,9 +2010,14 @@ class GPUModelRunner( ...@@ -1830,9 +2010,14 @@ class GPUModelRunner(
self.xdrope_positions.cpu[:, :total_num_scheduled_tokens], self.xdrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True, non_blocking=True,
) )
else: if self.use_async_spec_decode and (self.uses_mrope or self.uses_xdrope_dim > 0):
# Common case (1D positions) drift = self.num_computed_tokens[req_indices_gpu].to(
self.positions.copy_to_gpu(total_num_scheduled_tokens) torch.int64
) - self.input_batch.num_computed_tokens_cpu_tensor[req_indices].to(
device=self.device, dtype=torch.int64, non_blocking=True
)
target = self.mrope_positions if self.uses_mrope else self.xdrope_positions
target.gpu[:, :total_num_scheduled_tokens] += drift
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode: if not use_spec_decode:
...@@ -1857,12 +2042,13 @@ class GPUModelRunner( ...@@ -1857,12 +2042,13 @@ class GPUModelRunner(
draft_token_ids, draft_token_ids,
) in scheduler_output.scheduled_spec_decode_tokens.items(): ) in scheduler_output.scheduled_spec_decode_tokens.items():
req_idx = self.input_batch.req_id_to_index[req_id] req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids) draft_len = len(draft_token_ids)
num_draft_tokens[req_idx] = draft_len
if ( if (
self.input_batch.num_computed_tokens_cpu[req_idx] self.input_batch.num_computed_tokens_cpu[req_idx]
>= self.input_batch.num_prompt_tokens[req_idx] >= self.input_batch.num_prompt_tokens[req_idx]
): ):
num_decode_draft_tokens[req_idx] = len(draft_token_ids) num_decode_draft_tokens[req_idx] = draft_len
spec_decode_metadata = self._calc_spec_decode_metadata( spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens num_draft_tokens, cu_num_tokens
) )
...@@ -1924,16 +2110,7 @@ class GPUModelRunner( ...@@ -1924,16 +2110,7 @@ class GPUModelRunner(
# window size when capturing to make sure the correct kernel is selected. # window size when capturing to make sure the correct kernel is selected.
max_seq_len = self.max_model_len max_seq_len = self.max_model_len
else: else:
max_seq_len = self.seq_lens.np[:num_reqs].max().item() max_seq_len = self.optimistic_seq_lens_cpu.numpy()[:num_reqs].max().item()
if use_spec_decode:
if self.num_accepted_tokens_event is not None:
self.num_accepted_tokens_event.synchronize()
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
)
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
kv_cache_groups = self.kv_cache_config.kv_cache_groups kv_cache_groups = self.kv_cache_config.kv_cache_groups
...@@ -1963,22 +2140,29 @@ class GPUModelRunner( ...@@ -1963,22 +2140,29 @@ class GPUModelRunner(
attn_gid = self.routed_experts_attn_gid attn_gid = self.routed_experts_attn_gid
slot_mapping_attn = slot_mappings[attn_gid] slot_mapping_attn = slot_mappings[attn_gid]
self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy() self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy()
# Compute is_prefilling: True if request is still in prefill phase
# (num_computed_tokens < num_prompt_tokens). Used by mamba backends to
# distinguish actual decodes from short extends.
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs_padded :num_reqs_padded
] ]
num_prompt_tokens_cpu = self.input_batch.num_prompt_tokens_cpu_tensor[ num_prompt_tokens_cpu = self.input_batch.num_prompt_tokens_cpu_tensor[
:num_reqs_padded :num_reqs_padded
] ]
seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs_padded]
# is_prefilling: True if request is still in prefill phase.
# Used by mamba backends to distinguish actual decodes from
# short extends.
is_prefilling = num_computed_tokens_cpu < num_prompt_tokens_cpu is_prefilling = num_computed_tokens_cpu < num_prompt_tokens_cpu
if self.use_async_spec_decode:
# GPU tensors are authoritative in async mode.
seq_lens_cpu = None
num_computed_tokens_cpu = None
cm_base = CommonAttentionMetadata( cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
seq_lens=self.seq_lens.gpu[:num_reqs_padded], seq_lens=self.seq_lens[:num_reqs_padded],
_seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded], _seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu, _num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs_padded, num_reqs=num_reqs_padded,
num_actual_tokens=num_tokens_padded, num_actual_tokens=num_tokens_padded,
...@@ -1992,7 +2176,7 @@ class GPUModelRunner( ...@@ -1992,7 +2176,7 @@ class GPUModelRunner(
if self.dcp_world_size > 1: if self.dcp_world_size > 1:
self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
self.seq_lens.cpu[:num_reqs], self.optimistic_seq_lens_cpu[:num_reqs],
self.dcp_world_size, self.dcp_world_size,
self.dcp_rank, self.dcp_rank,
self.parallel_config.cp_kv_cache_interleave_size, self.parallel_config.cp_kv_cache_interleave_size,
...@@ -2396,33 +2580,34 @@ class GPUModelRunner( ...@@ -2396,33 +2580,34 @@ class GPUModelRunner(
# [4, 1, 3, 1, 2] # [4, 1, 3, 1, 2]
num_sampled_tokens = num_draft_tokens + 1 num_sampled_tokens = num_draft_tokens + 1
# Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11] # Step 1.
# arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] # cu_num_sampled_tokens: [4, 5, 8, 9, 11]
cu_num_sampled_tokens, arange = self._get_cumsum_and_arange( # _arange_scratch[:11]: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
num_sampled_tokens, cumsum_dtype=np.int32 cu_num_sampled_tokens = self._get_cumsum_and_arange(
num_sampled_tokens, self._arange_scratch, cumsum_dtype=np.int32
) )
# Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_indices = np.repeat( logits_indices = np.repeat(
cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens
) )
# Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] # Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
logits_indices += arange logits_indices += self._arange_scratch[: cu_num_sampled_tokens[-1]]
# Compute the bonus logits indices. # Compute the bonus logits indices.
bonus_logits_indices = cu_num_sampled_tokens - 1 bonus_logits_indices = cu_num_sampled_tokens - 1
# Compute the draft logits indices. # Compute the draft logits indices.
# cu_num_draft_tokens: [3, 3, 5, 5, 6] # cu_num_draft_tokens: [3, 3, 5, 5, 6]
# arange: [0, 1, 2, 0, 1, 0] # _arange_scratch[:6]: [0, 1, 2, 0, 1, 0]
cu_num_draft_tokens, arange = self._get_cumsum_and_arange( cu_num_draft_tokens = self._get_cumsum_and_arange(
num_draft_tokens, cumsum_dtype=np.int32 num_draft_tokens, self._arange_scratch, cumsum_dtype=np.int32
) )
# [0, 0, 0, 5, 5, 9] # [0, 0, 0, 5, 5, 9]
target_logits_indices = np.repeat( target_logits_indices = np.repeat(
cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens
) )
# [0, 1, 2, 5, 6, 9] # [0, 1, 2, 5, 6, 9]
target_logits_indices += arange target_logits_indices += self._arange_scratch[: cu_num_draft_tokens[-1]]
# TODO: Optimize the CPU -> GPU copy. # TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
...@@ -2924,7 +3109,7 @@ class GPUModelRunner( ...@@ -2924,7 +3109,7 @@ class GPUModelRunner(
) )
hidden_states = hidden_states[:num_scheduled_tokens] hidden_states = hidden_states[:num_scheduled_tokens]
seq_lens_cpu = self.seq_lens.cpu[:num_reqs] seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs]
pooling_metadata = self.input_batch.get_pooling_metadata() pooling_metadata = self.input_batch.get_pooling_metadata()
pooling_metadata.build_pooling_cursor( pooling_metadata.build_pooling_cursor(
...@@ -3083,9 +3268,9 @@ class GPUModelRunner( ...@@ -3083,9 +3268,9 @@ class GPUModelRunner(
elif self.uses_xdrope_dim > 0: elif self.uses_xdrope_dim > 0:
positions = self.xdrope_positions.gpu[:, :num_input_tokens] positions = self.xdrope_positions.gpu[:, :num_input_tokens]
else: else:
positions = self.positions.gpu[:num_input_tokens] positions = self.positions[:num_input_tokens]
if num_input_tokens > num_scheduled_tokens: if num_input_tokens > num_scheduled_tokens:
self.positions.gpu[num_scheduled_tokens:num_input_tokens].zero_() self.positions[num_scheduled_tokens:num_input_tokens].zero_()
if is_first_rank: if is_first_rank:
intermediate_tensors = None intermediate_tensors = None
...@@ -3610,7 +3795,7 @@ class GPUModelRunner( ...@@ -3610,7 +3795,7 @@ class GPUModelRunner(
self.synchronize_input_prep(), self.synchronize_input_prep(),
): ):
# Update persistent batch states. # Update persistent batch states.
self._update_states(scheduler_output) deferred_state_corrections_fn = self._update_states(scheduler_output)
if has_ec_transfer() and not get_ec_transfer().is_consumer: if has_ec_transfer() and not get_ec_transfer().is_consumer:
with self.maybe_get_ec_connector_output( with self.maybe_get_ec_connector_output(
...@@ -3723,6 +3908,12 @@ class GPUModelRunner( ...@@ -3723,6 +3908,12 @@ class GPUModelRunner(
pad_attn = cudagraph_mode == CUDAGraphMode.FULL pad_attn = cudagraph_mode == CUDAGraphMode.FULL
if self.cache_config.mamba_cache_mode == "align": if self.cache_config.mamba_cache_mode == "align":
# preprocess_mamba reads req_state.num_computed_tokens (CPU)
# to decide copy operations, so we must apply deferred
# corrections before it runs.
if deferred_state_corrections_fn:
deferred_state_corrections_fn()
deferred_state_corrections_fn = None
mamba_utils.preprocess_mamba( mamba_utils.preprocess_mamba(
scheduler_output, scheduler_output,
self.kv_cache_config, self.kv_cache_config,
...@@ -3734,6 +3925,14 @@ class GPUModelRunner( ...@@ -3734,6 +3925,14 @@ class GPUModelRunner(
self.model.get_mamba_state_copy_func(), self.model.get_mamba_state_copy_func(),
self._get_mamba_copy_bufs(), self._get_mamba_copy_bufs(),
) )
# preprocess_mamba resets num_accepted_tokens_cpu to 1
# for requests whose state was copied to a new block.
# Re-sync to GPU so the mamba kernel reads from the
# correct initial state slot (init_token_idx = 0).
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
)
self.num_accepted_tokens.copy_to_gpu(num_reqs)
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
...@@ -3894,6 +4093,12 @@ class GPUModelRunner( ...@@ -3894,6 +4093,12 @@ class GPUModelRunner(
slot_mappings, slot_mappings,
) )
self.kv_connector_output = kv_connector_output self.kv_connector_output = kv_connector_output
# Now the batch has been launched we can wait for corrections from the
# previous model forward without breaking async scheduling.
if deferred_state_corrections_fn:
deferred_state_corrections_fn()
return None return None
@torch.inference_mode @torch.inference_mode
...@@ -3958,6 +4163,7 @@ class GPUModelRunner( ...@@ -3958,6 +4163,7 @@ class GPUModelRunner(
self._draft_token_ids = None self._draft_token_ids = None
self._draft_token_req_ids = None self._draft_token_req_ids = None
self.valid_sampled_token_count_gpu = None
self.input_batch.prev_sampled_token_ids = None self.input_batch.prev_sampled_token_ids = None
def propose_draft_token_ids(sampled_token_ids): def propose_draft_token_ids(sampled_token_ids):
...@@ -4002,7 +4208,7 @@ class GPUModelRunner( ...@@ -4002,7 +4208,7 @@ class GPUModelRunner(
assert spec_decode_common_attn_metadata is not None assert spec_decode_common_attn_metadata is not None
next_token_ids, valid_sampled_tokens_count = ( next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded( self.drafter.prepare_next_token_ids_padded(
spec_decode_common_attn_metadata, self.optimistic_seq_lens_cpu,
sampled_token_ids, sampled_token_ids,
self.requests, self.requests,
self.input_batch, self.input_batch,
...@@ -4237,6 +4443,9 @@ class GPUModelRunner( ...@@ -4237,6 +4443,9 @@ class GPUModelRunner(
counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True) counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True)
self.valid_sampled_token_count_event.record() self.valid_sampled_token_count_event.record()
if self.use_async_spec_decode:
# Stash for GPU-side correction in _prepare_inputs.
self.valid_sampled_token_count_gpu = valid_sampled_tokens_count
self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1) self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)
def _get_valid_sampled_token_count(self) -> list[int]: def _get_valid_sampled_token_count(self) -> list[int]:
...@@ -4366,7 +4575,7 @@ class GPUModelRunner( ...@@ -4366,7 +4575,7 @@ class GPUModelRunner(
) )
next_token_ids, valid_sampled_tokens_count = ( next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded( self.drafter.prepare_next_token_ids_padded(
common_attn_metadata, self.optimistic_seq_lens_cpu,
sampled_token_ids, sampled_token_ids,
self.requests, self.requests,
self.input_batch, self.input_batch,
...@@ -4405,7 +4614,7 @@ class GPUModelRunner( ...@@ -4405,7 +4614,7 @@ class GPUModelRunner(
) )
next_token_ids, valid_sampled_tokens_count = ( next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded( self.drafter.prepare_next_token_ids_padded(
common_attn_metadata, self.optimistic_seq_lens_cpu,
sampled_token_ids, sampled_token_ids,
self.requests, self.requests,
self.input_batch, self.input_batch,
...@@ -5148,14 +5357,19 @@ class GPUModelRunner( ...@@ -5148,14 +5357,19 @@ class GPUModelRunner(
# In the mixed batch mode (used for FI warmup), we use # In the mixed batch mode (used for FI warmup), we use
# shorter sequence lengths to run faster. # shorter sequence lengths to run faster.
# TODO(luka) better system for describing dummy batches # TODO(luka) better system for describing dummy batches
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] # type: ignore[assignment] seq_lens = torch.tensor( # type: ignore[assignment]
[1] * num_decode_tokens + [num_prefill_tokens + 1],
dtype=torch.int,
)
else: else:
seq_lens = max_query_len # type: ignore[assignment] seq_lens = max_query_len # type: ignore[assignment]
self.seq_lens.np[:num_reqs] = seq_lens self.optimistic_seq_lens_cpu[:num_reqs] = seq_lens
self.seq_lens.np[num_reqs:] = 0 self.optimistic_seq_lens_cpu[num_reqs:].fill_(0)
self.seq_lens.copy_to_gpu() self.seq_lens.copy_(self.optimistic_seq_lens_cpu, non_blocking=True)
cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) cum_num_tokens = self._get_cumsum_and_arange(
num_scheduled_tokens, self.query_pos.np
)
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
self.query_start_loc.copy_to_gpu() self.query_start_loc.copy_to_gpu()
...@@ -5201,7 +5415,7 @@ class GPUModelRunner( ...@@ -5201,7 +5415,7 @@ class GPUModelRunner(
elif self.uses_xdrope_dim > 0: elif self.uses_xdrope_dim > 0:
positions = self.xdrope_positions.gpu[:, :num_tokens_padded] positions = self.xdrope_positions.gpu[:, :num_tokens_padded]
else: else:
positions = self.positions.gpu[:num_tokens_padded] positions = self.positions[:num_tokens_padded]
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
intermediate_tensors = None intermediate_tensors = None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment