Unverified Commit 0e9164b4 authored by Cyrus Leung, committed by GitHub
Browse files

[mypy] Enable type checking for test directory (#5017)

parent 1b8a0d71
...@@ -453,8 +453,8 @@ class ArcticForCausalLM(nn.Module): ...@@ -453,8 +453,8 @@ class ArcticForCausalLM(nn.Module):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
mlp_params_mapping = [] mlp_params_mapping: List[Tuple[str, str, int]] = []
expert_params_mapping = [] expert_params_mapping: List[Tuple[str, str, int]] = []
num_layers = self.config.num_hidden_layers num_layers = self.config.num_hidden_layers
for layer in range(num_layers): for layer in range(num_layers):
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# This file is based on the LLama model definition file in transformers # This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model.""" """PyTorch Cohere model."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Set, Tuple
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -352,7 +352,7 @@ class CohereForCausalLM(nn.Module): ...@@ -352,7 +352,7 @@ class CohereForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params = set() loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for param_name, shard_name, shard_id in stacked_params_mapping: for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name: if shard_name not in name:
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights.""" """Inference-only Gemma model compatible with HuggingFace weights."""
from functools import lru_cache from functools import lru_cache
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -363,7 +363,7 @@ class GemmaForCausalLM(nn.Module): ...@@ -363,7 +363,7 @@ class GemmaForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params = set() loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping: for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name: if shard_name not in name:
......
...@@ -123,7 +123,7 @@ class SequenceData: ...@@ -123,7 +123,7 @@ class SequenceData:
output_token_ids = [] output_token_ids = []
self.prompt_token_ids = prompt_token_ids self.prompt_token_ids = prompt_token_ids
self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids) self._prompt_token_ids_tuple = tuple(prompt_token_ids)
self.output_token_ids = output_token_ids self.output_token_ids = output_token_ids
self.cumulative_logprob = 0.0 self.cumulative_logprob = 0.0
# The number of tokens that are computed (that run against the model). # The number of tokens that are computed (that run against the model).
......
import copy import copy
import weakref import weakref
from typing import List, Tuple from typing import Dict, List, Tuple
import torch import torch
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
...@@ -71,7 +71,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase): ...@@ -71,7 +71,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
sample_len) sample_len)
# Run model sample_len times. # Run model sample_len times.
model_outputs = [] model_outputs: List[SamplerOutput] = []
for _ in range(sample_len): for _ in range(sample_len):
model_output = super().execute_model( model_output = super().execute_model(
execute_model_req=copied_execute_model_req) execute_model_req=copied_execute_model_req)
...@@ -132,7 +132,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase): ...@@ -132,7 +132,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
# Shallow-copy the list of SequenceGroupMetadata. This allows us to # Shallow-copy the list of SequenceGroupMetadata. This allows us to
# append tokens and change is_prompt without external side-effects. # append tokens and change is_prompt without external side-effects.
new_seq_group_metadata_list = [] new_seq_group_metadata_list: List[SequenceGroupMetadata] = []
for old_seq_group_metadata in seq_group_metadata_list: for old_seq_group_metadata in seq_group_metadata_list:
# We must shallow-copy seq_group_metadata as is_prompt could change. # We must shallow-copy seq_group_metadata as is_prompt could change.
...@@ -140,7 +140,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase): ...@@ -140,7 +140,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
new_seq_group_metadata_list.append(seq_group_metadata) new_seq_group_metadata_list.append(seq_group_metadata)
# We must shallow-copy seq_data as we will append token ids # We must shallow-copy seq_data as we will append token ids
new_seq_data = {} new_seq_data: Dict[int, SequenceData] = {}
for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
new_seq_data[seq_id] = copy.copy(old_seq_data) new_seq_data[seq_id] = copy.copy(old_seq_data)
new_seq_data[ new_seq_data[
......
...@@ -48,7 +48,7 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase): ...@@ -48,7 +48,7 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
self, self,
execute_model_req: ExecuteModelRequest, execute_model_req: ExecuteModelRequest,
sample_len: int, sample_len: int,
) -> Tuple[Optional[List[SamplerOutput]], bool]: ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
"""NGram match algo to pick proposal candidate. Returns the list of """NGram match algo to pick proposal candidate. Returns the list of
sampler output, one per SequenceGroupMetadata. sampler output, one per SequenceGroupMetadata.
...@@ -58,8 +58,8 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase): ...@@ -58,8 +58,8 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
self._raise_if_unsupported(execute_model_req) self._raise_if_unsupported(execute_model_req)
has_spec_out = False has_spec_out = False
token_id_list = [] token_id_list: List[Optional[torch.Tensor]] = []
token_prob_list = [] token_prob_list: List[Optional[torch.Tensor]] = []
for idx, seq_group_metadata in enumerate( for idx, seq_group_metadata in enumerate(
execute_model_req.seq_group_metadata_list): execute_model_req.seq_group_metadata_list):
seq_data = next(iter(seq_group_metadata.seq_data.values())) seq_data = next(iter(seq_group_metadata.seq_data.values()))
......
...@@ -7,8 +7,8 @@ from vllm.config import SpeculativeConfig ...@@ -7,8 +7,8 @@ from vllm.config import SpeculativeConfig
from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
SequenceGroupMetadata) SamplerOutput, SequenceGroupMetadata)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores) SpeculativeScorer, SpeculativeScores)
...@@ -516,13 +516,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -516,13 +516,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
topk_indices_by_step = topk_indices_by_step.tolist() topk_indices_by_step = topk_indices_by_step.tolist()
# Construct the output on a per-step, per-sequence basis. # Construct the output on a per-step, per-sequence basis.
sampler_output_list = [] sampler_output_list: List[SamplerOutput] = []
for step_index in range(num_steps): for step_index in range(num_steps):
if all(token_id == -1 if all(token_id == -1
for token_id in accepted_token_ids_by_step[step_index]): for token_id in accepted_token_ids_by_step[step_index]):
break break
step_output_token_ids = [] step_output_token_ids: List[CompletionSequenceGroupOutput] = []
for sequence_index in range(batch_size): for sequence_index in range(batch_size):
# Each sequence may have a different num_logprobs; retrieve it. # Each sequence may have a different num_logprobs; retrieve it.
num_logprobs = num_logprobs_per_seq[sequence_index] num_logprobs = num_logprobs_per_seq[sequence_index]
......
...@@ -26,10 +26,10 @@ def get_all_num_logprobs( ...@@ -26,10 +26,10 @@ def get_all_num_logprobs(
sequence. sequence.
""" """
all_num_logprobs = [] all_num_logprobs: List[int] = []
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
num_logprobs = seq_group_metadata.sampling_params.logprobs num_logprobs = seq_group_metadata.sampling_params.logprobs
if seq_group_metadata.sampling_params.logprobs is None: if num_logprobs is None:
num_logprobs = 0 num_logprobs = 0
all_num_logprobs.append(num_logprobs) all_num_logprobs.append(num_logprobs)
......
...@@ -44,7 +44,7 @@ class Detokenizer: ...@@ -44,7 +44,7 @@ class Detokenizer:
read_offset = 0 read_offset = 0
next_iter_prefix_offset = 0 next_iter_prefix_offset = 0
next_iter_read_offset = 0 next_iter_read_offset = 0
next_iter_tokens = [] next_iter_tokens: List[str] = []
prev_tokens = None prev_tokens = None
for token_position, prompt_logprobs_for_token in enumerate( for token_position, prompt_logprobs_for_token in enumerate(
......
...@@ -20,12 +20,13 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, ...@@ -20,12 +20,13 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
import numpy as np import numpy as np
import psutil import psutil
import torch import torch
import torch.types
from typing_extensions import ParamSpec
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import enable_trace_function_call, init_logger from vllm.logger import enable_trace_function_call, init_logger
T = TypeVar("T")
logger = init_logger(__name__) logger = init_logger(__name__)
STR_DTYPE_TO_TORCH_DTYPE = { STR_DTYPE_TO_TORCH_DTYPE = {
...@@ -37,6 +38,10 @@ STR_DTYPE_TO_TORCH_DTYPE = { ...@@ -37,6 +38,10 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"fp8_e5m2": torch.uint8, "fp8_e5m2": torch.uint8,
} }
P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")
class Device(enum.Enum): class Device(enum.Enum):
GPU = enum.auto() GPU = enum.auto()
...@@ -176,7 +181,7 @@ def random_uuid() -> str: ...@@ -176,7 +181,7 @@ def random_uuid() -> str:
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_vllm_instance_id(): def get_vllm_instance_id() -> str:
""" """
If the environment variable VLLM_INSTANCE_ID is set, return it. If the environment variable VLLM_INSTANCE_ID is set, return it.
Otherwise, return a random UUID. Otherwise, return a random UUID.
...@@ -192,7 +197,7 @@ def in_wsl() -> bool: ...@@ -192,7 +197,7 @@ def in_wsl() -> bool:
return "microsoft" in " ".join(uname()).lower() return "microsoft" in " ".join(uname()).lower()
def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]: def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
"""Take a blocking function, and run it in an executor thread. """Take a blocking function, and run it in an executor thread.
This function prevents the blocking function from blocking the This function prevents the blocking function from blocking the
...@@ -200,7 +205,7 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]: ...@@ -200,7 +205,7 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
The code in this function needs to be thread safe. The code in this function needs to be thread safe.
""" """
def _async_wrapper(*args, **kwargs) -> asyncio.Future: def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
p_func = partial(func, *args, **kwargs) p_func = partial(func, *args, **kwargs)
return loop.run_in_executor(executor=None, func=p_func) return loop.run_in_executor(executor=None, func=p_func)
...@@ -325,7 +330,7 @@ def update_environment_variables(envs: Dict[str, str]): ...@@ -325,7 +330,7 @@ def update_environment_variables(envs: Dict[str, str]):
os.environ[k] = v os.environ[k] = v
def chunk_list(lst, chunk_size): def chunk_list(lst: List[T], chunk_size: int) -> List[List[T]]:
"""Return a list of successive chunk_size-sized chunks from lst.""" """Return a list of successive chunk_size-sized chunks from lst."""
return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)] return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
...@@ -336,7 +341,7 @@ def cdiv(a: int, b: int) -> int: ...@@ -336,7 +341,7 @@ def cdiv(a: int, b: int) -> int:
def _generate_random_fp8( def _generate_random_fp8(
tensor: torch.tensor, tensor: torch.Tensor,
low: float, low: float,
high: float, high: float,
) -> None: ) -> None:
...@@ -398,7 +403,10 @@ def create_kv_caches_with_random_flash( ...@@ -398,7 +403,10 @@ def create_kv_caches_with_random_flash(
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
scale = head_size**-0.5 scale = head_size**-0.5
key_caches, value_caches = [], []
key_caches: List[torch.Tensor] = []
value_caches: List[torch.Tensor] = []
for _ in range(num_layers): for _ in range(num_layers):
key_value_cache = torch.empty(size=key_value_cache_shape, key_value_cache = torch.empty(size=key_value_cache_shape,
dtype=torch_dtype, dtype=torch_dtype,
...@@ -429,7 +437,7 @@ def create_kv_caches_with_random( ...@@ -429,7 +437,7 @@ def create_kv_caches_with_random(
scale = head_size**-0.5 scale = head_size**-0.5
x = 16 // torch.tensor([], dtype=torch_dtype).element_size() x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_caches = [] key_caches: List[torch.Tensor] = []
for _ in range(num_layers): for _ in range(num_layers):
key_cache = torch.empty(size=key_cache_shape, key_cache = torch.empty(size=key_cache_shape,
dtype=torch_dtype, dtype=torch_dtype,
...@@ -444,7 +452,7 @@ def create_kv_caches_with_random( ...@@ -444,7 +452,7 @@ def create_kv_caches_with_random(
key_caches.append(key_cache) key_caches.append(key_cache)
value_cache_shape = (num_blocks, num_heads, head_size, block_size) value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_caches = [] value_caches: List[torch.Tensor] = []
for _ in range(num_layers): for _ in range(num_layers):
value_cache = torch.empty(size=value_cache_shape, value_cache = torch.empty(size=value_cache_shape,
dtype=torch_dtype, dtype=torch_dtype,
...@@ -484,7 +492,7 @@ def is_pin_memory_available() -> bool: ...@@ -484,7 +492,7 @@ def is_pin_memory_available() -> bool:
class CudaMemoryProfiler: class CudaMemoryProfiler:
def __init__(self, device=None): def __init__(self, device: Optional[torch.types.Device] = None):
self.device = device self.device = device
def current_memory_usage(self) -> float: def current_memory_usage(self) -> float:
...@@ -560,13 +568,13 @@ def get_dtype_size(dtype: torch.dtype) -> int: ...@@ -560,13 +568,13 @@ def get_dtype_size(dtype: torch.dtype) -> int:
return torch.tensor([], dtype=dtype).element_size() return torch.tensor([], dtype=dtype).element_size()
def merge_dicts(dict1: Dict[Any, List[Any]], def merge_dicts(dict1: Dict[K, List[T]],
dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]: dict2: Dict[K, List[T]]) -> Dict[K, List[T]]:
"""Merge 2 dicts that have key -> List of items. """Merge 2 dicts that have key -> List of items.
When a key conflicts, the values in dict1 are prioritized. When a key conflicts, the values in dict1 are prioritized.
""" """
merged_dict = defaultdict(list) merged_dict: Dict[K, List[T]] = defaultdict(list)
for key, value in dict1.items(): for key, value in dict1.items():
merged_dict[key].extend(value) merged_dict[key].extend(value)
...@@ -577,7 +585,7 @@ def merge_dicts(dict1: Dict[Any, List[Any]], ...@@ -577,7 +585,7 @@ def merge_dicts(dict1: Dict[Any, List[Any]],
return dict(merged_dict) return dict(merged_dict)
def init_cached_hf_modules(): def init_cached_hf_modules() -> None:
""" """
Lazy initialization of the Hugging Face modules. Lazy initialization of the Hugging Face modules.
""" """
...@@ -613,7 +621,7 @@ def find_library(lib_name: str) -> str: ...@@ -613,7 +621,7 @@ def find_library(lib_name: str) -> str:
return locs[0] return locs[0]
def find_nccl_library(): def find_nccl_library() -> str:
""" """
We either use the library file specified by the `VLLM_NCCL_SO_PATH` We either use the library file specified by the `VLLM_NCCL_SO_PATH`
environment variable, or we find the library file brought by PyTorch. environment variable, or we find the library file brought by PyTorch.
......
...@@ -779,8 +779,8 @@ class ModelRunner: ...@@ -779,8 +779,8 @@ class ModelRunner:
# that will have unique loras, and therefore the max amount of memory # that will have unique loras, and therefore the max amount of memory
# consumption create dummy lora request copies from the lora request # consumption create dummy lora request copies from the lora request
# passed in, which contains a lora from the lora warmup path. # passed in, which contains a lora from the lora warmup path.
dummy_lora_requests = [] dummy_lora_requests: List[LoRARequest] = []
dummy_lora_requests_per_seq = [] dummy_lora_requests_per_seq: List[LoRARequest] = []
if self.lora_config: if self.lora_config:
assert self.lora_manager is not None assert self.lora_manager is not None
with self.lora_manager.dummy_lora_cache(): with self.lora_manager.dummy_lora_cache():
......
...@@ -99,8 +99,8 @@ class WorkerWrapperBase: ...@@ -99,8 +99,8 @@ class WorkerWrapperBase:
""" """
def __init__(self, def __init__(self,
worker_module_name=None, worker_module_name: str,
worker_class_name=None, worker_class_name: str,
trust_remote_code: bool = False) -> None: trust_remote_code: bool = False) -> None:
self.worker_module_name = worker_module_name self.worker_module_name = worker_module_name
self.worker_class_name = worker_class_name self.worker_class_name = worker_class_name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment