Unverified Commit 6d592eb4 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[Core] separate distributed_init from worker (#3904)

parent d036198e
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups.""" """Tensor and pipeline parallel groups."""
import contextlib import contextlib
from typing import Optional
import torch import torch
...@@ -14,14 +15,59 @@ _TENSOR_MODEL_PARALLEL_GROUP = None ...@@ -14,14 +15,59 @@ _TENSOR_MODEL_PARALLEL_GROUP = None
# Pipeline model parallel group that the current rank belongs to. # Pipeline model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None _PIPELINE_MODEL_PARALLEL_GROUP = None
# when people blindly call `torch.distributed.all_reduce` etc,
# it will use this group. It is initialized with the `backend`
# parameter of `init_distributed_environment` below.
# Essentially, this is `torch.distributed.group.WORLD`.
# We leave a line here to note that this is device-specific.
# Note that this variable is not safe to use, because when users
# call `init_distributed_environment` first, and then destroy
# the process group themselves, this variable will keep a reference to the
# destroyed process group, which is not useful.
_DEVICE_WORLD_GROUP = None
# during `init_distributed_environment`, we will also initialize a
# group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
_CPU_WORLD_GROUP = None
# In summary, after calling `init_distributed_environment`, we will
# always have two groups: a device-specific one (which is also the default
# group) and a CPU one. All processes will be part of both groups.
# A list of global ranks for each pipeline group to ease calculation of the # A list of global ranks for each pipeline group to ease calculation of the
# source rank when broadcasting from the first or last pipeline stage. # source rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None _PIPELINE_GLOBAL_RANKS = None
def init_distributed_environment(
    world_size: int,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize torch.distributed and record the two world-level groups.

    Nothing happens when torch.distributed is already initialized.
    Otherwise the default process group is created with ``backend``, and a
    second group spanning every rank is created with the ``gloo`` backend.

    Args:
        world_size: Number of processes passed to ``init_process_group``.
        rank: Rank of the calling process.
        distributed_init_method: Init method forwarded to
            ``init_process_group``; must not be None when initialization
            actually has to be performed.
        local_rank: Accepted but not used by this function.
        backend: Backend used for the default (WORLD) process group.

    Raises:
        AssertionError: if initialization is required but
            ``distributed_init_method`` is None.
    """
    global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP

    # Already set up: leave the existing groups untouched.
    if torch.distributed.is_initialized():
        return

    assert distributed_init_method is not None, (
        "distributed_init_method must be provided when initializing "
        "distributed environment")

    # `backend` determines the backend of the default WORLD group.
    torch.distributed.init_process_group(
        backend=backend,
        init_method=distributed_init_method,
        world_size=world_size,
        rank=rank,
    )
    _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD

    # Second group over all ranks, always on gloo, for CPU-side coordination.
    every_rank = list(range(torch.distributed.get_world_size()))
    _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=every_rank,
                                                   backend="gloo")
def initialize_model_parallel( def initialize_model_parallel(
tensor_model_parallel_size: int = 1, tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None: ) -> None:
""" """
Initialize model parallel groups. Initialize model parallel groups.
...@@ -48,6 +94,8 @@ def initialize_model_parallel( ...@@ -48,6 +94,8 @@ def initialize_model_parallel(
# Get world size and rank. Ensure some consistencies. # Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized() assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size() world_size: int = torch.distributed.get_world_size()
# get the backend of _DEVICE_WORLD_GROUP
backend = backend or torch.distributed.get_backend()
if (world_size != if (world_size !=
tensor_model_parallel_size * pipeline_model_parallel_size): tensor_model_parallel_size * pipeline_model_parallel_size):
...@@ -69,7 +117,7 @@ def initialize_model_parallel( ...@@ -69,7 +117,7 @@ def initialize_model_parallel(
for i in range(num_tensor_model_parallel_groups): for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size, ranks = range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size) (i + 1) * tensor_model_parallel_size)
group = torch.distributed.new_group(ranks) group = torch.distributed.new_group(ranks, backend=backend)
if rank in ranks: if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group _TENSOR_MODEL_PARALLEL_GROUP = group
...@@ -80,7 +128,7 @@ def initialize_model_parallel( ...@@ -80,7 +128,7 @@ def initialize_model_parallel(
"pipeline model parallel group is already initialized") "pipeline model parallel group is already initialized")
for i in range(num_pipeline_model_parallel_groups): for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size, num_pipeline_model_parallel_groups) ranks = range(i, world_size, num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks) group = torch.distributed.new_group(ranks, backend=backend)
if rank in ranks: if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group _PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks _PIPELINE_GLOBAL_RANKS = ranks
...@@ -89,14 +137,17 @@ def initialize_model_parallel( ...@@ -89,14 +137,17 @@ def initialize_model_parallel(
def ensure_model_parallel_initialized( def ensure_model_parallel_initialized(
tensor_model_parallel_size: int, tensor_model_parallel_size: int,
pipeline_model_parallel_size: int, pipeline_model_parallel_size: int,
backend: Optional[str] = None,
) -> None: ) -> None:
"""Helper to initialize model parallel groups if they are not initialized, """Helper to initialize model parallel groups if they are not initialized,
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized. values if the model parallel groups are initialized.
""" """
# get the backend of _DEVICE_WORLD_GROUP
backend = backend or torch.distributed.get_backend()
if not model_parallel_is_initialized(): if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size, initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size) pipeline_model_parallel_size, backend)
return return
assert ( assert (
...@@ -117,6 +168,12 @@ def model_parallel_is_initialized(): ...@@ -117,6 +168,12 @@ def model_parallel_is_initialized():
and _PIPELINE_MODEL_PARALLEL_GROUP is not None) and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
def get_cpu_world_group():
    """Return the CPU world group held in ``_CPU_WORLD_GROUP``.

    Raises:
        AssertionError: if the CPU world group has not been set yet.
    """
    cpu_group = _CPU_WORLD_GROUP
    assert cpu_group is not None, "CPU world group is not initialized"
    return cpu_group
def get_tensor_model_parallel_group(): def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to.""" """Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, ( assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
......
import ray import ray
from vllm.config import ParallelConfig from vllm.model_executor.parallel_utils.parallel_state import (
ensure_model_parallel_initialized, init_distributed_environment)
from vllm.utils import get_open_port from vllm.utils import get_open_port
from vllm.worker.worker import init_distributed_environment
def init_test_distributed_environment( def init_test_distributed_environment(
...@@ -12,15 +12,14 @@ def init_test_distributed_environment( ...@@ -12,15 +12,14 @@ def init_test_distributed_environment(
distributed_init_port: str, distributed_init_port: str,
local_rank: int = -1, local_rank: int = -1,
) -> None: ) -> None:
parallel_config = ParallelConfig(pipeline_parallel_size,
tensor_parallel_size,
worker_use_ray=True)
distributed_init_method = f"tcp://localhost:{distributed_init_port}" distributed_init_method = f"tcp://localhost:{distributed_init_port}"
init_distributed_environment( init_distributed_environment(
parallel_config, world_size=pipeline_parallel_size * tensor_parallel_size,
rank, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
local_rank=local_rank) local_rank=local_rank)
ensure_model_parallel_initialized(tensor_parallel_size,
pipeline_parallel_size)
def multi_process_tensor_parallel( def multi_process_tensor_parallel(
......
...@@ -13,7 +13,7 @@ from vllm.model_executor.model_loader import get_model ...@@ -13,7 +13,7 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict) broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
ensure_model_parallel_initialized) ensure_model_parallel_initialized, init_distributed_environment)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.model_runner import ModelRunner from vllm.worker.model_runner import ModelRunner
...@@ -251,25 +251,11 @@ class CPUWorker: ...@@ -251,25 +251,11 @@ class CPUWorker:
parallel_config = self.parallel_config parallel_config = self.parallel_config
rank = self.rank rank = self.rank
distributed_init_method = self.distributed_init_method distributed_init_method = self.distributed_init_method
init_distributed_environment(
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch "
"world size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
backend = "gloo"
torch.distributed.init_process_group(
backend=backend,
world_size=parallel_config.world_size, world_size=parallel_config.world_size,
rank=rank, rank=rank,
init_method=distributed_init_method, distributed_init_method=distributed_init_method,
backend="gloo",
) )
# A small all_reduce for warmup. # A small all_reduce for warmup.
......
...@@ -15,7 +15,7 @@ from vllm.model_executor.parallel_utils.communication_op import ( ...@@ -15,7 +15,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict) broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
ensure_model_parallel_initialized) ensure_model_parallel_initialized, init_distributed_environment)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner from vllm.worker.model_runner import ModelRunner
...@@ -97,7 +97,7 @@ class Worker: ...@@ -97,7 +97,7 @@ class Worker:
raise RuntimeError( raise RuntimeError(
f"Not support device type: {self.device_config.device}") f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment. # Initialize the distributed environment.
init_distributed_environment(self.parallel_config, self.rank, init_worker_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method, self.distributed_init_method,
self.local_rank) self.local_rank)
# Set random seed. # Set random seed.
...@@ -248,31 +248,15 @@ class Worker: ...@@ -248,31 +248,15 @@ class Worker:
self.parallel_config) self.parallel_config)
def init_distributed_environment( def init_worker_distributed_environment(
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
rank: int, rank: int,
distributed_init_method: Optional[str] = None, distributed_init_method: Optional[str] = None,
local_rank: int = -1, local_rank: int = -1,
) -> None: ) -> None:
"""Initialize the distributed environment.""" """Initialize the distributed environment."""
if torch.distributed.is_initialized(): init_distributed_environment(parallel_config.world_size, rank,
torch_world_size = torch.distributed.get_world_size() distributed_init_method, local_rank)
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch world "
"size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
torch.distributed.init_process_group(
backend="nccl",
world_size=parallel_config.world_size,
rank=rank,
init_method=distributed_init_method,
)
if pynccl_utils.is_initialized(): if pynccl_utils.is_initialized():
pynccl_world_size = pynccl_utils.get_world_size() pynccl_world_size = pynccl_utils.get_world_size()
...@@ -291,10 +275,6 @@ def init_distributed_environment( ...@@ -291,10 +275,6 @@ def init_distributed_environment(
init_method=distributed_init_method, init_method=distributed_init_method,
) )
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
if pynccl_utils.is_initialized():
pynccl_utils.all_reduce(torch.zeros(1).cuda())
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size)
...@@ -302,6 +282,11 @@ def init_distributed_environment( ...@@ -302,6 +282,11 @@ def init_distributed_environment(
if not parallel_config.disable_custom_all_reduce: if not parallel_config.disable_custom_all_reduce:
init_custom_ar() init_custom_ar()
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
if pynccl_utils.is_initialized():
pynccl_utils.all_reduce(torch.zeros(1).cuda())
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype. # Check if the GPU supports the dtype.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment