Unverified Commit db1764e4 authored by wangxiyuan's avatar wangxiyuan Committed by GitHub
Browse files

[Platform] allow platform to init dp group (#22243)


Signed-off-by: default avatarwangxiyuan <wangxiyuan1007@gmail.com>
parent 7f83b4ee
...@@ -334,7 +334,7 @@ class ParallelConfig: ...@@ -334,7 +334,7 @@ class ParallelConfig:
self.get_next_dp_init_port(), self.get_next_dp_init_port(),
self.data_parallel_rank, self.data_parallel_rank,
self.data_parallel_size, self.data_parallel_size,
backend="gloo", backend=current_platform.dist_backend,
) )
except DistNetworkError as e: except DistNetworkError as e:
# We only want to retry when the root cause is EADDRINUSE. # We only want to retry when the root cause is EADDRINUSE.
......
...@@ -415,7 +415,6 @@ class StatelessProcessGroup: ...@@ -415,7 +415,6 @@ class StatelessProcessGroup:
def init_gloo_process_group( def init_gloo_process_group(
backend: Backend,
prefix_store: PrefixStore, prefix_store: PrefixStore,
group_rank: int, group_rank: int,
group_size: int, group_size: int,
...@@ -432,7 +431,7 @@ def init_gloo_process_group( ...@@ -432,7 +431,7 @@ def init_gloo_process_group(
group_size, group_size,
) )
else: else:
options = ProcessGroup.Options(backend=backend) options = ProcessGroup.Options(backend="gloo")
pg = ProcessGroup( pg = ProcessGroup(
prefix_store, prefix_store,
group_rank, group_rank,
...@@ -504,19 +503,20 @@ def stateless_init_torch_distributed_process_group( ...@@ -504,19 +503,20 @@ def stateless_init_torch_distributed_process_group(
# Use a PrefixStore to avoid accidental overrides of keys used by # Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant. # different systems (e.g. RPC) in case the store is multi-tenant.
prefix_store = PrefixStore(init_method, store) prefix_store = PrefixStore(init_method, store)
try:
from vllm.platforms import current_platform
if backend == "gloo": return current_platform.stateless_init_device_torch_dist_pg(
return init_gloo_process_group(
backend=backend, backend=backend,
prefix_store=prefix_store, prefix_store=prefix_store,
group_rank=group_rank, group_rank=group_rank,
group_size=group_size, group_size=group_size,
timeout=timeout, timeout=timeout,
) )
from vllm.platforms import current_platform except NotImplementedError:
# If platform doesn't implement stateless_init_device_torch_dist_pg, it
return current_platform.stateless_init_device_torch_dist_pg( # will raise a NotImplementedError. In this case, we fall back to gloo.
backend=backend, return init_gloo_process_group(
prefix_store=prefix_store, prefix_store=prefix_store,
group_rank=group_rank, group_rank=group_rank,
group_size=group_size, group_size=group_size,
......
...@@ -6,13 +6,10 @@ pynvml. However, it should not initialize cuda context. ...@@ -6,13 +6,10 @@ pynvml. However, it should not initialize cuda context.
import os import os
from collections.abc import Callable from collections.abc import Callable
from datetime import timedelta
from functools import cache, wraps from functools import cache, wraps
from typing import TYPE_CHECKING, TypeVar from typing import TYPE_CHECKING, TypeVar
import torch import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
# import custom ops, trigger op registration # import custom ops, trigger op registration
...@@ -455,37 +452,6 @@ class CudaPlatformBase(Platform): ...@@ -455,37 +452,6 @@ class CudaPlatformBase(Platform):
def get_static_graph_wrapper_cls(cls) -> str: def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_graph.CUDAGraphWrapper" return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
@classmethod
def stateless_init_device_torch_dist_pg(
cls,
backend: str,
prefix_store: PrefixStore,
group_rank: int,
group_size: int,
timeout: timedelta,
) -> ProcessGroup:
assert is_nccl_available()
pg: ProcessGroup = ProcessGroup(
prefix_store,
group_rank,
group_size,
)
from torch.distributed.distributed_c10d import ProcessGroupNCCL
backend_options = ProcessGroupNCCL.Options()
backend_options._timeout = timeout
backend_class = ProcessGroupNCCL(
prefix_store, group_rank, group_size, backend_options
)
backend_type = ProcessGroup.BackendType.NCCL
device = torch.device("cuda")
pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group()
pg._register_backend(device, backend_type, backend_class)
return pg
@classmethod @classmethod
def device_count(cls) -> int: def device_count(cls) -> int:
return cuda_device_count_stateless() return cuda_device_count_stateless()
......
...@@ -551,7 +551,7 @@ class Platform: ...@@ -551,7 +551,7 @@ class Platform:
""" """
Init platform-specific torch distributed process group. Init platform-specific torch distributed process group.
""" """
raise RuntimeError(f"Unsupported torch distributed backend: {backend}") raise NotImplementedError
@classmethod @classmethod
def is_kv_cache_dtype_supported( def is_kv_cache_dtype_supported(
......
...@@ -2,13 +2,10 @@ ...@@ -2,13 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os import os
from datetime import timedelta
from functools import cache, lru_cache, wraps from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -476,37 +473,6 @@ class RocmPlatform(Platform): ...@@ -476,37 +473,6 @@ class RocmPlatform(Platform):
def get_static_graph_wrapper_cls(cls) -> str: def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_graph.CUDAGraphWrapper" return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
@classmethod
def stateless_init_device_torch_dist_pg(
cls,
backend: str,
prefix_store: PrefixStore,
group_rank: int,
group_size: int,
timeout: timedelta,
) -> ProcessGroup:
assert is_nccl_available()
pg: ProcessGroup = ProcessGroup(
prefix_store,
group_rank,
group_size,
)
from torch.distributed.distributed_c10d import ProcessGroupNCCL
backend_options = ProcessGroupNCCL.Options()
backend_options._timeout = timeout
backend_class = ProcessGroupNCCL(
prefix_store, group_rank, group_size, backend_options
)
backend_type = ProcessGroup.BackendType.NCCL
device = torch.device("cuda")
pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group()
pg._register_backend(device, backend_type, backend_class)
return pg
@classmethod @classmethod
def device_count(cls) -> int: def device_count(cls) -> int:
return cuda_device_count_stateless() return cuda_device_count_stateless()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment