Unverified Commit dc372b9c authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `vllm/device_allocator` and `vllm/distributed` (#18126)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 9b5b39b6
...@@ -74,8 +74,6 @@ exclude = [ ...@@ -74,8 +74,6 @@ exclude = [
# Python 3.8 typing. TODO: Remove these excludes after v1.0.0 # Python 3.8 typing. TODO: Remove these excludes after v1.0.0
"vllm/attention/**/*.py" = ["UP006", "UP035"] "vllm/attention/**/*.py" = ["UP006", "UP035"]
"vllm/core/**/*.py" = ["UP006", "UP035"] "vllm/core/**/*.py" = ["UP006", "UP035"]
"vllm/device_allocator/**/*.py" = ["UP006", "UP035"]
"vllm/distributed/**/*.py" = ["UP006", "UP035"]
"vllm/engine/**/*.py" = ["UP006", "UP035"] "vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"] "vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"] "vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
......
...@@ -11,7 +11,7 @@ import dataclasses ...@@ -11,7 +11,7 @@ import dataclasses
import gc import gc
import os import os
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Callable, Dict, Optional, Tuple, Union from typing import Any, Callable, Optional, Union
import torch import torch
...@@ -63,7 +63,7 @@ except ModuleNotFoundError: ...@@ -63,7 +63,7 @@ except ModuleNotFoundError:
libcudart = None libcudart = None
# py_device, py_alignedSize, py_d_mem, py_p_memHandle # py_device, py_alignedSize, py_d_mem, py_p_memHandle
HandleType = Tuple[int, int, int, int] HandleType = tuple[int, int, int, int]
@dataclasses.dataclass @dataclasses.dataclass
...@@ -148,9 +148,9 @@ class CuMemAllocator: ...@@ -148,9 +148,9 @@ class CuMemAllocator:
"Please track https://github.com/pytorch/pytorch/issues/147851 " "Please track https://github.com/pytorch/pytorch/issues/147851 "
"for the latest updates.") "for the latest updates.")
self.pointer_to_data: Dict[int, AllocationData] = {} self.pointer_to_data: dict[int, AllocationData] = {}
self.current_tag: str = CuMemAllocator.default_tag self.current_tag: str = CuMemAllocator.default_tag
self.allocator_and_pools: Dict[str, Any] = {} self.allocator_and_pools: dict[str, Any] = {}
def python_malloc_callback(self, allocation_handle: HandleType) -> None: def python_malloc_callback(self, allocation_handle: HandleType) -> None:
""" """
...@@ -172,7 +172,7 @@ class CuMemAllocator: ...@@ -172,7 +172,7 @@ class CuMemAllocator:
def sleep( def sleep(
self, self,
offload_tags: Optional[Union[Tuple[str, ...], offload_tags: Optional[Union[tuple[str, ...],
str]] = None) -> None: str]] = None) -> None:
""" """
Put the allocator in sleep mode. Put the allocator in sleep mode.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, Optional, Union from typing import Any, Optional, Union
import torch import torch
import torch.distributed import torch.distributed
...@@ -32,7 +32,7 @@ def tensor_model_parallel_gather(input_: torch.Tensor, ...@@ -32,7 +32,7 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
return get_tp_group().gather(input_, dst, dim) return get_tp_group().gather(input_, dst, dim)
def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, def broadcast_tensor_dict(tensor_dict: Optional[dict[Any, Union[torch.Tensor,
Any]]] = None, Any]]] = None,
src: int = 0): src: int = 0):
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -160,7 +160,7 @@ class DeviceCommunicatorBase: ...@@ -160,7 +160,7 @@ class DeviceCommunicatorBase:
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Dispatch the hidden states and router logits to the appropriate device. Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class. This is a no-op in the base class.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import os
from typing import List, Optional from typing import Optional
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
...@@ -126,7 +126,7 @@ class _CPUSHMDistributed: ...@@ -126,7 +126,7 @@ class _CPUSHMDistributed:
def gather(self, def gather(self,
input: torch.Tensor, input: torch.Tensor,
gather_list: Optional[List[torch.Tensor]], gather_list: Optional[list[torch.Tensor]],
dst: int = -1, dst: int = -1,
group: Optional[ProcessGroup] = None) -> None: group: Optional[ProcessGroup] = None) -> None:
# Note: different from the torch gather, here we use local dst rank. # Note: different from the torch gather, here we use local dst rank.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple from typing import Optional
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
...@@ -154,7 +154,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -154,7 +154,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_impl is not None assert self.all2all_impl is not None
hidden_states, router_logits = self.all2all_impl.dispatch( hidden_states, router_logits = self.all2all_impl.dispatch(
hidden_states, router_logits) hidden_states, router_logits)
......
...@@ -6,7 +6,7 @@ convenient for use when we just need to call a few functions. ...@@ -6,7 +6,7 @@ convenient for use when we just need to call a few functions.
import ctypes import ctypes
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional from typing import Any, Optional
# this line makes it possible to directly load `libcudart.so` using `ctypes` # this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch # noqa import torch # noqa
...@@ -32,7 +32,7 @@ class cudaIpcMemHandle_t(ctypes.Structure): ...@@ -32,7 +32,7 @@ class cudaIpcMemHandle_t(ctypes.Structure):
class Function: class Function:
name: str name: str
restype: Any restype: Any
argtypes: List[Any] argtypes: list[Any]
def find_loaded_library(lib_name) -> Optional[str]: def find_loaded_library(lib_name) -> Optional[str]:
...@@ -97,11 +97,11 @@ class CudaRTLibrary: ...@@ -97,11 +97,11 @@ class CudaRTLibrary:
# class attribute to store the mapping from the path to the library # class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times # to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {} path_to_library_cache: dict[str, Any] = {}
# class attribute to store the mapping from library path # class attribute to store the mapping from library path
# to the corresponding dictionary # to the corresponding dictionary
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} path_to_dict_mapping: dict[str, dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None): def __init__(self, so_file: Optional[str] = None):
if so_file is None: if so_file is None:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager from contextlib import contextmanager
from typing import List, Optional, Union from typing import Optional, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -276,7 +276,7 @@ class CustomAllreduce: ...@@ -276,7 +276,7 @@ class CustomAllreduce:
@staticmethod @staticmethod
def create_shared_buffer(size_in_bytes: int, def create_shared_buffer(size_in_bytes: int,
group: Optional[ProcessGroup] = None, group: Optional[ProcessGroup] = None,
uncached: Optional[bool] = False) -> List[int]: uncached: Optional[bool] = False) -> list[int]:
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes) pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
world_size = dist.get_world_size(group=group) world_size = dist.get_world_size(group=group)
...@@ -284,7 +284,7 @@ class CustomAllreduce: ...@@ -284,7 +284,7 @@ class CustomAllreduce:
handles = [None] * world_size handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group) dist.all_gather_object(handles, handle, group=group)
pointers: List[int] = [] pointers: list[int] = []
for i, h in enumerate(handles): for i, h in enumerate(handles):
if i == rank: if i == rank:
pointers.append(pointer) # type: ignore pointers.append(pointer) # type: ignore
...@@ -293,7 +293,7 @@ class CustomAllreduce: ...@@ -293,7 +293,7 @@ class CustomAllreduce:
return pointers return pointers
@staticmethod @staticmethod
def free_shared_buffer(pointers: List[int], def free_shared_buffer(pointers: list[int],
group: Optional[ProcessGroup] = None, group: Optional[ProcessGroup] = None,
rank: Optional[int] = 0) -> None: rank: Optional[int] = 0) -> None:
if rank is None: if rank is None:
......
...@@ -7,8 +7,9 @@ import pickle ...@@ -7,8 +7,9 @@ import pickle
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
from collections.abc import Sequence
from itertools import product from itertools import product
from typing import Dict, List, Optional, Sequence from typing import Optional
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
...@@ -149,7 +150,7 @@ def can_actually_p2p( ...@@ -149,7 +150,7 @@ def can_actually_p2p(
p_src.join() p_src.join()
p_tgt.join() p_tgt.join()
assert p_src.exitcode == 0 and p_tgt.exitcode == 0 assert p_src.exitcode == 0 and p_tgt.exitcode == 0
result: List[bool] = [] result: list[bool] = []
for src, tgt in zip(batch_src, batch_tgt): for src, tgt in zip(batch_src, batch_tgt):
a = result_queue.get() a = result_queue.get()
b = result_queue.get() b = result_queue.get()
...@@ -175,7 +176,7 @@ def can_actually_p2p( ...@@ -175,7 +176,7 @@ def can_actually_p2p(
# e.g. used by different vllm engines. The device id in the cache file is a # e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number # **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine. # of visible devices in the vllm engine.
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None _gpu_p2p_access_cache: Optional[dict[str, bool]] = None
def gpu_p2p_access_check(src: int, tgt: int) -> bool: def gpu_p2p_access_check(src: int, tgt: int) -> bool:
...@@ -204,7 +205,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: ...@@ -204,7 +205,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
# only the local master process (with local_rank == 0) can # only the local master process (with local_rank == 0) can
# enter this block to calculate the cache # enter this block to calculate the cache
logger.info("generating GPU P2P access cache in %s", path) logger.info("generating GPU P2P access cache in %s", path)
cache: Dict[str, bool] = {} cache: dict[str, bool] = {}
ids = list(range(num_dev)) ids = list(range(num_dev))
# batch of all pairs of GPUs # batch of all pairs of GPUs
batch_src, batch_tgt = zip(*list(product(ids, ids))) batch_src, batch_tgt = zip(*list(product(ids, ids)))
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
import ctypes import ctypes
import platform import platform
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional from typing import Any, Optional
import torch import torch
from torch.distributed import ReduceOp from torch.distributed import ReduceOp
...@@ -121,7 +121,7 @@ class ncclRedOpTypeEnum: ...@@ -121,7 +121,7 @@ class ncclRedOpTypeEnum:
class Function: class Function:
name: str name: str
restype: Any restype: Any
argtypes: List[Any] argtypes: list[Any]
class NCCLLibrary: class NCCLLibrary:
...@@ -210,11 +210,11 @@ class NCCLLibrary: ...@@ -210,11 +210,11 @@ class NCCLLibrary:
# class attribute to store the mapping from the path to the library # class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times # to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {} path_to_library_cache: dict[str, Any] = {}
# class attribute to store the mapping from library path # class attribute to store the mapping from library path
# to the corresponding dictionary # to the corresponding dictionary
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} path_to_dict_mapping: dict[str, dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None): def __init__(self, so_file: Optional[str] = None):
...@@ -238,7 +238,7 @@ class NCCLLibrary: ...@@ -238,7 +238,7 @@ class NCCLLibrary:
raise e raise e
if so_file not in NCCLLibrary.path_to_dict_mapping: if so_file not in NCCLLibrary.path_to_dict_mapping:
_funcs: Dict[str, Any] = {} _funcs: dict[str, Any] = {}
for func in NCCLLibrary.exported_functions: for func in NCCLLibrary.exported_functions:
f = getattr(self.lib, func.name) f = getattr(self.lib, func.name)
f.restype = func.restype f.restype = func.restype
......
...@@ -8,7 +8,7 @@ from contextlib import contextmanager ...@@ -8,7 +8,7 @@ from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from multiprocessing import shared_memory from multiprocessing import shared_memory
from threading import Event from threading import Event
from typing import Any, List, Optional, Tuple, Union from typing import Any, Optional, Union
from unittest.mock import patch from unittest.mock import patch
import torch import torch
...@@ -173,9 +173,9 @@ class ShmRingBuffer: ...@@ -173,9 +173,9 @@ class ShmRingBuffer:
@dataclass @dataclass
class Handle: class Handle:
local_reader_ranks: List[int] = field(default_factory=list) local_reader_ranks: list[int] = field(default_factory=list)
buffer_handle: Optional[Tuple[int, int, int, str]] = None buffer_handle: Optional[tuple[int, int, int, str]] = None
local_subscribe_addr: Optional[str] = None local_subscribe_addr: Optional[str] = None
remote_subscribe_addr: Optional[str] = None remote_subscribe_addr: Optional[str] = None
remote_addr_ipv6: bool = False remote_addr_ipv6: bool = False
...@@ -187,7 +187,7 @@ class MessageQueue: ...@@ -187,7 +187,7 @@ class MessageQueue:
self, self,
n_reader, # number of all readers n_reader, # number of all readers
n_local_reader, # number of local readers through shared memory n_local_reader, # number of local readers through shared memory
local_reader_ranks: Optional[List[int]] = None, local_reader_ranks: Optional[list[int]] = None,
max_chunk_bytes: int = 1024 * 1024 * 10, max_chunk_bytes: int = 1024 * 1024 * 10,
max_chunks: int = 10, max_chunks: int = 10,
connect_ip: Optional[str] = None, connect_ip: Optional[str] = None,
......
...@@ -8,7 +8,7 @@ The class provides two primary abstract methods: ...@@ -8,7 +8,7 @@ The class provides two primary abstract methods:
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Tuple, Union from typing import TYPE_CHECKING, Union
import torch import torch
...@@ -55,7 +55,7 @@ class KVConnectorBase(ABC): ...@@ -55,7 +55,7 @@ class KVConnectorBase(ABC):
self, self,
model_executable: torch.nn.Module, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata", model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor], kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor, hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors], IntermediateTensors],
) -> None: ) -> None:
...@@ -71,7 +71,7 @@ class KVConnectorBase(ABC): ...@@ -71,7 +71,7 @@ class KVConnectorBase(ABC):
start and end layer information. start and end layer information.
model_input (ModelInputForGPUWithSamplingMetadata): The input model_input (ModelInputForGPUWithSamplingMetadata): The input
metadata from vLLM. metadata from vLLM.
kv_caches (List[torch.Tensor]): List of KV caches (keys and values) kv_caches (list[torch.Tensor]): List of KV caches (keys and values)
for each layer. for each layer.
hidden_or_intermediate_states (Union[torch.Tensor, hidden_or_intermediate_states (Union[torch.Tensor,
IntermediateTensors]): IntermediateTensors]):
...@@ -88,8 +88,8 @@ class KVConnectorBase(ABC): ...@@ -88,8 +88,8 @@ class KVConnectorBase(ABC):
def recv_kv_caches_and_hidden_states( def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module, self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata", model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor] kv_caches: list[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]: "ModelInputForGPUWithSamplingMetadata"]:
""" """
Receive KV caches and hidden states from the connector. Receive KV caches and hidden states from the connector.
...@@ -104,7 +104,7 @@ class KVConnectorBase(ABC): ...@@ -104,7 +104,7 @@ class KVConnectorBase(ABC):
The model executable from vLLM modelrunner. The model executable from vLLM modelrunner.
model_input (ModelInputForGPUWithSamplingMetadata): model_input (ModelInputForGPUWithSamplingMetadata):
The model input from vLLM modelrunner. The model input from vLLM modelrunner.
kv_caches (List[torch.Tensor]): kv_caches (list[torch.Tensor]):
List of KV caches for each layer. List of KV caches for each layer.
Returns: Returns:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import importlib import importlib
from typing import TYPE_CHECKING, Callable, Dict, Type from typing import TYPE_CHECKING, Callable
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
...@@ -18,7 +18,7 @@ logger = init_logger(__name__) ...@@ -18,7 +18,7 @@ logger = init_logger(__name__)
class KVConnectorFactory: class KVConnectorFactory:
_registry: Dict[str, Callable[[], Type[KVConnectorBaseType]]] = {} _registry: dict[str, Callable[[], type[KVConnectorBaseType]]] = {}
@classmethod @classmethod
def register_connector(cls, name: str, module_path: str, def register_connector(cls, name: str, module_path: str,
...@@ -27,7 +27,7 @@ class KVConnectorFactory: ...@@ -27,7 +27,7 @@ class KVConnectorFactory:
if name in cls._registry: if name in cls._registry:
raise ValueError(f"Connector '{name}' is already registered.") raise ValueError(f"Connector '{name}' is already registered.")
def loader() -> Type[KVConnectorBaseType]: def loader() -> type[KVConnectorBaseType]:
module = importlib.import_module(module_path) module = importlib.import_module(module_path)
return getattr(module, class_name) return getattr(module, class_name)
......
...@@ -7,7 +7,7 @@ The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker ...@@ -7,7 +7,7 @@ The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker
(2) offload and share KV caches. (2) offload and share KV caches.
""" """
from typing import TYPE_CHECKING, List, Tuple, Union from typing import TYPE_CHECKING, Union
import torch import torch
...@@ -63,8 +63,8 @@ class LMCacheConnector(KVConnectorBase): ...@@ -63,8 +63,8 @@ class LMCacheConnector(KVConnectorBase):
def recv_kv_caches_and_hidden_states( def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module, self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata", model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor] kv_caches: list[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]: "ModelInputForGPUWithSamplingMetadata"]:
retrieve_status = self.lmcache_should_retrieve(model_input) retrieve_status = self.lmcache_should_retrieve(model_input)
...@@ -78,7 +78,7 @@ class LMCacheConnector(KVConnectorBase): ...@@ -78,7 +78,7 @@ class LMCacheConnector(KVConnectorBase):
self, self,
model_executable: torch.nn.Module, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata", model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor], kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor, hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors], IntermediateTensors],
) -> None: ) -> None:
......
...@@ -6,7 +6,7 @@ The MooncakeStoreConnector transfers KV caches between prefill vLLM workers ...@@ -6,7 +6,7 @@ The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
database-style KVStore. database-style KVStore.
""" """
import hashlib import hashlib
from typing import TYPE_CHECKING, List, Tuple, Union from typing import TYPE_CHECKING, Union
import torch import torch
...@@ -70,7 +70,7 @@ class MooncakeStoreConnector(KVConnectorBase): ...@@ -70,7 +70,7 @@ class MooncakeStoreConnector(KVConnectorBase):
self, self,
model_executable: torch.nn.Module, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata", model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor], kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor, hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors], IntermediateTensors],
) -> None: ) -> None:
...@@ -113,8 +113,8 @@ class MooncakeStoreConnector(KVConnectorBase): ...@@ -113,8 +113,8 @@ class MooncakeStoreConnector(KVConnectorBase):
def recv_kv_caches_and_hidden_states( def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module, self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata", model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor] kv_caches: list[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]: "ModelInputForGPUWithSamplingMetadata"]:
bypass_model_exec = True bypass_model_exec = True
input_tokens_tensor = model_input.input_tokens input_tokens_tensor = model_input.input_tokens
......
...@@ -8,7 +8,7 @@ MooncakePipe. ...@@ -8,7 +8,7 @@ MooncakePipe.
But the logic can be extended to support other pipe and lookup buffer. But the logic can be extended to support other pipe and lookup buffer.
""" """
from typing import TYPE_CHECKING, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Optional, Union
import torch import torch
...@@ -133,7 +133,7 @@ class SimpleConnector(KVConnectorBase): ...@@ -133,7 +133,7 @@ class SimpleConnector(KVConnectorBase):
) )
def select(self, input_tokens: Optional[torch.Tensor], def select(self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
assert self.consumer_buffer is not None, "Please initialize the "\ assert self.consumer_buffer is not None, "Please initialize the "\
"consumer buffer before calling select." "consumer buffer before calling select."
...@@ -152,7 +152,7 @@ class SimpleConnector(KVConnectorBase): ...@@ -152,7 +152,7 @@ class SimpleConnector(KVConnectorBase):
self, self,
model_executable: torch.nn.Module, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata", model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor], kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor, hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors], IntermediateTensors],
) -> None: ) -> None:
...@@ -207,8 +207,8 @@ class SimpleConnector(KVConnectorBase): ...@@ -207,8 +207,8 @@ class SimpleConnector(KVConnectorBase):
def recv_kv_caches_and_hidden_states( def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module, self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata", model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor] kv_caches: list[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]: "ModelInputForGPUWithSamplingMetadata"]:
# When bypass_model_exec is set to False, it means that at least for one # When bypass_model_exec is set to False, it means that at least for one
......
...@@ -5,13 +5,13 @@ import threading ...@@ -5,13 +5,13 @@ import threading
import time import time
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterator
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator from typing import TYPE_CHECKING, Any, Optional
import msgspec import msgspec
import torch import torch
import zmq import zmq
from typing_extensions import Optional
from vllm import envs from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
......
...@@ -5,7 +5,7 @@ This implementation is a shim wrapper on two APIs exposed by `kv_connector`: ...@@ -5,7 +5,7 @@ This implementation is a shim wrapper on two APIs exposed by `kv_connector`:
1. `send_kv_caches_and_hidden_states` 1. `send_kv_caches_and_hidden_states`
2. `recv_kv_caches_and_hidden_states 2. `recv_kv_caches_and_hidden_states
""" """
from typing import TYPE_CHECKING, List, Tuple, Union from typing import TYPE_CHECKING, Union
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
...@@ -53,7 +53,7 @@ class KVTransferAgent: ...@@ -53,7 +53,7 @@ class KVTransferAgent:
self, self,
model_executable: torch.nn.Module, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata", model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor], kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor, hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors], IntermediateTensors],
) -> None: ) -> None:
...@@ -68,8 +68,8 @@ class KVTransferAgent: ...@@ -68,8 +68,8 @@ class KVTransferAgent:
def recv_kv_caches_and_hidden_states( def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module, self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata", model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor] kv_caches: list[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]: "ModelInputForGPUWithSamplingMetadata"]:
return self.connector.recv_kv_caches_and_hidden_states( return self.connector.recv_kv_caches_and_hidden_states(
......
...@@ -13,7 +13,7 @@ These classes above are abstracted behind class `KVCacheBufferBase`. ...@@ -13,7 +13,7 @@ These classes above are abstracted behind class `KVCacheBufferBase`.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional from typing import Optional
import torch import torch
...@@ -93,7 +93,7 @@ class KVLookupBufferBase(KVCacheBufferBase): ...@@ -93,7 +93,7 @@ class KVLookupBufferBase(KVCacheBufferBase):
@abstractmethod @abstractmethod
def drop_select( def drop_select(
self, input_tokens: Optional[torch.Tensor], self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
"""Select and *drop* KV cache entries from the lookup buffer. """Select and *drop* KV cache entries from the lookup buffer.
The functionality is similar to the following python statements The functionality is similar to the following python statements
...@@ -111,7 +111,7 @@ class KVLookupBufferBase(KVCacheBufferBase): ...@@ -111,7 +111,7 @@ class KVLookupBufferBase(KVCacheBufferBase):
roi (torch.Tensor): A binary mask on top of the input tokens roi (torch.Tensor): A binary mask on top of the input tokens
Returns: Returns:
List[Optional[torch.Tensor]]: A list of tensors. Can be None. list[Optional[torch.Tensor]]: A list of tensors. Can be None.
Raises: Raises:
NotImplementedError: This method must be implemented in subclasses. NotImplementedError: This method must be implemented in subclasses.
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
""" """
import threading import threading
from collections import deque from collections import deque
from typing import Deque, List, Optional, Union from typing import Optional, Union
import torch import torch
...@@ -38,7 +38,7 @@ class SimpleBuffer(KVLookupBufferBase): ...@@ -38,7 +38,7 @@ class SimpleBuffer(KVLookupBufferBase):
data_pipe: on device (e.g. GPU) data_pipe: on device (e.g. GPU)
""" """
self.buffer: Deque[List[torch.Tensor]] = deque() self.buffer: deque[list[torch.Tensor]] = deque()
self.buffer_size = 0 self.buffer_size = 0
self.buffer_size_threshold = buffer_size_thresh self.buffer_size_threshold = buffer_size_thresh
...@@ -50,8 +50,8 @@ class SimpleBuffer(KVLookupBufferBase): ...@@ -50,8 +50,8 @@ class SimpleBuffer(KVLookupBufferBase):
self.normal_signal = torch.tensor([0], device="cpu") self.normal_signal = torch.tensor([0], device="cpu")
self.end_signal = None self.end_signal = None
def _matches(self, tokens_roi_sender: List[torch.Tensor], def _matches(self, tokens_roi_sender: list[torch.Tensor],
tokens_roi_recver: List[torch.Tensor]): tokens_roi_recver: list[torch.Tensor]):
# tokens_roi_sender: tokens and roi of the producer (in the buffer) # tokens_roi_sender: tokens and roi of the producer (in the buffer)
# tokens_roi_recver: tokens and roi of the consumer (query) # tokens_roi_recver: tokens and roi of the consumer (query)
...@@ -88,7 +88,7 @@ class SimpleBuffer(KVLookupBufferBase): ...@@ -88,7 +88,7 @@ class SimpleBuffer(KVLookupBufferBase):
tensor = tensor.float() tensor = tensor.float()
self.data_pipe.send_tensor(tensor) self.data_pipe.send_tensor(tensor)
def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): def _get_element_size(self, data: Optional[Union[list, torch.Tensor]]):
if isinstance(data, torch.Tensor): if isinstance(data, torch.Tensor):
return data.element_size() * data.numel() return data.element_size() * data.numel()
...@@ -151,7 +151,7 @@ class SimpleBuffer(KVLookupBufferBase): ...@@ -151,7 +151,7 @@ class SimpleBuffer(KVLookupBufferBase):
tokens_roi_recver = [input_tokens, roi] tokens_roi_recver = [input_tokens, roi]
def is_buffer_available( def is_buffer_available(
tokens_roi_recver: List[torch.Tensor], ) -> bool: tokens_roi_recver: list[torch.Tensor], ) -> bool:
# perform input tokens and roi matching # perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1) # FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so # but this buffer size won't (and shouldn't) be too large so
...@@ -184,7 +184,7 @@ class SimpleBuffer(KVLookupBufferBase): ...@@ -184,7 +184,7 @@ class SimpleBuffer(KVLookupBufferBase):
def drop_select( def drop_select(
self, input_tokens: Optional[torch.Tensor], self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
assert self.request_handling_thread is None, \ assert self.request_handling_thread is None, \
"drop_select should be called by the KV cache consumer "\ "drop_select should be called by the KV cache consumer "\
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment