"vllm/vscode:/vscode.git/clone" did not exist on "4c69e228b32220ac9159dfdcf0df13ea776e630d"
Unverified commit 09473ee4, authored by SangBin Cho, committed by GitHub
Browse files

[mypy] Add mypy type annotation part 1 (#4006)

parent d4ec9ffb
# Static type checking with mypy.
#
# Each package listed in the "Mypy" step is checked on its own with
# --follow-imports=skip, so only the files matched by the glob are analysed
# and imported modules are not type-checked transitively. Packages that are
# not checked yet are kept as commented-out commands under the TODO.
name: mypy

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  # The job id matches the workflow name so its check is reported as "mypy"
  # and cannot be confused with the linter job.
  mypy:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install mypy==1.9.0
        pip install types-setuptools
        pip install types-PyYAML
        pip install types-requests
    - name: Mypy
      run: |
        mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
        # TODO(sang): Follow up
        # mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
        # mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
        # mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
        # mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
        # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
...@@ -93,9 +93,23 @@ fi ...@@ -93,9 +93,23 @@ fi
echo 'vLLM yapf: Done' echo 'vLLM yapf: Done'
# Run mypy # Run mypy
# TODO(zhuohan): Enable mypy echo 'vLLM mypy:'
# echo 'vLLM mypy:' mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
# mypy mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
CODESPELL_EXCLUDES=( CODESPELL_EXCLUDES=(
'--skip' '*docs/source/_build/**' '--skip' '*docs/source/_build/**'
...@@ -228,5 +242,3 @@ if ! git diff --quiet &>/dev/null; then ...@@ -228,5 +242,3 @@ if ! git diff --quiet &>/dev/null; then
exit 1 exit 1
fi fi
...@@ -46,10 +46,13 @@ ignore = [ ...@@ -46,10 +46,13 @@ ignore = [
python_version = "3.8" python_version = "3.8"
ignore_missing_imports = true ignore_missing_imports = true
check_untyped_defs = true
files = "vllm" files = "vllm"
# TODO(woosuk): Include the code from Megatron and HuggingFace. # TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/" exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
]
[tool.codespell] [tool.codespell]
......
...@@ -12,3 +12,4 @@ pydantic >= 2.0 # Required for OpenAI server. ...@@ -12,3 +12,4 @@ pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0 prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer tiktoken == 0.6.0 # Required for DBRX tokenizer
outlines == 0.0.34 # Requires torch >= 2.1.0 outlines == 0.0.34 # Requires torch >= 2.1.0
typing_extensions
\ No newline at end of file
...@@ -7,7 +7,7 @@ codespell==2.2.6 ...@@ -7,7 +7,7 @@ codespell==2.2.6
isort==5.13.2 isort==5.13.2
# type checking # type checking
mypy==0.991 mypy==1.9.0
types-PyYAML types-PyYAML
types-requests types-requests
types-setuptools types-setuptools
......
...@@ -2,7 +2,7 @@ import enum ...@@ -2,7 +2,7 @@ import enum
import json import json
import os import os
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, ClassVar, Optional, Union from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
import torch import torch
from packaging.version import Version from packaging.version import Version
...@@ -141,7 +141,7 @@ class ModelConfig: ...@@ -141,7 +141,7 @@ class ModelConfig:
supported_load_format = [ supported_load_format = [
"auto", "pt", "safetensors", "npcache", "dummy" "auto", "pt", "safetensors", "npcache", "dummy"
] ]
rocm_not_supported_load_format = [] rocm_not_supported_load_format: List[str] = []
if load_format not in supported_load_format: if load_format not in supported_load_format:
raise ValueError( raise ValueError(
f"Unknown load format: {self.load_format}. Must be one of " f"Unknown load format: {self.load_format}. Must be one of "
...@@ -679,6 +679,9 @@ class SpeculativeConfig: ...@@ -679,6 +679,9 @@ class SpeculativeConfig:
"num_speculative_tokens to be provided, but found " "num_speculative_tokens to be provided, but found "
f"{speculative_model=} and {num_speculative_tokens=}.") f"{speculative_model=} and {num_speculative_tokens=}.")
assert (speculative_model is not None
and num_speculative_tokens is not None)
# TODO: The user should be able to specify revision/quantization/max # TODO: The user should be able to specify revision/quantization/max
# model len for the draft model. It is not currently supported. # model len for the draft model. It is not currently supported.
draft_revision = None draft_revision = None
...@@ -993,7 +996,7 @@ def _get_and_verify_max_len( ...@@ -993,7 +996,7 @@ def _get_and_verify_max_len(
derived_max_model_len *= scaling_factor derived_max_model_len *= scaling_factor
if max_model_len is None: if max_model_len is None:
max_model_len = derived_max_model_len max_model_len = int(derived_max_model_len)
elif max_model_len > derived_max_model_len: elif max_model_len > derived_max_model_len:
# Some models might have a separate key for specifying model_max_length # Some models might have a separate key for specifying model_max_length
# that will be bigger than derived_max_model_len. We compare user input # that will be bigger than derived_max_model_len. We compare user input
......
"""A block manager that manages token blocks.""" """A block manager that manages token blocks."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence as GenericSequence
from itertools import count, takewhile from itertools import count, takewhile
from os.path import commonprefix from os.path import commonprefix
from typing import Dict, List, Optional, Set from typing import Dict, List, Optional, Set
...@@ -231,10 +232,10 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -231,10 +232,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
if self.enable_caching: if self.enable_caching:
logger.info("Automatic prefix caching is enabled.") logger.info("Automatic prefix caching is enabled.")
self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size, self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
num_gpu_blocks) Device.GPU, block_size, num_gpu_blocks)
self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size, self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
num_cpu_blocks) Device.CPU, block_size, num_cpu_blocks)
else: else:
self.gpu_allocator = UncachedBlockAllocator( self.gpu_allocator = UncachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks) Device.GPU, block_size, num_gpu_blocks)
...@@ -588,7 +589,8 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -588,7 +589,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
for b in takewhile(lambda b: b.computed, block_table[:-1]) for b in takewhile(lambda b: b.computed, block_table[:-1])
] ]
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
"""Return the block ids that are common for a given sequence group. """Return the block ids that are common for a given sequence group.
Used in prefill (can skip prefill of some blocks). Used in prefill (can skip prefill of some blocks).
......
"""A block manager that manages token blocks.""" """A block manager that manages token blocks."""
from collections.abc import Sequence as GenericSequence
from typing import Dict, List, Optional from typing import Dict, List, Optional
from vllm.core.block.block_table import BlockTable from vllm.core.block.block_table import BlockTable
...@@ -205,7 +206,8 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -205,7 +206,8 @@ class BlockSpaceManagerV2(BlockSpaceManager):
# as computed. # as computed.
self.block_allocator.mark_blocks_as_computed() self.block_allocator.mark_blocks_as_computed()
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
"""Determine which blocks for which we skip prefill. """Determine which blocks for which we skip prefill.
With prefix caching we can skip prefill for previously-generated blocks. With prefix caching we can skip prefill for previously-generated blocks.
......
import enum import enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence as GenericSequence
from typing import Dict, List from typing import Dict, List
from vllm.sequence import Sequence, SequenceGroup from vllm.sequence import Sequence, SequenceGroup
...@@ -103,7 +104,8 @@ class BlockSpaceManager(ABC): ...@@ -103,7 +104,8 @@ class BlockSpaceManager(ABC):
pass pass
@abstractmethod @abstractmethod
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
pass pass
@abstractmethod @abstractmethod
......
...@@ -42,8 +42,8 @@ class SchedulingBudget: ...@@ -42,8 +42,8 @@ class SchedulingBudget:
""" """
token_budget: int token_budget: int
max_num_seqs: int max_num_seqs: int
_requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set) _requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
_requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set) _requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
_num_batched_tokens: int = 0 _num_batched_tokens: int = 0
_num_curr_seqs: int = 0 _num_curr_seqs: int = 0
...@@ -133,7 +133,7 @@ class SchedulerOutputs: ...@@ -133,7 +133,7 @@ class SchedulerOutputs:
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_swap_out and not self.blocks_to_copy) and not self.blocks_to_swap_out and not self.blocks_to_copy)
def _sort_by_lora_ids(self) -> bool: def _sort_by_lora_ids(self):
self.scheduled_seq_groups = sorted( self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups, self.scheduled_seq_groups,
key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id)) key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
...@@ -337,7 +337,8 @@ class Scheduler: ...@@ -337,7 +337,8 @@ class Scheduler:
self.free_seq(seq) self.free_seq(seq)
def has_unfinished_seqs(self) -> bool: def has_unfinished_seqs(self) -> bool:
return self.waiting or self.running or self.swapped return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0
def get_num_unfinished_seq_groups(self) -> int: def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped) return len(self.waiting) + len(self.running) + len(self.swapped)
...@@ -404,7 +405,7 @@ class Scheduler: ...@@ -404,7 +405,7 @@ class Scheduler:
budget.subtract_num_seqs(seq_group.request_id, budget.subtract_num_seqs(seq_group.request_id,
num_running_seqs) num_running_seqs)
if curr_loras is not None and seq_group.lora_int_id > 0: if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.pop(seq_group.lora_int_id) curr_loras.remove(seq_group.lora_int_id)
if running_queue: if running_queue:
# Preempt the lowest-priority sequence groups. # Preempt the lowest-priority sequence groups.
...@@ -496,7 +497,7 @@ class Scheduler: ...@@ -496,7 +497,7 @@ class Scheduler:
now = time.time() now = time.time()
swapped_queue = policy.sort_by_priority(now, swapped_queue) swapped_queue = policy.sort_by_priority(now, swapped_queue)
leftover_swapped = deque() leftover_swapped: Deque[SequenceGroup] = deque()
while swapped_queue: while swapped_queue:
seq_group = swapped_queue[0] seq_group = swapped_queue[0]
...@@ -507,7 +508,9 @@ class Scheduler: ...@@ -507,7 +508,9 @@ class Scheduler:
lora_int_id = 0 lora_int_id = 0
if self.lora_enabled: if self.lora_enabled:
lora_int_id = seq_group.lora_int_id lora_int_id = seq_group.lora_int_id
if (lora_int_id > 0 and lora_int_id not in curr_loras assert curr_loras is not None
assert self.lora_config is not None
if (lora_int_id > 0 and (lora_int_id not in curr_loras)
and len(curr_loras) >= self.lora_config.max_loras): and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so # We don't have a space for another LoRA, so
# we ignore this request for now. # we ignore this request for now.
...@@ -593,7 +596,7 @@ class Scheduler: ...@@ -593,7 +596,7 @@ class Scheduler:
# Copy the queue so that the input queue is not modified. # Copy the queue so that the input queue is not modified.
waiting_queue = deque([s for s in waiting_queue]) waiting_queue = deque([s for s in waiting_queue])
leftover_waiting_sequences = deque() leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue: while self._passed_delay(time.time()) and waiting_queue:
seq_group = waiting_queue[0] seq_group = waiting_queue[0]
...@@ -635,6 +638,8 @@ class Scheduler: ...@@ -635,6 +638,8 @@ class Scheduler:
lora_int_id = 0 lora_int_id = 0
if self.lora_enabled: if self.lora_enabled:
lora_int_id = seq_group.lora_int_id lora_int_id = seq_group.lora_int_id
assert curr_loras is not None
assert self.lora_config is not None
if (self.lora_enabled and lora_int_id > 0 if (self.lora_enabled and lora_int_id > 0
and lora_int_id not in curr_loras and lora_int_id not in curr_loras
and len(curr_loras) >= self.lora_config.max_loras): and len(curr_loras) >= self.lora_config.max_loras):
...@@ -780,7 +785,7 @@ class Scheduler: ...@@ -780,7 +785,7 @@ class Scheduler:
token_budget=self.scheduler_config.max_num_batched_tokens, token_budget=self.scheduler_config.max_num_batched_tokens,
max_num_seqs=self.scheduler_config.max_num_seqs, max_num_seqs=self.scheduler_config.max_num_seqs,
) )
curr_loras = set() curr_loras: Set[int] = set()
remaining_waiting, prefills = (self.waiting, remaining_waiting, prefills = (self.waiting,
SchedulerPrefillOutputs.create_empty()) SchedulerPrefillOutputs.create_empty())
...@@ -1087,7 +1092,7 @@ class Scheduler: ...@@ -1087,7 +1092,7 @@ class Scheduler:
def _get_num_new_tokens(self, seq_group: SequenceGroup, def _get_num_new_tokens(self, seq_group: SequenceGroup,
status: SequenceStatus, enable_chunking: bool, status: SequenceStatus, enable_chunking: bool,
budget: SchedulingBudget) -> Tuple[int, bool]: budget: SchedulingBudget) -> int:
"""Get the next new tokens to compute for a given sequence group """Get the next new tokens to compute for a given sequence group
that's in a given `status`. that's in a given `status`.
......
from collections import namedtuple from collections import namedtuple
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
...@@ -144,7 +144,7 @@ def broadcast_tensor_dict( ...@@ -144,7 +144,7 @@ def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0, src: int = 0,
group: Optional[ProcessGroup] = None, group: Optional[ProcessGroup] = None,
) -> Dict[Any, Union[torch.Tensor, Any]]: ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.""" """Broadcast the input tensor dictionary."""
group = group or torch.distributed.group.WORLD group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group) ranks = torch.distributed.get_process_group_ranks(group)
...@@ -157,10 +157,10 @@ def broadcast_tensor_dict( ...@@ -157,10 +157,10 @@ def broadcast_tensor_dict(
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
if rank == src: if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance( assert isinstance(
tensor_dict, tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}") dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
metadata_list = []
for key, value in tensor_dict.items(): for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor): if isinstance(value, torch.Tensor):
assert value.is_cuda, ( assert value.is_cuda, (
...@@ -190,10 +190,10 @@ def broadcast_tensor_dict( ...@@ -190,10 +190,10 @@ def broadcast_tensor_dict(
torch.distributed.broadcast_object_list(recv_metadata_list, torch.distributed.broadcast_object_list(recv_metadata_list,
src=src, src=src,
group=group) group=group)
metadata_list = recv_metadata_list[0] assert recv_metadata_list[0] is not None
tensor_dict = {} tensor_dict = {}
async_handles = [] async_handles = []
for key, value in metadata_list: for key, value in recv_metadata_list[0]:
if isinstance(value, TensorMetadata): if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size, tensor = torch.empty(value.size,
dtype=value.dtype, dtype=value.dtype,
......
import pickle import pickle
from typing import List, Optional, Tuple from typing import Callable, List, Optional, Tuple
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import get_ip, is_hip, set_cuda_visible_devices from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
from vllm.worker.worker import Worker
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -18,15 +19,20 @@ try: ...@@ -18,15 +19,20 @@ try:
if init_cached_hf_modules: if init_cached_hf_modules:
from transformers.dynamic_module_utils import init_hf_modules from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules() init_hf_modules()
self.worker = None self._worker: Optional[Worker] = None
# Since the compiled DAG runs a main execution # Since the compiled DAG runs a main execution
# in a different thread that calls cuda.set_device. # in a different thread that calls cuda.set_device.
# The flag indicates is set_device is called on # The flag indicates is set_device is called on
# that thread. # that thread.
self.compiled_dag_cuda_device_set = False self.compiled_dag_cuda_device_set = False
def init_worker(self, worker_init_fn): def init_worker(self, worker_init_fn: Callable[[], Worker]):
self.worker = worker_init_fn() self._worker = worker_init_fn()
@property
def worker(self) -> Worker:
assert self._worker is not None
return self._worker
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.worker, name) return getattr(self.worker, name)
...@@ -70,8 +76,8 @@ except ImportError as e: ...@@ -70,8 +76,8 @@ except ImportError as e:
logger.warning(f"Failed to import Ray with {e!r}. " logger.warning(f"Failed to import Ray with {e!r}. "
"For distributed inference, please install Ray with " "For distributed inference, please install Ray with "
"`pip install ray`.") "`pip install ray`.")
ray = None ray = None # type: ignore
RayWorkerVllm = None RayWorkerVllm = None # type: ignore
def initialize_ray_cluster( def initialize_ray_cluster(
......
...@@ -47,6 +47,7 @@ async def generate(request: Request) -> Response: ...@@ -47,6 +47,7 @@ async def generate(request: Request) -> Response:
sampling_params = SamplingParams(**request_dict) sampling_params = SamplingParams(**request_dict)
request_id = random_uuid() request_id = random_uuid()
assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id) results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case # Streaming case
......
...@@ -170,8 +170,12 @@ class LLM: ...@@ -170,8 +170,12 @@ class LLM:
multi_modal_data.data = multi_modal_data.data.to(torch.float16) multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine. # Add requests to the engine.
num_requests = len(prompts) if prompts is not None else len( if prompts is not None:
prompt_token_ids) num_requests = len(prompts)
else:
assert prompt_token_ids is not None
num_requests = len(prompt_token_ids)
for i in range(num_requests): for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[ token_ids = None if prompt_token_ids is None else prompt_token_ids[
......
import os import os
from typing import Dict, List, Optional from typing import Dict, List, Optional, Tuple
import torch import torch
...@@ -61,7 +61,7 @@ class CPUExecutor(ExecutorBase): ...@@ -61,7 +61,7 @@ class CPUExecutor(ExecutorBase):
self.driver_worker.init_device() self.driver_worker.init_device()
self.driver_worker.load_model() self.driver_worker.load_model()
def determine_num_available_blocks(self) -> tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the """Determine the number of available KV blocks by invoking the
underlying worker. underlying worker.
""" """
......
from typing import Dict, List, Optional from typing import Dict, List, Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig,
...@@ -66,7 +66,7 @@ class GPUExecutor(ExecutorBase): ...@@ -66,7 +66,7 @@ class GPUExecutor(ExecutorBase):
self.driver_worker.init_device() self.driver_worker.init_device()
self.driver_worker.load_model() self.driver_worker.load_model()
def determine_num_available_blocks(self) -> tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the """Determine the number of available KV blocks by invoking the
underlying worker. underlying worker.
""" """
......
from typing import Dict, List, Optional from typing import Dict, List, Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig,
...@@ -47,7 +47,7 @@ class NeuronExecutor(ExecutorBase): ...@@ -47,7 +47,7 @@ class NeuronExecutor(ExecutorBase):
self.driver_worker.init_device() self.driver_worker.init_device()
self.driver_worker.load_model() self.driver_worker.load_model()
def determine_num_available_blocks(self) -> tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the """Determine the number of available KV blocks by invoking the
underlying worker. underlying worker.
""" """
......
...@@ -3,7 +3,7 @@ import copy ...@@ -3,7 +3,7 @@ import copy
import os import os
import pickle import pickle
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig,
...@@ -197,7 +197,7 @@ class RayGPUExecutor(ExecutorBase): ...@@ -197,7 +197,7 @@ class RayGPUExecutor(ExecutorBase):
max_parallel_loading_workers, max_parallel_loading_workers,
) )
def determine_num_available_blocks(self) -> tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks. """Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes This invokes `determine_num_available_blocks` on each worker and takes
...@@ -205,7 +205,7 @@ class RayGPUExecutor(ExecutorBase): ...@@ -205,7 +205,7 @@ class RayGPUExecutor(ExecutorBase):
compatible with all workers. compatible with all workers.
Returns: Returns:
- tuple[num_gpu_blocks, num_cpu_blocks] - Tuple[num_gpu_blocks, num_cpu_blocks]
""" """
# Get the maximum number of blocks that can be allocated on GPU and CPU. # Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers("determine_num_available_blocks", ) num_blocks = self._run_workers("determine_num_available_blocks", )
...@@ -276,7 +276,7 @@ class RayGPUExecutor(ExecutorBase): ...@@ -276,7 +276,7 @@ class RayGPUExecutor(ExecutorBase):
self, self,
method: str, method: str,
*args, *args,
driver_args: Optional[List[Any]] = None, driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None, driver_kwargs: Optional[Dict[str, Any]] = None,
max_concurrent_workers: Optional[int] = None, max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False, use_ray_compiled_dag: bool = False,
...@@ -291,6 +291,7 @@ class RayGPUExecutor(ExecutorBase): ...@@ -291,6 +291,7 @@ class RayGPUExecutor(ExecutorBase):
if use_ray_compiled_dag: if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single # Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it. # input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1) output_channels = self.forward_dag.execute(1)
else: else:
# Start the ray workers first. # Start the ray workers first.
...@@ -369,7 +370,7 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): ...@@ -369,7 +370,7 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
self, self,
method: str, method: str,
*args, *args,
driver_args: Optional[List[Any]] = None, driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None, driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs, **kwargs,
) -> Any: ) -> Any:
......
...@@ -5,7 +5,8 @@ from functools import cached_property ...@@ -5,7 +5,8 @@ from functools import cached_property
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import torch import torch
from pydantic import conint from pydantic import Field
from typing_extensions import Annotated
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
...@@ -127,7 +128,7 @@ class SamplingParams: ...@@ -127,7 +128,7 @@ class SamplingParams:
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True, spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None, logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[conint(ge=1)] = None, truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
) -> None: ) -> None:
self.n = n self.n = n
self.best_of = best_of if best_of is not None else n self.best_of = best_of if best_of is not None else n
......
...@@ -171,10 +171,10 @@ class SequenceData: ...@@ -171,10 +171,10 @@ class SequenceData:
return self.prompt_token_ids[-1] return self.prompt_token_ids[-1]
return self.output_token_ids[-1] return self.output_token_ids[-1]
def get_prompt_token_ids(self) -> int: def get_prompt_token_ids(self) -> List[int]:
return self.prompt_token_ids return self.prompt_token_ids
def get_output_token_ids(self) -> int: def get_output_token_ids(self) -> List[int]:
return self.output_token_ids return self.output_token_ids
@property @property
...@@ -370,7 +370,7 @@ class SequenceGroupState: ...@@ -370,7 +370,7 @@ class SequenceGroupState:
"""Mutable state tied to a specific sequence group""" """Mutable state tied to a specific sequence group"""
# torch.Generator used in seeded sampling # torch.Generator used in seeded sampling
generator: Optional = None generator: Optional = None # type: ignore
class MultiModalData: class MultiModalData:
...@@ -599,7 +599,7 @@ class SequenceGroupMetadata: ...@@ -599,7 +599,7 @@ class SequenceGroupMetadata:
return self.lora_request.lora_int_id if self.lora_request else 0 return self.lora_request.lora_int_id if self.lora_request else 0
@property @property
def token_chunk_size(self) -> int: def token_chunk_size(self) -> Optional[int]:
"""Return the number of tokens to be processed (chunk size).""" """Return the number of tokens to be processed (chunk size)."""
return self._token_chunk_size return self._token_chunk_size
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment