"components/vscode:/vscode.git/clone" did not exist on "ba9a8a9fb6f33a65a1745d096d83d7d789048357"
Unverified Commit dc372b9c authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `vllm/device_allocator` and `vllm/distributed` (#18126)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 9b5b39b6
......@@ -15,7 +15,7 @@
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Optional, Tuple
from typing import Callable, Optional
import torch
......@@ -35,7 +35,7 @@ class BrokenPipeException(Exception):
super().__init__(self.message)
Metadata = Dict[str, Optional[torch.Tensor]]
Metadata = dict[str, Optional[torch.Tensor]]
class PyNcclPipe(KVPipeBase):
......@@ -83,7 +83,7 @@ class PyNcclPipe(KVPipeBase):
def _get_device_send_recv_impl(
self, group: StatelessProcessGroup
) -> Tuple[Callable[[torch.Tensor, int], None], Callable[
) -> tuple[Callable[[torch.Tensor, int], None], Callable[
[torch.Tensor, int], None]]:
send: Callable[[torch.Tensor, int], None]
......
......@@ -29,7 +29,7 @@ from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
from unittest.mock import patch
import torch
......@@ -54,15 +54,15 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
def _split_tensor_dict(
tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
tensor_dict: dict[str, Union[torch.Tensor, Any]]
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
metadata_list: List[Tuple[str, Any]] = []
tensor_list: List[torch.Tensor] = []
metadata_list: list[tuple[str, Any]] = []
tensor_list: list[torch.Tensor] = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here,
......@@ -78,7 +78,7 @@ def _split_tensor_dict(
return metadata_list, tensor_list
_group_name_counter: Dict[str, int] = {}
_group_name_counter: dict[str, int] = {}
def _get_unique_name(name: str) -> str:
......@@ -94,7 +94,7 @@ def _get_unique_name(name: str) -> str:
return newname
_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
def _register_group(group: "GroupCoordinator") -> None:
......@@ -182,7 +182,7 @@ class GroupCoordinator:
# available attributes:
rank: int # global rank
ranks: List[int] # global ranks in the group
ranks: list[int] # global ranks in the group
world_size: int # size of the group
# difference between `local_rank` and `rank_in_group`:
# if we have a group of size 4 across two nodes:
......@@ -201,7 +201,7 @@ class GroupCoordinator:
def __init__(
self,
group_ranks: List[List[int]],
group_ranks: list[list[int]],
local_rank: int,
torch_distributed_backend: Union[str, Backend],
use_device_communicator: bool,
......@@ -435,7 +435,7 @@ class GroupCoordinator:
return recv[0]
def broadcast_object_list(self,
obj_list: List[Any],
obj_list: list[Any],
src: int = 0,
group: Optional[ProcessGroup] = None):
"""Broadcast the input object list.
......@@ -518,11 +518,11 @@ class GroupCoordinator:
def broadcast_tensor_dict(
self,
tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
......@@ -536,7 +536,7 @@ class GroupCoordinator:
rank_in_group = self.rank_in_group
if rank_in_group == src:
metadata_list: List[Tuple[Any, Any]] = []
metadata_list: list[tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
......@@ -603,10 +603,10 @@ class GroupCoordinator:
def send_tensor_dict(
self,
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
tensor_dict: dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank.
"""
......@@ -626,7 +626,7 @@ class GroupCoordinator:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
metadata_list: List[Tuple[Any, Any]] = []
metadata_list: list[tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), f"Expecting a dictionary, got {type(tensor_dict)}"
......@@ -661,7 +661,7 @@ class GroupCoordinator:
self,
src: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
......@@ -682,7 +682,7 @@ class GroupCoordinator:
assert src < self.world_size, f"Invalid src rank ({src})"
recv_metadata_list = self.recv_object(src=src)
tensor_dict: Dict[str, Any] = {}
tensor_dict: dict[str, Any] = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
......@@ -764,7 +764,7 @@ class GroupCoordinator:
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
if self.device_communicator is not None:
return self.device_communicator.dispatch(hidden_states,
router_logits)
......@@ -782,7 +782,7 @@ def get_world_group() -> GroupCoordinator:
return _WORLD
def init_world_group(ranks: List[int], local_rank: int,
def init_world_group(ranks: list[int], local_rank: int,
backend: str) -> GroupCoordinator:
return GroupCoordinator(
group_ranks=[ranks],
......@@ -794,7 +794,7 @@ def init_world_group(ranks: List[int], local_rank: int,
def init_model_parallel_group(
group_ranks: List[List[int]],
group_ranks: list[list[int]],
local_rank: int,
backend: str,
use_message_queue_broadcaster: bool = False,
......@@ -1182,7 +1182,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
source_rank: int = 0) -> List[bool]:
source_rank: int = 0) -> list[bool]:
"""
This is a collective operation that returns if each rank is in the same node
as the source rank. It tests if processes are attached to the same
......
......@@ -10,7 +10,8 @@ import pickle
import socket
import time
from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import Any, Optional
import torch
from torch.distributed import ProcessGroup, TCPStore
......@@ -69,7 +70,7 @@ def split_tensor_along_last_dim(
def get_pp_indices(num_hidden_layers: int, pp_rank: int,
pp_size: int) -> Tuple[int, int]:
pp_size: int) -> tuple[int, int]:
"""Try to evenly distribute layers across partitions.
If the number of layers is not divisible by the number of partitions,
......@@ -132,15 +133,15 @@ class StatelessProcessGroup:
data_expiration_seconds: int = 3600 # 1 hour
# dst rank -> counter
send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict)
# src rank -> counter
recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
broadcast_send_counter: int = 0
broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(
broadcast_recv_src_counter: dict[int, int] = dataclasses.field(
default_factory=dict)
# A deque to store the data entries, with key and timestamp.
entries: Deque[Tuple[str,
entries: deque[tuple[str,
float]] = dataclasses.field(default_factory=deque)
def __post_init__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment