Unverified Commit 63e7176f authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[Core][Refactor] move parallel_utils into vllm/distributed (#3950)

[WIP][Core][Refactor] move vllm/model_executor/parallel_utils into vllm/distributed and vllm/device_communicators (#3950)
parent 934d3662
......@@ -11,8 +11,7 @@ from transformers import (AutoModelForCausalLM, AutoProcessor,
from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.model_executor.parallel_utils.parallel_state import (
destroy_model_parallel)
from vllm.distributed import destroy_model_parallel
from vllm.sequence import MultiModalData
from vllm.transformers_utils.tokenizer import get_tokenizer
......
......@@ -8,8 +8,8 @@ import pytest
import ray
import torch
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict, tensor_model_parallel_all_gather,
from vllm.distributed import (broadcast_tensor_dict,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
......
......@@ -6,9 +6,8 @@ import ray
import torch
import torch.distributed as dist
from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators import custom_all_reduce
from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
......@@ -26,10 +25,10 @@ def graph_allreduce(world_size, rank, distributed_init_port):
init_test_distributed_environment(1, world_size, rank,
distributed_init_port)
custom_ar.init_custom_ar()
custom_all_reduce.init_custom_all_reduce()
for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
with custom_ar.capture():
with custom_all_reduce.capture():
# use integers so result matches NCCL exactly
inp1 = torch.randint(1,
16, (sz, ),
......@@ -62,8 +61,8 @@ def eager_allreduce(world_size, rank, distributed_init_port):
distributed_init_port)
sz = 1024
custom_ar.init_custom_ar()
fa = custom_ar.get_handle()
custom_all_reduce.init_custom_all_reduce()
fa = custom_all_reduce.get_handle()
inp = torch.ones(sz, dtype=torch.float32, device=device)
out = fa.all_reduce_unreg(inp)
assert torch.allclose(out, inp * world_size)
......
......@@ -4,7 +4,7 @@ import os
import pytest
import torch
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetUniqueId)
......
......@@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download
import vllm
from vllm.config import LoRAConfig
from vllm.distributed import destroy_model_parallel, initialize_model_parallel
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
......@@ -19,8 +20,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils.parallel_state import (
destroy_model_parallel, initialize_model_parallel)
def cleanup():
......
from .communication_op import *
from .parallel_state import *
from .utils import *
......@@ -4,12 +4,10 @@ from typing import Any, Dict, List, Optional, Union
import torch
from torch.distributed import ProcessGroup
from vllm.model_executor.parallel_utils import pynccl_utils
from vllm.model_executor.parallel_utils.custom_all_reduce import (
custom_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce)
from .parallel_state import (get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
is_pynccl_enabled_for_all_reduce)
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
......@@ -24,6 +22,10 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
TLDR: always assume this function modifies its input, but use the return
value as the output.
"""
from vllm.distributed.device_communicators import pynccl_utils
from vllm.distributed.device_communicators.custom_all_reduce import (
custom_all_reduce)
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:
return input_
......
......@@ -5,8 +5,6 @@ import torch
import torch.distributed as dist
from vllm.logger import init_logger
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
try:
import pynvml
......@@ -25,6 +23,9 @@ _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
def init_custom_ar() -> None:
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
global _CA_HANDLE
if _CA_HANDLE is not None:
return
......
......@@ -9,7 +9,7 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
try:
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetVersion)
except Exception as e:
# in non-NVIDIA environments, we can't import the nccl module
......
......@@ -8,8 +8,6 @@ from typing import Optional
import torch
from vllm.model_executor.parallel_utils import pynccl_utils
# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Pipeline model parallel group that the current rank belongs to.
......@@ -266,6 +264,7 @@ def destroy_model_parallel():
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS = None
from vllm.distributed.device_communicators import pynccl_utils
# Destroy the pynccl states if any.
pynccl_utils.destroy_process_group()
......@@ -279,6 +278,7 @@ _ENABLE_PYNCCL_FOR_ALL_REDUCE = False
@contextlib.contextmanager
def with_pynccl_for_all_reduce():
from vllm.distributed.device_communicators import pynccl_utils
"""use pynccl instead of torch.distributed for all reduce"""
tp_size = get_tensor_model_parallel_world_size()
if tp_size == 1:
......
......@@ -10,6 +10,12 @@ import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
......@@ -18,13 +24,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import (
split_tensor_along_last_dim)
if TYPE_CHECKING:
pass
......
......@@ -7,10 +7,9 @@ import torch.nn as nn
import torch.nn.functional as F
from vllm._C import ops
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.utils import set_weight_attrs
......
......@@ -5,13 +5,12 @@ import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import (
divide, split_tensor_along_last_dim)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
......
......@@ -4,8 +4,7 @@ from typing import Optional
import torch
import torch.nn as nn
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_gather)
from vllm.distributed import tensor_model_parallel_gather
from vllm.model_executor.sampling_metadata import SamplingMetadata
......
......@@ -4,11 +4,9 @@ import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.model_executor.parallel_utils.communication_op import (
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64
......
......@@ -27,6 +27,8 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
......@@ -38,8 +40,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
......
......@@ -24,6 +24,8 @@ from torch import nn
from transformers import BloomConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
......@@ -33,8 +35,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment