Unverified Commit 0ae11f78 authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[Mypy] Part 3 fix typing for nested directories for most of directory (#4161)

parent 34128a69
...@@ -32,19 +32,20 @@ jobs: ...@@ -32,19 +32,20 @@ jobs:
pip install types-setuptools pip install types-setuptools
- name: Mypy - name: Mypy
run: | run: |
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/attention --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml # TODO(sang): Fix nested dir
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/model_executor/*.py --config-file pyproject.toml
# TODO(sang): Follow up # TODO(sang): Fix nested dir
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml # mypy vllm/lora/*.py --config-file pyproject.toml
...@@ -94,21 +94,19 @@ echo 'vLLM yapf: Done' ...@@ -94,21 +94,19 @@ echo 'vLLM yapf: Done'
# Run mypy # Run mypy
echo 'vLLM mypy:' echo 'vLLM mypy:'
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/attention --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
# TODO(sang): Follow up mypy vllm/worker --config-file pyproject.toml
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/model_executor/*.py --config-file pyproject.toml
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml # mypy vllm/lora/*.py --config-file pyproject.toml
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
CODESPELL_EXCLUDES=( CODESPELL_EXCLUDES=(
......
...@@ -46,15 +46,17 @@ ignore = [ ...@@ -46,15 +46,17 @@ ignore = [
python_version = "3.8" python_version = "3.8"
ignore_missing_imports = true ignore_missing_imports = true
check_untyped_defs = true check_untyped_defs = true
follow_imports = "skip"
files = "vllm" files = "vllm"
# TODO(woosuk): Include the code from Megatron and HuggingFace. # TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = [ exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/", "vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
# Ignore triton kernels in ops.
'vllm/attention/ops/.*\.py$'
] ]
[tool.codespell] [tool.codespell]
ignore-words-list = "dout, te, indicies" ignore-words-list = "dout, te, indicies"
skip = "./tests/prompts,./benchmarks/sonnet.txt" skip = "./tests/prompts,./benchmarks/sonnet.txt"
......
...@@ -116,7 +116,7 @@ class AttentionImpl(ABC): ...@@ -116,7 +116,7 @@ class AttentionImpl(ABC):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata[AttentionMetadataPerStage], attn_metadata: AttentionMetadata,
kv_scale: float, kv_scale: float,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -248,6 +248,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -248,6 +248,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
assert prefill_meta.prompt_lens is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0: if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# triton attention # triton attention
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
......
...@@ -106,7 +106,7 @@ class TorchSDPABackendImpl(AttentionImpl): ...@@ -106,7 +106,7 @@ class TorchSDPABackendImpl(AttentionImpl):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata, attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float, kv_scale: float,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention. """Forward pass with torch SDPA and PagedAttention.
...@@ -136,6 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl): ...@@ -136,6 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl):
kv_scale) kv_scale)
if attn_metadata.is_prompt: if attn_metadata.is_prompt:
assert attn_metadata.prompt_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0): if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1) key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
......
...@@ -288,6 +288,7 @@ class XFormersImpl(AttentionImpl): ...@@ -288,6 +288,7 @@ class XFormersImpl(AttentionImpl):
value: shape = [num_prefill_tokens, num_kv_heads, head_size] value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
""" """
assert attn_metadata.prompt_lens is not None
original_query = query original_query = query
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K]. # GQA/MQA requires the shape [B, M, G, H, K].
......
...@@ -104,6 +104,7 @@ class BlockTable: ...@@ -104,6 +104,7 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be appended. token_ids (List[int]): The sequence of token IDs to be appended.
""" """
assert self._is_allocated assert self._is_allocated
assert self._blocks is not None
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots) num_lookahead_slots)
......
...@@ -99,7 +99,7 @@ class CopyOnWriteTracker: ...@@ -99,7 +99,7 @@ class CopyOnWriteTracker:
refcounter: RefCounter, refcounter: RefCounter,
allocator: BlockAllocator, allocator: BlockAllocator,
): ):
self._copy_on_writes = defaultdict(list) self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
self._refcounter = refcounter self._refcounter = refcounter
self._allocator = allocator self._allocator = allocator
...@@ -138,6 +138,8 @@ class CopyOnWriteTracker: ...@@ -138,6 +138,8 @@ class CopyOnWriteTracker:
prev_block=block.prev_block).block_id prev_block=block.prev_block).block_id
# Track src/dst copy. # Track src/dst copy.
assert src_block_id is not None
assert block_id is not None
self._copy_on_writes[src_block_id].append(block_id) self._copy_on_writes[src_block_id].append(block_id)
return block_id return block_id
...@@ -180,6 +182,6 @@ def get_all_blocks_recursively(last_block: Block) -> List[Block]: ...@@ -180,6 +182,6 @@ def get_all_blocks_recursively(last_block: Block) -> List[Block]:
recurse(block.prev_block, lst) recurse(block.prev_block, lst)
lst.append(block) lst.append(block)
all_blocks = [] all_blocks: List[Block] = []
recurse(last_block, all_blocks) recurse(last_block, all_blocks)
return all_blocks return all_blocks
...@@ -52,8 +52,7 @@ class Block(ABC): ...@@ -52,8 +52,7 @@ class Block(ABC):
class BlockAllocator(ABC): class BlockAllocator(ABC):
@abstractmethod @abstractmethod
def allocate_mutable(self, prev_block: Optional[Block], def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
device: Device) -> Block:
pass pass
@abstractmethod @abstractmethod
...@@ -98,8 +97,7 @@ class BlockAllocator(ABC): ...@@ -98,8 +97,7 @@ class BlockAllocator(ABC):
class DeviceAwareBlockAllocator(BlockAllocator): class DeviceAwareBlockAllocator(BlockAllocator):
@abstractmethod @abstractmethod
def allocate_mutable(self, prev_block: Optional[Block], def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
device: Device) -> Block:
pass pass
@abstractmethod @abstractmethod
......
import os import os
from contextlib import contextmanager from contextlib import contextmanager
from typing import List, Optional from typing import Any, List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -18,7 +18,7 @@ except ImportError: ...@@ -18,7 +18,7 @@ except ImportError:
logger = init_logger(__name__) logger = init_logger(__name__)
_CA_HANDLE = None _CA_HANDLE: Optional["CustomAllreduce"] = None
_IS_CAPTURING = False _IS_CAPTURING = False
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
...@@ -51,7 +51,7 @@ def init_custom_ar() -> None: ...@@ -51,7 +51,7 @@ def init_custom_ar() -> None:
"Cannot test GPU P2P because not all GPUs are visible to the " "Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set.") " is set.")
return False return
# test nvlink first, this will filter out most of the cases # test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported # where custom allreduce is not supported
if "CUDA_VISIBLE_DEVICES" in os.environ: if "CUDA_VISIBLE_DEVICES" in os.environ:
...@@ -117,7 +117,7 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: ...@@ -117,7 +117,7 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
ca_handle = get_handle() ca_handle = get_handle()
# when custom allreduce is disabled, this will be None # when custom allreduce is disabled, this will be None
if ca_handle is None: if ca_handle is None:
return return None
if is_capturing(): if is_capturing():
if torch.cuda.is_current_stream_capturing(): if torch.cuda.is_current_stream_capturing():
if ca_handle.should_custom_ar(input): if ca_handle.should_custom_ar(input):
...@@ -135,6 +135,8 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: ...@@ -135,6 +135,8 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
if ca_handle.should_custom_ar(input): if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_unreg(input) return ca_handle.all_reduce_unreg(input)
return None
@contextmanager @contextmanager
def _nvml(): def _nvml():
...@@ -224,14 +226,14 @@ class CustomAllreduce: ...@@ -224,14 +226,14 @@ class CustomAllreduce:
return self._gather_ipc_meta(shard_data) return self._gather_ipc_meta(shard_data)
def _gather_ipc_meta(self, shard_data): def _gather_ipc_meta(self, shard_data):
all_data = [None] * self.world_size all_data: List[Optional[Any]] = [None] * self.world_size
dist.all_gather_object(all_data, shard_data) dist.all_gather_object(all_data, shard_data)
handles = [] handles = []
offsets = [] offsets = []
for i in range(len(all_data)): for i in range(len(all_data)):
handles.append(all_data[i][0]) handles.append(all_data[i][0]) # type: ignore
offsets.append(all_data[i][1]) offsets.append(all_data[i][1]) # type: ignore
return handles, offsets return handles, offsets
def register_buffer(self, inp: torch.Tensor): def register_buffer(self, inp: torch.Tensor):
......
...@@ -107,9 +107,10 @@ _c_ncclCommInitRank.argtypes = [ ...@@ -107,9 +107,10 @@ _c_ncclCommInitRank.argtypes = [
ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
] ]
ncclDataType_t = ctypes.c_int
# enums
class ncclDataType_t(ctypes.c_int): class ncclDataTypeEnum:
ncclInt8 = 0 ncclInt8 = 0
ncclChar = 0 ncclChar = 0
ncclUint8 = 1 ncclUint8 = 1
...@@ -128,7 +129,7 @@ class ncclDataType_t(ctypes.c_int): ...@@ -128,7 +129,7 @@ class ncclDataType_t(ctypes.c_int):
ncclNumTypes = 10 ncclNumTypes = 10
@classmethod @classmethod
def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t': def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8: if dtype == torch.int8:
return cls.ncclInt8 return cls.ncclInt8
if dtype == torch.uint8: if dtype == torch.uint8:
...@@ -148,7 +149,10 @@ class ncclDataType_t(ctypes.c_int): ...@@ -148,7 +149,10 @@ class ncclDataType_t(ctypes.c_int):
raise ValueError(f"Unsupported dtype: {dtype}") raise ValueError(f"Unsupported dtype: {dtype}")
class ncclRedOp_t(ctypes.c_int): ncclRedOp_t = ctypes.c_int
class ncclRedOpTypeEnum:
ncclSum = 0 ncclSum = 0
ncclProd = 1 ncclProd = 1
ncclMax = 2 ncclMax = 2
...@@ -157,7 +161,7 @@ class ncclRedOp_t(ctypes.c_int): ...@@ -157,7 +161,7 @@ class ncclRedOp_t(ctypes.c_int):
ncclNumOps = 5 ncclNumOps = 5
@classmethod @classmethod
def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t': def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM: if op == ReduceOp.SUM:
return cls.ncclSum return cls.ncclSum
if op == ReduceOp.PRODUCT: if op == ReduceOp.PRODUCT:
...@@ -180,8 +184,8 @@ class ncclRedOp_t(ctypes.c_int): ...@@ -180,8 +184,8 @@ class ncclRedOp_t(ctypes.c_int):
_c_ncclAllReduce = nccl.ncclAllReduce _c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.restype = ctypes.c_int _c_ncclAllReduce.restype = ctypes.c_int
_c_ncclAllReduce.argtypes = [ _c_ncclAllReduce.argtypes = [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
] ]
# equivalent to c declaration: # equivalent to c declaration:
...@@ -251,8 +255,8 @@ class NCCLCommunicator: ...@@ -251,8 +255,8 @@ class NCCLCommunicator:
result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()), result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()), ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(), tensor.numel(),
ncclDataType_t.from_torch(tensor.dtype), ncclDataTypeEnum.from_torch(tensor.dtype),
ncclRedOp_t.from_torch(op), self.comm, ncclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream)) ctypes.c_void_p(stream.cuda_stream))
assert result == 0 assert result == 0
......
...@@ -30,6 +30,7 @@ def is_initialized() -> bool: ...@@ -30,6 +30,7 @@ def is_initialized() -> bool:
def set_pynccl_stream(stream: torch.cuda.Stream): def set_pynccl_stream(stream: torch.cuda.Stream):
"""Set the cuda stream for communication""" """Set the cuda stream for communication"""
try: try:
assert comm is not None
comm.stream = stream comm.stream = stream
yield yield
finally: finally:
...@@ -52,6 +53,7 @@ def init_process_group(world_size: int, ...@@ -52,6 +53,7 @@ def init_process_group(world_size: int,
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group.""" """All-reduces the input tensor across the process group."""
assert input_.is_cuda, f"{input_} should be a cuda tensor" assert input_.is_cuda, f"{input_} should be a cuda tensor"
assert comm is not None
comm.all_reduce(input_, op) comm.all_reduce(input_, op)
...@@ -62,8 +64,9 @@ def destroy_process_group() -> None: ...@@ -62,8 +64,9 @@ def destroy_process_group() -> None:
def get_world_size() -> int: def get_world_size() -> int:
"""Returns the world size.""" """Returns the world size."""
assert comm is not None
return comm.world_size return comm.world_size
def get_nccl_backend(): def get_nccl_backend() -> Optional["NCCLCommunicator"]:
return comm return comm
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, Iterable, List from typing import Callable, List
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
...@@ -8,6 +8,7 @@ from vllm.core.scheduler import Scheduler ...@@ -8,6 +8,7 @@ from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
class SequenceGroupOutputProcessor(ABC): class SequenceGroupOutputProcessor(ABC):
...@@ -27,7 +28,7 @@ class SequenceGroupOutputProcessor(ABC): ...@@ -27,7 +28,7 @@ class SequenceGroupOutputProcessor(ABC):
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
detokenizer: Detokenizer, detokenizer: Detokenizer,
scheduler: Scheduler, scheduler: Scheduler,
seq_counter: Iterable[int], seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: "StopChecker", stop_checker: "StopChecker",
): ):
......
from typing import Callable, Iterable, List from typing import Callable, List
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
...@@ -11,6 +11,7 @@ from vllm.sampling_params import SamplingParams ...@@ -11,6 +11,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, Sequence, SequenceGroup, from vllm.sequence import (Logprob, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus) SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -33,7 +34,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -33,7 +34,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
self, self,
detokenizer: Detokenizer, detokenizer: Detokenizer,
scheduler: Scheduler, scheduler: Scheduler,
seq_counter: Iterable[int], seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: StopChecker, stop_checker: StopChecker,
): ):
......
from typing import Iterable, List, Tuple, Union from typing import Dict, List, Tuple, Union
from vllm.config import SchedulerConfig from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
...@@ -10,6 +10,7 @@ from vllm.sampling_params import SamplingParams ...@@ -10,6 +10,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus) SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -33,7 +34,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -33,7 +34,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
detokenizer: Detokenizer, detokenizer: Detokenizer,
scheduler: Scheduler, scheduler: Scheduler,
seq_counter: Iterable[int], seq_counter: Counter,
stop_checker: StopChecker, stop_checker: StopChecker,
): ):
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
...@@ -69,7 +70,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -69,7 +70,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
samples = outputs.samples samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs() existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict = { parent_child_dict: Dict[int, List[SequenceOutput]] = {
parent_seq.seq_id: [] parent_seq.seq_id: []
for parent_seq in parent_seqs for parent_seq in parent_seqs
} }
...@@ -92,7 +93,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -92,7 +93,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
continue continue
# Fork the parent sequence if there are multiple child samples. # Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]: for child_sample in child_samples[:-1]:
new_child_seq_id = next(self.seq_counter) new_child_seq_id: int = next(self.seq_counter)
child = parent.fork(new_child_seq_id) child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token, child.append_token_id(child_sample.output_token,
child_sample.logprobs) child_sample.logprobs)
......
...@@ -8,7 +8,9 @@ def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput], ...@@ -8,7 +8,9 @@ def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
"""Helper method which transforms a 2d list organized by """Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step]. [step][sequence group] into [sequence group][step].
""" """
output_by_sequence_group = [[] for _ in range(num_seq_groups)] output_by_sequence_group: List[List[SamplerOutput]] = [
[] for _ in range(num_seq_groups)
]
for step in sampler_outputs: for step in sampler_outputs:
for i, sequence_group_output in enumerate(step): for i, sequence_group_output in enumerate(step):
output_by_sequence_group[i].append(sequence_group_output) output_by_sequence_group[i].append(sequence_group_output)
......
...@@ -18,6 +18,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs ...@@ -18,6 +18,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest, ErrorResponse) CompletionRequest, ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
...@@ -26,8 +27,8 @@ from vllm.usage.usage_lib import UsageContext ...@@ -26,8 +27,8 @@ from vllm.usage.usage_lib import UsageContext
TIMEOUT_KEEP_ALIVE = 5 # seconds TIMEOUT_KEEP_ALIVE = 5 # seconds
openai_serving_chat: OpenAIServingChat = None openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion = None openai_serving_completion: OpenAIServingCompletion
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -95,6 +96,7 @@ async def create_chat_completion(request: ChatCompletionRequest, ...@@ -95,6 +96,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
return StreamingResponse(content=generator, return StreamingResponse(content=generator,
media_type="text/event-stream") media_type="text/event-stream")
else: else:
assert isinstance(generator, ChatCompletionResponse)
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
......
...@@ -4,7 +4,8 @@ import time ...@@ -4,7 +4,8 @@ import time
from typing import Dict, List, Literal, Optional, Union from typing import Dict, List, Literal, Optional, Union
import torch import torch
from pydantic import BaseModel, Field, conint, model_validator from pydantic import BaseModel, Field, model_validator
from typing_extensions import Annotated
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -30,7 +31,7 @@ class ModelPermission(BaseModel): ...@@ -30,7 +31,7 @@ class ModelPermission(BaseModel):
allow_fine_tuning: bool = False allow_fine_tuning: bool = False
organization: str = "*" organization: str = "*"
group: Optional[str] = None group: Optional[str] = None
is_blocking: str = False is_blocking: bool = False
class ModelCard(BaseModel): class ModelCard(BaseModel):
...@@ -56,7 +57,7 @@ class UsageInfo(BaseModel): ...@@ -56,7 +57,7 @@ class UsageInfo(BaseModel):
class ResponseFormat(BaseModel): class ResponseFormat(BaseModel):
# type must be "json_object" or "text" # type must be "json_object" or "text"
type: str = Literal["text", "json_object"] type: Literal["text", "json_object"]
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
...@@ -152,6 +153,7 @@ class ChatCompletionRequest(BaseModel): ...@@ -152,6 +153,7 @@ class ChatCompletionRequest(BaseModel):
def logit_bias_logits_processor( def logit_bias_logits_processor(
token_ids: List[int], token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor: logits: torch.Tensor) -> torch.Tensor:
assert self.logit_bias is not None
for token_id, bias in self.logit_bias.items(): for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec # Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias)) bias = min(100, max(-100, bias))
...@@ -213,7 +215,7 @@ class CompletionRequest(BaseModel): ...@@ -213,7 +215,7 @@ class CompletionRequest(BaseModel):
logit_bias: Optional[Dict[str, float]] = None logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None logprobs: Optional[int] = None
max_tokens: Optional[int] = 16 max_tokens: Optional[int] = 16
n: Optional[int] = 1 n: int = 1
presence_penalty: Optional[float] = 0.0 presence_penalty: Optional[float] = 0.0
seed: Optional[int] = None seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
...@@ -235,7 +237,7 @@ class CompletionRequest(BaseModel): ...@@ -235,7 +237,7 @@ class CompletionRequest(BaseModel):
min_tokens: Optional[int] = 0 min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True
truncate_prompt_tokens: Optional[conint(ge=1)] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: end-completion-sampling-params # doc: end-completion-sampling-params
# doc: begin-completion-extra-params # doc: begin-completion-extra-params
...@@ -289,6 +291,7 @@ class CompletionRequest(BaseModel): ...@@ -289,6 +291,7 @@ class CompletionRequest(BaseModel):
def logit_bias_logits_processor( def logit_bias_logits_processor(
token_ids: List[int], token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor: logits: torch.Tensor) -> torch.Tensor:
assert self.logit_bias is not None
for token_id, bias in self.logit_bias.items(): for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec # Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias)) bias = min(100, max(-100, bias))
......
...@@ -115,12 +115,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -115,12 +115,12 @@ class OpenAIServingChat(OpenAIServing):
first_iteration = True first_iteration = True
# Send response for each token for each request.n (index) # Send response for each token for each request.n (index)
assert request.n is not None
previous_texts = [""] * request.n previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n previous_num_tokens = [0] * request.n
finish_reason_sent = [False] * request.n finish_reason_sent = [False] * request.n
try: try:
async for res in result_generator: async for res in result_generator:
res: RequestOutput
# We need to do it here, because if there are exceptions in # We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST # the result_generator, it needs to be sent as the FIRST
# response (by the try...catch). # response (by the try...catch).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment