Unverified Commit ff2cfdb1 authored by Xu-Chen, committed by GitHub

[Feature] add disable-custom-all-reduce (#1148)


Co-authored-by: chenxu02 <chenxu02@zhihu.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent a8ae6403
@@ -37,6 +37,7 @@ from vllm.distributed import (
     get_tp_group,
     init_distributed_environment,
     initialize_model_parallel,
+    set_custom_all_reduce,
 )
 from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
@@ -105,6 +106,7 @@ class ModelRunner:
             nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
         else:
             nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
+        set_custom_all_reduce(not server_args.disable_custom_all_reduce)
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
...
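
For context on the ModelRunner hunk above: vLLM's set_custom_all_reduce toggles a process-global switch that is read when the tensor-parallel groups are created, which is why the call is placed before init_distributed_environment and initialize_model_parallel. Below is a minimal single-process sketch of that ordering; the helper name init_tp and the literal rank/port values are illustrative assumptions, not code from this commit.

# Minimal sketch of the initialization ordering shown in the hunk above
# (assumed single-node, rank-0 setup; init_tp is a hypothetical helper).
from vllm.distributed import (
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
)


def init_tp(tp_size: int, nccl_port: int, disable_custom_all_reduce: bool) -> None:
    # Toggle vLLM's custom all-reduce kernel before any groups are built;
    # when disabled, tensor-parallel all-reduce falls back to NCCL.
    set_custom_all_reduce(not disable_custom_all_reduce)
    init_distributed_environment(
        backend="nccl",
        world_size=tp_size,
        rank=0,
        distributed_init_method=f"tcp://127.0.0.1:{nccl_port}",
    )
    initialize_model_parallel(tensor_model_parallel_size=tp_size)
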
@@ -86,6 +86,7 @@ class ServerArgs:
     enable_mla: bool = False
     attention_reduce_in_fp32: bool = False
     efficient_weight_load: bool = False
+    disable_custom_all_reduce: bool = False
 
     # Distributed args
     nccl_init_addr: Optional[str] = None
@@ -428,6 +429,12 @@ class ServerArgs:
             action="store_true",
             help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
         )
+        parser.add_argument(
+            "--disable-custom-all-reduce",
+            action="store_true",
+            default=False,
+            help="Disable the custom all-reduce kernel and fall back to NCCL.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
...
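
On the ServerArgs side, the new dataclass field and the store_true flag share a name (argparse maps --disable-custom-all-reduce to the attribute disable_custom_all_reduce), so a from_cli_args that iterates over dataclasses.fields can pick the value up without extra plumbing; whether the real class does exactly this is not shown in the diff. A self-contained sketch of that pattern, using a hypothetical ReducedServerArgs stand-in rather than the real class:

import argparse
from dataclasses import dataclass, fields


@dataclass
class ReducedServerArgs:
    # Mirrors the new field added in this commit; all other fields omitted.
    disable_custom_all_reduce: bool = False

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser) -> None:
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            default=False,
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # Rebuild the dataclass from the parsed namespace, field by field.
        return cls(**{f.name: getattr(args, f.name) for f in fields(cls)})


parser = argparse.ArgumentParser()
ReducedServerArgs.add_cli_args(parser)
print(ReducedServerArgs.from_cli_args(parser.parse_args(["--disable-custom-all-reduce"])))
# ReducedServerArgs(disable_custom_all_reduce=True)

In this pattern, adding a new boolean option needs no change to from_cli_args itself, as long as the flag and field names line up.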