Unverified Commit d781930f authored by Mengqing Cao's avatar Mengqing Cao Committed by GitHub
Browse files

[Platform][Dist] Make torch distributed process group extendable (#18763)


Signed-off-by: default avatarMengqing Cao <cmq0113@163.com>
parent ce75efee
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses import dataclasses
import datetime
import os import os
import pickle import pickle
import socket import socket
...@@ -14,14 +13,14 @@ import time ...@@ -14,14 +13,14 @@ import time
import uuid import uuid
from collections import deque from collections import deque
from collections.abc import Sequence from collections.abc import Sequence
from datetime import timedelta
from typing import Any, Optional from typing import Any, Optional
import torch import torch
from torch.distributed import ProcessGroup, TCPStore from torch.distributed import ProcessGroup, TCPStore
from torch.distributed.distributed_c10d import (Backend, PrefixStore, from torch.distributed.distributed_c10d import (Backend, PrefixStore,
_get_default_timeout, _get_default_timeout,
_unregister_process_group, _unregister_process_group)
is_nccl_available)
from torch.distributed.rendezvous import rendezvous from torch.distributed.rendezvous import rendezvous
import vllm.envs as envs import vllm.envs as envs
...@@ -406,7 +405,7 @@ class StatelessProcessGroup: ...@@ -406,7 +405,7 @@ class StatelessProcessGroup:
port=port, port=port,
world_size=world_size, world_size=world_size,
is_master=launch_server, is_master=launch_server,
timeout=datetime.timedelta(seconds=store_timeout), timeout=timedelta(seconds=store_timeout),
use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215 use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215
master_listen_fd=listen_fd, master_listen_fd=listen_fd,
) )
...@@ -419,6 +418,43 @@ class StatelessProcessGroup: ...@@ -419,6 +418,43 @@ class StatelessProcessGroup:
data_expiration_seconds=data_expiration_seconds) data_expiration_seconds=data_expiration_seconds)
def init_gloo_process_group(backend: Backend, prefix_store: PrefixStore,
group_rank: int, group_size: int,
timeout: timedelta) -> ProcessGroup:
"""
Stateless init ProcessGroup with gloo backend compatible with
different torch versions.
"""
if is_torch_equal_or_newer("2.6"):
pg = ProcessGroup(
prefix_store,
group_rank,
group_size,
)
else:
options = ProcessGroup.Options(backend=backend)
pg = ProcessGroup(
prefix_store,
group_rank,
group_size,
options,
)
from torch.distributed.distributed_c10d import ProcessGroupGloo
backend_class = ProcessGroupGloo(prefix_store,
group_rank,
group_size,
timeout=timeout)
backend_type = ProcessGroup.BackendType.GLOO
device = torch.device("cpu")
if is_torch_equal_or_newer("2.6"):
# _set_default_backend is supported in torch >= 2.6
pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group()
pg._register_backend(device, backend_type, backend_class)
return pg
def stateless_init_torch_distributed_process_group( def stateless_init_torch_distributed_process_group(
host: str, port: int, rank: int, world_size: int, host: str, port: int, rank: int, world_size: int,
backend: str) -> ProcessGroup: backend: str) -> ProcessGroup:
...@@ -468,40 +504,19 @@ def stateless_init_torch_distributed_process_group( ...@@ -468,40 +504,19 @@ def stateless_init_torch_distributed_process_group(
# different systems (e.g. RPC) in case the store is multi-tenant. # different systems (e.g. RPC) in case the store is multi-tenant.
prefix_store = PrefixStore(init_method, store) prefix_store = PrefixStore(init_method, store)
pg: ProcessGroup = ProcessGroup(
prefix_store,
group_rank,
group_size,
)
if backend == "gloo": if backend == "gloo":
from torch.distributed.distributed_c10d import ProcessGroupGloo return init_gloo_process_group(backend=backend,
backend_class = ProcessGroupGloo(prefix_store, prefix_store=prefix_store,
group_rank, group_rank=group_rank,
group_size, group_size=group_size,
timeout=timeout) timeout=timeout)
backend_type = ProcessGroup.BackendType.GLOO from vllm.platforms import current_platform
device = torch.device("cpu") return current_platform.stateless_init_device_torch_dist_pg(
elif backend == "nccl": backend=backend,
assert is_nccl_available() prefix_store=prefix_store,
from torch.distributed.distributed_c10d import ProcessGroupNCCL group_rank=group_rank,
group_size=group_size,
backend_options = ProcessGroupNCCL.Options() timeout=timeout)
backend_options._timeout = timeout
backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
backend_options)
backend_type = ProcessGroup.BackendType.NCCL
device = torch.device("cuda")
else:
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group()
pg._register_backend(device, backend_type, backend_class)
return pg
def stateless_destroy_torch_distributed_process_group( def stateless_destroy_torch_distributed_process_group(
......
...@@ -4,10 +4,13 @@ pynvml. However, it should not initialize cuda context. ...@@ -4,10 +4,13 @@ pynvml. However, it should not initialize cuda context.
""" """
import os import os
from datetime import timedelta
from functools import wraps from functools import wraps
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
import torch import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
# import custom ops, trigger op registration # import custom ops, trigger op registration
...@@ -316,6 +319,36 @@ class CudaPlatformBase(Platform): ...@@ -316,6 +319,36 @@ class CudaPlatformBase(Platform):
def get_piecewise_backend_cls(cls) -> str: def get_piecewise_backend_cls(cls) -> str:
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa
@classmethod
def stateless_init_device_torch_dist_pg(
cls,
backend: str,
prefix_store: PrefixStore,
group_rank: int,
group_size: int,
timeout: timedelta,
) -> ProcessGroup:
assert is_nccl_available()
pg: ProcessGroup = ProcessGroup(
prefix_store,
group_rank,
group_size,
)
from torch.distributed.distributed_c10d import ProcessGroupNCCL
backend_options = ProcessGroupNCCL.Options()
backend_options._timeout = timeout
backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
backend_options)
backend_type = ProcessGroup.BackendType.NCCL
device = torch.device("cuda")
pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group()
pg._register_backend(device, backend_type, backend_class)
return pg
# NVML utils # NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
......
...@@ -3,11 +3,13 @@ import enum ...@@ -3,11 +3,13 @@ import enum
import os import os
import platform import platform
import random import random
from datetime import timedelta
from platform import uname from platform import uname
from typing import TYPE_CHECKING, NamedTuple, Optional, Union from typing import TYPE_CHECKING, NamedTuple, Optional, Union
import numpy as np import numpy as np
import torch import torch
from torch.distributed import PrefixStore, ProcessGroup
from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -486,6 +488,20 @@ class Platform: ...@@ -486,6 +488,20 @@ class Platform:
""" """
return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa
@classmethod
def stateless_init_device_torch_dist_pg(
cls,
backend: str,
prefix_store: PrefixStore,
group_rank: int,
group_size: int,
timeout: timedelta,
) -> ProcessGroup:
"""
Init platform-specific torch distributed process group.
"""
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
class UnspecifiedPlatform(Platform): class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED _enum = PlatformEnum.UNSPECIFIED
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import os
from datetime import timedelta
from functools import cache, lru_cache, wraps from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import torch import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -387,3 +390,33 @@ class RocmPlatform(Platform): ...@@ -387,3 +390,33 @@ class RocmPlatform(Platform):
@classmethod @classmethod
def get_piecewise_backend_cls(cls) -> str: def get_piecewise_backend_cls(cls) -> str:
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa
@classmethod
def stateless_init_device_torch_dist_pg(
cls,
backend: str,
prefix_store: PrefixStore,
group_rank: int,
group_size: int,
timeout: timedelta,
) -> ProcessGroup:
assert is_nccl_available()
pg: ProcessGroup = ProcessGroup(
prefix_store,
group_rank,
group_size,
)
from torch.distributed.distributed_c10d import ProcessGroupNCCL
backend_options = ProcessGroupNCCL.Options()
backend_options._timeout = timeout
backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
backend_options)
backend_type = ProcessGroup.BackendType.NCCL
device = torch.device("cuda")
pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group()
pg._register_backend(device, backend_type, backend_class)
return pg
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment