Unverified Commit 0e9164b4 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[mypy] Enable type checking for test directory (#5017)

parent 1b8a0d71
......@@ -453,8 +453,8 @@ class ArcticForCausalLM(nn.Module):
("qkv_proj", "v_proj", "v"),
]
mlp_params_mapping = []
expert_params_mapping = []
mlp_params_mapping: List[Tuple[str, str, int]] = []
expert_params_mapping: List[Tuple[str, str, int]] = []
num_layers = self.config.num_hidden_layers
for layer in range(num_layers):
......
......@@ -20,7 +20,7 @@
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Set, Tuple
import torch
import torch.utils.checkpoint
......@@ -352,7 +352,7 @@ class CohereForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params = set()
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
......
......@@ -15,7 +15,7 @@
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from functools import lru_cache
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Set, Tuple
import torch
from torch import nn
......@@ -363,7 +363,7 @@ class GemmaForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params = set()
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
......
......@@ -123,7 +123,7 @@ class SequenceData:
output_token_ids = []
self.prompt_token_ids = prompt_token_ids
self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
self._prompt_token_ids_tuple = tuple(prompt_token_ids)
self.output_token_ids = output_token_ids
self.cumulative_logprob = 0.0
# The number of tokens that are computed (that run against the model).
......
import copy
import weakref
from typing import List, Tuple
from typing import Dict, List, Tuple
import torch
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
......@@ -71,7 +71,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
sample_len)
# Run model sample_len times.
model_outputs = []
model_outputs: List[SamplerOutput] = []
for _ in range(sample_len):
model_output = super().execute_model(
execute_model_req=copied_execute_model_req)
......@@ -132,7 +132,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
# Shallow-copy the list of SequenceGroupMetadata. This allows us to
# append tokens and change is_prompt without external side-effects.
new_seq_group_metadata_list = []
new_seq_group_metadata_list: List[SequenceGroupMetadata] = []
for old_seq_group_metadata in seq_group_metadata_list:
# We must shallow-copy seq_group_metadata as is_prompt could change.
......@@ -140,7 +140,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
new_seq_group_metadata_list.append(seq_group_metadata)
# We must shallow-copy seq_data as we will append token ids
new_seq_data = {}
new_seq_data: Dict[int, SequenceData] = {}
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
new_seq_data[seq_id] = copy.copy(old_seq_data)
new_seq_data[
......
......@@ -48,7 +48,7 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
) -> Tuple[Optional[List[SamplerOutput]], bool]:
) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
"""NGram match algo to pick proposal candidate. Returns the list of
sampler output, one per SequenceGroupMetadata.
......@@ -58,8 +58,8 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
self._raise_if_unsupported(execute_model_req)
has_spec_out = False
token_id_list = []
token_prob_list = []
token_id_list: List[Optional[torch.Tensor]] = []
token_prob_list: List[Optional[torch.Tensor]] = []
for idx, seq_group_metadata in enumerate(
execute_model_req.seq_group_metadata_list):
seq_data = next(iter(seq_group_metadata.seq_data.values()))
......
......@@ -7,8 +7,8 @@ from vllm.config import SpeculativeConfig
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
SamplerOutput, SequenceGroupMetadata)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
......@@ -516,13 +516,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
topk_indices_by_step = topk_indices_by_step.tolist()
# Construct the output on a per-step, per-sequence basis.
sampler_output_list = []
sampler_output_list: List[SamplerOutput] = []
for step_index in range(num_steps):
if all(token_id == -1
for token_id in accepted_token_ids_by_step[step_index]):
break
step_output_token_ids = []
step_output_token_ids: List[CompletionSequenceGroupOutput] = []
for sequence_index in range(batch_size):
# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs = num_logprobs_per_seq[sequence_index]
......
......@@ -26,10 +26,10 @@ def get_all_num_logprobs(
sequence.
"""
all_num_logprobs = []
all_num_logprobs: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
num_logprobs = seq_group_metadata.sampling_params.logprobs
if seq_group_metadata.sampling_params.logprobs is None:
if num_logprobs is None:
num_logprobs = 0
all_num_logprobs.append(num_logprobs)
......
......@@ -44,7 +44,7 @@ class Detokenizer:
read_offset = 0
next_iter_prefix_offset = 0
next_iter_read_offset = 0
next_iter_tokens = []
next_iter_tokens: List[str] = []
prev_tokens = None
for token_position, prompt_logprobs_for_token in enumerate(
......
......@@ -20,12 +20,13 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
import numpy as np
import psutil
import torch
import torch.types
from typing_extensions import ParamSpec
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import enable_trace_function_call, init_logger
T = TypeVar("T")
logger = init_logger(__name__)
STR_DTYPE_TO_TORCH_DTYPE = {
......@@ -37,6 +38,10 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"fp8_e5m2": torch.uint8,
}
P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")
class Device(enum.Enum):
GPU = enum.auto()
......@@ -176,7 +181,7 @@ def random_uuid() -> str:
@lru_cache(maxsize=None)
def get_vllm_instance_id():
def get_vllm_instance_id() -> str:
"""
If the environment variable VLLM_INSTANCE_ID is set, return it.
Otherwise, return a random UUID.
......@@ -192,7 +197,7 @@ def in_wsl() -> bool:
return "microsoft" in " ".join(uname()).lower()
def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
"""Take a blocking function, and run it on in an executor thread.
This function prevents the blocking function from blocking the
......@@ -200,7 +205,7 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
The code in this function needs to be thread safe.
"""
def _async_wrapper(*args, **kwargs) -> asyncio.Future:
def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
loop = asyncio.get_event_loop()
p_func = partial(func, *args, **kwargs)
return loop.run_in_executor(executor=None, func=p_func)
......@@ -325,7 +330,7 @@ def update_environment_variables(envs: Dict[str, str]):
os.environ[k] = v
def chunk_list(lst, chunk_size):
def chunk_list(lst: List[T], chunk_size: int) -> List[List[T]]:
"""Yield successive chunk_size chunks from lst."""
return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
......@@ -336,7 +341,7 @@ def cdiv(a: int, b: int) -> int:
def _generate_random_fp8(
tensor: torch.tensor,
tensor: torch.Tensor,
low: float,
high: float,
) -> None:
......@@ -398,7 +403,10 @@ def create_kv_caches_with_random_flash(
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
scale = head_size**-0.5
key_caches, value_caches = [], []
key_caches: List[torch.Tensor] = []
value_caches: List[torch.Tensor] = []
for _ in range(num_layers):
key_value_cache = torch.empty(size=key_value_cache_shape,
dtype=torch_dtype,
......@@ -429,7 +437,7 @@ def create_kv_caches_with_random(
scale = head_size**-0.5
x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_caches = []
key_caches: List[torch.Tensor] = []
for _ in range(num_layers):
key_cache = torch.empty(size=key_cache_shape,
dtype=torch_dtype,
......@@ -444,7 +452,7 @@ def create_kv_caches_with_random(
key_caches.append(key_cache)
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_caches = []
value_caches: List[torch.Tensor] = []
for _ in range(num_layers):
value_cache = torch.empty(size=value_cache_shape,
dtype=torch_dtype,
......@@ -484,7 +492,7 @@ def is_pin_memory_available() -> bool:
class CudaMemoryProfiler:
def __init__(self, device=None):
def __init__(self, device: Optional[torch.types.Device] = None):
self.device = device
def current_memory_usage(self) -> float:
......@@ -560,13 +568,13 @@ def get_dtype_size(dtype: torch.dtype) -> int:
return torch.tensor([], dtype=dtype).element_size()
def merge_dicts(dict1: Dict[Any, List[Any]],
dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]:
def merge_dicts(dict1: Dict[K, List[T]],
dict2: Dict[K, List[T]]) -> Dict[K, List[T]]:
"""Merge 2 dicts that have key -> List of items.
When a key conflicts, the values in dict1 is prioritized.
"""
merged_dict = defaultdict(list)
merged_dict: Dict[K, List[T]] = defaultdict(list)
for key, value in dict1.items():
merged_dict[key].extend(value)
......@@ -577,7 +585,7 @@ def merge_dicts(dict1: Dict[Any, List[Any]],
return dict(merged_dict)
def init_cached_hf_modules():
def init_cached_hf_modules() -> None:
"""
Lazy initialization of the Hugging Face modules.
"""
......@@ -613,7 +621,7 @@ def find_library(lib_name: str) -> str:
return locs[0]
def find_nccl_library():
def find_nccl_library() -> str:
"""
We either use the library file specified by the `VLLM_NCCL_SO_PATH`
environment variable, or we find the library file brought by PyTorch.
......
......@@ -779,8 +779,8 @@ class ModelRunner:
# that will have unique loras, an therefore the max amount of memory
# consumption create dummy lora request copies from the lora request
# passed in, which contains a lora from the lora warmup path.
dummy_lora_requests = []
dummy_lora_requests_per_seq = []
dummy_lora_requests: List[LoRARequest] = []
dummy_lora_requests_per_seq: List[LoRARequest] = []
if self.lora_config:
assert self.lora_manager is not None
with self.lora_manager.dummy_lora_cache():
......
......@@ -99,8 +99,8 @@ class WorkerWrapperBase:
"""
def __init__(self,
worker_module_name=None,
worker_class_name=None,
worker_module_name: str,
worker_class_name: str,
trust_remote_code: bool = False) -> None:
self.worker_module_name = worker_module_name
self.worker_class_name = worker_class_name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment