"docs/source/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "b1054cbb8a8fbe0e266a9adb5e0bde555c9f3c3b"
Unverified commit 5b23c3f2, authored by Junda Chen and committed by GitHub

Add `group` as an argument in broadcast ops (#2522)

parent 00efdc84
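The change threads an optional `ProcessGroup` through `broadcast`, `broadcast_object_list`, and `broadcast_tensor_dict`; when no group is passed, the helpers fall back to `torch.distributed.group.WORLD`, so existing call sites behave as before. As a quick orientation before the diff, here is a minimal, hypothetical sketch of how a caller might use the new `group` parameter. The four-rank layout, the subgroup built with `torch.distributed.new_group`, and the `communication_op` module path are illustrative assumptions, not part of this commit.

```python
# Hypothetical usage sketch, not part of this commit. Assumes a 4-process
# job with torch.distributed already initialized (e.g. NCCL backend) and that
# the helpers live in vllm.model_executor.parallel_utils.communication_op.
import torch
import torch.distributed as dist

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast, broadcast_object_list)


def run(local_rank: int) -> None:
    # new_group is a collective call: every rank must execute it, even the
    # ranks that will not be members of the resulting subgroup.
    tp_group = dist.new_group(ranks=[0, 1])  # illustrative subgroup

    if dist.get_rank() in (0, 1):
        x = torch.zeros(8, device=f"cuda:{local_rank}")
        # `src` is a global rank and must belong to `group`; omitting `group`
        # falls back to torch.distributed.group.WORLD (the old behavior).
        broadcast(x, src=0, group=tp_group)
        broadcast_object_list(["sampling params"], src=0, group=tp_group)
```

Because `src` is a global rank, the new `assert src in ranks` check rejects a source rank that is not a member of the supplied group, which the old `0 <= src < world_size` check could not express for subgroups.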
 from collections import namedtuple
 from typing import Any, Dict, List, Optional, Union

+from torch.distributed import ProcessGroup
 import torch

 from vllm.model_executor.parallel_utils.parallel_state import (
@@ -86,47 +88,59 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
     return output_tensor


-def broadcast(input_: torch.Tensor, src: int = 0):
+def broadcast(input_: torch.Tensor,
+              src: int = 0,
+              group: Optional[ProcessGroup] = None):
     """Broadcast the input tensor."""
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"

     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return input_
     # Broadcast.
-    torch.distributed.broadcast(input_, src=src)
+    torch.distributed.broadcast(input_, src=src, group=group)
     return input_


-def broadcast_object_list(obj_list: List[Any], src: int = 0):
+def broadcast_object_list(obj_list: List[Any],
+                          src: int = 0,
+                          group: Optional[ProcessGroup] = None):
     """Broadcast the input object list."""
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"

     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return obj_list
     # Broadcast.
-    torch.distributed.broadcast_object_list(obj_list, src=src)
+    torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
     return obj_list


 TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])


-def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
-                                                                 Any]]] = None,
-                          src: int = 0) -> Dict[Any, Union[torch.Tensor, Any]]:
+def broadcast_tensor_dict(
+    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
+    src: int = 0,
+    group: Optional[ProcessGroup] = None,
+) -> Dict[Any, Union[torch.Tensor, Any]]:
     """Broadcast the input tensor dictionary."""
-    rank = torch.distributed.get_rank()
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"

     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return tensor_dict

+    rank = torch.distributed.get_rank()
     if rank == src:
         assert isinstance(
             tensor_dict,
@@ -141,14 +155,18 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
                     (key, TensorMetadata(value.dtype, value.size())))
             else:
                 metadata_list.append((key, value))
-        torch.distributed.broadcast_object_list([metadata_list], src=src)
+        torch.distributed.broadcast_object_list([metadata_list],
+                                                src=src,
+                                                group=group)
         for key, value in metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = tensor_dict[key]
                 torch.distributed.broadcast(tensor, src=src)
     else:
         recv_metadata_list = [None]
-        torch.distributed.broadcast_object_list(recv_metadata_list, src=src)
+        torch.distributed.broadcast_object_list(recv_metadata_list,
+                                                src=src,
+                                                group=group)
         metadata_list = recv_metadata_list[0]
         tensor_dict = {}
         async_handles = []
@@ -159,7 +177,8 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
                                      device="cuda")
                 async_handle = torch.distributed.broadcast(tensor,
                                                            src=src,
-                                                           async_op=True)
+                                                           async_op=True,
+                                                           group=group)
                 async_handles.append(async_handle)
                 tensor_dict[key] = tensor
             else:
...
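For completeness, a hedged sketch of how `broadcast_tensor_dict` might be driven across a subgroup after this change: the source rank ships a metadata list describing each tensor (plus any plain Python values), then broadcasts the tensors themselves, while the other ranks in the group allocate empty CUDA tensors from that metadata and receive into them. The driver/worker rank layout and the module path are assumptions for illustration.

```python
# Hypothetical sketch, not from this commit. Assumes torch.distributed is
# initialized and that broadcast_tensor_dict is importable from
# vllm.model_executor.parallel_utils.communication_op.
import torch
import torch.distributed as dist

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)


def sync_inputs() -> None:
    # Collective call: all ranks take part in creating the subgroup.
    group = dist.new_group(ranks=[0, 1])  # illustrative driver/worker pair
    rank = dist.get_rank()

    if rank == 0:
        payload = {
            "input_ids": torch.randint(0, 32000, (4, 16), device="cuda"),
            "is_prompt": True,  # non-tensor values travel in the metadata list
        }
        broadcast_tensor_dict(payload, src=0, group=group)
    elif rank == 1:
        # Non-src ranks pass no dict; they get back a dict of CUDA tensors
        # rebuilt from the broadcast metadata plus the plain Python values.
        payload = broadcast_tensor_dict(src=0, group=group)
```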