"vllm/vscode:/vscode.git/clone" did not exist on "e15a5ff07b89646090458ce6ae7908bcbae82b90"
Commit ca796e19 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.1' into v0.8.1-ori

parents e983c804 61c7a1b8
......@@ -47,6 +47,11 @@ class Sampler(nn.Module):
logits = self.apply_penalties(logits, sampling_metadata)
# Sample the next token.
sampled = self.sample(logits, sampling_metadata)
# Convert sampled token ids to int64 (long) type to ensure compatibility
# with subsequent operations that may use these values as indices.
# This conversion is necessary because FlashInfer sampling operations
# return int32 (while PyTorch argmax and topk return int64).
sampled = sampled.long()
# Gather the logprobs of the topk and sampled token (if requested).
# Get logprobs and rank tensors (if requested)
......@@ -139,12 +144,14 @@ class Sampler(nn.Module):
or sampled tokens (if sampled
logprobs); 1D token ID tensor
with (num tokens) elements
Must be int64.
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
"""
assert token_ids.dtype == torch.int64
# Find the topK values.
topk_logprobs, topk_indices = torch.topk(logprobs,
num_logprobs,
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import numpy as np
import torch
@dataclass
class SpecDecodeMetadata:
# [num_tokens]
draft_token_ids: torch.Tensor
# [batch_size]
num_draft_tokens: list[int]
# [batch_size]
cu_num_draft_tokens: torch.Tensor
# [num_tokens]
target_logits_indices: torch.Tensor
# [batch_size]
bonus_logits_indices: torch.Tensor
# [num_tokens + batch_size]
logits_indices: torch.Tensor
def __post_init__(self):
self.max_spec_len = max(self.num_draft_tokens)
@classmethod
def make_dummy(
cls,
draft_token_ids: list[list[int]],
device: torch.device,
) -> "SpecDecodeMetadata":
batch_size = len(draft_token_ids)
num_draft_tokens = [len(ids) for ids in draft_token_ids]
flattened_draft_token_ids = sum(draft_token_ids, [])
num_tokens = len(flattened_draft_token_ids)
draft_token_ids_tensor = torch.tensor(flattened_draft_token_ids,
dtype=torch.int32,
device=device)
cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(
device)
target_logits_indices = torch.zeros(num_tokens,
dtype=torch.int32,
device=device)
bonus_logits_indices = torch.zeros(batch_size,
dtype=torch.int32,
device=device)
logits_indices = torch.zeros(num_tokens + batch_size,
dtype=torch.int32,
device=device)
return cls(
draft_token_ids=draft_token_ids_tensor,
num_draft_tokens=num_draft_tokens,
cu_num_draft_tokens=cu_num_draft_tokens_tensor,
target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)
# SPDX-License-Identifier: Apache-2.0
from vllm.v1.sample.ops.topk_topp_sampler import random_sample # noqa
from vllm.v1.worker.gpu_input_batch import InputBatch
......
......@@ -7,75 +7,27 @@ from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import LazyLoader
from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar)
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
import xgrammar as xgr
import torch
from vllm.v1.request import Request
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
logger = init_logger(__name__)
class StructuredOutputManager:
"""Engine-level manager for structured output requests."""
def __init__(self, vllm_config: VllmConfig):
self.backend: Optional[StructuredOutputBackend] = None
self.vllm_config = vllm_config
self.init_complete = False
def _delayed_init(self):
"""Initialization delayed until we know it is needed."""
tokenizer_group = init_tokenizer_from_configs(
model_config=self.vllm_config.model_config,
scheduler_config=self.vllm_config.scheduler_config,
parallel_config=self.vllm_config.parallel_config,
lora_config=self.vllm_config.lora_config) # type: ignore[arg-type]
tokenizer_group.ping()
tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.vocab_size = self.vllm_config.model_config.get_vocab_size()
if isinstance(tokenizer, MistralTokenizer):
# NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
try:
encoded_vocab = [
token for token, _ in sorted(
tokenizer.get_vocab().items(),
key=lambda x: x[1],
)
]
stop_token_ids = None
if hasattr(
tokenizer,
"eos_token_id",
) and tokenizer.eos_token_id is not None:
stop_token_ids = [tokenizer.eos_token_id]
except AttributeError as e:
raise ValueError(
f"Cannot get the vocabulary of the tokenizer "
f"{type(tokenizer)}. The tokenizer should have a "
"get_vocab method.") from e
tokenizer_info = xgr.TokenizerInfo(
encoded_vocab=encoded_vocab,
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type=xgr.VocabType.BYTE_FALLBACK,
vocab_size=self.vocab_size,
stop_token_ids=stop_token_ids,
add_prefix_space=True,
)
else:
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
tokenizer,
vocab_size=self.vocab_size,
)
self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
self._grammar_bitmask: Optional[torch.Tensor] = None
# The default max_workers if not specified is the number of CPUs * 5,
# which is way too high since these tasks are CPU-bound, not I/O bound.
......@@ -83,28 +35,30 @@ class StructuredOutputManager:
# compilation, so we set it to half the number of CPUs.
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self._grammar_bitmask = xgr.allocate_token_bitmask(
self.vllm_config.scheduler_config.max_num_seqs,
self.vocab_size,
)
self.init_complete = True
def grammar_init(self, request: Request) -> None:
if request.structured_output_request is None:
return
# The first time this is called, we need to finish initialization
# of xgrammar. We defer it to avoid the import of xgrammar and
# initialization cost if it is not going to be used.
if not self.init_complete:
self._delayed_init()
# Initialize the backend the first time it is needed.
#
# NOTE: We only support a single backend. We do NOT support different
# backends on a per-request basis in V1 (for now, anyway...).
if self.backend is None:
backend_name = request.sampling_params.guided_decoding.backend_name
if backend_name == "xgrammar":
self.backend = XgrammarBackend(self.vllm_config)
else:
raise ValueError(
f"Unsupported structured output backend: {backend_name}")
grammar: Future[Grammar] = self.executor.submit(
self._async_create_grammar, request)
grammar: Future[StructuredOutputGrammar] = self.executor.submit(
self._async_create_grammar, request, self.backend)
request.structured_output_request.grammar = grammar # type: ignore[assignment]
def _async_create_grammar(self, request: Request) -> Grammar:
def _async_create_grammar(
self, request: Request,
backend: StructuredOutputBackend) -> StructuredOutputGrammar:
key = request.structured_output_request.structured_output_key # type: ignore[union-attr]
# Note that the request was validated in the engine core client,
......@@ -114,28 +68,8 @@ class StructuredOutputManager:
# though it should be unlikely as we test that up front as well.
request_type, grammar_spec = key
if request_type == StructuredOutputOptions.JSON:
# TODO -- allow any_whitespace to be configurable
# pending merge of https://github.com/vllm-project/vllm/pull/12744
ctx = self.compiler.compile_json_schema(grammar_spec,
any_whitespace=False)
elif request_type == StructuredOutputOptions.JSON_OBJECT:
ctx = self.compiler.compile_builtin_json_grammar()
elif request_type == StructuredOutputOptions.GRAMMAR:
ctx = self.compiler.compile_grammar(grammar_spec)
elif request_type == StructuredOutputOptions.REGEX:
ctx = self.compiler.compile_regex(grammar_spec)
else:
logger.error("Validation should have already occurred. "
"Please file an issue.")
raise ValueError(
f"grammar is not of valid supported types. ({request_type!s})")
return Grammar(
matcher=xgr.GrammarMatcher(ctx),
vocab_size=self.vocab_size,
ctx=ctx,
)
assert self.backend is not None
return self.backend.compile_grammar(request_type, grammar_spec)
def grammar_bitmask(
self,
......@@ -147,6 +81,11 @@ class StructuredOutputManager:
if not structured_output_request_ids:
return None
if self._grammar_bitmask is None:
assert self.backend is not None
self._grammar_bitmask = self.backend.allocate_token_bitmask(
self.vllm_config.scheduler_config.max_num_seqs)
# Fill the bitmask using the index of each request equal to its
# position in the batch. Resize the bitmask down to the size of
# the batch.
......@@ -154,7 +93,7 @@ class StructuredOutputManager:
for req_id, batch_index in structured_output_request_ids.items():
request = requests[req_id].structured_output_request
assert request is not None and request.grammar is not None
if not request.grammar.matcher.is_terminated():
if not request.grammar.is_terminated():
request.grammar.fill_bitmask(bitmask_tensor, batch_index)
if batch_len < self._grammar_bitmask.shape[0]:
bitmask_tensor = self._grammar_bitmask[:batch_len]
......
# SPDX-License-Identifier: Apache-2.0
import enum
from abc import ABC, abstractmethod
import torch
class StructuredOutputOptions(enum.Enum):
JSON = enum.auto()
JSON_OBJECT = enum.auto()
REGEX = enum.auto()
GRAMMAR = enum.auto()
CHOICE = enum.auto()
StructuredOutputKey = tuple[StructuredOutputOptions, str]
class StructuredOutputGrammar(ABC):
"""Request-level backend for structured output requests."""
@abstractmethod
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
"""
Determines whether the provided tokens are accepted for the
given request.
Args:
request_id (str): The unique identifier for the request.
tokens (list[int]): A list of token IDs to evaluate.
Returns:
bool: True if the tokens are accepted, False otherwise.
"""
@abstractmethod
def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
"""
Fills the bitmask for a specific batch index.
Args:
bitmask (torch.Tensor): The bitmask to fill
batch_index (int): The index in the bitmask to fill
"""
@abstractmethod
def is_terminated(self) -> bool:
"""
Checks whether the structured output process has terminated.
Returns:
bool: True if the process is terminated, False otherwise.
"""
@abstractmethod
def reset(self):
"""
Resets the state of the structured output grammar.
"""
class StructuredOutputBackend(ABC):
"""Engine-level backend for structured output requests."""
@abstractmethod
def compile_grammar(self, request_type: StructuredOutputOptions,
grammar_spec: str) -> StructuredOutputGrammar:
"""
Compiles a grammar specification into a structured output grammar.
Args:
request_type (StructuredOutputOptions): The type of structured
output request.
grammar_spec (str): The grammar specification to compile.
Returns:
StructuredOutputGrammar: The compiled structured output grammar.
"""
@abstractmethod
def allocate_token_bitmask(self, max_num_seqs: int):
"""
Allocates a token bitmask for the specified maximum number of sequences.
Args:
max_num_seqs (int): The maximum number of sequences for which
to allocate the bitmask.
"""
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar,
StructuredOutputOptions)
if TYPE_CHECKING:
import xgrammar as xgr
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
logger = init_logger(__name__)
class XgrammarBackend(StructuredOutputBackend):
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
tokenizer_group = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
parallel_config=vllm_config.parallel_config,
lora_config=vllm_config.lora_config) # type: ignore[arg-type]
tokenizer_group.ping()
tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.vocab_size = vllm_config.model_config.get_vocab_size()
if isinstance(tokenizer, MistralTokenizer):
# NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
try:
encoded_vocab = [
token for token, _ in sorted(
tokenizer.get_vocab().items(),
key=lambda x: x[1],
)
]
stop_token_ids = None
if hasattr(
tokenizer,
"eos_token_id",
) and tokenizer.eos_token_id is not None:
stop_token_ids = [tokenizer.eos_token_id]
except AttributeError as e:
raise ValueError(
f"Cannot get the vocabulary of the tokenizer "
f"{type(tokenizer)}. The tokenizer should have a "
"get_vocab method.") from e
tokenizer_info = xgr.TokenizerInfo( # type: ignore
encoded_vocab=encoded_vocab,
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type=xgr.VocabType.BYTE_FALLBACK,
vocab_size=self.vocab_size,
stop_token_ids=stop_token_ids,
add_prefix_space=True,
)
else:
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
tokenizer,
vocab_size=self.vocab_size,
)
self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
def compile_grammar(self, request_type: StructuredOutputOptions,
grammar_spec: str) -> StructuredOutputGrammar:
if request_type == StructuredOutputOptions.JSON:
ctx = self.compiler.compile_json_schema(grammar_spec,
any_whitespace=False)
elif request_type == StructuredOutputOptions.JSON_OBJECT:
ctx = self.compiler.compile_builtin_json_grammar()
elif request_type == StructuredOutputOptions.GRAMMAR:
ctx = self.compiler.compile_grammar(grammar_spec)
elif request_type == StructuredOutputOptions.REGEX:
ctx = self.compiler.compile_regex(grammar_spec)
else:
logger.error(
"Validation should have already occurred. Please file an issue."
)
raise ValueError(
f"grammar is not of valid supported types. ({request_type!s})")
return XgrammarGrammar(
matcher=xgr.GrammarMatcher(ctx),
vocab_size=self.vocab_size,
ctx=ctx,
)
def allocate_token_bitmask(self, max_num_seqs: int):
return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size)
@dataclass
class XgrammarGrammar(StructuredOutputGrammar):
# NOTE: This would be a generic-enough class for
# supporting different backends, in the future.
# For now, just xgrammar.
#
# TODO: support max_rollback_tokens
# https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
# for jump-forward decoding
vocab_size: int
matcher: xgr.GrammarMatcher = field(hash=False)
ctx: xgr.CompiledGrammar = field(hash=False)
num_processed_tokens: int = field(default_factory=lambda: 0,
repr=False,
hash=False,
init=False)
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
"""Accepts a list of tokens and advances the FSM.
Returns True if the FSM was advanced successfully.
Returns False if the FSM failed to advance.
"""
for token in tokens:
if not self.matcher.accept_token(token):
logger.error(
"Failed to advance FSM for request %s "
"for tokens %s. Please file an issue.", request_id, token)
return False
self.num_processed_tokens += 1
return True
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
self.matcher.fill_next_token_bitmask(bitmask, idx)
def is_terminated(self) -> bool:
return self.matcher.is_terminated()
def reset(self):
self.num_processed_tokens = 0
self.matcher.reset()
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import enum
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from vllm.logger import init_logger
from vllm.utils import LazyLoader
if TYPE_CHECKING:
import xgrammar as xgr
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
logger = init_logger(__name__)
class StructuredOutputOptions(enum.Enum):
JSON = enum.auto()
JSON_OBJECT = enum.auto()
REGEX = enum.auto()
GRAMMAR = enum.auto()
CHOICE = enum.auto()
StructuredOutputKey = tuple[StructuredOutputOptions, str]
@dataclass
class Grammar:
# NOTE: This would be a generic-enough class for
# supporting different backends, in the future.
# For now, just xgrammar.
#
# TODO: support max_rollback_tokens
# https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
# for jump-forward decoding
vocab_size: int
matcher: xgr.GrammarMatcher = field(hash=False)
ctx: xgr.CompiledGrammar = field(hash=False)
num_processed_tokens: int = field(default_factory=lambda: 0,
repr=False,
hash=False,
init=False)
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
"""Accepts a list of tokens and advances the FSM.
Returns True if the FSM was advanced successfully.
Returns False if the FSM failed to advance.
"""
for token in tokens:
if not self.matcher.accept_token(token):
logger.error(
"Failed to advance FSM for request %s "
"for tokens %s. Please file an issue.", request_id, token)
return False
self.num_processed_tokens += 1
return True
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> bool:
return self.matcher.fill_next_token_bitmask(bitmask, idx)
def reset(self):
self.num_processed_tokens = 0
self.matcher.reset()
def __copy__(self):
return Grammar(
matcher=xgr.GrammarMatcher(self.ctx),
vocab_size=self.vocab_size,
ctx=self.ctx,
)
......@@ -9,15 +9,17 @@ from concurrent.futures._base import TimeoutError
from typing import Optional, Union, cast
from vllm.sampling_params import SamplingParams
from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey,
StructuredOutputOptions)
from vllm.v1.structured_output.backend_types import (StructuredOutputGrammar,
StructuredOutputKey,
StructuredOutputOptions)
@dataclasses.dataclass
class StructuredOutputRequest:
sampling_params: SamplingParams
_grammar: Optional[Union[Future[Grammar], Grammar]] = None
_grammar: Optional[Union[Future[StructuredOutputGrammar],
StructuredOutputGrammar]] = None
def _check_grammar_completion(self) -> bool:
# NOTE: We have to lazy import to gate circular imports
......@@ -37,12 +39,16 @@ class StructuredOutputRequest:
return self._check_grammar_completion()
@property
def grammar(self) -> Optional[Grammar]:
def grammar(self) -> Optional[StructuredOutputGrammar]:
completed = self._check_grammar_completion()
return cast(Optional[Grammar], self._grammar) if completed else None
return cast(Optional[StructuredOutputGrammar],
self._grammar) if completed else None
@grammar.setter
def grammar(self, grammar: Union[Grammar, Future[Grammar]]) -> None:
def grammar(
self, grammar: Union[StructuredOutputGrammar,
Future[StructuredOutputGrammar]]
) -> None:
self._grammar = grammar
@functools.cached_property
......
......@@ -34,7 +34,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported
from vllm.v1.utils import bind_kv_cache
......@@ -149,7 +150,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_spec_decode = False
if self.speculative_config:
self.use_spec_decode = True
self.rejection_sampler = RejectionSampler()
# TODO: find a better way to check if we are using ngram.
assert self.speculative_config.ngram_prompt_lookup_min, \
"Currently, only ngram spec decode is supported in V1."
......@@ -162,6 +162,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.speculative_config.ngram_prompt_lookup_min,
self.speculative_config.num_speculative_tokens,
)
self.rejection_sampler = RejectionSampler()
# Request states.
self.requests: dict[str, CachedRequestState] = {}
......@@ -452,7 +453,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[FlashAttentionMetadata, torch.Tensor]:
) -> tuple[FlashAttentionMetadata, torch.Tensor,
Optional[SpecDecodeMetadata]]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
......@@ -577,22 +579,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if use_spec_decode:
logits_indices = self._calc_spec_decode_metadata(
scheduler_output, cu_num_tokens)
else:
if not use_spec_decode:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = attn_metadata.query_start_loc[1:] - 1
spec_decode_metadata = None
else:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
logits_indices = spec_decode_metadata.logits_indices
# Hot-Swap lora model
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
return attn_metadata, logits_indices
return attn_metadata, logits_indices, spec_decode_metadata
def _compute_cascade_attn_prefix_len(
self,
......@@ -732,50 +745,79 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _calc_spec_decode_metadata(
self,
scheduler_output: "SchedulerOutput",
cu_num_tokens: np.ndarray,
) -> torch.Tensor:
# Get the number of spec decode tokens for each request.
num_reqs = self.input_batch.num_reqs
num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32)
for i, req_id in enumerate(self.input_batch.req_ids):
num_spec_decode_tokens[i] = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
# Get spec decode logits indices.
# E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
# cu_num_tokens: [4, 104, 107, 207, 209]
# num_spec_tokens_list: [3, 0, 2, 0, 1]
# num_sampled_tokens: [4, 1, 3, 1, 2]
# spec_decode_logits_indices:
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
num_sampled_tokens = num_spec_decode_tokens + 1
# logits_start_loc: [0, 103, 104, 206, 207]
logits_start_loc = cu_num_tokens - num_sampled_tokens
# [0, 103, 104, 206, 207] ->
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
# The following three lines:
# [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
# Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
# Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
# -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
cumsums_sampled_offsets = np.repeat(
cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
# Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
# -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
total_num_sampled_tokens = num_sampled_tokens.sum()
sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
cumsums_sampled_offsets)
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
spec_decode_logits_indices = logits_start_loc + sampled_arange
return torch.from_numpy(spec_decode_logits_indices).to(
num_draft_tokens: np.ndarray,
cu_num_scheduled_tokens: np.ndarray,
) -> SpecDecodeMetadata:
# Inputs:
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
# num_draft_tokens: [ 3, 0, 2, 0, 1]
# Outputs:
# cu_num_draft_tokens: [ 3, 3, 5, 5, 6]
# logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106,
# 206, 207, 208]
# target_logits_indices: [ 0, 1, 2, 5, 6, 9]
# bonus_logits_indices: [ 3, 4, 7, 8, 10]
# Compute the logits indices.
# [4, 1, 3, 1, 2]
num_sampled_tokens = num_draft_tokens + 1
# Step 1. [4, 5, 8, 9, 11]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
total_num_sampled_tokens = cu_num_sampled_tokens[-1]
# Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
num_sampled_tokens)
# Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets
# Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_indices = np.repeat(
cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
# Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
logits_indices += arange
# Compute the bonus logits indices.
bonus_logits_indices = cu_num_sampled_tokens - 1
# Compute the draft logits indices.
# [3, 3, 5, 5, 6]
cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
total_num_draft_tokens = cu_num_draft_tokens[-1]
# [0, 0, 0, 3, 3, 5]
cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens,
num_draft_tokens)
# [0, 1, 2, 0, 1, 0]
arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets
# [0, 0, 0, 5, 5, 9]
target_logits_indices = np.repeat(
cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens)
# [0, 1, 2, 5, 6, 9]
target_logits_indices += arange
# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
self.device, non_blocking=True)
logits_indices = torch.from_numpy(logits_indices).to(self.device,
non_blocking=True)
target_logits_indices = torch.from_numpy(target_logits_indices).to(
self.device, non_blocking=True)
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
self.device, non_blocking=True)
# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
draft_token_ids = self.input_ids[logits_indices]
draft_token_ids = draft_token_ids[target_logits_indices + 1]
metadata = SpecDecodeMetadata(
draft_token_ids=draft_token_ids,
num_draft_tokens=num_draft_tokens.tolist(),
cu_num_draft_tokens=cu_num_draft_tokens,
target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)
return metadata
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
......@@ -931,7 +973,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs = []
# Prepare the decoder inputs.
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
attn_metadata, logits_indices, spec_decode_metadata = (
self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
......@@ -1006,31 +1049,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if not self.use_spec_decode:
if spec_decode_metadata is None:
sampler_output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
else:
draft_token_ids = [
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
for req_id in self.input_batch.req_ids
]
sample_lens = [len(tokens) + 1 for tokens in draft_token_ids]
recover_logits_idx = np.cumsum(sample_lens) - 1
target_probs = self.rejection_sampler.compute_probs(
logits, sampling_metadata, sample_lens)
# TODO(woosuk): Optimize the memory usage.
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.model.sample(
logits=logits[recover_logits_idx, :],
logits=bonus_logits,
sampling_metadata=sampling_metadata,
)
bonus_token_ids = sampler_output.sampled_token_ids
# TODO(woosuk): Optimize the memory usage.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
draft_token_ids,
spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
target_probs,
sampling_metadata)
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
# TODO(woosuk): The following loop can be slow since it iterates over
......@@ -1066,13 +1107,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
valid_sampled_token_ids = sampled_token_ids.tolist()
else:
# Includes spec decode tokens.
valid_mask = sampled_token_ids != INVALID_TOKEN_ID
gen_lens = valid_mask.sum(dim=1).tolist()
# TODO(woosuk): Optimize this.
valid_sampled_token_ids = [
seq.tolist()
for seq in sampled_token_ids[valid_mask].split(gen_lens)
]
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids, self.input_batch.vocab_size)
if not self.use_spec_decode:
spec_token_ids = None
......@@ -1316,6 +1352,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"initializing the engine.") from e
else:
raise e
if self.use_spec_decode:
draft_token_ids = [[0] for _ in range(num_reqs)]
dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids, self.device)
num_tokens = sum(len(ids) for ids in draft_token_ids)
# draft_probs = torch.randn(
# num_tokens, logits.shape[-1], device=self.device,
# dtype=logits.dtype)
draft_probs = None
target_logits = torch.randn(num_tokens,
logits.shape[-1],
device=self.device,
dtype=logits.dtype)
# NOTE(woosuk): Here, we should use int32 because the sampler uses
# int32 for bonus_token_ids. If the dtype mismatches, re-compilation
# will occur at runtime.
bonus_token_ids = torch.zeros(num_reqs,
device=self.device,
dtype=torch.int32)
self.rejection_sampler(
dummy_spec_decode_metadata,
draft_probs,
target_logits,
bonus_token_ids,
dummy_metadata,
)
return sampler_output
def profile_run(self) -> None:
......
......@@ -410,6 +410,9 @@ class TPUModelRunner:
# Do the padding and copy the tensors to the TPU.
padded_total_num_scheduled_tokens = _get_padded_token_len(
total_num_scheduled_tokens)
# Zero out to avoid spurious values from prev iteration (last cp chunk)
self.input_ids_cpu[
total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0
self.input_ids = self.input_ids_cpu[:
padded_total_num_scheduled_tokens].to(
self.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment