Unverified commit 63e7176f, authored by youkaichao and committed by GitHub
Browse files

[Core][Refactor] move parallel_utils into vllm/distributed (#3950)

[WIP][Core][Refactor] move vllm/model_executor/parallel_utils into vllm/distributed and vllm/distributed/device_communicators (#3950)
parent 934d3662
...@@ -11,8 +11,7 @@ from transformers import (AutoModelForCausalLM, AutoProcessor, ...@@ -11,8 +11,7 @@ from transformers import (AutoModelForCausalLM, AutoProcessor,
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.distributed import destroy_model_parallel
destroy_model_parallel)
from vllm.sequence import MultiModalData from vllm.sequence import MultiModalData
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
......
...@@ -8,9 +8,9 @@ import pytest ...@@ -8,9 +8,9 @@ import pytest
import ray import ray
import torch import torch
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.distributed import (broadcast_tensor_dict,
broadcast_tensor_dict, tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.test_utils import (init_test_distributed_environment, from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel) multi_process_tensor_parallel)
......
...@@ -6,9 +6,8 @@ import ray ...@@ -6,9 +6,8 @@ import ray
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.distributed.device_communicators import custom_all_reduce
tensor_model_parallel_all_reduce)
from vllm.test_utils import (init_test_distributed_environment, from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel) multi_process_tensor_parallel)
...@@ -26,10 +25,10 @@ def graph_allreduce(world_size, rank, distributed_init_port): ...@@ -26,10 +25,10 @@ def graph_allreduce(world_size, rank, distributed_init_port):
init_test_distributed_environment(1, world_size, rank, init_test_distributed_environment(1, world_size, rank,
distributed_init_port) distributed_init_port)
custom_ar.init_custom_ar() custom_all_reduce.init_custom_all_reduce()
for sz in test_sizes: for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]: for dtype in [torch.float32, torch.float16, torch.bfloat16]:
with custom_ar.capture(): with custom_all_reduce.capture():
# use integers so result matches NCCL exactly # use integers so result matches NCCL exactly
inp1 = torch.randint(1, inp1 = torch.randint(1,
16, (sz, ), 16, (sz, ),
...@@ -62,8 +61,8 @@ def eager_allreduce(world_size, rank, distributed_init_port): ...@@ -62,8 +61,8 @@ def eager_allreduce(world_size, rank, distributed_init_port):
distributed_init_port) distributed_init_port)
sz = 1024 sz = 1024
custom_ar.init_custom_ar() custom_all_reduce.init_custom_all_reduce()
fa = custom_ar.get_handle() fa = custom_all_reduce.get_handle()
inp = torch.ones(sz, dtype=torch.float32, device=device) inp = torch.ones(sz, dtype=torch.float32, device=device)
out = fa.all_reduce_unreg(inp) out = fa.all_reduce_unreg(inp)
assert torch.allclose(out, inp * world_size) assert torch.allclose(out, inp * world_size)
......
...@@ -4,8 +4,8 @@ import os ...@@ -4,8 +4,8 @@ import os
import pytest import pytest
import torch import torch
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetUniqueId) ncclGetUniqueId)
def distributed_run(fn, world_size): def distributed_run(fn, world_size):
......
...@@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download ...@@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download
import vllm import vllm
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.distributed import destroy_model_parallel, initialize_model_parallel
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
...@@ -19,8 +20,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -19,8 +20,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils.parallel_state import (
destroy_model_parallel, initialize_model_parallel)
def cleanup(): def cleanup():
......
from .communication_op import *
from .parallel_state import *
from .utils import *
...@@ -4,12 +4,10 @@ from typing import Any, Dict, List, Optional, Union ...@@ -4,12 +4,10 @@ from typing import Any, Dict, List, Optional, Union
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from vllm.model_executor.parallel_utils import pynccl_utils from .parallel_state import (get_tensor_model_parallel_group,
from vllm.model_executor.parallel_utils.custom_all_reduce import ( get_tensor_model_parallel_rank,
custom_all_reduce) get_tensor_model_parallel_world_size,
from vllm.model_executor.parallel_utils.parallel_state import ( is_pynccl_enabled_for_all_reduce)
get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce)
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
...@@ -24,6 +22,10 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: ...@@ -24,6 +22,10 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
TLDR: always assume this function modifies its input, but use the return TLDR: always assume this function modifies its input, but use the return
value as the output. value as the output.
""" """
from vllm.distributed.device_communicators import pynccl_utils
from vllm.distributed.device_communicators.custom_all_reduce import (
custom_all_reduce)
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1: if get_tensor_model_parallel_world_size() == 1:
return input_ return input_
......
...@@ -5,8 +5,6 @@ import torch ...@@ -5,8 +5,6 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
try: try:
import pynvml import pynvml
...@@ -25,6 +23,9 @@ _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] ...@@ -25,6 +23,9 @@ _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
def init_custom_ar() -> None: def init_custom_ar() -> None:
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
global _CA_HANDLE global _CA_HANDLE
if _CA_HANDLE is not None: if _CA_HANDLE is not None:
return return
......
...@@ -9,8 +9,8 @@ from vllm.logger import init_logger ...@@ -9,8 +9,8 @@ from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
try: try:
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetVersion) ncclGetVersion)
except Exception as e: except Exception as e:
# in non-NVIDIA environments, we can't import the nccl module # in non-NVIDIA environments, we can't import the nccl module
# e.g. when running on machines with AMD GPUs # e.g. when running on machines with AMD GPUs
......
...@@ -8,8 +8,6 @@ from typing import Optional ...@@ -8,8 +8,6 @@ from typing import Optional
import torch import torch
from vllm.model_executor.parallel_utils import pynccl_utils
# Tensor model parallel group that the current rank belongs to. # Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None _TENSOR_MODEL_PARALLEL_GROUP = None
# Pipeline model parallel group that the current rank belongs to. # Pipeline model parallel group that the current rank belongs to.
...@@ -266,6 +264,7 @@ def destroy_model_parallel(): ...@@ -266,6 +264,7 @@ def destroy_model_parallel():
_PIPELINE_MODEL_PARALLEL_GROUP = None _PIPELINE_MODEL_PARALLEL_GROUP = None
global _PIPELINE_GLOBAL_RANKS global _PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS = None _PIPELINE_GLOBAL_RANKS = None
from vllm.distributed.device_communicators import pynccl_utils
# Destroy the pynccl states if any. # Destroy the pynccl states if any.
pynccl_utils.destroy_process_group() pynccl_utils.destroy_process_group()
...@@ -279,6 +278,7 @@ _ENABLE_PYNCCL_FOR_ALL_REDUCE = False ...@@ -279,6 +278,7 @@ _ENABLE_PYNCCL_FOR_ALL_REDUCE = False
@contextlib.contextmanager @contextlib.contextmanager
def with_pynccl_for_all_reduce(): def with_pynccl_for_all_reduce():
from vllm.distributed.device_communicators import pynccl_utils
"""use pynccl instead of torch.distributed for all reduce""" """use pynccl instead of torch.distributed for all reduce"""
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if tp_size == 1: if tp_size == 1:
......
...@@ -10,6 +10,12 @@ import torch.nn.functional as F ...@@ -10,6 +10,12 @@ import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
from vllm.lora.punica import add_lora, add_lora_slice, bgmv from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -18,13 +24,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -18,13 +24,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import (
split_tensor_along_last_dim)
if TYPE_CHECKING: if TYPE_CHECKING:
pass pass
......
...@@ -7,10 +7,9 @@ import torch.nn as nn ...@@ -7,10 +7,9 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from vllm._C import ops from vllm._C import ops
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
......
...@@ -5,13 +5,12 @@ import torch ...@@ -5,13 +5,12 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import (
divide, split_tensor_along_last_dim)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -4,8 +4,7 @@ from typing import Optional ...@@ -4,8 +4,7 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.distributed import tensor_model_parallel_gather
tensor_model_parallel_gather)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
......
...@@ -4,11 +4,9 @@ import torch ...@@ -4,11 +4,9 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce) get_tensor_model_parallel_world_size,
from vllm.model_executor.parallel_utils.parallel_state import ( tensor_model_parallel_all_reduce)
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64 DEFAULT_VOCAB_PADDING_SIZE = 64
......
...@@ -27,6 +27,8 @@ from transformers import PretrainedConfig ...@@ -27,6 +27,8 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
...@@ -38,8 +40,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -38,8 +40,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader, from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
......
...@@ -24,6 +24,8 @@ from torch import nn ...@@ -24,6 +24,8 @@ from torch import nn
from transformers import BloomConfig from transformers import BloomConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
...@@ -33,8 +35,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -33,8 +35,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader, from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment