Unverified Commit dc372b9c authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `vllm/device_allocator` and `vllm/distributed` (#18126)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 9b5b39b6
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import threading import threading
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Optional, Tuple from typing import Callable, Optional
import torch import torch
...@@ -35,7 +35,7 @@ class BrokenPipeException(Exception): ...@@ -35,7 +35,7 @@ class BrokenPipeException(Exception):
super().__init__(self.message) super().__init__(self.message)
Metadata = Dict[str, Optional[torch.Tensor]] Metadata = dict[str, Optional[torch.Tensor]]
class PyNcclPipe(KVPipeBase): class PyNcclPipe(KVPipeBase):
...@@ -83,7 +83,7 @@ class PyNcclPipe(KVPipeBase): ...@@ -83,7 +83,7 @@ class PyNcclPipe(KVPipeBase):
def _get_device_send_recv_impl( def _get_device_send_recv_impl(
self, group: StatelessProcessGroup self, group: StatelessProcessGroup
) -> Tuple[Callable[[torch.Tensor, int], None], Callable[ ) -> tuple[Callable[[torch.Tensor, int], None], Callable[
[torch.Tensor, int], None]]: [torch.Tensor, int], None]]:
send: Callable[[torch.Tensor, int], None] send: Callable[[torch.Tensor, int], None]
......
...@@ -29,7 +29,7 @@ from collections import namedtuple ...@@ -29,7 +29,7 @@ from collections import namedtuple
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing import shared_memory from multiprocessing import shared_memory
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Optional, Union
from unittest.mock import patch from unittest.mock import patch
import torch import torch
...@@ -54,15 +54,15 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) ...@@ -54,15 +54,15 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
def _split_tensor_dict( def _split_tensor_dict(
tensor_dict: Dict[str, Union[torch.Tensor, Any]] tensor_dict: dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: ) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
"""Split the tensor dictionary into two parts: """Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced 1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata. by its metadata.
2. A list of tensors. 2. A list of tensors.
""" """
metadata_list: List[Tuple[str, Any]] = [] metadata_list: list[tuple[str, Any]] = []
tensor_list: List[torch.Tensor] = [] tensor_list: list[torch.Tensor] = []
for key, value in tensor_dict.items(): for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor): if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here, # Note: we cannot use `value.device` here,
...@@ -78,7 +78,7 @@ def _split_tensor_dict( ...@@ -78,7 +78,7 @@ def _split_tensor_dict(
return metadata_list, tensor_list return metadata_list, tensor_list
_group_name_counter: Dict[str, int] = {} _group_name_counter: dict[str, int] = {}
def _get_unique_name(name: str) -> str: def _get_unique_name(name: str) -> str:
...@@ -94,7 +94,7 @@ def _get_unique_name(name: str) -> str: ...@@ -94,7 +94,7 @@ def _get_unique_name(name: str) -> str:
return newname return newname
_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {} _groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
def _register_group(group: "GroupCoordinator") -> None: def _register_group(group: "GroupCoordinator") -> None:
...@@ -182,7 +182,7 @@ class GroupCoordinator: ...@@ -182,7 +182,7 @@ class GroupCoordinator:
# available attributes: # available attributes:
rank: int # global rank rank: int # global rank
ranks: List[int] # global ranks in the group ranks: list[int] # global ranks in the group
world_size: int # size of the group world_size: int # size of the group
# difference between `local_rank` and `rank_in_group`: # difference between `local_rank` and `rank_in_group`:
# if we have a group of size 4 across two nodes: # if we have a group of size 4 across two nodes:
...@@ -201,7 +201,7 @@ class GroupCoordinator: ...@@ -201,7 +201,7 @@ class GroupCoordinator:
def __init__( def __init__(
self, self,
group_ranks: List[List[int]], group_ranks: list[list[int]],
local_rank: int, local_rank: int,
torch_distributed_backend: Union[str, Backend], torch_distributed_backend: Union[str, Backend],
use_device_communicator: bool, use_device_communicator: bool,
...@@ -435,7 +435,7 @@ class GroupCoordinator: ...@@ -435,7 +435,7 @@ class GroupCoordinator:
return recv[0] return recv[0]
def broadcast_object_list(self, def broadcast_object_list(self,
obj_list: List[Any], obj_list: list[Any],
src: int = 0, src: int = 0,
group: Optional[ProcessGroup] = None): group: Optional[ProcessGroup] = None):
"""Broadcast the input object list. """Broadcast the input object list.
...@@ -518,11 +518,11 @@ class GroupCoordinator: ...@@ -518,11 +518,11 @@ class GroupCoordinator:
def broadcast_tensor_dict( def broadcast_tensor_dict(
self, self,
tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None,
src: int = 0, src: int = 0,
group: Optional[ProcessGroup] = None, group: Optional[ProcessGroup] = None,
metadata_group: Optional[ProcessGroup] = None metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary. """Broadcast the input tensor dictionary.
NOTE: `src` is the local rank of the source rank. NOTE: `src` is the local rank of the source rank.
""" """
...@@ -536,7 +536,7 @@ class GroupCoordinator: ...@@ -536,7 +536,7 @@ class GroupCoordinator:
rank_in_group = self.rank_in_group rank_in_group = self.rank_in_group
if rank_in_group == src: if rank_in_group == src:
metadata_list: List[Tuple[Any, Any]] = [] metadata_list: list[tuple[Any, Any]] = []
assert isinstance( assert isinstance(
tensor_dict, tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}") dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
...@@ -603,10 +603,10 @@ class GroupCoordinator: ...@@ -603,10 +603,10 @@ class GroupCoordinator:
def send_tensor_dict( def send_tensor_dict(
self, self,
tensor_dict: Dict[str, Union[torch.Tensor, Any]], tensor_dict: dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None, dst: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None, all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary. """Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank. NOTE: `dst` is the local rank of the source rank.
""" """
...@@ -626,7 +626,7 @@ class GroupCoordinator: ...@@ -626,7 +626,7 @@ class GroupCoordinator:
dst = (self.rank_in_group + 1) % self.world_size dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})" assert dst < self.world_size, f"Invalid dst rank ({dst})"
metadata_list: List[Tuple[Any, Any]] = [] metadata_list: list[tuple[Any, Any]] = []
assert isinstance( assert isinstance(
tensor_dict, tensor_dict,
dict), f"Expecting a dictionary, got {type(tensor_dict)}" dict), f"Expecting a dictionary, got {type(tensor_dict)}"
...@@ -661,7 +661,7 @@ class GroupCoordinator: ...@@ -661,7 +661,7 @@ class GroupCoordinator:
self, self,
src: Optional[int] = None, src: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None, all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary. """Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank. NOTE: `src` is the local rank of the source rank.
""" """
...@@ -682,7 +682,7 @@ class GroupCoordinator: ...@@ -682,7 +682,7 @@ class GroupCoordinator:
assert src < self.world_size, f"Invalid src rank ({src})" assert src < self.world_size, f"Invalid src rank ({src})"
recv_metadata_list = self.recv_object(src=src) recv_metadata_list = self.recv_object(src=src)
tensor_dict: Dict[str, Any] = {} tensor_dict: dict[str, Any] = {}
for key, value in recv_metadata_list: for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata): if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size, tensor = torch.empty(value.size,
...@@ -764,7 +764,7 @@ class GroupCoordinator: ...@@ -764,7 +764,7 @@ class GroupCoordinator:
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
if self.device_communicator is not None: if self.device_communicator is not None:
return self.device_communicator.dispatch(hidden_states, return self.device_communicator.dispatch(hidden_states,
router_logits) router_logits)
...@@ -782,7 +782,7 @@ def get_world_group() -> GroupCoordinator: ...@@ -782,7 +782,7 @@ def get_world_group() -> GroupCoordinator:
return _WORLD return _WORLD
def init_world_group(ranks: List[int], local_rank: int, def init_world_group(ranks: list[int], local_rank: int,
backend: str) -> GroupCoordinator: backend: str) -> GroupCoordinator:
return GroupCoordinator( return GroupCoordinator(
group_ranks=[ranks], group_ranks=[ranks],
...@@ -794,7 +794,7 @@ def init_world_group(ranks: List[int], local_rank: int, ...@@ -794,7 +794,7 @@ def init_world_group(ranks: List[int], local_rank: int,
def init_model_parallel_group( def init_model_parallel_group(
group_ranks: List[List[int]], group_ranks: list[list[int]],
local_rank: int, local_rank: int,
backend: str, backend: str,
use_message_queue_broadcaster: bool = False, use_message_queue_broadcaster: bool = False,
...@@ -1182,7 +1182,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): ...@@ -1182,7 +1182,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup], def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
source_rank: int = 0) -> List[bool]: source_rank: int = 0) -> list[bool]:
""" """
This is a collective operation that returns if each rank is in the same node This is a collective operation that returns if each rank is in the same node
as the source rank. It tests if processes are attached to the same as the source rank. It tests if processes are attached to the same
......
...@@ -10,7 +10,8 @@ import pickle ...@@ -10,7 +10,8 @@ import pickle
import socket import socket
import time import time
from collections import deque from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple from collections.abc import Sequence
from typing import Any, Optional
import torch import torch
from torch.distributed import ProcessGroup, TCPStore from torch.distributed import ProcessGroup, TCPStore
...@@ -69,7 +70,7 @@ def split_tensor_along_last_dim( ...@@ -69,7 +70,7 @@ def split_tensor_along_last_dim(
def get_pp_indices(num_hidden_layers: int, pp_rank: int, def get_pp_indices(num_hidden_layers: int, pp_rank: int,
pp_size: int) -> Tuple[int, int]: pp_size: int) -> tuple[int, int]:
"""Try to evenly distribute layers across partitions. """Try to evenly distribute layers across partitions.
If the number of layers is not divisible by the number of partitions, If the number of layers is not divisible by the number of partitions,
...@@ -132,15 +133,15 @@ class StatelessProcessGroup: ...@@ -132,15 +133,15 @@ class StatelessProcessGroup:
data_expiration_seconds: int = 3600 # 1 hour data_expiration_seconds: int = 3600 # 1 hour
# dst rank -> counter # dst rank -> counter
send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict) send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict)
# src rank -> counter # src rank -> counter
recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict) recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
broadcast_send_counter: int = 0 broadcast_send_counter: int = 0
broadcast_recv_src_counter: Dict[int, int] = dataclasses.field( broadcast_recv_src_counter: dict[int, int] = dataclasses.field(
default_factory=dict) default_factory=dict)
# A deque to store the data entries, with key and timestamp. # A deque to store the data entries, with key and timestamp.
entries: Deque[Tuple[str, entries: deque[tuple[str,
float]] = dataclasses.field(default_factory=deque) float]] = dataclasses.field(default_factory=deque)
def __post_init__(self): def __post_init__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment