Unverified Commit 0ae11f78 authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[Mypy] Part 3 fix typing for nested directories for most of directory (#4161)

parent 34128a69
......@@ -32,19 +32,20 @@ jobs:
pip install types-setuptools
- name: Mypy
run: |
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/attention --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/model_executor/*.py --config-file pyproject.toml
# TODO(sang): Fix nested dir
# mypy vllm/lora/*.py --config-file pyproject.toml
......@@ -94,21 +94,19 @@ echo 'vLLM yapf: Done'
# Run mypy
echo 'vLLM mypy:'
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/attention --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor/*.py --config-file pyproject.toml
# mypy vllm/lora/*.py --config-file pyproject.toml
CODESPELL_EXCLUDES=(
......
......@@ -46,15 +46,17 @@ ignore = [
python_version = "3.8"
ignore_missing_imports = true
check_untyped_defs = true
check_untyped_defs = true
follow_imports = "skip"
files = "vllm"
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
# Ignore triton kernels in ops.
'vllm/attention/ops/.*\.py$'
]
[tool.codespell]
ignore-words-list = "dout, te, indicies"
skip = "./tests/prompts,./benchmarks/sonnet.txt"
......
......@@ -116,7 +116,7 @@ class AttentionImpl(ABC):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
attn_metadata: AttentionMetadata,
kv_scale: float,
) -> torch.Tensor:
raise NotImplementedError
......@@ -248,6 +248,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.prompt_lens is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
......
......@@ -106,7 +106,7 @@ class TorchSDPABackendImpl(AttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata,
attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
......@@ -136,6 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl):
kv_scale)
if attn_metadata.is_prompt:
assert attn_metadata.prompt_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
......
......@@ -288,6 +288,7 @@ class XFormersImpl(AttentionImpl):
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
"""
assert attn_metadata.prompt_lens is not None
original_query = query
if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K].
......
......@@ -104,6 +104,7 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be appended.
"""
assert self._is_allocated
assert self._blocks is not None
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots)
......
......@@ -99,7 +99,7 @@ class CopyOnWriteTracker:
refcounter: RefCounter,
allocator: BlockAllocator,
):
self._copy_on_writes = defaultdict(list)
self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
self._refcounter = refcounter
self._allocator = allocator
......@@ -138,6 +138,8 @@ class CopyOnWriteTracker:
prev_block=block.prev_block).block_id
# Track src/dst copy.
assert src_block_id is not None
assert block_id is not None
self._copy_on_writes[src_block_id].append(block_id)
return block_id
......@@ -180,6 +182,6 @@ def get_all_blocks_recursively(last_block: Block) -> List[Block]:
recurse(block.prev_block, lst)
lst.append(block)
all_blocks = []
all_blocks: List[Block] = []
recurse(last_block, all_blocks)
return all_blocks
......@@ -52,8 +52,7 @@ class Block(ABC):
class BlockAllocator(ABC):
@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
pass
@abstractmethod
......@@ -98,8 +97,7 @@ class BlockAllocator(ABC):
class DeviceAwareBlockAllocator(BlockAllocator):
@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
pass
@abstractmethod
......
import os
from contextlib import contextmanager
from typing import List, Optional
from typing import Any, List, Optional
import torch
import torch.distributed as dist
......@@ -18,7 +18,7 @@ except ImportError:
logger = init_logger(__name__)
_CA_HANDLE = None
_CA_HANDLE: Optional["CustomAllreduce"] = None
_IS_CAPTURING = False
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
......@@ -51,7 +51,7 @@ def init_custom_ar() -> None:
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set.")
return False
return
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
if "CUDA_VISIBLE_DEVICES" in os.environ:
......@@ -117,7 +117,7 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
ca_handle = get_handle()
# when custom allreduce is disabled, this will be None
if ca_handle is None:
return
return None
if is_capturing():
if torch.cuda.is_current_stream_capturing():
if ca_handle.should_custom_ar(input):
......@@ -135,6 +135,8 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_unreg(input)
return None
@contextmanager
def _nvml():
......@@ -224,14 +226,14 @@ class CustomAllreduce:
return self._gather_ipc_meta(shard_data)
def _gather_ipc_meta(self, shard_data):
all_data = [None] * self.world_size
all_data: List[Optional[Any]] = [None] * self.world_size
dist.all_gather_object(all_data, shard_data)
handles = []
offsets = []
for i in range(len(all_data)):
handles.append(all_data[i][0])
offsets.append(all_data[i][1])
handles.append(all_data[i][0]) # type: ignore
offsets.append(all_data[i][1]) # type: ignore
return handles, offsets
def register_buffer(self, inp: torch.Tensor):
......
......@@ -107,9 +107,10 @@ _c_ncclCommInitRank.argtypes = [
ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
]
ncclDataType_t = ctypes.c_int
# enums
class ncclDataType_t(ctypes.c_int):
class ncclDataTypeEnum:
ncclInt8 = 0
ncclChar = 0
ncclUint8 = 1
......@@ -128,7 +129,7 @@ class ncclDataType_t(ctypes.c_int):
ncclNumTypes = 10
@classmethod
def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.ncclInt8
if dtype == torch.uint8:
......@@ -148,7 +149,10 @@ class ncclDataType_t(ctypes.c_int):
raise ValueError(f"Unsupported dtype: {dtype}")
class ncclRedOp_t(ctypes.c_int):
ncclRedOp_t = ctypes.c_int
class ncclRedOpTypeEnum:
ncclSum = 0
ncclProd = 1
ncclMax = 2
......@@ -157,7 +161,7 @@ class ncclRedOp_t(ctypes.c_int):
ncclNumOps = 5
@classmethod
def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.ncclSum
if op == ReduceOp.PRODUCT:
......@@ -180,8 +184,8 @@ class ncclRedOp_t(ctypes.c_int):
_c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.restype = ctypes.c_int
_c_ncclAllReduce.argtypes = [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
]
# equivalent to c declaration:
......@@ -251,8 +255,8 @@ class NCCLCommunicator:
result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
ncclDataType_t.from_torch(tensor.dtype),
ncclRedOp_t.from_torch(op), self.comm,
ncclDataTypeEnum.from_torch(tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream))
assert result == 0
......
......@@ -30,6 +30,7 @@ def is_initialized() -> bool:
def set_pynccl_stream(stream: torch.cuda.Stream):
"""Set the cuda stream for communication"""
try:
assert comm is not None
comm.stream = stream
yield
finally:
......@@ -52,6 +53,7 @@ def init_process_group(world_size: int,
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_cuda, f"{input_} should be a cuda tensor"
assert comm is not None
comm.all_reduce(input_, op)
......@@ -62,8 +64,9 @@ def destroy_process_group() -> None:
def get_world_size() -> int:
"""Returns the world size."""
assert comm is not None
return comm.world_size
def get_nccl_backend():
def get_nccl_backend() -> Optional["NCCLCommunicator"]:
return comm
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List
from typing import Callable, List
from transformers import PreTrainedTokenizer
......@@ -8,6 +8,7 @@ from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
class SequenceGroupOutputProcessor(ABC):
......@@ -27,7 +28,7 @@ class SequenceGroupOutputProcessor(ABC):
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Iterable[int],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: "StopChecker",
):
......
from typing import Callable, Iterable, List
from typing import Callable, List
from transformers import PreTrainedTokenizer
......@@ -11,6 +11,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
logger = init_logger(__name__)
......@@ -33,7 +34,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
self,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Iterable[int],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: StopChecker,
):
......
from typing import Iterable, List, Tuple, Union
from typing import Dict, List, Tuple, Union
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
......@@ -10,6 +10,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
logger = init_logger(__name__)
......@@ -33,7 +34,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Iterable[int],
seq_counter: Counter,
stop_checker: StopChecker,
):
self.scheduler_config = scheduler_config
......@@ -69,7 +70,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict = {
parent_child_dict: Dict[int, List[SequenceOutput]] = {
parent_seq.seq_id: []
for parent_seq in parent_seqs
}
......@@ -92,7 +93,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
continue
# Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]:
new_child_seq_id = next(self.seq_counter)
new_child_seq_id: int = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token,
child_sample.logprobs)
......
......@@ -8,7 +8,9 @@ def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
"""
output_by_sequence_group = [[] for _ in range(num_seq_groups)]
output_by_sequence_group: List[List[SamplerOutput]] = [
[] for _ in range(num_seq_groups)
]
for step in sampler_outputs:
for i, sequence_group_output in enumerate(step):
output_by_sequence_group[i].append(sequence_group_output)
......
......@@ -18,6 +18,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest, ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
......@@ -26,8 +27,8 @@ from vllm.usage.usage_lib import UsageContext
TIMEOUT_KEEP_ALIVE = 5 # seconds
openai_serving_chat: OpenAIServingChat = None
openai_serving_completion: OpenAIServingCompletion = None
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
logger = init_logger(__name__)
......@@ -95,6 +96,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
assert isinstance(generator, ChatCompletionResponse)
return JSONResponse(content=generator.model_dump())
......
......@@ -4,7 +4,8 @@ import time
from typing import Dict, List, Literal, Optional, Union
import torch
from pydantic import BaseModel, Field, conint, model_validator
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Annotated
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
......@@ -30,7 +31,7 @@ class ModelPermission(BaseModel):
allow_fine_tuning: bool = False
organization: str = "*"
group: Optional[str] = None
is_blocking: str = False
is_blocking: bool = False
class ModelCard(BaseModel):
......@@ -56,7 +57,7 @@ class UsageInfo(BaseModel):
class ResponseFormat(BaseModel):
# type must be "json_object" or "text"
type: str = Literal["text", "json_object"]
type: Literal["text", "json_object"]
class ChatCompletionRequest(BaseModel):
......@@ -152,6 +153,7 @@ class ChatCompletionRequest(BaseModel):
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
assert self.logit_bias is not None
for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias))
......@@ -213,7 +215,7 @@ class CompletionRequest(BaseModel):
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None
max_tokens: Optional[int] = 16
n: Optional[int] = 1
n: int = 1
presence_penalty: Optional[float] = 0.0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
......@@ -235,7 +237,7 @@ class CompletionRequest(BaseModel):
min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
truncate_prompt_tokens: Optional[conint(ge=1)] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
......@@ -289,6 +291,7 @@ class CompletionRequest(BaseModel):
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
assert self.logit_bias is not None
for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias))
......
......@@ -115,12 +115,12 @@ class OpenAIServingChat(OpenAIServing):
first_iteration = True
# Send response for each token for each request.n (index)
assert request.n is not None
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
finish_reason_sent = [False] * request.n
try:
async for res in result_generator:
res: RequestOutput
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment