Unverified Commit 0b217da6 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `vllm/adapter_commons` (#18073)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 19324d66
...@@ -72,7 +72,6 @@ exclude = [ ...@@ -72,7 +72,6 @@ exclude = [
"vllm/version.py" = ["F401"] "vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"] "vllm/_version.py" = ["ALL"]
# Python 3.8 typing. TODO: Remove these excludes after v1.0.0 # Python 3.8 typing. TODO: Remove these excludes after v1.0.0
"vllm/adapter_commons/**/*.py" = ["UP006", "UP035"]
"vllm/attention/**/*.py" = ["UP006", "UP035"] "vllm/attention/**/*.py" = ["UP006", "UP035"]
"vllm/core/**/*.py" = ["UP006", "UP035"] "vllm/core/**/*.py" = ["UP006", "UP035"]
"vllm/device_allocator/**/*.py" = ["UP006", "UP035"] "vllm/device_allocator/**/*.py" = ["UP006", "UP035"]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple
@dataclass @dataclass
class AdapterMapping: class AdapterMapping:
# Per every token in input_ids: # Per every token in input_ids:
index_mapping: Tuple[int, ...] index_mapping: tuple[int, ...]
# Per sampled token: # Per sampled token:
prompt_mapping: Tuple[int, ...] prompt_mapping: tuple[int, ...]
def __post_init__(self): def __post_init__(self):
self.index_mapping = tuple(self.index_mapping) self.index_mapping = tuple(self.index_mapping)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, TypeVar from typing import Any, Callable, Optional, TypeVar
from torch import nn from torch import nn
...@@ -49,9 +49,9 @@ class AdapterModelManager(ABC): ...@@ -49,9 +49,9 @@ class AdapterModelManager(ABC):
model: the model to be adapted. model: the model to be adapted.
""" """
self.model: nn.Module = model self.model: nn.Module = model
self._registered_adapters: Dict[int, Any] = {} self._registered_adapters: dict[int, Any] = {}
# Dict instead of a Set for compatibility with LRUCache. # Dict instead of a Set for compatibility with LRUCache.
self._active_adapters: Dict[int, None] = {} self._active_adapters: dict[int, None] = {}
self.adapter_type = 'Adapter' self.adapter_type = 'Adapter'
self._last_mapping = None self._last_mapping = None
...@@ -97,7 +97,7 @@ class AdapterModelManager(ABC): ...@@ -97,7 +97,7 @@ class AdapterModelManager(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def list_adapters(self) -> Dict[int, Any]: def list_adapters(self) -> dict[int, Any]:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, Optional, Set from typing import Any, Callable, Optional
## model functions ## model functions
def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], def deactivate_adapter(adapter_id: int, active_adapters: dict[int, None],
deactivate_func: Callable) -> bool: deactivate_func: Callable) -> bool:
if adapter_id in active_adapters: if adapter_id in active_adapters:
deactivate_func(adapter_id) deactivate_func(adapter_id)
...@@ -13,7 +13,7 @@ def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], ...@@ -13,7 +13,7 @@ def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None],
return False return False
def add_adapter(adapter: Any, registered_adapters: Dict[int, Any], def add_adapter(adapter: Any, registered_adapters: dict[int, Any],
capacity: int, add_func: Callable) -> bool: capacity: int, add_func: Callable) -> bool:
if adapter.id not in registered_adapters: if adapter.id not in registered_adapters:
if len(registered_adapters) >= capacity: if len(registered_adapters) >= capacity:
...@@ -32,23 +32,23 @@ def set_adapter_mapping(mapping: Any, last_mapping: Any, ...@@ -32,23 +32,23 @@ def set_adapter_mapping(mapping: Any, last_mapping: Any,
return last_mapping return last_mapping
def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any], def remove_adapter(adapter_id: int, registered_adapters: dict[int, Any],
deactivate_func: Callable) -> bool: deactivate_func: Callable) -> bool:
deactivate_func(adapter_id) deactivate_func(adapter_id)
return bool(registered_adapters.pop(adapter_id, None)) return bool(registered_adapters.pop(adapter_id, None))
def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]: def list_adapters(registered_adapters: dict[int, Any]) -> dict[int, Any]:
return dict(registered_adapters) return dict(registered_adapters)
def get_adapter(adapter_id: int, def get_adapter(adapter_id: int,
registered_adapters: Dict[int, Any]) -> Optional[Any]: registered_adapters: dict[int, Any]) -> Optional[Any]:
return registered_adapters.get(adapter_id) return registered_adapters.get(adapter_id)
## worker functions ## worker functions
def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any], def set_active_adapters_worker(requests: set[Any], mapping: Optional[Any],
apply_adapters_func, apply_adapters_func,
set_adapter_mapping_func) -> None: set_adapter_mapping_func) -> None:
apply_adapters_func(requests) apply_adapters_func(requests)
...@@ -66,7 +66,7 @@ def add_adapter_worker(adapter_request: Any, list_adapters_func, ...@@ -66,7 +66,7 @@ def add_adapter_worker(adapter_request: Any, list_adapters_func,
return loaded return loaded
def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, def apply_adapters_worker(adapter_requests: set[Any], list_adapters_func,
adapter_slots: int, remove_adapter_func, adapter_slots: int, remove_adapter_func,
add_adapter_func) -> None: add_adapter_func) -> None:
models_that_exist = list_adapters_func() models_that_exist = list_adapters_func()
...@@ -88,5 +88,5 @@ def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, ...@@ -88,5 +88,5 @@ def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func,
add_adapter_func(models_map[adapter_id]) add_adapter_func(models_map[adapter_id])
def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]: def list_adapters_worker(adapter_manager_list_adapters_func) -> set[int]:
return set(adapter_manager_list_adapters_func()) return set(adapter_manager_list_adapters_func())
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Optional, Set from typing import Any, Optional
import torch import torch
...@@ -17,7 +17,7 @@ class AbstractWorkerManager(ABC): ...@@ -17,7 +17,7 @@ class AbstractWorkerManager(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def set_active_adapters(self, requests: Set[Any], def set_active_adapters(self, requests: set[Any],
mapping: Optional[Any]) -> None: mapping: Optional[Any]) -> None:
raise NotImplementedError raise NotImplementedError
...@@ -34,5 +34,5 @@ class AbstractWorkerManager(ABC): ...@@ -34,5 +34,5 @@ class AbstractWorkerManager(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def list_adapters(self) -> Set[int]: def list_adapters(self) -> set[int]:
raise NotImplementedError raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment