Unverified Commit 819fc591 authored by Yuan Luo, committed by GitHub

Add prefix for torch symm mem (#12506)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent 7efd8b3d
"""For Now, SYMM_MEM is only supported on TP8 case """For Now, TORCH_SYMM_MEM is only supported on following limited tp case
export WORLD_SIZE=1 SM90: {
2: 64 * MiB, # 64 MB
4: 64 * MiB, # 64 MB
6: 128 * MiB, # 128 MB
8: 128 * MiB, # 128 MB
},
SM100: {
2: 64 * MiB, # 64 MB
4: 64 * MiB, # 64 MB
6: 128 * MiB, # 128 MB
8: 128 * MiB, # 128 MB
}
export WORLD_SIZE=8
export RANK=0 export RANK=0
export MASTER_ADDR=127.0.0.1 export MASTER_ADDR=127.0.0.1
export MASTER_PORT=12345 export MASTER_PORT=12345
...@@ -9,7 +22,7 @@ torchrun --nproc_per_node gpu \ ...@@ -9,7 +22,7 @@ torchrun --nproc_per_node gpu \
--nnodes $WORLD_SIZE \ --nnodes $WORLD_SIZE \
--node_rank $RANK \ --node_rank $RANK \
--master_addr $MASTER_ADDR \ --master_addr $MASTER_ADDR \
--master_port $MASTER_PORT ./benchmark/kernels/all_reduce/benchmark_symm_mem.py --master_port $MASTER_PORT ./benchmark/kernels/all_reduce/benchmark_torch_symm_mem.py
""" """
import os import os
...@@ -22,12 +35,14 @@ from torch.distributed import ProcessGroup ...@@ -22,12 +35,14 @@ from torch.distributed import ProcessGroup
from sglang.srt.distributed import init_distributed_environment from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
from sglang.srt.distributed.device_communicators.symm_mem import SymmMemCommunicator from sglang.srt.distributed.device_communicators.torch_symm_mem import (
TorchSymmMemCommunicator,
)
from sglang.srt.distributed.parallel_state import ( from sglang.srt.distributed.parallel_state import (
get_tensor_model_parallel_group, get_tensor_model_parallel_group,
graph_capture, graph_capture,
initialize_model_parallel, initialize_model_parallel,
set_symm_mem_all_reduce, set_torch_symm_mem_all_reduce,
) )
# CI environment detection # CI environment detection
...@@ -42,10 +57,10 @@ def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Ten ...@@ -42,10 +57,10 @@ def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Ten
return torch_input return torch_input
def symm_mem_allreduce( def torch_symm_mem_allreduce(
symm_mem_input: torch.Tensor, symm_mem_comm: SymmMemCommunicator torch_symm_mem_input: torch.Tensor, torch_symm_mem_comm: TorchSymmMemCommunicator
) -> torch.Tensor: ) -> torch.Tensor:
return symm_mem_comm.all_reduce(symm_mem_input) return torch_symm_mem_comm.all_reduce(torch_symm_mem_input)
def pynccl_allreduce( def pynccl_allreduce(
...@@ -170,7 +185,7 @@ if __name__ == "__main__": ...@@ -170,7 +185,7 @@ if __name__ == "__main__":
rank = dist.get_rank() rank = dist.get_rank()
torch.cuda.set_device(rank % 8) torch.cuda.set_device(rank % 8)
device = torch.cuda.current_device() device = torch.cuda.current_device()
set_symm_mem_all_reduce(True) set_torch_symm_mem_all_reduce(True)
init_distributed_environment( init_distributed_environment(
world_size=world_size, world_size=world_size,
rank=rank, rank=rank,
...@@ -180,7 +195,7 @@ if __name__ == "__main__": ...@@ -180,7 +195,7 @@ if __name__ == "__main__":
group = get_tensor_model_parallel_group().device_group group = get_tensor_model_parallel_group().device_group
cpu_group = get_tensor_model_parallel_group().cpu_group cpu_group = get_tensor_model_parallel_group().cpu_group
pynccl_comm = get_tensor_model_parallel_group().pynccl_comm pynccl_comm = get_tensor_model_parallel_group().pynccl_comm
symm_mem_comm = get_tensor_model_parallel_group().symm_mem_comm torch_symm_mem_comm = get_tensor_model_parallel_group().torch_symm_mem_comm
dist.barrier() dist.barrier()
profile = False profile = False
dtype = torch.bfloat16 dtype = torch.bfloat16
...@@ -204,10 +219,12 @@ if __name__ == "__main__": ...@@ -204,10 +219,12 @@ if __name__ == "__main__":
lambda inp: torch_allreduce(inp, group), inp_randn lambda inp: torch_allreduce(inp, group), inp_randn
) )
symm_mem_eager_output, symm_mem_eager_time = _bench_eager_time( symm_mem_eager_output, symm_mem_eager_time = _bench_eager_time(
lambda inp: symm_mem_allreduce(inp, symm_mem_comm), inp_randn lambda inp: torch_symm_mem_allreduce(inp, torch_symm_mem_comm),
inp_randn,
) )
symm_mem_graph_output, symm_mem_graph_time = _bench_graph_time( symm_mem_graph_output, symm_mem_graph_time = _bench_graph_time(
lambda inp: symm_mem_allreduce(inp, symm_mem_comm), inp_randn lambda inp: torch_symm_mem_allreduce(inp, torch_symm_mem_comm),
inp_randn,
) )
# since pynccl is inplace op, this return result is not correct if graph loop > 1 # since pynccl is inplace op, this return result is not correct if graph loop > 1
_, pynccl_graph_time = _bench_graph_time( _, pynccl_graph_time = _bench_graph_time(
...@@ -229,6 +246,6 @@ if __name__ == "__main__": ...@@ -229,6 +246,6 @@ if __name__ == "__main__":
if rank == 0: if rank == 0:
print_markdown_table(result) print_markdown_table(result)
if profile: if profile:
prof_dir = f"prof/symm_mem" prof_dir = f"prof/torch_symm_mem"
os.makedirs(prof_dir, exist_ok=True) os.makedirs(prof_dir, exist_ok=True)
ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz") ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz")
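The benchmark above compares three all-reduce paths in both eager and CUDA-graph mode: plain `torch.distributed`, PyNccl, and the renamed torch symm-mem communicator. The `_bench_eager_time` helper it calls is not part of this diff; as a reference point, a minimal sketch of what such a CUDA-event timing wrapper might look like (warmup and iteration counts are illustrative assumptions, not the benchmark's actual values):

```python
import torch


def _bench_eager_time(fn, inp: torch.Tensor, warmup: int = 5, iters: int = 20):
    """Time fn(inp) with CUDA events; returns (last output, average ms per call)."""
    for _ in range(warmup):
        out = fn(inp)  # warm up kernels and communicator setup
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        out = fn(inp)
    end.record()
    torch.cuda.synchronize()
    return out, start.elapsed_time(end) / iters
```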
 MiB = 1024 * 1024
-SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
+TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
     9: {
         2: 64 * MiB,  # 64 MB
         4: 64 * MiB,  # 64 MB
...
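The renamed table (only partially shown here) is keyed first by device capability major version (9 for SM90, 10 for SM100, matching the benchmark docstring) and then by tensor-parallel world size; the value is the largest message, in bytes, that the torch symm-mem all-reduce will accept. A hedged sketch of that lookup; the helper name and the capability-10 key are assumptions, not excerpts:

```python
from typing import Optional

MiB = 1024 * 1024

# Mirrors the shape of TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES:
# {device capability major: {world size: max message size in bytes}}
_MAX_SIZES = {
    9: {2: 64 * MiB, 4: 64 * MiB, 6: 128 * MiB, 8: 128 * MiB},   # SM90
    10: {2: 64 * MiB, 4: 64 * MiB, 6: 128 * MiB, 8: 128 * MiB},  # SM100 (per the benchmark docstring)
}


def torch_symm_mem_max_bytes(capability_major: int, world_size: int) -> Optional[int]:
    """Return the byte cap for this (GPU generation, TP size), or None if unsupported."""
    return _MAX_SIZES.get(capability_major, {}).get(world_size)
```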
@@ -7,33 +7,29 @@ import torch.distributed as dist
 from torch.distributed import ProcessGroup

 from sglang.srt.distributed.device_communicators.all_reduce_utils import (
-    SYMM_MEM_ALL_REDUCE_MAX_SIZES,
+    TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES,
 )
 from sglang.srt.utils import is_cuda, is_hip

 try:
     import torch.distributed._symmetric_memory as torch_symm_mem

-    symm_mem_available = True
+    _is_cuda = is_cuda()
+    _is_hip = is_hip()
+
+    torch_symm_mem_available = False
+    if _is_cuda:
+        torch_symm_mem_available = True
 except ImportError:
-    symm_mem_available = False
+    torch_symm_mem_available = False

 logger = logging.getLogger(__name__)

-_is_cuda = is_cuda()
-_is_hip = is_hip()
-
-symm_mem_is_available = False
-if _is_hip:
-    symm_mem_is_available = False
-if _is_cuda:
-    symm_mem_is_available = True
-

-class SymmMemCommunicator:
+class TorchSymmMemCommunicator:
     """
-    Thin wrapper around symmetric-memory collectives.
+    Thin wrapper around torch-symmetric-memory collectives.

     This communicator:
       - Validates device capability and world size.
@@ -62,7 +58,7 @@ class SymmMemCommunicator:
         self.disabled = True

-        if not symm_mem_available:
+        if not torch_symm_mem_available:
             return

         if isinstance(device, int):
@@ -77,19 +73,22 @@ class SymmMemCommunicator:
         self.device_capability = torch.cuda.get_device_capability(device)[0]
         if self.device_capability < 9:
             logger.warning(
-                "SymmMemCommunicator: Device capability %s not supported, "
+                "TorchSymmMemCommunicator: Device capability %s not supported, "
                 "communicator is not available.",
                 self.device_capability,
             )
             return
-        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]:
+        if (
+            self.world_size
+            not in TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]
+        ):
             logger.warning(
-                "SymmMemCommunicator: World size %d not supported, "
+                "TorchSymmMemCommunicator: World size %d not supported, "
                 "communicator is not available.",
                 self.world_size,
             )
             return
-        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
+        self.max_size = TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
             self.world_size
         ]
         self.buffer = torch_symm_mem.empty(
@@ -100,7 +99,7 @@ class SymmMemCommunicator:
         handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
         if handle.multicast_ptr == 0:
             logger.warning(
-                "SymmMemCommunicator: symmetric memory "
+                "TorchSymmMemCommunicator: torch symmetric memory "
                 "multicast operations are not supported."
             )
             self.buffer = None
@@ -108,7 +107,7 @@ class SymmMemCommunicator:
             return
         self.disabled = False

-    def should_symm_mem_allreduce(self, inp: torch.Tensor):
+    def should_torch_symm_mem_allreduce(self, inp: torch.Tensor):
         """
         Fast-path eligibility check for a given tensor.
@@ -135,7 +134,7 @@ class SymmMemCommunicator:
         self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
     ) -> Optional[torch.Tensor]:
         """
-        Perform an in-place sum all-reduce via symmetric memory.
+        Perform an in-place sum all-reduce via torch symmetric memory.

         Args:
             inp: Input tensor on the target CUDA device (bfloat16).
...
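For context, the communicator is a thin layer over `torch.distributed._symmetric_memory`: allocate a buffer with `empty`, exchange handles with `rendezvous`, verify multicast support via `handle.multicast_ptr`, then reduce into the shared buffer. A rough standalone sketch of that flow on a recent PyTorch; the `torch.ops.symm_mem.multimem_all_reduce_` op and its argument order are assumptions based on common usage, not taken from this diff:

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as torch_symm_mem

# Assumes launch via torchrun so RANK/WORLD_SIZE/MASTER_* are already set.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
group_name = dist.group.WORLD.group_name

# Allocate a symmetric buffer and map it across all ranks.
buf = torch_symm_mem.empty(4096, dtype=torch.bfloat16, device="cuda")
handle = torch_symm_mem.rendezvous(buf, group_name)

if handle.multicast_ptr != 0:  # NVLS multicast support, the same check the communicator performs
    inp = torch.randn(4096, dtype=torch.bfloat16, device="cuda")
    buf.copy_(inp)
    # Assumed reduction op; TorchSymmMemCommunicator.all_reduce wraps a call of this kind.
    torch.ops.symm_mem.multimem_all_reduce_(buf, "sum", group_name)
    out = buf.clone()

dist.destroy_process_group()
```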
@@ -217,14 +217,16 @@ class GroupCoordinator:
     use_pynccl: bool  # a hint of whether to use PyNccl
     use_pymscclpp: bool  # a hint of whether to use PyMsccl
     use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
-    use_torch_symm_mem: bool  # a hint of whether to use SymmMemAllReduce
+    use_torch_symm_mem_all_reduce: (
+        bool  # a hint of whether to use TorchSymmMemAllReduce
+    )
     use_message_queue_broadcaster: (
         bool  # a hint of whether to use message queue broadcaster
     )
     # communicators are only created for world size > 1
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator
-    symm_mem_comm: Optional[Any]  # Symm mem communicator
+    torch_symm_mem_comm: Optional[Any]  # Torch symm mem communicator
     mq_broadcaster: Optional[Any]  # shared memory broadcaster

     def __init__(
@@ -235,7 +237,7 @@ class GroupCoordinator:
         use_pynccl: bool,
         use_pymscclpp: bool,
         use_custom_allreduce: bool,
-        use_torch_symm_mem: bool,
+        use_torch_symm_mem_all_reduce: bool,
         use_hpu_communicator: bool,
         use_xpu_communicator: bool,
         use_npu_communicator: bool,
@@ -295,7 +297,7 @@ class GroupCoordinator:
         self.pynccl_use_current_stream = pynccl_use_current_stream
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
-        self.use_torch_symm_mem = use_torch_symm_mem
+        self.use_torch_symm_mem_all_reduce = use_torch_symm_mem_all_reduce
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
         self.use_npu_communicator = use_npu_communicator
@@ -311,8 +313,8 @@ class GroupCoordinator:
             from sglang.srt.distributed.device_communicators.pynccl import (
                 PyNcclCommunicator,
             )
-            from sglang.srt.distributed.device_communicators.symm_mem import (
-                SymmMemCommunicator,
+            from sglang.srt.distributed.device_communicators.torch_symm_mem import (
+                TorchSymmMemCommunicator,
             )

             if is_hip():
@@ -363,9 +365,9 @@ class GroupCoordinator:
                 except Exception as e:
                     logger.warning(f"Failed to initialize QuickAllReduce: {e}")

-        self.symm_mem_comm: Optional[SymmMemCommunicator] = None
-        if self.use_torch_symm_mem and self.world_size > 1:
-            self.symm_mem_comm = SymmMemCommunicator(
+        self.torch_symm_mem_comm: Optional[TorchSymmMemCommunicator] = None
+        if self.use_torch_symm_mem_all_reduce and self.world_size > 1:
+            self.torch_symm_mem_comm = TorchSymmMemCommunicator(
                 group=self.cpu_group,
                 device=self.device,
             )
@@ -580,11 +582,11 @@ class GroupCoordinator:
            ):
                outplace_all_reduce_method = "pymscclpp"
            elif (
-                self.symm_mem_comm is not None
-                and not self.symm_mem_comm.disabled
-                and self.symm_mem_comm.should_symm_mem_allreduce(input_)
+                self.torch_symm_mem_comm is not None
+                and not self.torch_symm_mem_comm.disabled
+                and self.torch_symm_mem_comm.should_torch_symm_mem_allreduce(input_)
            ):
-                outplace_all_reduce_method = "symm_mem"
+                outplace_all_reduce_method = "torch_symm_mem"
            if outplace_all_reduce_method is not None:
                return torch.ops.sglang.outplace_all_reduce(
                    input_,
@@ -601,7 +603,7 @@ class GroupCoordinator:
         ca_comm = self.ca_comm
         qr_comm = self.qr_comm
         pymscclpp_comm = self.pymscclpp_comm
-        symm_mem_comm = self.symm_mem_comm
+        torch_symm_mem_comm = self.torch_symm_mem_comm
         assert any([qr_comm, ca_comm, pymscclpp_comm])
         if outplace_all_reduce_method == "ca":
             assert not ca_comm.disabled
@@ -609,9 +611,9 @@ class GroupCoordinator:
         elif outplace_all_reduce_method == "qr":
             assert not qr_comm.disabled
             out = qr_comm.quick_all_reduce(input_)
-        elif outplace_all_reduce_method == "symm_mem":
-            assert not symm_mem_comm.disabled
-            out = symm_mem_comm.all_reduce(input_)
+        elif outplace_all_reduce_method == "torch_symm_mem":
+            assert not torch_symm_mem_comm.disabled
+            out = torch_symm_mem_comm.all_reduce(input_)
         else:
             assert not pymscclpp_comm.disabled
             out = pymscclpp_comm.all_reduce(input_)
@@ -620,11 +622,11 @@ class GroupCoordinator:
     def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
         pynccl_comm = self.pynccl_comm
-        symm_mem_comm = self.symm_mem_comm
+        torch_symm_mem_comm = self.torch_symm_mem_comm
         if pynccl_comm is not None and not pynccl_comm.disabled:
             pynccl_comm.all_reduce(input_)
-        elif symm_mem_comm is not None and not symm_mem_comm.disabled:
-            symm_mem_comm.all_reduce(input_)
+        elif torch_symm_mem_comm is not None and not torch_symm_mem_comm.disabled:
+            torch_symm_mem_comm.all_reduce(input_)
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
@@ -1267,7 +1269,7 @@ def init_world_group(
         use_pynccl=False,
         use_pymscclpp=False,
         use_custom_allreduce=False,
-        use_torch_symm_mem=False,
+        use_torch_symm_mem_all_reduce=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
         use_npu_communicator=False,
@@ -1284,15 +1286,15 @@ def init_model_parallel_group(
     group_name: Optional[str] = None,
     use_mscclpp_allreduce: Optional[bool] = None,
     pynccl_use_current_stream: bool = True,
-    use_symm_mem_allreduce: Optional[bool] = None,
+    use_torch_symm_mem_allreduce: Optional[bool] = None,
     torch_compile: Optional[bool] = None,
 ) -> GroupCoordinator:
     if use_custom_allreduce is None:
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
     if use_mscclpp_allreduce is None:
         use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
-    if use_symm_mem_allreduce is None:
-        use_symm_mem_allreduce = _ENABLE_SYMM_MEM_ALL_REDUCE
+    if use_torch_symm_mem_allreduce is None:
+        use_torch_symm_mem_allreduce = _ENABLE_TORCH_SYMM_MEM_ALL_REDUCE
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
@@ -1300,7 +1302,7 @@ def init_model_parallel_group(
         use_pynccl=not (_is_npu or _is_xpu),
         use_pymscclpp=use_mscclpp_allreduce,
         use_custom_allreduce=use_custom_allreduce,
-        use_torch_symm_mem=use_symm_mem_allreduce,
+        use_torch_symm_mem_all_reduce=use_torch_symm_mem_allreduce,
         use_hpu_communicator=True,
         use_xpu_communicator=True,
         use_npu_communicator=True,
@@ -1388,7 +1390,7 @@ logger = logging.getLogger(__name__)
 _ENABLE_CUSTOM_ALL_REDUCE = True
 _ENABLE_MSCCLPP_ALL_REDUCE = False
-_ENABLE_SYMM_MEM_ALL_REDUCE = False
+_ENABLE_TORCH_SYMM_MEM_ALL_REDUCE = False


 def set_custom_all_reduce(enable: bool):
@@ -1401,9 +1403,9 @@ def set_mscclpp_all_reduce(enable: bool):
     _ENABLE_MSCCLPP_ALL_REDUCE = enable


-def set_symm_mem_all_reduce(enable: bool):
-    global _ENABLE_SYMM_MEM_ALL_REDUCE
-    _ENABLE_SYMM_MEM_ALL_REDUCE = enable
+def set_torch_symm_mem_all_reduce(enable: bool):
+    global _ENABLE_TORCH_SYMM_MEM_ALL_REDUCE
+    _ENABLE_TORCH_SYMM_MEM_ALL_REDUCE = enable


 def init_distributed_environment(
...
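As the benchmark in this PR shows, the new setter has to run before the tensor-parallel group is created, since `GroupCoordinator` decides at construction time whether to build a `TorchSymmMemCommunicator`. A minimal sketch of that ordering; the world size and the reliance on default init arguments are assumptions:

```python
import os

from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    initialize_model_parallel,
    set_torch_symm_mem_all_reduce,
)

# Enable the torch symm-mem path before any group is created.
set_torch_symm_mem_all_reduce(True)

# Illustrative values; under torchrun these come from the environment.
world_size = int(os.environ.get("WORLD_SIZE", "8"))
rank = int(os.environ.get("RANK", "0"))

init_distributed_environment(world_size=world_size, rank=rank)
initialize_model_parallel(tensor_model_parallel_size=world_size)

# Present only when world_size > 1 and the device is SM90+ with multicast support.
torch_symm_mem_comm = get_tensor_model_parallel_group().torch_symm_mem_comm
```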
@@ -280,7 +280,7 @@ def initialize_dp_attention(
         use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
         use_pymscclpp=False,
         use_custom_allreduce=False,
-        use_torch_symm_mem=False,
+        use_torch_symm_mem_all_reduce=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
         use_npu_communicator=False,
...
@@ -56,7 +56,7 @@ from sglang.srt.distributed import (
     initialize_model_parallel,
     set_custom_all_reduce,
     set_mscclpp_all_reduce,
-    set_symm_mem_all_reduce,
+    set_torch_symm_mem_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
@@ -608,7 +608,7 @@ class ModelRunner:
         dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
-        set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)
+        set_torch_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)
         if not self.is_draft_worker:
             if self.device == "cpu":
...
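On the server side, the toggle travels from `ServerArgs.enable_torch_symm_mem` through `set_torch_symm_mem_all_reduce` before model-parallel initialization in `ModelRunner`. A hedged illustration of setting it programmatically; only the `enable_torch_symm_mem` field name comes from this diff, while the import path, model path, and other arguments are a plausible sketch rather than an excerpt:

```python
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    tp_size=8,
    enable_torch_symm_mem=True,  # forwarded to set_torch_symm_mem_all_reduce() in ModelRunner
)
```

On the command line this presumably surfaces as `--enable-torch-symm-mem`, in line with how the sibling `enable_mscclpp` option is exposed.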