Unverified Commit 43e3f8e4 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Misc] Various code simplifications (#31666)


Signed-off-by: default avatarnjhill <nickhill123@gmail.com>
parent bb4337b3
...@@ -10,10 +10,7 @@ logger = init_logger(__name__) ...@@ -10,10 +10,7 @@ logger = init_logger(__name__)
class AsyncScheduler(Scheduler): class AsyncScheduler(Scheduler):
def _update_after_schedule( def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
self,
scheduler_output: SchedulerOutput,
) -> None:
super()._update_after_schedule(scheduler_output) super()._update_after_schedule(scheduler_output)
pending_structured_output_tokens = False pending_structured_output_tokens = False
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
...@@ -41,9 +38,7 @@ class AsyncScheduler(Scheduler): ...@@ -41,9 +38,7 @@ class AsyncScheduler(Scheduler):
) )
def _update_request_with_output( def _update_request_with_output(
self, self, request: Request, new_token_ids: list[int]
request: Request,
new_token_ids: list[int],
) -> tuple[list[int], bool]: ) -> tuple[list[int], bool]:
if request.discard_latest_async_tokens: if request.discard_latest_async_tokens:
# If the request is force preempted in reset_prefix_cache, we # If the request is force preempted in reset_prefix_cache, we
......
...@@ -85,10 +85,7 @@ class SchedulerInterface(ABC): ...@@ -85,10 +85,7 @@ class SchedulerInterface(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def update_draft_token_ids( def update_draft_token_ids(self, draft_token_ids: "DraftTokenIds") -> None:
self,
draft_token_ids: "DraftTokenIds",
) -> None:
"""Update the draft token ids for the scheduled requests.""" """Update the draft token ids for the scheduled requests."""
raise NotImplementedError raise NotImplementedError
......
...@@ -762,11 +762,7 @@ class Scheduler(SchedulerInterface): ...@@ -762,11 +762,7 @@ class Scheduler(SchedulerInterface):
self._update_after_schedule(scheduler_output) self._update_after_schedule(scheduler_output)
return scheduler_output return scheduler_output
def _preempt_request( def _preempt_request(self, request: Request, timestamp: float) -> None:
self,
request: Request,
timestamp: float,
) -> None:
"""Preempt a request and put it back to the waiting queue. """Preempt a request and put it back to the waiting queue.
NOTE: The request should be popped from the running queue outside of this NOTE: The request should be popped from the running queue outside of this
...@@ -786,10 +782,7 @@ class Scheduler(SchedulerInterface): ...@@ -786,10 +782,7 @@ class Scheduler(SchedulerInterface):
# Put the request back to the waiting queue. # Put the request back to the waiting queue.
self.waiting.prepend_request(request) self.waiting.prepend_request(request)
def _update_after_schedule( def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
self,
scheduler_output: SchedulerOutput,
) -> None:
# Advance the number of computed tokens for the request AFTER # Advance the number of computed tokens for the request AFTER
# the request is scheduled. # the request is scheduled.
# 1. The scheduler_output of the current step has to include the # 1. The scheduler_output of the current step has to include the
...@@ -1006,8 +999,7 @@ class Scheduler(SchedulerInterface): ...@@ -1006,8 +999,7 @@ class Scheduler(SchedulerInterface):
) )
curr_embeds_start, curr_embeds_end = ( curr_embeds_start, curr_embeds_end = (
mm_feature.mm_position.get_embeds_indices_in_range( mm_feature.mm_position.get_embeds_indices_in_range(
start_idx_rel, start_idx_rel, end_idx_rel
end_idx_rel,
) )
) )
# There's no embeddings in the current range of encoder placeholder tokens # There's no embeddings in the current range of encoder placeholder tokens
...@@ -1034,8 +1026,7 @@ class Scheduler(SchedulerInterface): ...@@ -1034,8 +1026,7 @@ class Scheduler(SchedulerInterface):
) )
def get_grammar_bitmask( def get_grammar_bitmask(
self, self, scheduler_output: SchedulerOutput
scheduler_output: SchedulerOutput,
) -> GrammarOutput | None: ) -> GrammarOutput | None:
# Collect list of scheduled request ids that use structured output. # Collect list of scheduled request ids that use structured output.
# The corresponding rows of the bitmask will be in this order. # The corresponding rows of the bitmask will be in this order.
...@@ -1285,9 +1276,7 @@ class Scheduler(SchedulerInterface): ...@@ -1285,9 +1276,7 @@ class Scheduler(SchedulerInterface):
return engine_core_outputs return engine_core_outputs
def _update_request_with_output( def _update_request_with_output(
self, self, request: Request, new_token_ids: list[int]
request: Request,
new_token_ids: list[int],
) -> tuple[list[int], bool]: ) -> tuple[list[int], bool]:
# Append generated tokens and check for stop. Note that if # Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner # a request is still being prefilled, we expect the model runner
...@@ -1328,10 +1317,7 @@ class Scheduler(SchedulerInterface): ...@@ -1328,10 +1317,7 @@ class Scheduler(SchedulerInterface):
# in the decoder's KV cache. # in the decoder's KV cache.
self.encoder_cache_manager.free_encoder_input(request, input_id) self.encoder_cache_manager.free_encoder_input(request, input_id)
def update_draft_token_ids( def update_draft_token_ids(self, draft_token_ids: DraftTokenIds) -> None:
self,
draft_token_ids: DraftTokenIds,
) -> None:
for req_id, spec_token_ids in zip( for req_id, spec_token_ids in zip(
draft_token_ids.req_ids, draft_token_ids.req_ids,
draft_token_ids.draft_token_ids, draft_token_ids.draft_token_ids,
...@@ -1361,9 +1347,7 @@ class Scheduler(SchedulerInterface): ...@@ -1361,9 +1347,7 @@ class Scheduler(SchedulerInterface):
request.record_event(EngineCoreEventType.QUEUED) request.record_event(EngineCoreEventType.QUEUED)
def finish_requests( def finish_requests(
self, self, request_ids: str | Iterable[str], finished_status: RequestStatus
request_ids: str | Iterable[str],
finished_status: RequestStatus,
) -> None: ) -> None:
"""Handles the finish signal from outside the scheduler. """Handles the finish signal from outside the scheduler.
......
...@@ -204,10 +204,7 @@ class EagleProposer: ...@@ -204,10 +204,7 @@ class EagleProposer:
) )
# Precompute draft position offsets in flattened tree. # Precompute draft position offsets in flattened tree.
self.tree_draft_pos_offsets = torch.arange( self.tree_draft_pos_offsets = torch.arange(
1, 1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
len(self.tree_choices) + 1,
device=device,
dtype=torch.int32,
).repeat(max_batch_size, 1) ).repeat(max_batch_size, 1)
def _get_positions(self, num_tokens: int): def _get_positions(self, num_tokens: int):
...@@ -287,8 +284,7 @@ class EagleProposer: ...@@ -287,8 +284,7 @@ class EagleProposer:
per_layer_attn_metadata[layer_name] = draft_indexer_metadata per_layer_attn_metadata[layer_name] = draft_indexer_metadata
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp( num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
num_tokens_padded=num_tokens,
) )
cudagraph_runtime_mode = CUDAGraphMode.NONE cudagraph_runtime_mode = CUDAGraphMode.NONE
...@@ -391,8 +387,7 @@ class EagleProposer: ...@@ -391,8 +387,7 @@ class EagleProposer:
draft_token_ids_list = [draft_token_ids] draft_token_ids_list = [draft_token_ids]
batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp( batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=batch_size, num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
num_tokens_padded=batch_size,
) )
if ( if (
...@@ -610,10 +605,8 @@ class EagleProposer: ...@@ -610,10 +605,8 @@ class EagleProposer:
assert discard_request_mask.dtype == torch.bool assert discard_request_mask.dtype == torch.bool
assert backup_tokens_gpu.dtype == torch.int32 assert backup_tokens_gpu.dtype == torch.int32
next_token_ids = torch.empty((batch_size,), dtype=torch.int32, device=device) next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
valid_sampled_tokens_count = torch.empty( valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)
(batch_size,), dtype=torch.int32, device=device
)
# Kernel grid: one program per request (row) # Kernel grid: one program per request (row)
grid = (batch_size,) grid = (batch_size,)
...@@ -782,8 +775,7 @@ class EagleProposer: ...@@ -782,8 +775,7 @@ class EagleProposer:
max_query_len=query_len, max_query_len=query_len,
) )
attn_metadata = tree_attn_metadata_builder.build_for_drafting( attn_metadata = tree_attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata, draft_index=level + 1
draft_index=level + 1,
) )
# Apply new attention metadata to all layers. # Apply new attention metadata to all layers.
...@@ -1161,8 +1153,8 @@ class EagleProposer: ...@@ -1161,8 +1153,8 @@ class EagleProposer:
def dummy_run( def dummy_run(
self, self,
num_tokens: int, num_tokens: int,
use_cudagraphs=True, use_cudagraphs: bool = True,
is_graph_capturing=False, is_graph_capturing: bool = False,
) -> None: ) -> None:
# Determine if CUDA graphs should be used for this run. # Determine if CUDA graphs should be used for this run.
cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
...@@ -1174,8 +1166,7 @@ class EagleProposer: ...@@ -1174,8 +1166,7 @@ class EagleProposer:
): ):
if fwd_idx <= 1: if fwd_idx <= 1:
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp( num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
num_tokens_padded=num_tokens,
) )
if ( if (
cudagraphs_enabled cudagraphs_enabled
...@@ -1342,9 +1333,5 @@ def compute_probs_and_sample_next_token( ...@@ -1342,9 +1333,5 @@ def compute_probs_and_sample_next_token(
next_token_ids = probs.div(q).argmax(dim=-1).view(-1) next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
if not sampling_metadata.all_random: if not sampling_metadata.all_random:
greedy_token_ids = probs.argmax(dim=-1) greedy_token_ids = probs.argmax(dim=-1)
next_token_ids = torch.where( next_token_ids = torch.where(is_greedy, greedy_token_ids, next_token_ids)
is_greedy,
greedy_token_ids,
next_token_ids,
)
return next_token_ids, probs return next_token_ids, probs
...@@ -28,8 +28,6 @@ if TYPE_CHECKING: ...@@ -28,8 +28,6 @@ if TYPE_CHECKING:
else: else:
torch = LazyLoader("torch", globals(), "torch") torch = LazyLoader("torch", globals(), "torch")
ReasoningParser = object
Request = object
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -98,7 +96,7 @@ class StructuredOutputManager: ...@@ -98,7 +96,7 @@ class StructuredOutputManager:
self.vllm_config.structured_outputs_config.enable_in_reasoning self.vllm_config.structured_outputs_config.enable_in_reasoning
) )
def grammar_init(self, request: Request) -> None: def grammar_init(self, request: "Request") -> None:
if request.structured_output_request is None: if request.structured_output_request is None:
return return
...@@ -156,10 +154,7 @@ class StructuredOutputManager: ...@@ -156,10 +154,7 @@ class StructuredOutputManager:
grammar = self._create_grammar(request) # type: ignore[assignment] grammar = self._create_grammar(request) # type: ignore[assignment]
request.structured_output_request.grammar = grammar # type: ignore[assignment] request.structured_output_request.grammar = grammar # type: ignore[assignment]
def _create_grammar( def _create_grammar(self, request: "Request") -> StructuredOutputGrammar:
self,
request: Request,
) -> StructuredOutputGrammar:
key = request.structured_output_request.structured_output_key # type: ignore[union-attr] key = request.structured_output_request.structured_output_key # type: ignore[union-attr]
# Note that the request was validated in the engine core client, # Note that the request was validated in the engine core client,
...@@ -173,8 +168,7 @@ class StructuredOutputManager: ...@@ -173,8 +168,7 @@ class StructuredOutputManager:
return self.backend.compile_grammar(request_type, grammar_spec) return self.backend.compile_grammar(request_type, grammar_spec)
def _fill_bitmasks( def _fill_bitmasks(
self, self, batch: Iterable[tuple[StructuredOutputGrammar, int, bool]]
batch: Iterable[tuple[StructuredOutputGrammar, int, bool]],
) -> None: ) -> None:
assert self._grammar_bitmask is not None assert self._grammar_bitmask is not None
for grammar, index, apply_bitmask in batch: for grammar, index, apply_bitmask in batch:
...@@ -187,14 +181,13 @@ class StructuredOutputManager: ...@@ -187,14 +181,13 @@ class StructuredOutputManager:
self._grammar_bitmask[index].fill_(self._full_mask) self._grammar_bitmask[index].fill_(self._full_mask)
def _async_submit_fill_bitmask( def _async_submit_fill_bitmask(
self, self, batch: list[tuple[StructuredOutputGrammar, int, bool]]
batch: list[tuple[StructuredOutputGrammar, int, bool]],
) -> Future: ) -> Future:
return self.executor_for_fillmask.submit(self._fill_bitmasks, batch) return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)
def grammar_bitmask( def grammar_bitmask(
self, self,
requests: dict[str, Request], requests: dict[str, "Request"],
structured_output_request_ids: list[str], structured_output_request_ids: list[str],
scheduled_spec_decode_tokens: dict[str, list[int]], scheduled_spec_decode_tokens: dict[str, list[int]],
) -> "npt.NDArray[np.int32] | None": ) -> "npt.NDArray[np.int32] | None":
...@@ -239,11 +232,10 @@ class StructuredOutputManager: ...@@ -239,11 +232,10 @@ class StructuredOutputManager:
if TYPE_CHECKING: if TYPE_CHECKING:
assert structured_output_request is not None assert structured_output_request is not None
assert structured_output_request.grammar is not None assert structured_output_request.grammar is not None
grammar = structured_output_request.grammar
apply_bitmask = self.should_fill_bitmask(request) apply_bitmask = self.should_fill_bitmask(request)
batch.append( batch.append((grammar, cumulative_index, apply_bitmask))
(structured_output_request.grammar, cumulative_index, apply_bitmask)
)
if len(batch) == self.fill_bitmask_parallel_batch_size: if len(batch) == self.fill_bitmask_parallel_batch_size:
promises.append(self._async_submit_fill_bitmask(batch)) promises.append(self._async_submit_fill_bitmask(batch))
batch = [] batch = []
...@@ -264,34 +256,23 @@ class StructuredOutputManager: ...@@ -264,34 +256,23 @@ class StructuredOutputManager:
if TYPE_CHECKING: if TYPE_CHECKING:
assert structured_output_request is not None assert structured_output_request is not None
assert structured_output_request.grammar is not None assert structured_output_request.grammar is not None
grammar = structured_output_request.grammar
apply_bitmask = self.should_fill_bitmask(request) apply_bitmask = self.should_fill_bitmask(request)
state_advancements = 0 state_advancements = 0
req_tokens = scheduled_spec_decode_tokens.get(req_id, ()) req_tokens = scheduled_spec_decode_tokens.get(req_id, ())
for token in itertools.chain(req_tokens, (None,)): for token in itertools.chain(req_tokens, (-1,)):
self._fill_bitmasks( self._fill_bitmasks(((grammar, cumulative_index, apply_bitmask),))
( if token == -1:
( # Stop advancing the grammar once we hit a padding token.
structured_output_request.grammar, apply_bitmask = False
cumulative_index, if apply_bitmask and not grammar.is_terminated():
apply_bitmask, accepted = grammar.accept_tokens(req_id, [token])
),
)
)
if (
apply_bitmask
and token is not None
and not structured_output_request.grammar.is_terminated()
):
accepted = structured_output_request.grammar.accept_tokens(
req_id, [token]
)
assert accepted, (token, req_id, scheduled_spec_decode_tokens) assert accepted, (token, req_id, scheduled_spec_decode_tokens)
state_advancements += 1 state_advancements += 1
cumulative_index += 1 cumulative_index += 1
if state_advancements > 0: if state_advancements > 0:
structured_output_request.grammar.rollback(state_advancements) grammar.rollback(state_advancements)
bitmask_tensor = self._grammar_bitmask bitmask_tensor = self._grammar_bitmask
if cumulative_index < bitmask_tensor.shape[0]: if cumulative_index < bitmask_tensor.shape[0]:
...@@ -302,7 +283,7 @@ class StructuredOutputManager: ...@@ -302,7 +283,7 @@ class StructuredOutputManager:
# and deserialization when sending this to the GPU workers. # and deserialization when sending this to the GPU workers.
return bitmask_tensor.numpy() return bitmask_tensor.numpy()
def should_fill_bitmask(self, request: Request) -> bool: def should_fill_bitmask(self, request: "Request") -> bool:
# NOTE (Hanchen) if enable_in_reasoning is True, it means that # NOTE (Hanchen) if enable_in_reasoning is True, it means that
# the model needs to be constrained in reasoning. So we should always # the model needs to be constrained in reasoning. So we should always
# enable the bitmask filling. # enable the bitmask filling.
...@@ -318,7 +299,7 @@ class StructuredOutputManager: ...@@ -318,7 +299,7 @@ class StructuredOutputManager:
return request.structured_output_request.reasoning_ended return request.structured_output_request.reasoning_ended
return True return True
def should_advance(self, request: Request) -> bool: def should_advance(self, request: "Request") -> bool:
if not request.use_structured_output: if not request.use_structured_output:
return False return False
......
...@@ -5,6 +5,7 @@ from __future__ import annotations ...@@ -5,6 +5,7 @@ from __future__ import annotations
import hashlib import hashlib
import importlib.metadata import importlib.metadata
import os import os
import tempfile
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import numpy as np import numpy as np
...@@ -34,9 +35,6 @@ else: ...@@ -34,9 +35,6 @@ else:
"convert_slow_tokenizer", globals(), "transformers.convert_slow_tokenizer" "convert_slow_tokenizer", globals(), "transformers.convert_slow_tokenizer"
) )
TokenizerLike = object
SchedulerOutput = object
InputBatch = object
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -72,13 +70,12 @@ def apply_grammar_bitmask( ...@@ -72,13 +70,12 @@ def apply_grammar_bitmask(
# request in the batch, as the logit indices are offset by this amount. # request in the batch, as the logit indices are offset by this amount.
struct_out_req_batch_indices: dict[str, int] = {} struct_out_req_batch_indices: dict[str, int] = {}
cumulative_offset = 0 cumulative_offset = 0
seq = sorted(input_batch.req_id_to_index.items(), key=lambda x: x[1]) spec_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id, batch_index in seq: struct_out_req_ids = set(grammar_output.structured_output_request_ids)
for batch_index, req_id in enumerate(input_batch.req_ids):
logit_index = batch_index + cumulative_offset logit_index = batch_index + cumulative_offset
cumulative_offset += len( cumulative_offset += len(spec_tokens.get(req_id, ()))
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) if req_id in struct_out_req_ids:
)
if req_id in grammar_output.structured_output_request_ids:
struct_out_req_batch_indices[req_id] = logit_index struct_out_req_batch_indices[req_id] = logit_index
out_indices = [] out_indices = []
...@@ -91,14 +88,12 @@ def apply_grammar_bitmask( ...@@ -91,14 +88,12 @@ def apply_grammar_bitmask(
) )
cumulative_index = 0 cumulative_index = 0
for req_id in grammar_output.structured_output_request_ids: for req_id in grammar_output.structured_output_request_ids:
num_spec_tokens = len( num_spec_tokens = len(spec_tokens.get(req_id, ()))
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) if (logit_idx := struct_out_req_batch_indices.get(req_id)) is not None:
)
if req_id in struct_out_req_batch_indices:
logit_index = struct_out_req_batch_indices[req_id]
for i in range(1 + num_spec_tokens): for i in range(1 + num_spec_tokens):
sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i] bitmask_index = logit_idx + i
out_indices.append(logit_index + i) sorted_bitmask[bitmask_index] = grammar_bitmask[cumulative_index + i]
out_indices.append(bitmask_index)
cumulative_index += 1 + num_spec_tokens cumulative_index += 1 + num_spec_tokens
# Copy async to device as tensor. # Copy async to device as tensor.
...@@ -149,21 +144,19 @@ def get_outlines_cache_path() -> str: ...@@ -149,21 +144,19 @@ def get_outlines_cache_path() -> str:
if outlines_cache_dir: if outlines_cache_dir:
# OUTLINES_CACHE_DIR takes precedence # OUTLINES_CACHE_DIR takes precedence
return outlines_cache_dir return outlines_cache_dir
elif xdg_cache_home: if xdg_cache_home:
return os.path.join(xdg_cache_home, ".cache", "outlines") return os.path.join(xdg_cache_home, ".cache", "outlines")
# If homedir is "/", we may be inside a container, and thus writing to # If homedir is "/", we may be inside a container, and thus writing to
# root would be problematic, so we fall back to using a tempfile. # root would be problematic, so we fall back to using a tempfile.
# Also validate the path exists, since os.path.expanduser does # Also validate the path exists, since os.path.expanduser does
# not guarantee existence. # not guarantee existence.
elif os.path.isdir(home_dir) and home_dir != "/": if os.path.isdir(home_dir) and home_dir != "/":
# Default Unix fallback: ~/.cache/outlines # Default Unix fallback: ~/.cache/outlines
return os.path.join(home_dir, ".cache", "outlines") return os.path.join(home_dir, ".cache", "outlines")
else:
import tempfile
# home_dir may be / inside a docker container without existing user # home_dir may be / inside a docker container without existing user
tempdir = tempfile.gettempdir() tempdir = tempfile.gettempdir()
return os.path.join(tempdir, ".cache", "outlines") return os.path.join(tempdir, ".cache", "outlines")
def get_outlines_cache(): def get_outlines_cache():
...@@ -184,8 +177,8 @@ def get_outlines_cache(): ...@@ -184,8 +177,8 @@ def get_outlines_cache():
cache.clear() cache.clear()
cache.set("__version__", outlines_version) cache.set("__version__", outlines_version)
return cache return cache
else:
return LRUCache(maxsize=128) return LRUCache(maxsize=128)
re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$") re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
...@@ -193,8 +186,7 @@ re_replacement_seq = re.compile(r"^.{0,6}�+.{0,6}$") ...@@ -193,8 +186,7 @@ re_replacement_seq = re.compile(r"^.{0,6}�+.{0,6}$")
def _reduced_vocabulary( def _reduced_vocabulary(
tokenizer: TokenizerLike, tokenizer: TokenizerLike, eos_token_id: int
eos_token_id: int,
) -> dict[bytes, list[int]]: ) -> dict[bytes, list[int]]:
"""Create a map from vocabulary tokens to lists of equivalent token ids. """Create a map from vocabulary tokens to lists of equivalent token ids.
...@@ -267,17 +259,13 @@ def get_outlines_vocabulary(tokenizer: TokenizerLike) -> oc.Vocabulary: ...@@ -267,17 +259,13 @@ def get_outlines_vocabulary(tokenizer: TokenizerLike) -> oc.Vocabulary:
return tokenizer._outlines_vocabulary # type: ignore return tokenizer._outlines_vocabulary # type: ignore
try: try:
if ( if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
hasattr(
tokenizer,
"eos_token_id",
)
and tokenizer.eos_token_id is not None
):
eos_token_id = tokenizer.eos_token_id eos_token_id = tokenizer.eos_token_id
else: else:
raise ValueError( raise ValueError(
f"Error during structured outputs setup for outlines: Tokenizer ({type(tokenizer)}) has no `eos_token_id` property, but `eos_token_id` is required for structured outputs to work properly." # noqa: E501 "Error during structured outputs setup for outlines: Tokenizer "
f"({type(tokenizer)}) has no `eos_token_id` property, but "
"`eos_token_id` is required for structured outputs to work properly."
) )
reduced_vocab = _reduced_vocabulary( reduced_vocab = _reduced_vocabulary(
...@@ -290,7 +278,7 @@ def get_outlines_vocabulary(tokenizer: TokenizerLike) -> oc.Vocabulary: ...@@ -290,7 +278,7 @@ def get_outlines_vocabulary(tokenizer: TokenizerLike) -> oc.Vocabulary:
return vocabulary return vocabulary
except AttributeError as e: except AttributeError as e:
raise ValueError( raise ValueError(
f"Cannot get the vocabulary of the tokenizer " "Cannot get the vocabulary of the tokenizer "
f"({type(tokenizer)}). The tokenizer should have a " f"({type(tokenizer)}). The tokenizer should have a "
"get_vocab method." "get_vocab method."
) from e ) from e
......
...@@ -3564,14 +3564,13 @@ class GPUModelRunner( ...@@ -3564,14 +3564,13 @@ class GPUModelRunner(
def _get_valid_sampled_token_count(self) -> list[int]: def _get_valid_sampled_token_count(self) -> list[int]:
# Wait until valid_sampled_tokens_count is copied to cpu, # Wait until valid_sampled_tokens_count is copied to cpu,
prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids
if ( sampled_count_event = self.valid_sampled_token_count_event
self.valid_sampled_token_count_event is None if sampled_count_event is None or prev_sampled_token_ids is None:
or prev_sampled_token_ids is None
):
return [] return []
counts_cpu = self.valid_sampled_token_count_cpu counts_cpu = self.valid_sampled_token_count_cpu
self.valid_sampled_token_count_event.synchronize() assert counts_cpu is not None
sampled_count_event.synchronize()
return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist() return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()
def propose_draft_token_ids( def propose_draft_token_ids(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment