Unverified Commit 89cd9235 authored by Lianmin Zheng, committed by GitHub

Roll back to use vllm custom allreduce (#3006)

parent dc188132
@@ -12,7 +12,7 @@ import torch.library
 from sglang.srt.utils import is_hpu
 logger = logging.getLogger(__name__)
-use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=False)
+use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
 if not is_hpu():
     if use_vllm_custom_allreduce:
...
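A note on the flipped default above: `os.environ.get` returns a string whenever the variable is set, so any non-empty value (including "0" or "false") is truthy in the later `if use_vllm_custom_allreduce:` check. A minimal sketch of an explicit boolean parse; the `get_bool_env` helper is hypothetical and not part of this commit:

```python
import os

def get_bool_env(name: str, default: bool) -> bool:
    # os.environ.get returns a str when the variable is set; map common
    # spellings to a real bool instead of relying on string truthiness.
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")

# Example: treats USE_VLLM_CUSTOM_ALLREDUCE=0 as "off".
use_vllm_custom_allreduce = get_bool_env("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
```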
-from .communication_op import *
-from .parallel_state import *
-from .utils import *
+from sglang.srt.distributed.communication_op import *
+from sglang.srt.distributed.parallel_state import *
+from sglang.srt.distributed.utils import *
@@ -4,7 +4,7 @@ from typing import Any, Dict, Optional, Union
 import torch
 import torch.distributed
-from .parallel_state import get_tp_group
+from sglang.srt.distributed.parallel_state import get_tp_group
 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
...
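For context, a usage sketch of `tensor_model_parallel_all_reduce` from the hunk above. This assumes the distributed environment and tensor-parallel group have already been initialized by the runtime; the tensor shape is illustrative:

```python
import torch
from sglang.srt.distributed.communication_op import tensor_model_parallel_all_reduce

# Each tensor-parallel rank holds a partial result (e.g. from a sharded matmul);
# the all-reduce sums the partials so every rank ends up with the full output.
partial_out = torch.randn(4, 8, device="cuda")
full_out = tensor_model_parallel_all_reduce(partial_out)
assert full_out.shape == partial_out.shape
```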
@@ -7,7 +7,6 @@ import pickle
 import subprocess
 import sys
 import tempfile
-from functools import lru_cache
 from itertools import product
 from typing import Dict, List, Optional, Sequence
...
@@ -57,7 +57,7 @@ def find_nccl_library() -> str:
             so_file = "librccl.so.1"
         else:
             raise ValueError("NCCL only supports CUDA and ROCm backends.")
-        logger.info("Found nccl from library %s", so_file)
+        logger.debug("Found nccl from library %s", so_file)
     return so_file
...
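The selection logic around the changed log line follows the usual CUDA/ROCm split. A condensed, standalone sketch is below; only `librccl.so.1` and the error branch are visible in the hunk, so the `libnccl.so.2` name for CUDA builds is an assumption:

```python
import torch

def pick_nccl_so_name() -> str:
    # CUDA builds of torch ship NCCL; ROCm builds ship RCCL.
    if torch.version.cuda is not None:
        return "libnccl.so.2"  # assumed CUDA counterpart, not shown in the hunk
    elif torch.version.hip is not None:
        return "librccl.so.1"
    raise ValueError("NCCL only supports CUDA and ROCm backends.")
```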
@@ -313,7 +313,7 @@ class MessageQueue:
             remote_subscribe_port=remote_subscribe_port,
         )
-        logger.info("vLLM message queue communication handle: %s", self.handle)
+        logger.debug("Message queue communication handle: %s", self.handle)
     def export_handle(self) -> Handle:
         return self.handle
...
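Both the NCCL-library and message-queue log statements move from `info` to `debug`, so they no longer appear at the default verbosity. If you still want them (for example while debugging multi-node startup), raise the relevant logger's level; a minimal sketch, assuming these modules follow the usual `logging.getLogger(__name__)` naming:

```python
import logging

logging.basicConfig(level=logging.INFO)  # keep the rest at the default verbosity
# Show debug-level messages from the distributed communication modules only.
logging.getLogger("sglang.srt.distributed").setLevel(logging.DEBUG)
```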
@@ -5,9 +5,9 @@ from typing import Optional
 import torch
 import torch.nn as nn
 from einops import rearrange, repeat
-from vllm.distributed import parallel_state
-from vllm.distributed import utils as dist_utils
+from sglang.srt.distributed import parallel_state
+from sglang.srt.distributed import utils as dist_utils
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
...
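These two modules replace their vllm counterparts one for one. A sketch of how they are typically used together in an attention layer; this is assumed usage for illustration, the head count is a placeholder, and an initialized tensor-parallel group is assumed:

```python
from sglang.srt.distributed import parallel_state
from sglang.srt.distributed import utils as dist_utils

# Split attention heads evenly across the tensor-parallel group.
tp_size = parallel_state.get_tensor_model_parallel_world_size()
num_heads = 16  # placeholder value
heads_per_rank = dist_utils.divide(num_heads, tp_size)  # raises if not divisible
```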
@@ -33,7 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import monkey_patch_vllm_all_gather
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -72,7 +71,6 @@ def patch_model(
     try:
         if enable_compile:
             _to_torch(model, reverse=False, batch_size=batch_size)
-            monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
             # We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -88,7 +86,6 @@ def patch_model(
     finally:
         if enable_compile:
             _to_torch(model, reverse=True, batch_size=batch_size)
-            monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
...
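After this commit, the only state swapped around compilation is the custom-allreduce communicator on the TP group; the vllm all-gather monkey patch is gone from both the setup and the teardown path. A condensed sketch of the remaining pattern, simplified for illustration: `_to_torch` is a helper defined elsewhere in the same file and is passed in here to keep the sketch self-contained, and the real function also drives the compile call itself:

```python
from contextlib import contextmanager

@contextmanager
def compile_scope(model, tp_group, batch_size, enable_compile, _to_torch):
    # Back up the custom-allreduce communicator, hand control back to the
    # caller for torch.compile, then restore everything in `finally`.
    backup_ca_comm = None
    try:
        if enable_compile:
            _to_torch(model, reverse=False, batch_size=batch_size)
            backup_ca_comm = tp_group.ca_comm
            # Custom allreduce is kept because it is much faster than the
            # built-in torch allreduce in this path.
        yield model
    finally:
        if enable_compile:
            _to_torch(model, reverse=True, batch_size=batch_size)
            tp_group.ca_comm = backup_ca_comm
```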
@@ -63,8 +63,8 @@ from sglang.srt.utils import (
     init_custom_process_group,
     is_cuda,
     is_hip,
+    monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
@@ -229,7 +229,8 @@ class ModelRunner:
             backend = "gloo"
         if not self.server_args.enable_p2p_check:
-            monkey_patch_vllm_p2p_access_check(self.gpu_id)
+            monkey_patch_p2p_access_check()
         if self.server_args.dist_init_addr:
             dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
         else:
...
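In other words, the always-allow patch is the default path: only when the server is started with the option behind `server_args.enable_p2p_check` (presumably exposed as `--enable-p2p-check`) does the runtime keep the real per-GPU peer-to-peer probe.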
@@ -518,68 +518,24 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
         pass
-def monkey_patch_vllm_p2p_access_check(gpu_id: int):
+def monkey_patch_p2p_access_check():
     """
-    Monkey patch the slow p2p access check in vllm.
+    Monkey patch the slow p2p access check.
     NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
     """
-    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+    import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt
     setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
     # Suppress the warnings from this delete function when using sglang.bench_one_batch
-    from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
+    from sglang.srt.distributed.device_communicators.custom_all_reduce import (
+        CustomAllreduce,
+    )
     setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None)
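The patch relies on plain attribute replacement at module scope: any later caller that looks the function up through the module sees the stub. A tiny self-contained illustration of the same `setattr` pattern; the `demo` module and `slow_gpu_p2p_access_check` below are made up for the example:

```python
import types

demo = types.ModuleType("demo")

def slow_gpu_p2p_access_check(src: int, dst: int) -> bool:
    raise RuntimeError("imagine an expensive GPU peer-to-peer probe here")

demo.gpu_p2p_access_check = slow_gpu_p2p_access_check

# Same trick as above: swap in a stub that always reports p2p access as allowed.
setattr(demo, "gpu_p2p_access_check", lambda *args, **kwargs: True)
assert demo.gpu_p2p_access_check(0, 1) is True
```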
-vllm_all_gather_backup = None
-def monkey_patch_vllm_all_gather(reverse: bool = False):
-    """Monkey patch all-gather to remove in-place operations."""
-    from torch.distributed import _functional_collectives as funcol
-    from vllm.distributed.parallel_state import GroupCoordinator
-    global vllm_all_gather_backup
-    if vllm_all_gather_backup is None:
-        vllm_all_gather_backup = GroupCoordinator.all_gather
-    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
-        world_size = self.world_size
-        # Bypass the function if we are using only 1 GPU.
-        if world_size == 1:
-            return input_
-        assert (
-            -input_.dim() <= dim < input_.dim()
-        ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
-        if dim < 0:
-            # Convert negative dim to positive.
-            dim += input_.dim()
-        input_size = input_.size()
-        # Allocate output tensor.
-        output_tensor = torch.empty(
-            (world_size,) + input_size, dtype=input_.dtype, device=input_.device
-        )
-        output_tensor = funcol.all_gather_tensor(
-            input_, gather_dim=0, group=self.device_group
-        ).view((world_size,) + input_size)
-        # Reshape
-        output_tensor = output_tensor.movedim(0, dim)
-        output_tensor = output_tensor.reshape(
-            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
-        )
-        return output_tensor
-    if reverse:
-        setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup)
-    else:
-        setattr(GroupCoordinator, "all_gather", all_gather)
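The removed patch built an out-of-place all-gather on top of `funcol.all_gather_tensor` and then rearranged the result. The shape manipulation at its core can be checked without any process group; a small standalone sketch of that reshape, with illustrative values:

```python
import torch

world_size, dim = 4, 1
input_size = (2, 3)

# Stand-in for the gathered result: per-rank shards stacked on a new leading dim.
stacked = torch.arange(world_size * 2 * 3).reshape((world_size,) + input_size)

# movedim + reshape concatenates the shards along `dim`, exactly like the
# removed all_gather did after the collective.
out = stacked.movedim(0, dim).reshape(
    input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1:]
)
assert out.shape == (2, 12)
assert torch.equal(out, torch.cat(list(stacked), dim=dim))
```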
 def monkey_patch_vllm_gguf_config():
     from vllm.model_executor.layers.quantization.gguf import (
         GGUFConfig,
...