Unverified Commit 09473ee4 authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[mypy] Add mypy type annotation part 1 (#4006)

parent d4ec9ffb
name: mypy
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
pull_request:
branches:
- main
jobs:
ruff:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install mypy==1.9.0
pip install types-setuptools
pip install types-PyYAML
pip install types-requests
pip install types-setuptools
- name: Mypy
run: |
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
......@@ -93,9 +93,23 @@ fi
echo 'vLLM yapf: Done'
# Run mypy
# TODO(zhuohan): Enable mypy
# echo 'vLLM mypy:'
# mypy
echo 'vLLM mypy:'
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
CODESPELL_EXCLUDES=(
'--skip' '*docs/source/_build/**'
......@@ -228,5 +242,3 @@ if ! git diff --quiet &>/dev/null; then
exit 1
fi
......@@ -46,10 +46,13 @@ ignore = [
python_version = "3.8"
ignore_missing_imports = true
check_untyped_defs = true
files = "vllm"
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/"
exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
]
[tool.codespell]
......
......@@ -12,3 +12,4 @@ pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
outlines == 0.0.34 # Requires torch >= 2.1.0
typing_extensions
\ No newline at end of file
......@@ -7,7 +7,7 @@ codespell==2.2.6
isort==5.13.2
# type checking
mypy==0.991
mypy==1.9.0
types-PyYAML
types-requests
types-setuptools
......
......@@ -2,7 +2,7 @@ import enum
import json
import os
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, ClassVar, Optional, Union
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
import torch
from packaging.version import Version
......@@ -141,7 +141,7 @@ class ModelConfig:
supported_load_format = [
"auto", "pt", "safetensors", "npcache", "dummy"
]
rocm_not_supported_load_format = []
rocm_not_supported_load_format: List[str] = []
if load_format not in supported_load_format:
raise ValueError(
f"Unknown load format: {self.load_format}. Must be one of "
......@@ -679,6 +679,9 @@ class SpeculativeConfig:
"num_speculative_tokens to be provided, but found "
f"{speculative_model=} and {num_speculative_tokens=}.")
assert (speculative_model is not None
and num_speculative_tokens is not None)
# TODO: The user should be able to specify revision/quantization/max
# model len for the draft model. It is not currently supported.
draft_revision = None
......@@ -993,7 +996,7 @@ def _get_and_verify_max_len(
derived_max_model_len *= scaling_factor
if max_model_len is None:
max_model_len = derived_max_model_len
max_model_len = int(derived_max_model_len)
elif max_model_len > derived_max_model_len:
# Some models might have a separate key for specifying model_max_length
# that will be bigger than derived_max_model_len. We compare user input
......
"""A block manager that manages token blocks."""
from abc import ABC, abstractmethod
from collections.abc import Sequence as GenericSequence
from itertools import count, takewhile
from os.path import commonprefix
from typing import Dict, List, Optional, Set
......@@ -231,10 +232,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
if self.enable_caching:
logger.info("Automatic prefix caching is enabled.")
self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
num_gpu_blocks)
self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
num_cpu_blocks)
self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks)
self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
Device.CPU, block_size, num_cpu_blocks)
else:
self.gpu_allocator = UncachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks)
......@@ -588,7 +589,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
for b in takewhile(lambda b: b.computed, block_table[:-1])
]
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
"""Return the block ids that are common for a given sequence group.
Used in prefill (can skip prefill of some blocks).
......
"""A block manager that manages token blocks."""
from collections.abc import Sequence as GenericSequence
from typing import Dict, List, Optional
from vllm.core.block.block_table import BlockTable
......@@ -205,7 +206,8 @@ class BlockSpaceManagerV2(BlockSpaceManager):
# as computed.
self.block_allocator.mark_blocks_as_computed()
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
"""Determine which blocks for which we skip prefill.
With prefix caching we can skip prefill for previously-generated blocks.
......
import enum
from abc import ABC, abstractmethod
from collections.abc import Sequence as GenericSequence
from typing import Dict, List
from vllm.sequence import Sequence, SequenceGroup
......@@ -103,7 +104,8 @@ class BlockSpaceManager(ABC):
pass
@abstractmethod
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
pass
@abstractmethod
......
......@@ -42,8 +42,8 @@ class SchedulingBudget:
"""
token_budget: int
max_num_seqs: int
_requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set)
_requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set)
_requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
_requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
_num_batched_tokens: int = 0
_num_curr_seqs: int = 0
......@@ -133,7 +133,7 @@ class SchedulerOutputs:
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_swap_out and not self.blocks_to_copy)
def _sort_by_lora_ids(self) -> bool:
def _sort_by_lora_ids(self):
self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups,
key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
......@@ -337,7 +337,8 @@ class Scheduler:
self.free_seq(seq)
def has_unfinished_seqs(self) -> bool:
return self.waiting or self.running or self.swapped
return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)
......@@ -404,7 +405,7 @@ class Scheduler:
budget.subtract_num_seqs(seq_group.request_id,
num_running_seqs)
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.pop(seq_group.lora_int_id)
curr_loras.remove(seq_group.lora_int_id)
if running_queue:
# Preempt the lowest-priority sequence groups.
......@@ -496,7 +497,7 @@ class Scheduler:
now = time.time()
swapped_queue = policy.sort_by_priority(now, swapped_queue)
leftover_swapped = deque()
leftover_swapped: Deque[SequenceGroup] = deque()
while swapped_queue:
seq_group = swapped_queue[0]
......@@ -507,7 +508,9 @@ class Scheduler:
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
if (lora_int_id > 0 and lora_int_id not in curr_loras
assert curr_loras is not None
assert self.lora_config is not None
if (lora_int_id > 0 and (lora_int_id not in curr_loras)
and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so
# we ignore this request for now.
......@@ -593,7 +596,7 @@ class Scheduler:
# Copy the queue so that the input queue is not modified.
waiting_queue = deque([s for s in waiting_queue])
leftover_waiting_sequences = deque()
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue:
seq_group = waiting_queue[0]
......@@ -635,6 +638,8 @@ class Scheduler:
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
assert curr_loras is not None
assert self.lora_config is not None
if (self.lora_enabled and lora_int_id > 0
and lora_int_id not in curr_loras
and len(curr_loras) >= self.lora_config.max_loras):
......@@ -780,7 +785,7 @@ class Scheduler:
token_budget=self.scheduler_config.max_num_batched_tokens,
max_num_seqs=self.scheduler_config.max_num_seqs,
)
curr_loras = set()
curr_loras: Set[int] = set()
remaining_waiting, prefills = (self.waiting,
SchedulerPrefillOutputs.create_empty())
......@@ -1087,7 +1092,7 @@ class Scheduler:
def _get_num_new_tokens(self, seq_group: SequenceGroup,
status: SequenceStatus, enable_chunking: bool,
budget: SchedulingBudget) -> Tuple[int, bool]:
budget: SchedulingBudget) -> int:
"""Get the next new tokens to compute for a given sequence group
that's in a given `status`.
......
from collections import namedtuple
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch.distributed import ProcessGroup
......@@ -144,7 +144,7 @@ def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
) -> Dict[Any, Union[torch.Tensor, Any]]:
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary."""
group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group)
......@@ -157,10 +157,10 @@ def broadcast_tensor_dict(
rank = torch.distributed.get_rank()
if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
metadata_list = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
assert value.is_cuda, (
......@@ -190,10 +190,10 @@ def broadcast_tensor_dict(
torch.distributed.broadcast_object_list(recv_metadata_list,
src=src,
group=group)
metadata_list = recv_metadata_list[0]
assert recv_metadata_list[0] is not None
tensor_dict = {}
async_handles = []
for key, value in metadata_list:
for key, value in recv_metadata_list[0]:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
......
import pickle
from typing import List, Optional, Tuple
from typing import Callable, List, Optional, Tuple
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
from vllm.worker.worker import Worker
logger = init_logger(__name__)
......@@ -18,15 +19,20 @@ try:
if init_cached_hf_modules:
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
self.worker = None
self._worker: Optional[Worker] = None
# Since the compiled DAG runs a main execution
# in a different thread that calls cuda.set_device.
# The flag indicates is set_device is called on
# that thread.
self.compiled_dag_cuda_device_set = False
def init_worker(self, worker_init_fn):
self.worker = worker_init_fn()
def init_worker(self, worker_init_fn: Callable[[], Worker]):
self._worker = worker_init_fn()
@property
def worker(self) -> Worker:
assert self._worker is not None
return self._worker
def __getattr__(self, name):
return getattr(self.worker, name)
......@@ -70,8 +76,8 @@ except ImportError as e:
logger.warning(f"Failed to import Ray with {e!r}. "
"For distributed inference, please install Ray with "
"`pip install ray`.")
ray = None
RayWorkerVllm = None
ray = None # type: ignore
RayWorkerVllm = None # type: ignore
def initialize_ray_cluster(
......
......@@ -47,6 +47,7 @@ async def generate(request: Request) -> Response:
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case
......
......@@ -170,8 +170,12 @@ class LLM:
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine.
num_requests = len(prompts) if prompts is not None else len(
prompt_token_ids)
if prompts is not None:
num_requests = len(prompts)
else:
assert prompt_token_ids is not None
num_requests = len(prompt_token_ids)
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
......
import os
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
import torch
......@@ -61,7 +61,7 @@ class CPUExecutor(ExecutorBase):
self.driver_worker.init_device()
self.driver_worker.load_model()
def determine_num_available_blocks(self) -> tuple[int, int]:
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
......
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
......@@ -66,7 +66,7 @@ class GPUExecutor(ExecutorBase):
self.driver_worker.init_device()
self.driver_worker.load_model()
def determine_num_available_blocks(self) -> tuple[int, int]:
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
......
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
......@@ -47,7 +47,7 @@ class NeuronExecutor(ExecutorBase):
self.driver_worker.init_device()
self.driver_worker.load_model()
def determine_num_available_blocks(self) -> tuple[int, int]:
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
......
......@@ -3,7 +3,7 @@ import copy
import os
import pickle
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
......@@ -197,7 +197,7 @@ class RayGPUExecutor(ExecutorBase):
max_parallel_loading_workers,
)
def determine_num_available_blocks(self) -> tuple[int, int]:
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes
......@@ -205,7 +205,7 @@ class RayGPUExecutor(ExecutorBase):
compatible with all workers.
Returns:
- tuple[num_gpu_blocks, num_cpu_blocks]
- Tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers("determine_num_available_blocks", )
......@@ -276,7 +276,7 @@ class RayGPUExecutor(ExecutorBase):
self,
method: str,
*args,
driver_args: Optional[List[Any]] = None,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
......@@ -291,6 +291,7 @@ class RayGPUExecutor(ExecutorBase):
if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1)
else:
# Start the ray workers first.
......@@ -369,7 +370,7 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
self,
method: str,
*args,
driver_args: Optional[List[Any]] = None,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
......
......@@ -5,7 +5,8 @@ from functools import cached_property
from typing import Callable, List, Optional, Union
import torch
from pydantic import conint
from pydantic import Field
from typing_extensions import Annotated
_SAMPLING_EPS = 1e-5
......@@ -127,7 +128,7 @@ class SamplingParams:
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[conint(ge=1)] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
......
......@@ -171,10 +171,10 @@ class SequenceData:
return self.prompt_token_ids[-1]
return self.output_token_ids[-1]
def get_prompt_token_ids(self) -> int:
def get_prompt_token_ids(self) -> List[int]:
return self.prompt_token_ids
def get_output_token_ids(self) -> int:
def get_output_token_ids(self) -> List[int]:
return self.output_token_ids
@property
......@@ -370,7 +370,7 @@ class SequenceGroupState:
"""Mutable state tied to a specific sequence group"""
# torch.Generator used in seeded sampling
generator: Optional = None
generator: Optional = None # type: ignore
class MultiModalData:
......@@ -599,7 +599,7 @@ class SequenceGroupMetadata:
return self.lora_request.lora_int_id if self.lora_request else 0
@property
def token_chunk_size(self) -> int:
def token_chunk_size(self) -> Optional[int]:
"""Return the number of tokens to be processed (chunk size)."""
return self._token_chunk_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment