Unverified Commit 43c4f3d7 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Begin deprecation of `get_tensor_model_*_group` (#22494)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 1712543d
...@@ -10,8 +10,7 @@ import torch.distributed as dist ...@@ -10,8 +10,7 @@ import torch.distributed as dist
from vllm.distributed.communication_op import ( # noqa from vllm.distributed.communication_op import ( # noqa
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, from vllm.distributed.parallel_state import get_tp_group, graph_capture
get_tp_group, graph_capture)
from ..utils import (ensure_model_parallel_initialized, from ..utils import (ensure_model_parallel_initialized,
init_test_distributed_environment, multi_process_parallel) init_test_distributed_environment, multi_process_parallel)
...@@ -37,7 +36,7 @@ def graph_allreduce( ...@@ -37,7 +36,7 @@ def graph_allreduce(
init_test_distributed_environment(tp_size, pp_size, rank, init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port) distributed_init_port)
ensure_model_parallel_initialized(tp_size, pp_size) ensure_model_parallel_initialized(tp_size, pp_size)
group = get_tensor_model_parallel_group().device_group group = get_tp_group().device_group
# A small all_reduce for warmup. # A small all_reduce for warmup.
# this is needed because device communicators might be created lazily # this is needed because device communicators might be created lazily
......
...@@ -10,8 +10,7 @@ import torch.distributed as dist ...@@ -10,8 +10,7 @@ import torch.distributed as dist
from vllm.distributed.communication_op import ( # noqa from vllm.distributed.communication_op import ( # noqa
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, from vllm.distributed.parallel_state import get_tp_group, graph_capture
get_tp_group, graph_capture)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import (ensure_model_parallel_initialized, from ..utils import (ensure_model_parallel_initialized,
...@@ -42,7 +41,7 @@ def graph_quickreduce( ...@@ -42,7 +41,7 @@ def graph_quickreduce(
init_test_distributed_environment(tp_size, pp_size, rank, init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port) distributed_init_port)
ensure_model_parallel_initialized(tp_size, pp_size) ensure_model_parallel_initialized(tp_size, pp_size)
group = get_tensor_model_parallel_group().device_group group = get_tp_group().device_group
# A small all_reduce for warmup. # A small all_reduce for warmup.
# this is needed because device communicators might be created lazily # this is needed because device communicators might be created lazily
......
...@@ -36,6 +36,7 @@ from unittest.mock import patch ...@@ -36,6 +36,7 @@ from unittest.mock import patch
import torch import torch
import torch.distributed import torch.distributed
from torch.distributed import Backend, ProcessGroup from torch.distributed import Backend, ProcessGroup
from typing_extensions import deprecated
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.device_communicators.base_device_communicator import ( from vllm.distributed.device_communicators.base_device_communicator import (
...@@ -894,8 +895,12 @@ def get_tp_group() -> GroupCoordinator: ...@@ -894,8 +895,12 @@ def get_tp_group() -> GroupCoordinator:
return _TP return _TP
# kept for backward compatibility @deprecated("`get_tensor_model_parallel_group` has been replaced with "
get_tensor_model_parallel_group = get_tp_group "`get_tp_group` and may be removed after v0.12. Please use "
"`get_tp_group` instead.")
def get_tensor_model_parallel_group():
return get_tp_group()
_PP: Optional[GroupCoordinator] = None _PP: Optional[GroupCoordinator] = None
...@@ -921,8 +926,11 @@ def get_pp_group() -> GroupCoordinator: ...@@ -921,8 +926,11 @@ def get_pp_group() -> GroupCoordinator:
return _PP return _PP
# kept for backward compatibility @deprecated("`get_pipeline_model_parallel_group` has been replaced with "
get_pipeline_model_parallel_group = get_pp_group "`get_pp_group` and may be removed in v0.12. Please use "
"`get_pp_group` instead.")
def get_pipeline_model_parallel_group():
return get_pp_group()
@contextmanager @contextmanager
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment