Unverified commit 819fc591 authored by Yuan Luo, committed by GitHub

Add prefix for torch symm mem (#12506)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent 7efd8b3d
"""For Now, SYMM_MEM is only supported on TP8 case
export WORLD_SIZE=1
"""For Now, TORCH_SYMM_MEM is only supported on following limited tp case
SM90: {
2: 64 * MiB, # 64 MiB
4: 64 * MiB, # 64 MiB
6: 128 * MiB, # 128 MiB
8: 128 * MiB, # 128 MiB
},
SM100: {
2: 64 * MiB, # 64 MiB
4: 64 * MiB, # 64 MiB
6: 128 * MiB, # 128 MiB
8: 128 * MiB, # 128 MiB
}
export WORLD_SIZE=8
export RANK=0
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=12345
......@@ -9,7 +22,7 @@ torchrun --nproc_per_node gpu \
--nnodes $WORLD_SIZE \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT ./benchmark/kernels/all_reduce/benchmark_symm_mem.py
--master_port $MASTER_PORT ./benchmark/kernels/all_reduce/benchmark_torch_symm_mem.py
"""
import os
......@@ -22,12 +35,14 @@ from torch.distributed import ProcessGroup
from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
from sglang.srt.distributed.device_communicators.symm_mem import SymmMemCommunicator
from sglang.srt.distributed.device_communicators.torch_symm_mem import (
TorchSymmMemCommunicator,
)
from sglang.srt.distributed.parallel_state import (
get_tensor_model_parallel_group,
graph_capture,
initialize_model_parallel,
set_symm_mem_all_reduce,
set_torch_symm_mem_all_reduce,
)
# CI environment detection
......@@ -42,10 +57,10 @@ def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Ten
return torch_input
def symm_mem_allreduce(
symm_mem_input: torch.Tensor, symm_mem_comm: SymmMemCommunicator
def torch_symm_mem_allreduce(
torch_symm_mem_input: torch.Tensor, torch_symm_mem_comm: TorchSymmMemCommunicator
) -> torch.Tensor:
return symm_mem_comm.all_reduce(symm_mem_input)
return torch_symm_mem_comm.all_reduce(torch_symm_mem_input)
def pynccl_allreduce(
......@@ -170,7 +185,7 @@ if __name__ == "__main__":
rank = dist.get_rank()
torch.cuda.set_device(rank % 8)
device = torch.cuda.current_device()
set_symm_mem_all_reduce(True)
set_torch_symm_mem_all_reduce(True)
init_distributed_environment(
world_size=world_size,
rank=rank,
......@@ -180,7 +195,7 @@ if __name__ == "__main__":
group = get_tensor_model_parallel_group().device_group
cpu_group = get_tensor_model_parallel_group().cpu_group
pynccl_comm = get_tensor_model_parallel_group().pynccl_comm
symm_mem_comm = get_tensor_model_parallel_group().symm_mem_comm
torch_symm_mem_comm = get_tensor_model_parallel_group().torch_symm_mem_comm
dist.barrier()
profile = False
dtype = torch.bfloat16
......@@ -204,10 +219,12 @@ if __name__ == "__main__":
lambda inp: torch_allreduce(inp, group), inp_randn
)
symm_mem_eager_output, symm_mem_eager_time = _bench_eager_time(
lambda inp: symm_mem_allreduce(inp, symm_mem_comm), inp_randn
lambda inp: torch_symm_mem_allreduce(inp, torch_symm_mem_comm),
inp_randn,
)
symm_mem_graph_output, symm_mem_graph_time = _bench_graph_time(
lambda inp: symm_mem_allreduce(inp, symm_mem_comm), inp_randn
lambda inp: torch_symm_mem_allreduce(inp, torch_symm_mem_comm),
inp_randn,
)
# since pynccl all_reduce is an in-place op, the returned result is not correct if the graph loop count > 1
_, pynccl_graph_time = _bench_graph_time(
......@@ -229,6 +246,6 @@ if __name__ == "__main__":
if rank == 0:
print_markdown_table(result)
if profile:
prof_dir = f"prof/symm_mem"
prof_dir = f"prof/torch_symm_mem"
os.makedirs(prof_dir, exist_ok=True)
ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz")
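The benchmark above relies on timing helpers such as `_bench_eager_time` and `_bench_graph_time` whose bodies fall outside these hunks. A minimal, hypothetical sketch of what a CUDA-event-based eager timer like this might look like; the name `bench_eager_time`, its signature, and its defaults are assumptions, not the benchmark's actual code:

```python
# Hypothetical sketch only: the real _bench_eager_time is not shown in this diff.
import torch


def bench_eager_time(fn, inp: torch.Tensor, warmup: int = 5, iters: int = 20):
    """Run fn(inp) eagerly and return (last output, average milliseconds per call)."""
    for _ in range(warmup):
        out = fn(inp)  # warm up kernels and the CUDA allocator
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        out = fn(inp)
    end.record()
    torch.cuda.synchronize()
    return out, start.elapsed_time(end) / iters
```

A graph-timing variant would presumably capture the same callable into a `torch.cuda.CUDAGraph` and replay it, which is why the in-place pynccl result above is noted as invalid when the graph loop exceeds one iteration.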
MiB = 1024 * 1024
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
9: {
2: 64 * MiB, # 64 MiB
4: 64 * MiB, # 64 MiB
......
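The renamed `TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES` table maps a device capability major version (9 for SM90; the SM100 entry is truncated by this hunk) to a per-world-size message-size ceiling in bytes. Below is a hedged sketch of the kind of lookup the communicator performs against it; the helper name is illustrative, and the SM100 key and its values are inferred from the benchmark docstring above rather than taken from this hunk:

```python
# Illustrative lookup only; the real check lives in
# TorchSymmMemCommunicator.should_torch_symm_mem_allreduce.
import torch

MiB = 1024 * 1024
TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    # capability major -> {world size -> max message size in bytes}
    9: {2: 64 * MiB, 4: 64 * MiB, 6: 128 * MiB, 8: 128 * MiB},   # SM90
    10: {2: 64 * MiB, 4: 64 * MiB, 6: 128 * MiB, 8: 128 * MiB},  # SM100 (key assumed)
}


def fits_torch_symm_mem(inp: torch.Tensor, capability: int, world_size: int) -> bool:
    """True if `inp` is small enough for the torch symm-mem all-reduce fast path."""
    max_size = TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES.get(capability, {}).get(world_size)
    return max_size is not None and inp.numel() * inp.element_size() <= max_size
```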
......@@ -7,33 +7,29 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup
from sglang.srt.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES,
TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES,
)
from sglang.srt.utils import is_cuda, is_hip
try:
import torch.distributed._symmetric_memory as torch_symm_mem
symm_mem_available = True
_is_cuda = is_cuda()
_is_hip = is_hip()
torch_symm_mem_available = False
if _is_cuda:
torch_symm_mem_available = True
except ImportError:
symm_mem_available = False
torch_symm_mem_available = False
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_hip = is_hip()
symm_mem_is_available = False
if _is_hip:
symm_mem_is_available = False
if _is_cuda:
symm_mem_is_available = True
class SymmMemCommunicator:
class TorchSymmMemCommunicator:
"""
Thin wrapper around symmetric-memory collectives.
Thin wrapper around torch-symmetric-memory collectives.
This communicator:
- Validates device capability and world size.
......@@ -62,7 +58,7 @@ class SymmMemCommunicator:
self.disabled = True
if not symm_mem_available:
if not torch_symm_mem_available:
return
if isinstance(device, int):
......@@ -77,19 +73,22 @@ class SymmMemCommunicator:
self.device_capability = torch.cuda.get_device_capability(device)[0]
if self.device_capability < 9:
logger.warning(
"SymmMemCommunicator: Device capability %s not supported, "
"TorchSymmMemCommunicator: Device capability %s not supported, "
"communicator is not available.",
self.device_capability,
)
return
if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]:
if (
self.world_size
not in TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]
):
logger.warning(
"SymmMemCommunicator: World size %d not supported, "
"TorchSymmMemCommunicator: World size %d not supported, "
"communicator is not available.",
self.world_size,
)
return
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
self.max_size = TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
self.world_size
]
self.buffer = torch_symm_mem.empty(
......@@ -100,7 +99,7 @@ class SymmMemCommunicator:
handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
if handle.multicast_ptr == 0:
logger.warning(
"SymmMemCommunicator: symmetric memory "
"TorchSymmMemCommunicator: torch symmetric memory "
"multicast operations are not supported."
)
self.buffer = None
......@@ -108,7 +107,7 @@ class SymmMemCommunicator:
return
self.disabled = False
def should_symm_mem_allreduce(self, inp: torch.Tensor):
def should_torch_symm_mem_allreduce(self, inp: torch.Tensor):
"""
Fast-path eligibility check for a given tensor.
......@@ -135,7 +134,7 @@ class SymmMemCommunicator:
self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
) -> Optional[torch.Tensor]:
"""
Perform an in-place sum all-reduce via symmetric memory.
Perform an in-place sum all-reduce via torch symmetric memory.
Args:
inp: Input tensor on the target CUDA device (bfloat16).
......
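Taken together, the renamed communicator is constructed once per group with `TorchSymmMemCommunicator(group=cpu_group, device=device)` and then consulted through `disabled`, `should_torch_symm_mem_allreduce`, and `all_reduce`, as the benchmark and the `GroupCoordinator` changes show. A hedged usage sketch; the wrapper function below is illustrative, not SGLang API:

```python
# Sketch of the intended call pattern; in SGLang the communicator is built
# once inside GroupCoordinator, not per call.
from typing import Optional

import torch
import torch.distributed as dist

from sglang.srt.distributed.device_communicators.torch_symm_mem import (
    TorchSymmMemCommunicator,
)


def all_reduce_with_fallback(
    inp: torch.Tensor,
    comm: Optional[TorchSymmMemCommunicator],
    device_group: dist.ProcessGroup,
) -> torch.Tensor:
    """Use the torch symm-mem fast path when eligible, else a plain all_reduce."""
    if (
        comm is not None
        and not comm.disabled
        and comm.should_torch_symm_mem_allreduce(inp)
    ):
        return comm.all_reduce(inp)
    dist.all_reduce(inp, group=device_group)
    return inp
```

This mirrors the `outplace_all_reduce_method == "torch_symm_mem"` branch added to `GroupCoordinator` below.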
......@@ -217,14 +217,16 @@ class GroupCoordinator:
use_pynccl: bool # a hint of whether to use PyNccl
use_pymscclpp: bool # a hint of whether to use PyMsccl
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
use_torch_symm_mem: bool # a hint of whether to use SymmMemAllReduce
use_torch_symm_mem_all_reduce: (
bool # a hint of whether to use TorchSymmMemAllReduce
)
use_message_queue_broadcaster: (
bool # a hint of whether to use message queue broadcaster
)
# communicators are only created for world size > 1
pynccl_comm: Optional[Any] # PyNccl communicator
ca_comm: Optional[Any] # Custom allreduce communicator
symm_mem_comm: Optional[Any] # Symm mem communicator
torch_symm_mem_comm: Optional[Any] # Torch symm mem communicator
mq_broadcaster: Optional[Any] # shared memory broadcaster
def __init__(
......@@ -235,7 +237,7 @@ class GroupCoordinator:
use_pynccl: bool,
use_pymscclpp: bool,
use_custom_allreduce: bool,
use_torch_symm_mem: bool,
use_torch_symm_mem_all_reduce: bool,
use_hpu_communicator: bool,
use_xpu_communicator: bool,
use_npu_communicator: bool,
......@@ -295,7 +297,7 @@ class GroupCoordinator:
self.pynccl_use_current_stream = pynccl_use_current_stream
self.use_pymscclpp = use_pymscclpp
self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem
self.use_torch_symm_mem_all_reduce = use_torch_symm_mem_all_reduce
self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator
self.use_npu_communicator = use_npu_communicator
......@@ -311,8 +313,8 @@ class GroupCoordinator:
from sglang.srt.distributed.device_communicators.pynccl import (
PyNcclCommunicator,
)
from sglang.srt.distributed.device_communicators.symm_mem import (
SymmMemCommunicator,
from sglang.srt.distributed.device_communicators.torch_symm_mem import (
TorchSymmMemCommunicator,
)
if is_hip():
......@@ -363,9 +365,9 @@ class GroupCoordinator:
except Exception as e:
logger.warning(f"Failed to initialize QuickAllReduce: {e}")
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if self.use_torch_symm_mem and self.world_size > 1:
self.symm_mem_comm = SymmMemCommunicator(
self.torch_symm_mem_comm: Optional[TorchSymmMemCommunicator] = None
if self.use_torch_symm_mem_all_reduce and self.world_size > 1:
self.torch_symm_mem_comm = TorchSymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
......@@ -580,11 +582,11 @@ class GroupCoordinator:
):
outplace_all_reduce_method = "pymscclpp"
elif (
self.symm_mem_comm is not None
and not self.symm_mem_comm.disabled
and self.symm_mem_comm.should_symm_mem_allreduce(input_)
self.torch_symm_mem_comm is not None
and not self.torch_symm_mem_comm.disabled
and self.torch_symm_mem_comm.should_torch_symm_mem_allreduce(input_)
):
outplace_all_reduce_method = "symm_mem"
outplace_all_reduce_method = "torch_symm_mem"
if outplace_all_reduce_method is not None:
return torch.ops.sglang.outplace_all_reduce(
input_,
......@@ -601,7 +603,7 @@ class GroupCoordinator:
ca_comm = self.ca_comm
qr_comm = self.qr_comm
pymscclpp_comm = self.pymscclpp_comm
symm_mem_comm = self.symm_mem_comm
torch_symm_mem_comm = self.torch_symm_mem_comm
assert any([qr_comm, ca_comm, pymscclpp_comm, torch_symm_mem_comm])
if outplace_all_reduce_method == "ca":
assert not ca_comm.disabled
......@@ -609,9 +611,9 @@ class GroupCoordinator:
elif outplace_all_reduce_method == "qr":
assert not qr_comm.disabled
out = qr_comm.quick_all_reduce(input_)
elif outplace_all_reduce_method == "symm_mem":
assert not symm_mem_comm.disabled
out = symm_mem_comm.all_reduce(input_)
elif outplace_all_reduce_method == "torch_symm_mem":
assert not torch_symm_mem_comm.disabled
out = torch_symm_mem_comm.all_reduce(input_)
else:
assert not pymscclpp_comm.disabled
out = pymscclpp_comm.all_reduce(input_)
......@@ -620,11 +622,11 @@ class GroupCoordinator:
def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
pynccl_comm = self.pynccl_comm
symm_mem_comm = self.symm_mem_comm
torch_symm_mem_comm = self.torch_symm_mem_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.all_reduce(input_)
elif symm_mem_comm is not None and not symm_mem_comm.disabled:
symm_mem_comm.all_reduce(input_)
elif torch_symm_mem_comm is not None and not torch_symm_mem_comm.disabled:
torch_symm_mem_comm.all_reduce(input_)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
......@@ -1267,7 +1269,7 @@ def init_world_group(
use_pynccl=False,
use_pymscclpp=False,
use_custom_allreduce=False,
use_torch_symm_mem=False,
use_torch_symm_mem_all_reduce=False,
use_hpu_communicator=False,
use_xpu_communicator=False,
use_npu_communicator=False,
......@@ -1284,15 +1286,15 @@ def init_model_parallel_group(
group_name: Optional[str] = None,
use_mscclpp_allreduce: Optional[bool] = None,
pynccl_use_current_stream: bool = True,
use_symm_mem_allreduce: Optional[bool] = None,
use_torch_symm_mem_allreduce: Optional[bool] = None,
torch_compile: Optional[bool] = None,
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
if use_mscclpp_allreduce is None:
use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
if use_symm_mem_allreduce is None:
use_symm_mem_allreduce = _ENABLE_SYMM_MEM_ALL_REDUCE
if use_torch_symm_mem_allreduce is None:
use_torch_symm_mem_allreduce = _ENABLE_TORCH_SYMM_MEM_ALL_REDUCE
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
......@@ -1300,7 +1302,7 @@ def init_model_parallel_group(
use_pynccl=not (_is_npu or _is_xpu),
use_pymscclpp=use_mscclpp_allreduce,
use_custom_allreduce=use_custom_allreduce,
use_torch_symm_mem=use_symm_mem_allreduce,
use_torch_symm_mem_all_reduce=use_torch_symm_mem_allreduce,
use_hpu_communicator=True,
use_xpu_communicator=True,
use_npu_communicator=True,
......@@ -1388,7 +1390,7 @@ logger = logging.getLogger(__name__)
_ENABLE_CUSTOM_ALL_REDUCE = True
_ENABLE_MSCCLPP_ALL_REDUCE = False
_ENABLE_SYMM_MEM_ALL_REDUCE = False
_ENABLE_TORCH_SYMM_MEM_ALL_REDUCE = False
def set_custom_all_reduce(enable: bool):
......@@ -1401,9 +1403,9 @@ def set_mscclpp_all_reduce(enable: bool):
_ENABLE_MSCCLPP_ALL_REDUCE = enable
def set_symm_mem_all_reduce(enable: bool):
global _ENABLE_SYMM_MEM_ALL_REDUCE
_ENABLE_SYMM_MEM_ALL_REDUCE = enable
def set_torch_symm_mem_all_reduce(enable: bool):
global _ENABLE_TORCH_SYMM_MEM_ALL_REDUCE
_ENABLE_TORCH_SYMM_MEM_ALL_REDUCE = enable
def init_distributed_environment(
......
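Since `init_model_parallel_group` resolves `use_torch_symm_mem_allreduce=None` from the module-level `_ENABLE_TORCH_SYMM_MEM_ALL_REDUCE` flag, the renamed setter has to run before the parallel groups are created. A hedged ordering sketch modeled on the benchmark's `__main__` block, assuming a `torchrun` launch so the `env://` rendezvous variables are present; kwargs not shown in the diff are omitted or marked as assumptions:

```python
# Ordering sketch only: the toggle must precede TP group construction so the
# new default is picked up when use_torch_symm_mem_allreduce is left as None.
import os

from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    initialize_model_parallel,
    set_torch_symm_mem_all_reduce,
)

WORLD_SIZE = 8  # e.g. single-node TP8; see the size table in the benchmark docstring

set_torch_symm_mem_all_reduce(True)  # flips _ENABLE_TORCH_SYMM_MEM_ALL_REDUCE
init_distributed_environment(world_size=WORLD_SIZE, rank=int(os.environ["RANK"]))
initialize_model_parallel(tensor_model_parallel_size=WORLD_SIZE)  # kwarg name assumed

# None, or a (possibly disabled) TorchSymmMemCommunicator
comm = get_tensor_model_parallel_group().torch_symm_mem_comm
```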
......@@ -280,7 +280,7 @@ def initialize_dp_attention(
use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
use_pymscclpp=False,
use_custom_allreduce=False,
use_torch_symm_mem=False,
use_torch_symm_mem_all_reduce=False,
use_hpu_communicator=False,
use_xpu_communicator=False,
use_npu_communicator=False,
......
......@@ -56,7 +56,7 @@ from sglang.srt.distributed import (
initialize_model_parallel,
set_custom_all_reduce,
set_mscclpp_all_reduce,
set_symm_mem_all_reduce,
set_torch_symm_mem_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
......@@ -608,7 +608,7 @@ class ModelRunner:
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)
set_torch_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)
if not self.is_draft_worker:
if self.device == "cpu":
......