"src/vscode:/vscode.git/clone" did not exist on "b62d9a1fdc0910bb864340b7bc29e86f6aa31d47"
Unverified commit 28d4d472, authored by li haoyang and committed by GitHub

[Feature] Integrate quick allreduce and select the best allreduce implementation (#6619)


Signed-off-by: Haoyang Li <Haoyang.Li@amd.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
parent f4674df6
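For context while reading the diff below: the selection logic added to GroupCoordinator.all_reduce (in the parallel_state.py hunk further down) tries the backends in a fixed priority order and falls back to the in-place path when none applies. A minimal sketch of that ordering, illustrative only and not part of the patch:

def pick_outplace_all_reduce_method(inp, qr_comm, ca_comm, pymscclpp_comm):
    # Mirrors the new branch in GroupCoordinator.all_reduce: quick allreduce
    # (ROCm, quantized) is preferred, then custom allreduce, then PyMscclpp;
    # None means the caller uses the in-place torch.distributed path.
    if qr_comm is not None and not qr_comm.disabled and qr_comm.should_quick_allreduce(inp):
        return "qr"
    if ca_comm is not None and not ca_comm.disabled and ca_comm.should_custom_ar(inp):
        return "ca"
    if pymscclpp_comm is not None and not pymscclpp_comm.disabled and pymscclpp_comm.should_mscclpp_allreduce(inp):
        return "pymscclpp"
    return None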
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
import logging
from typing import List, Tuple
from typing import List, Optional, Tuple
import torch
@@ -114,6 +114,34 @@ else:
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
    return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
# ROCM custom quick allreduce
def init_custom_qr(
    rank: int, world_size: int, qr_max_size: Optional[int] = None
) -> int:
    return sgl_kernel.allreduce.init_custom_qr(world_size, rank, qr_max_size)

def qr_get_handle(fa: int) -> torch.Tensor:
    return sgl_kernel.allreduce.qr_get_handle(fa)

def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
    sgl_kernel.allreduce.qr_open_handles(fa, handles)

def qr_all_reduce(
    fa: int,
    inp: torch.Tensor,
    out: torch.Tensor,
    quant_level: int,
    cast_bf2half: bool,
) -> None:
    sgl_kernel.allreduce.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)

def qr_destroy(fa: int) -> None:
    sgl_kernel.allreduce.qr_destroy(fa)

def qr_max_size() -> int:
    return sgl_kernel.allreduce.qr_max_size()
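# Illustrative lifecycle of these wrappers (comments only, not part of the patch);
# the real orchestration lives in QuickAllReduce.create_shared_buffer() and
# quick_all_reduce() further below:
#   fa = init_custom_qr(rank, world_size)             # opaque handle to the device context
#   handle = qr_get_handle(fa)                        # IPC handle for this rank's buffer
#   dist.all_gather_object(handles, handle, group=g)  # exchange handles over a CPU group
#   qr_open_handles(fa, handles)                      # map the peers' buffers
#   qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)
#   qr_destroy(fa)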
def mscclpp_generate_unique_id() -> bytes:
    return sgl_kernel.allreduce.mscclpp_generate_unique_id()
...
@@ -4,18 +4,18 @@ import ctypes
import logging
import os
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, List, Optional, TypeVar, Union
from typing import Any, List, Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec
from sglang.srt import _custom_ops as ops
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check,
is_full_nvlink,
is_weak_contiguous,
)
from sglang.srt.distributed.parallel_state import in_the_same_node_as
from sglang.srt.utils import is_cuda, is_hip
@@ -25,23 +25,6 @@ logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_hip = is_hip()
if _is_cuda:
try:
import pynvml
except ImportError as e:
logger.warning("Failed to import pynvml with %r", e)
if _is_hip:
try:
from amdsmi import (
AmdSmiException,
amdsmi_get_processor_handles,
amdsmi_init,
amdsmi_shut_down,
amdsmi_topo_get_link_type,
)
except ImportError as e:
logger.warning("Failed to import amdsmi with %r", e)
try:
if ops.use_vllm_custom_allreduce and not _is_hip:
@@ -57,70 +40,6 @@ except Exception:
logger = logging.getLogger(__name__)
_P = ParamSpec("_P")
_R = TypeVar("_R")
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
if _is_hip:
try:
amdsmi_init()
return fn(*args, **kwargs)
finally:
amdsmi_shut_down()
else:
pynvml.nvmlInit()
try:
return fn(*args, **kwargs)
finally:
pynvml.nvmlShutdown()
return wrapper
@with_nvml_context
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
if _is_hip:
"""
query if the set of gpus are fully connected by xgmi (1 hop)
"""
handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
link_type = amdsmi_topo_get_link_type(handle, peer_handle)
# type is 2 for XGMI
if link_type["hops"] != 1 or link_type["type"] != 2:
return False
except AmdSmiException as error:
logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
return False
return True
else:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
"""
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
)
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
return False
except pynvml.NVMLError:
logger.exception(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped."
)
return False
return True
def _can_p2p(rank: int, world_size: int) -> bool:
# SGLANG_SKIP_P2P_CHECK can be set to False in sglang
@@ -136,13 +55,6 @@ def _can_p2p(rank: int, world_size: int) -> bool:
return True
def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size()
)
class CustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
_MAX_CAR_SIZE = 8192 * 1024
...
@@ -8,17 +8,44 @@ import pickle
import subprocess
import sys
import tempfile
from functools import wraps
from itertools import product
from typing import Dict, List, Optional, Sequence
from typing import Callable, Dict, List, Optional, Sequence, TypeVar
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from typing_extensions import ParamSpec
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from sglang.srt.utils import is_cuda, is_hip
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_hip = is_hip()
if _is_cuda:
try:
import pynvml
except ImportError as e:
logger.warning("Failed to import pynvml with %r", e)
if _is_hip:
try:
from amdsmi import (
AmdSmiException,
amdsmi_get_processor_handles,
amdsmi_init,
amdsmi_shut_down,
amdsmi_topo_get_link_type,
)
except ImportError as e:
logger.warning("Failed to import amdsmi with %r", e)
_P = ParamSpec("_P")
_R = TypeVar("_R")
def update_environment_variables(envs: Dict[str, str]):
for k, v in envs.items():
@@ -282,6 +309,74 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
if _is_hip:
try:
amdsmi_init()
return fn(*args, **kwargs)
finally:
amdsmi_shut_down()
else:
pynvml.nvmlInit()
try:
return fn(*args, **kwargs)
finally:
pynvml.nvmlShutdown()
return wrapper
@with_nvml_context
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
if _is_hip:
"""
query if the set of gpus are fully connected by xgmi (1 hop)
"""
handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
link_type = amdsmi_topo_get_link_type(handle, peer_handle)
# type is 2 for XGMI
if link_type["hops"] != 1 or link_type["type"] != 2:
return False
except AmdSmiException as error:
logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
return False
return True
else:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
"""
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
)
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
return False
except pynvml.NVMLError:
logger.exception(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped."
)
return False
return True
def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size()
)
__all__ = ["gpu_p2p_access_check"]
if __name__ == "__main__":
...
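The NVLink/XGMI topology helpers and is_weak_contiguous now live in custom_all_reduce_utils.py so that both CustomAllreduce and the new QuickAllReduce can share them. A rough usage sketch, illustrative only and simplified (the real code gathers each rank's physical device id over the CPU group, as in QuickAllReduce.__init__ below):

import os
import torch
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import is_full_nvlink

visible = os.environ.get("CUDA_VISIBLE_DEVICES")
device_ids = list(map(int, visible.split(","))) if visible else list(range(torch.cuda.device_count()))
# True when every GPU pair is one XGMI hop (ROCm) or NVLink-reachable (CUDA)
fully_connected = is_full_nvlink(device_ids, world_size=len(device_ids))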
# SPDX-License-Identifier: Apache-2.0
import logging
import os
from enum import Enum
from typing import Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from sglang.srt import _custom_ops as ops
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
is_full_nvlink,
is_weak_contiguous,
)
from sglang.srt.distributed.parallel_state import in_the_same_node_as
from sglang.srt.utils import is_cuda, is_hip
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_hip = is_hip()
try:
ops.qr_max_size()
quick_ar = True
except Exception:
# For CPUs and CUDA
quick_ar = False
def qr_rocm_arch_available():
if not _is_hip:
return False
try:
props = torch.cuda.get_device_properties(0)
gcn_arch = getattr(props, "gcnArchName", "")
supported_archs = ["gfx94", "gfx95"]
return any(gfx in gcn_arch for gfx in supported_archs)
except Exception as e:
logger.warning("Failed to determine ROCm for quick allreduce: %s", e)
return False
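# Example (illustrative): on MI300-series GPUs, torch.cuda.get_device_properties(0).gcnArchName
# typically reports a string beginning with "gfx942", which matches the "gfx94" prefix above.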
class QuickReduceRegime(Enum):
FP = 0
INT8 = 1
INT6 = 2
INT4 = 3
NONE = 4
MB = 1024 * 1024
class QuickAllReduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 8]
_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
# The following data is based on kernel tests.
# In this order [FP, INT8, INT6, INT4].
_QR_MIN_SIZE = {
(torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],
(torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB],
(torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],
(torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB],
(torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB],
(torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
}
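# Worked example of reading the table above (illustrative): for float16 inputs with
# world_size=8 and INT4 quantization, _QR_MIN_SIZE[(torch.float16, 8)][QuickReduceRegime.INT4.value]
# is 2 * MB, so quick allreduce is only used for messages of at least 2 MB
# (and at most qr_max_size); see should_quick_allreduce() below.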
def __init__(
self, group: ProcessGroup, device: Union[int, str, torch.device]
) -> None:
"""
Custom allreduce provides non-destructive acceleration and is
available for CUDA and ROCm MI300 series.
Custom quick allreduce leverages quantization for further
acceleration on ROCm. It currently supports Q8, Q6, and Q4
quantization formats and FP(float16, bfloat16).
Quick allreduce is designed as a complement to custom allreduce.
Its initialization requires even stricter conditions.
Only the ROCm MI300 series is supported for quick allreduce at
this time.
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be bind to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
"""
self.disabled = True
if not qr_rocm_arch_available():
logger.debug(
"Custom quick allreduce is only supported on ROCm MI300 series."
)
return
if not quick_ar:
# disable because of missing quick reduce library
# e.g. in a cuda environment
logger.info(
"Custom quick allreduce is disabled because "
"of missing custom quick allreduce library"
)
return
self.group = group
assert (
dist.get_backend(group) != dist.Backend.NCCL
), "Custom quick allreduce should be attached to a non-NCCL group."
if not all(in_the_same_node_as(group, source_rank=0)):
# No need to initialize custom quick allreduce for
# multi-node case.
logger.warning(
"Custom quick allreduce is disabled because this "
"process group spans across nodes."
)
return
rank = dist.get_rank(group=self.group)
world_size = dist.get_world_size(group=self.group)
self.rank = rank
self.world_size = world_size
if world_size == 1:
# No need to initialize QuickReduce for single GPU case.
return
if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES:
logger.warning(
"Custom quick allreduce is disabled due to an "
"unsupported world size: %d. Supported world sizes: %s.",
world_size,
str(QuickAllReduce._SUPPORTED_WORLD_SIZES),
)
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
assert isinstance(device, torch.device)
self.device = device
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(torch.cuda.device_count()))
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu")
for _ in range(self.world_size)
]
dist.all_gather(gather_list, tensor, group=self.group)
physical_device_ids = [t.item() for t in gather_list]
# test nvlink first, this will filter out most of the cases
# where custom quick allreduce is not supported
# this checks hardware and driver support for NVLink (or XGMI on ROCm)
if _is_cuda or _is_hip:
self.fully_connected = is_full_nvlink(physical_device_ids, self.world_size)
if self.world_size > 2 and not self.fully_connected:
logger.debug(
"Custom quick allreduce is disabled because it's not supported "
"on more than two PCIe-only GPUs. "
)
return
self.init_quick_all_reduce()
def init_quick_all_reduce(self):
# On ROCm, bfloat16 kernels are slower than fp16
# due to slower math operations.
# If the environment variable is set to 1, we cast bf16 inputs to fp16.
self.use_fp16_kernels = int(
os.environ.get("ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", 1)
)
regime_str = os.environ.get("ROCM_QUICK_REDUCE_QUANTIZATION", "NONE")
if regime_str not in QuickReduceRegime.__members__:
logger.warning(
"Custom quick allreduce: "
f"invalid quantization level: {regime_str}. "
"Supported levels: "
f"{list(QuickReduceRegime.__members__.keys())}"
)
return
if regime_str == "NONE":
logger.debug(
"Custom quick allreduce is disabled based "
"on env variable "
"ROCM_QUICK_REDUCE_QUANTIZATION='NONE'"
)
return
self.qr_quant_level = QuickReduceRegime[regime_str]
# TODO: If the dtype is neither bfloat16 nor float16,
# quick allreduce should not be created.
# ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB
qr_max_size = int(os.environ.get("ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", 0))
if qr_max_size > 0:
if qr_max_size < 1:
logger.info(
"You should not set a max_size smaller than 1MB, which can "
"lead to error or degradation to custom allreduce or rccl."
)
qr_max_size = qr_max_size * MB
# If qr_max_size is None, then 2GB is used by default.
self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size)
self.qr_max_size = qr_max_size if qr_max_size > 0 else ops.qr_max_size()
self.create_shared_buffer()
self.disabled = False
def create_shared_buffer(self):
"""
Creates a shared buffer for quickreduce.
Has to be called after init_custom_qr
"""
handle = ops.qr_get_handle(self._ptr)
world_size = dist.get_world_size(group=self.group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=self.group)
ops.qr_open_handles(self._ptr, handles)
def should_quick_allreduce(self, inp: torch.Tensor):
"""
Check if quickreduce is available
"""
if self.disabled:
return False
if inp.dtype not in self._SUPPORTED_DTYPES:
return False
inp_size = inp.numel() * inp.element_size()
# custom quick allreduce requires input byte size to be
# multiples of 16
if inp_size % 16 != 0:
return False
if not is_weak_contiguous(inp):
return False
dtype = inp.dtype
if self.use_fp16_kernels:
dtype = torch.float16
return (
inp_size <= self.qr_max_size
and inp_size
>= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value]
)
def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):
"""Performs an out-of-place custom quick all reduce."""
# quick allreduce doesn't require a separate graph mode,
# as QR uses static IPC buffer.
if out is None:
out = torch.empty_like(inp)
ops.qr_all_reduce(
self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels
)
return out
def close(self):
if not self.disabled and getattr(self, "_ptr", None):
if ops is not None:
ops.qr_destroy(self._ptr)
self._ptr = 0
self.disabled = True
def __del__(self):
self.close()
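QuickAllReduce is configured entirely through the environment variables read in init_quick_all_reduce() above. A hedged configuration example; the values are illustrative, not recommendations:

import os

# Enable INT6 quantized quick allreduce, keep bf16 inputs on the bf16 kernels,
# and cap the IPC buffer at 64 MB instead of the 2 GB default.
os.environ["ROCM_QUICK_REDUCE_QUANTIZATION"] = "INT6"     # FP | INT8 | INT6 | INT4 | NONE (default)
os.environ["ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16"] = "0"   # default 1: route bf16 through fp16 kernels
os.environ["ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB"] = "64"  # optional cap, in MB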
@@ -44,6 +44,7 @@ from sglang.srt.utils import (
get_bool_env_var,
get_int_env_var,
is_cuda_alike,
is_hip,
is_npu,
is_shm_available,
supports_custom_op,
@@ -126,14 +127,18 @@ if supports_custom_op():
fake_impl=inplace_all_reduce_fake,
)
def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
def outplace_all_reduce(
tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str
) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce_out_place(tensor)
return group._all_reduce_out_place(tensor, outplace_all_reduce_method)
def outplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
def outplace_all_reduce_fake(
tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str
) -> torch.Tensor:
return torch.empty_like(tensor)
direct_register_custom_op(
@@ -264,6 +269,12 @@ class GroupCoordinator:
PyNcclCommunicator,
)
if is_hip():
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce,
qr_rocm_arch_available,
)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
@@ -283,6 +294,7 @@ class GroupCoordinator:
)
self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
try:
@@ -295,6 +307,18 @@ class GroupCoordinator:
f"Setup Custom allreduce failed with {e}. To silence this "
"warning, specify --disable-custom-all-reduce explicitly."
)
if is_hip():
try:
# Initialize a custom quick all-reduce implementation for AMD
# when rocm >= gfx942. Quick reduce is designed as a
# complement to custom allreduce.
# Based on quickreduce (https://github.com/mk1-project/quickreduce).
if qr_rocm_arch_available():
self.qr_comm = QuickAllReduce(
group=self.cpu_group, device=self.device
)
except Exception as e:
logger.warning(f"Failed to initialize QuickAllReduce: {e}")
from sglang.srt.distributed.device_communicators.hpu_communicator import (
HpuCommunicator,
@@ -373,7 +397,8 @@ class GroupCoordinator:
graph_capture_context = GraphCaptureContext(stream)
else:
stream = graph_capture_context.stream
# We don't need a capture context for custom quick allreduce because the IPC
# access is already collected in init(), so quick allreduce can be captured directly.
ca_comm = self.ca_comm
maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
@@ -388,23 +413,24 @@ class GroupCoordinator:
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
# --------------------------------------------
# quick allreduce | enabled | enabled |
# custom allreduce | enabled | enabled |
# PyNccl | disabled| enabled |
# PyMscclpp | disabled| enabled |
# torch.distributed | enabled | disabled|
#
# Note: When custom quick allreduce is enabled, a runtime check
# will be performed. If the tensor size is too small, it will
# automatically fall back to the next available option.
# Note that custom allreduce will have a runtime check, if the
# tensor size is too large, it will fallback to the next
# available option.
# Note that the PyMsccl needs to register the tensor in ahead,
# which will introduce large overhead in the eager case,
# therefore it is only supported in the graph case.
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using
# CUDA graph, we use either custom all-reduce kernel or
# PyTorch NCCL. We always prioritize using custom all-reduce
# kernel but fall back to PyTorch or pynccl if it is
# disabled or not supported.
# In summary: We select the appropriate allreduce method for
# each mode based on the algorithm order in the table and
# their usage conditions.
pynccl_comm = self.pynccl_comm
maybe_pynccl_context: Any
if not pynccl_comm:
@@ -464,27 +490,47 @@ class GroupCoordinator:
if self.npu_communicator is not None and not self.npu_communicator.disabled:
return self.npu_communicator.all_reduce(input_)
outplace_all_reduce_method = None
if (
self.qr_comm is not None
and not self.qr_comm.disabled
and self.qr_comm.should_quick_allreduce(input_)
):
outplace_all_reduce_method = "qr"
elif (
self.ca_comm is not None
and not self.ca_comm.disabled
and self.ca_comm.should_custom_ar(input_)
) or (
):
outplace_all_reduce_method = "ca"
elif (
self.pymscclpp_comm is not None
and not self.pymscclpp_comm.disabled
and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
):
outplace_all_reduce_method = "pymscclpp"
if outplace_all_reduce_method is not None:
return torch.ops.sglang.outplace_all_reduce(
input_, group_name=self.unique_name
input_,
group_name=self.unique_name,
outplace_all_reduce_method=outplace_all_reduce_method,
)
else:
torch.ops.sglang.inplace_all_reduce(input_, group_name=self.unique_name)
return input_
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
def _all_reduce_out_place(
self, input_: torch.Tensor, outplace_all_reduce_method: str
) -> torch.Tensor:
qr_comm = self.qr_comm
ca_comm = self.ca_comm
pymscclpp_comm = self.pymscclpp_comm
assert ca_comm is not None or pymscclpp_comm is not None
assert any([qr_comm, ca_comm, pymscclpp_comm])
if ca_comm is not None and not ca_comm.disabled:
if outplace_all_reduce_method == "qr":
assert not qr_comm.disabled
out = qr_comm.quick_all_reduce(input_)
elif outplace_all_reduce_method == "ca":
assert not ca_comm.disabled
out = ca_comm.custom_all_reduce(input_)
else:
assert not pymscclpp_comm.disabled
...
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#ifdef USE_ROCM
#include "quick_all_reduce.h"
quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional<int64_t> qr_max_size) {
if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported");
if (world_size == 6) throw std::invalid_argument("world size == 6 is not supported");
if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now");
if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in");
quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms();
fptr->init(world_size, rank, qr_max_size);
return (quickreduce::fptr_t)fptr;
}
void qr_destroy(quickreduce::fptr_t _fa) {
if (_fa) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
fa->destroy();
delete fa;
}
}
torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
hipIpcMemHandle_t handle = fa->get_handle();
auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto data_handle = torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t));
return data_handle;
}
void qr_open_handles(quickreduce::fptr_t _fa, const std::vector<torch::Tensor>& handles) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
std::vector<hipIpcMemHandle_t> ipc_handles;
ipc_handles.reserve(handles.size());
for (auto& handle : handles) {
// Ensure the tensor is on the same device as the current device.
hipIpcMemHandle_t ipc_handle;
std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t));
ipc_handles.push_back(ipc_handle);
}
fa->open_ipc_handles(ipc_handles);
}
void qr_all_reduce(
quickreduce::fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize);
if (out.scalar_type() == at::ScalarType::Half) {
fa->allreduce<half, false>(
reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()),
out.numel(),
quant_level,
stream);
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
if (cast_bf2half) {
fa->allreduce<half, true>(
reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()),
out.numel(),
quant_level,
stream);
} else {
fa->allreduce<quickreduce::nv_bfloat16, false>(
reinterpret_cast<quickreduce::nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<quickreduce::nv_bfloat16*>(out.data_ptr()),
out.numel(),
quant_level,
stream);
}
} else {
throw std::runtime_error("quick allreduce only supports float16 and bfloat16");
}
}
int64_t qr_max_size() {
// The default is 2GB (2,147,483,648 bytes)
return static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
}
#define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 2>, cast_bf2half>; \
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 4>, cast_bf2half>; \
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 8>, cast_bf2half>;
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true)
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false)
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false)
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false)
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false)
#endif // USE_ROCM
#pragma once
#include <hip/hip_runtime.h>
#include <vector>
#include "quick_all_reduce.cuh"
#define HIP_CHECK(err) \
do { \
hipError_t err_ = (err); \
if (err_ != hipSuccess) { \
std::printf("HIP error %d at %s:%d. %s\n", err_, __FILE__, __LINE__, hipGetErrorString(err_)); \
throw std::runtime_error("HIP error"); \
} \
} while (0)
namespace quickreduce {
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
template <typename AllReduceKernel, typename T>
__global__ __quickreduce_launch_bounds_two_shot__ static void allreduce_prototype_twoshot(
T const* A,
T* B,
uint32_t N,
uint32_t num_blocks,
int rank,
uint8_t** dbuffer_list,
uint32_t data_offset,
uint32_t flag_color) {
int block = blockIdx.x;
int grid = gridDim.x;
while (block < num_blocks) {
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, flag_color);
block += grid;
flag_color++;
}
}
#define TWOSHOT_DISPATCH(__codec) \
if (world_size == 2) { \
using LineCodec = __codec<T, 2>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL( \
(allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), \
dim3(kBlockTwoShot), \
0, \
stream, \
A, \
B, \
N, \
num_blocks, \
rank, \
dbuffer_list, \
data_offset, \
flag_color); \
} else if (world_size == 4) { \
using LineCodec = __codec<T, 4>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL( \
(allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), \
dim3(kBlockTwoShot), \
0, \
stream, \
A, \
B, \
N, \
num_blocks, \
rank, \
dbuffer_list, \
data_offset, \
flag_color); \
} else if (world_size == 8) { \
using LineCodec = __codec<T, 8>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL( \
(allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), \
dim3(kBlockTwoShot), \
0, \
stream, \
A, \
B, \
N, \
num_blocks, \
rank, \
dbuffer_list, \
data_offset, \
flag_color); \
}
enum QuickReduceQuantLevel {
F16 = 0,
INT8 = 1,
INT6 = 2,
INT4 = 3,
};
struct DeviceComms {
// Max problem size is 2GB (in bytes) or half of uint32_t max value.
int64_t kMaxProblemSize = static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
// Max TP-8
static int constexpr kMaxWorldSize = 8;
bool initialized = false;
uint32_t flag_color = 1;
int world_size;
int rank;
uint8_t* dbuffer;
uint8_t** dbuffer_list;
hipIpcMemHandle_t buffer_ipc_handle;
std::vector<hipIpcMemHandle_t> all_buffer_ipc_handles;
std::vector<uint8_t*> buffer_list;
uint32_t data_offset;
DeviceComms() : initialized(false), world_size(1), rank(0) {}
~DeviceComms() {
destroy();
}
void init(int world_size, int rank, std::optional<int64_t> max_problem_size = std::nullopt) {
destroy();
this->world_size = world_size;
this->rank = rank;
if (max_problem_size.has_value() && max_problem_size.value() > 0) {
this->kMaxProblemSize = max_problem_size.value();
}
// Allocate buffer size for worst case: F16 2-stage buffer.
uint32_t flags_buffer_size = 2 * world_size * kMaxNumBlocks * sizeof(uint32_t);
static int64_t data_buffer_size = 2 * this->kMaxProblemSize;
int64_t total_buffer_size = flags_buffer_size + data_buffer_size;
data_offset = flags_buffer_size;
HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, hipDeviceMallocUncached));
// Clear the flags buffer.
HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size));
// Device-side list of IPC buffers.
buffer_list.resize(world_size);
HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*)));
// Create IPC handles for rank's communication buffer.
all_buffer_ipc_handles.resize(world_size);
HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer));
initialized = true;
}
int get_world_size() {
return world_size;
}
int get_rank() {
return rank;
}
bool status() {
return initialized;
}
hipIpcMemHandle_t const get_handle() {
return buffer_ipc_handle;
}
void destroy() {
if (initialized) {
for (int i = 0; i < world_size; i++) {
if (i != rank) {
HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i]));
}
}
HIP_CHECK(hipFree(dbuffer));
HIP_CHECK(hipFree(dbuffer_list));
initialized = false;
}
}
void open_ipc_handles(std::vector<hipIpcMemHandle_t> const& ipc_handles) {
assert(ipc_handles.size() == all_buffer_ipc_handles.size());
for (int i = 0; i < world_size; i++) {
all_buffer_ipc_handles[i] = ipc_handles[i];
}
// Open device memory access to the IPC communication buffers.
// Note: For our own rank, we do not need to open a handle.
for (int i = 0; i < world_size; i++) {
if (i != rank) {
HIP_CHECK(
hipIpcOpenMemHandle((void**)&buffer_list[i], all_buffer_ipc_handles[i], hipIpcMemLazyEnablePeerAccess));
} else {
buffer_list[i] = dbuffer;
}
}
HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), world_size * sizeof(uint8_t*), hipMemcpyHostToDevice));
}
template <typename T, bool cast_bf2half>
void allreduce(T const* A, T* B, uint32_t N, int quant_level, hipStream_t stream) {
if (world_size != 2 && world_size != 4 && world_size != 8) {
throw std::runtime_error("All Reduce not supported for world_size = " + std::to_string(world_size));
}
// Configuration.
uint32_t msg_size = N * sizeof(T);
uint32_t num_blocks = divceil(msg_size, kTileSize);
uint32_t grid = min(kMaxNumBlocks, num_blocks);
auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
switch (quant_level_) {
case QuickReduceQuantLevel::INT8:
TWOSHOT_DISPATCH(CodecQ8)
break;
case QuickReduceQuantLevel::INT6:
TWOSHOT_DISPATCH(CodecQ6)
break;
case QuickReduceQuantLevel::INT4:
TWOSHOT_DISPATCH(CodecQ4)
break;
default:
TWOSHOT_DISPATCH(CodecFP)
break;
}
HIP_CHECK(cudaGetLastError());
// Rotate the flag color.
flag_color += divceil(N, grid);
}
};
} // namespace quickreduce
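To make the buffer layout in DeviceComms::init concrete, here is the arithmetic for the default configuration; illustrative only, using kMaxNumBlocks = 304 * 4 from quick_all_reduce.cuh below:

world_size = 8
kMaxNumBlocks = 304 * 4                                  # 1216
kMaxProblemSize = 2**31                                  # 2 GB default (qr_max_size)
flags_buffer_size = 2 * world_size * kMaxNumBlocks * 4   # sizeof(uint32_t) -> 77,824 bytes
data_buffer_size = 2 * kMaxProblemSize                   # two-stage worst case -> 4 GiB
data_offset = flags_buffer_size                          # data region follows the flags region
total_buffer_size = flags_buffer_size + data_buffer_size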
#pragma once
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <cstdint>
#define __quickreduce_device_inline__ __device__ __forceinline__
#define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4)
#define __quickreduce_launch_bounds_one_shot__ __launch_bounds__(512, 4)
namespace quickreduce {
typedef __hip_bfloat16 nv_bfloat16;
typedef __hip_bfloat162 nv_bfloat162;
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
// Setup acquire-release semantics for vector memory reads (mubuf instruction)
// as per architecture.
#if defined(__gfx942__)
// CDNA3: Scope bits sc0, sc1
#define MUBUF_ACQUIRE 16
#define MUBUF_RELEASE 16
#elif (defined(__gfx908__) || defined(__gfx90a__))
// CDNA1 and CDNA2 - glc bit
#define MUBUF_ACQUIRE 1
#define MUBUF_RELEASE 0
#endif
static constexpr int kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t
// Number of atoms (4xf16x2_t) processed by a single thread
static constexpr int kAtoms = 8;
// We use a workgroup of 256 threads
static constexpr int kBlockSize = 256;
static constexpr int kAtomStride = kBlockSize;
// Size and atom stride of source/destination data that the block will
// process.
// Workgroup scope = Tile = (256 threads x 8 atoms x 16B)
static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t);
// Max number of blocks. 304 CUs on MI300
static constexpr int kMaxNumBlocks = 304 * 4;
// Standard CDNA wavefront size.
static constexpr int kWavefront = 64;
// 256 thread, 4 wavefronts.
static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1};
// Number of threads in a group for quantization
// It corresponds to 32 F16 elements in quantization block
static constexpr int kThreadGroupSize = 8;
// Methods
__quickreduce_device_inline__ __host__ unsigned long divceil(unsigned long x, unsigned long y) {
return ((x + y - 1) / y);
}
union BufferResource {
__quickreduce_device_inline__ constexpr BufferResource() : config(0x00020000U) {}
__quickreduce_device_inline__ constexpr BufferResource(void* buffer_address, uint32_t buffer_size)
: address(buffer_address), range(buffer_size), config(0x00020000U) {}
int32x4_t descriptor;
struct {
void* address; // 8B, out of which first 48b is address, and 16b is stride
// (unused)
uint32_t range; // Byte range for the buffer resource
uint32_t config; // Constant, DFMT=32b
};
};
__quickreduce_device_inline__ static int32x4_t buffer_load_dwordx4(
int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
__quickreduce_device_inline__ static void
buffer_store_dwordx4(int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm(
"llvm.amdgcn.raw.buffer.store.v4i32");
__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) {
#if defined(__gfx942__)
if (value) {
asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::);
} else {
asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::);
}
#endif
}
union bf162_int_union {
int i;
nv_bfloat162 bf2;
};
template <typename T>
__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B);
template <>
__quickreduce_device_inline__ void packed_assign_add<half>(int32x4_t* A, int32x4_t* B) {
int32x4_t& tR_fragment = A[0];
int32x4_t& tA_fragment = B[0];
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[0]) : "v"(tR_fragment[0]), "v"(tA_fragment[0]));
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[1]) : "v"(tR_fragment[1]), "v"(tA_fragment[1]));
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[2]) : "v"(tR_fragment[2]), "v"(tA_fragment[2]));
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[3]) : "v"(tR_fragment[3]), "v"(tA_fragment[3]));
}
template <>
__quickreduce_device_inline__ void packed_assign_add<nv_bfloat16>(int32x4_t* A, int32x4_t* B) {
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(A);
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(B);
#pragma unroll
for (int i = 0; i < 4; i++) {
tA[i] = __hadd2(tA[i], tB[i]);
}
}
template <typename T>
__quickreduce_device_inline__ int packed_max(int a, int b);
template <>
__quickreduce_device_inline__ int packed_max<half>(int a, int b) {
int result;
asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
return result;
}
template <>
__quickreduce_device_inline__ int packed_max<nv_bfloat16>(int a, int b) {
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2 = __hmax2(A.bf2, B.bf2);
return R.i;
}
template <typename T>
__quickreduce_device_inline__ int packed_min(int a, int b);
template <>
__quickreduce_device_inline__ int packed_min<half>(int a, int b) {
int result;
asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
return result;
}
template <>
__quickreduce_device_inline__ int packed_min<nv_bfloat16>(int a, int b) {
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2 = __hmin2(A.bf2, B.bf2);
return R.i;
}
template <typename T>
__quickreduce_device_inline__ int packed_abs_max(int a, int b);
template <>
__quickreduce_device_inline__ int packed_abs_max<half>(int a, int b) {
half2 wmaxh2 = __builtin_bit_cast(half2, a);
half2 wminh2 = __builtin_bit_cast(half2, b);
half2 wblockmaxh2;
wblockmaxh2.x = __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x;
wblockmaxh2.y = __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y;
return __builtin_bit_cast(int, wblockmaxh2);
}
template <>
__quickreduce_device_inline__ int packed_abs_max<nv_bfloat16>(int a, int b) {
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x;
R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y;
return R.i;
}
template <typename T>
__quickreduce_device_inline__ int packed_add(int a, int b);
template <>
__quickreduce_device_inline__ int packed_add<half>(int a, int b) {
int result;
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
return result;
}
template <>
__quickreduce_device_inline__ int packed_add<nv_bfloat16>(int a, int b) {
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2 = __hadd2(A.bf2, B.bf2);
return R.i;
}
template <>
__quickreduce_device_inline__ int packed_add<int16_t>(int a, int b) {
int result;
asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
return result;
}
template <typename T>
__quickreduce_device_inline__ int packed_sub(int a, int b);
template <>
__quickreduce_device_inline__ int packed_sub<half>(int a, int b) {
int result;
// MI300 lacks packed fp16 sub instruction. So we do -1 * min + max
asm volatile("v_pk_fma_f16 %0, %1, %2 %3" : "=v"(result) : "v"(kNegOne), "v"(b), "v"(a));
return result;
}
template <>
__quickreduce_device_inline__ int packed_sub<nv_bfloat16>(int a, int b) {
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2 = __hsub2(A.bf2, B.bf2);
return R.i;
}
template <typename T>
__quickreduce_device_inline__ int packed_mul(int a, int b);
template <>
__quickreduce_device_inline__ int packed_mul<half>(int a, int b) {
int result;
asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
return result;
}
template <>
__quickreduce_device_inline__ int packed_mul<nv_bfloat16>(int a, int b) {
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
nv_bfloat162 tR = __hmul2(*tA, *tB);
return *(reinterpret_cast<int*>(&tR));
}
template <typename T>
__quickreduce_device_inline__ int packed_rcp(int a);
template <>
__quickreduce_device_inline__ int packed_rcp<half>(int a) {
return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a)));
}
template <>
__quickreduce_device_inline__ int packed_rcp<nv_bfloat16>(int a) {
bf162_int_union A, R;
A.i = a;
R.bf2 = h2rcp(A.bf2);
return R.i;
}
// changes dtype
__quickreduce_device_inline__ float T2float_cast(half a) {
return __half2float(a);
}
__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) {
return __bfloat162float(a);
}
template <typename T>
__quickreduce_device_inline__ int group_abs_max(int32x4_t atom) {
const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize;
int wmax, wmin, wblockmax;
int a, b;
a = packed_max<T>(atom[0], atom[1]);
b = packed_max<T>(atom[2], atom[3]);
wmax = packed_max<T>(a, b);
a = packed_min<T>(atom[0], atom[1]);
b = packed_min<T>(atom[2], atom[3]);
wmin = packed_min<T>(a, b);
// Reduce the max among a group of threads
// Note: This is basically 2 blocks of values setup as the
// upper/lower halves of the f16x2_t
for (int i = 1; i < kThreadGroupSize; i <<= 1) {
int x = __shfl_down(wmax, i);
wmax = packed_max<T>(wmax, x);
int y = __shfl_down(wmin, i);
wmin = packed_min<T>(wmin, y);
}
wblockmax = packed_abs_max<T>(wmax, wmin);
// Share with the cohort
wblockmax = __shfl(wblockmax, group_leader);
return wblockmax;
}
__quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr, uint32_t flag) {
__atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE);
}
__quickreduce_device_inline__ void wait_sync_flag(uint32_t* flag_ptr, uint32_t flag) {
while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) {
}
}
} // namespace quickreduce
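The constants above fix the two-shot kernel's tile geometry; a short worked example of how DeviceComms::allreduce derives its launch grid from them (illustrative only):

kBlockSize, kAtoms, atom_bytes = 256, 8, 16        # 16 bytes per int32x4_t atom
kTileSize = kBlockSize * kAtoms * atom_bytes       # 32,768 bytes per workgroup tile
msg_size = 16 * 1024 * 1024                        # e.g. a 16 MB message
num_blocks = -(-msg_size // kTileSize)             # divceil -> 512 tiles
grid = min(304 * 4, num_blocks)                    # capped at kMaxNumBlocks = 1216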
@@ -54,6 +54,25 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle);
m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle);
// quick allreduce
#ifdef USE_ROCM
m.def(
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
"cast_bf2half) -> ()");
m.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
m.def("init_custom_qr", &init_custom_qr);
m.def("qr_destroy", &qr_destroy);
m.def("qr_get_handle", &qr_get_handle);
m.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
m.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
// Max input size in bytes
m.def("qr_max_size", &qr_max_size);
#endif
/*
* From csrc/moe
*/
...
@@ -66,6 +66,13 @@ void register_graph_buffers(
fptr_t _fa, const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets);
torch::Tensor allocate_meta_buffer(int64_t size);
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
// quick allreduce
fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional<int64_t> qr_max_size = std::nullopt);
void qr_destroy(fptr_t _fa);
torch::Tensor qr_get_handle(fptr_t _fa);
void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false);
int64_t qr_max_size();
#else
// custom allreduce
fptr_t
@@ -77,6 +84,8 @@ std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs);
void register_graph_buffers(
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
// mscclpp
torch::Tensor mscclpp_generate_unique_id();
fptr_t mscclpp_init_context(
const torch::Tensor& unique_id,
...
from typing import List, Tuple
from typing import List, Optional, Tuple
import torch
@@ -49,6 +49,38 @@ if torch.version.hip is not None:
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
    return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)
# ROCM quick allreduce
def init_custom_qr(
    rank: int, world_size: int, qr_max_size: Optional[int] = None
) -> int:
    return torch.ops.sgl_kernel.init_custom_qr.default(
        world_size, rank, qr_max_size
    )

def qr_get_handle(fa: int) -> torch.Tensor:
    return torch.ops.sgl_kernel.qr_get_handle.default(fa)

def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
    torch.ops.sgl_kernel.qr_open_handles.default(fa, handles)

def qr_all_reduce(
    fa: int,
    profile: int,
    inp: torch.Tensor,
    out: torch.Tensor,
    cast_bf162half: bool,
) -> None:
    torch.ops.sgl_kernel.qr_all_reduce.default(
        fa, profile, inp, out, cast_bf162half
    )

def qr_destroy(fa: int) -> None:
    torch.ops.sgl_kernel.qr_destroy.default(fa)

def qr_max_size() -> int:
    return torch.ops.sgl_kernel.qr_max_size.default()
# mscclpp
def mscclpp_generate_unique_id() -> bytes:
    raise NotImplementedError()
...
@@ -41,6 +41,7 @@ include_dirs = [
sources = [
"csrc/allreduce/custom_all_reduce.hip",
"csrc/allreduce/quick_all_reduce.cu",
"csrc/moe/moe_align_kernel.cu",
"csrc/moe/moe_topk_softmax_kernels.cu",
"csrc/torch_extension_rocm.cc",
...
import os
import random
import socket
import unittest
from typing import Any
import ray
import torch
import torch.distributed as dist
from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.communication_op import ( # noqa
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
qr_rocm_arch_available,
)
from sglang.srt.distributed.parallel_state import (
get_tensor_model_parallel_group,
graph_capture,
initialize_model_parallel,
)
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(42)
random.seed(44) # keep the deterministic seed
def get_open_port() -> int:
# try ipv4
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
except OSError:
# try ipv6
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def multi_process_parallel(
world_size: int, cls: Any, test_target: Any, quant_mode: str
) -> None:
# Using ray helps debugging the error when it failed
# as compared to multiprocessing.
# NOTE: We need to set working_dir for distributed tests,
# otherwise we may get import errors on ray workers
ray.init(log_to_driver=True)
distributed_init_port = get_open_port()
refs = []
for rank in range(world_size):
refs.append(
test_target.remote(cls, world_size, rank, distributed_init_port, quant_mode)
)
ray.get(refs)
ray.shutdown()
class TestQuickAllReduce(CustomTestCase):
TEST_SIZES = [
2 * 1024 * 1024,
4 * 1024 * 1024,
8 * 1024 * 1024,
16 * 1024 * 1024,
32 * 1024 * 1024,
]
TEST_LOOP = 5
# Too many configurations can lead to a test grid that is too large, and the
# TP group takes too long to boot, so we only test 4 out of the 12 configurations.
# WORLD_SIZES = [2, 4, 8]
# QUANT_MODE = ["FP", "INT8", "INT6", "INT4"]
QUANT_MODE_WORLD_SIZE_PART = [["FP", 8], ["INT4", 4], ["INT8", 2], ["INT6", 2]]
@unittest.skipIf(
not qr_rocm_arch_available(),
"Only test Quick AllReduce on ROCm architectures >= gfx94*",
)
def test_graph_allreduce(self):
for quant_mode_world_size_part in self.QUANT_MODE_WORLD_SIZE_PART:
quant_mode = quant_mode_world_size_part[0]
world_size = quant_mode_world_size_part[1]
if world_size > torch.cuda.device_count():
continue
multi_process_parallel(world_size, self, self.graph_allreduce, quant_mode)
@unittest.skipIf(
not qr_rocm_arch_available(),
"Only test Quick AllReduce on ROCm architectures >= gfx94*",
)
def test_eager_allreduce(self):
for quant_mode_world_size_part in self.QUANT_MODE_WORLD_SIZE_PART:
quant_mode = quant_mode_world_size_part[0]
world_size = quant_mode_world_size_part[1]
if world_size > torch.cuda.device_count():
continue
multi_process_parallel(world_size, self, self.eager_allreduce, quant_mode)
@ray.remote(num_gpus=1, max_calls=1)
def graph_allreduce(self, world_size, rank, distributed_init_port, quant_mode):
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
os.environ["ROCM_QUICK_REDUCE_QUANTIZATION"] = quant_mode
os.environ["ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16"] = "0"
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
init_distributed_environment(
world_size=world_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=rank,
)
initialize_model_parallel(tensor_model_parallel_size=world_size)
group = get_tensor_model_parallel_group().device_group
# A small all_reduce for warmup.
# this is needed because device communicators might be created lazily
# (e.g. NCCL). This will ensure that the communicator is initialized
# before any communication happens, so that this group can be used for
# graph capture immediately.
data = torch.zeros(1)
data = data.to(device=device)
torch.distributed.all_reduce(data, group=group)
torch.cuda.synchronize()
del data
for sz in self.TEST_SIZES:
for dtype in [torch.float16, torch.bfloat16]:
for _ in range(self.TEST_LOOP):
with graph_capture() as graph_capture_context:
# use integers so result matches NCCL exactly
inp1 = torch.randint(
1,
23,
(sz,),
dtype=dtype,
device=torch.cuda.current_device(),
)
inp2 = torch.randint(
-23,
1,
(sz,),
dtype=dtype,
device=torch.cuda.current_device(),
)
torch.cuda.synchronize()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(
graph, stream=graph_capture_context.stream
):
out1 = tensor_model_parallel_all_reduce(inp1)
# the input buffer is immediately modified to test
# synchronization
dist.all_reduce(inp1, group=group)
out2 = tensor_model_parallel_all_reduce(inp2)
dist.all_reduce(inp2, group=group)
graph.replay()
atol = 1.25 * world_size
rtol = 0.5 * world_size
for inp, out in [[inp1, out1], [inp2, out2]]:
torch.testing.assert_close(out, inp, atol=atol, rtol=rtol)
# try:
# torch.testing.assert_close(out, inp, atol=atol, rtol=rtol)
# except AssertionError as e:
# print("Max abs diff:", (out - inp).abs().max())
# print("Max rel diff:", ((out - inp).abs() / inp.abs().clamp(min=1e-5)).max())
@ray.remote(num_gpus=1, max_calls=1)
def eager_allreduce(self, world_size, rank, distributed_init_port, quant_mode):
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
os.environ["ROCM_QUICK_REDUCE_QUANTIZATION"] = quant_mode
os.environ["ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16"] = "0"
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
init_distributed_environment(
world_size=world_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=rank,
)
initialize_model_parallel(tensor_model_parallel_size=world_size)
group = get_tensor_model_parallel_group().device_group
for sz in self.TEST_SIZES:
for dtype in [torch.float16, torch.bfloat16]:
for _ in range(self.TEST_LOOP):
inp1 = torch.randint(
1,
23,
(sz,),
dtype=dtype,
device=torch.cuda.current_device(),
)
out1 = tensor_model_parallel_all_reduce(inp1)
dist.all_reduce(inp1, group=group)
atol = 1.25 * world_size
rtol = 0.5 * world_size
torch.testing.assert_close(out1, inp1, atol=atol, rtol=rtol)
# try:
# torch.testing.assert_close(out1, inp1, atol=atol, rtol=rtol)
# except AssertionError as e:
# print("Max abs diff:", (out1 - inp1).abs().max())
# print("Max rel diff:", ((out1 - inp1).abs() / inp1.abs().clamp(min=1e-5)).max())
if __name__ == "__main__":
unittest.main()