Commit ebfe47e4 authored by Chenggang Zhao's avatar Chenggang Zhao
Browse files

Initial commit

parents
import torch
from .utils import EventOverlap
from .buffer import Buffer
# noinspection PyUnresolvedReferences
from deep_ep_cpp import Config
import os
import torch
import torch.distributed as dist
from typing import Callable, List, Tuple, Optional, Union
# noinspection PyUnresolvedReferences
import deep_ep_cpp
# noinspection PyUnresolvedReferences
from deep_ep_cpp import Config, EventHandle
from .utils import EventOverlap
class Buffer:
"""
The core expert-parallel (EP) communication buffers for Mixture of Experts (MoE) model, which supports:
- high-throughput intranode all-to-all (dispatch and combine, using NVLink)
- high-throughput internode all-to-all (dispatch and combine, using RDMA without AR)
- low-latency all-to-all (dispatch and combine, using RDMA, AR supported)
Attributes:
num_sms: the SMs used in high-throughput kernels.
rank: the local rank number.
group_size: the number of ranks in the group.
group: the communication group.
num_nvl_bytes: the buffer size for intranode NVLink communication.
num_rdma_bytes: the buffer size for internode (also for intranode with low-latency mode) RDMA communication.
runtime: the C++ runtime.
"""
num_sms: int = 20
def __init__(self, group: dist.ProcessGroup,
num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
low_latency_mode: bool = False, num_qps_per_rank: int = 1) -> None:
"""
Initialize the communication buffer.
Arguments:
group: the communication group.
num_nvl_bytes: the buffer size for intranode NVLink communication.
num_rdma_bytes: the buffer size for internode (also for intranode with low-latency mode) RDMA communication.
low_latency_mode: whether to enable low-latency mode.
num_qps_per_rank: the number of QPs for RDMA, the low-latency mode requires that this number equals
to the number of local experts.
"""
# TODO: argument docs
# Initialize the CPP runtime
self.rank = group.rank()
self.group_size = group.size()
self.group = group
self.num_nvl_bytes = num_nvl_bytes
self.num_rdma_bytes = num_rdma_bytes
self.low_latency_mode = low_latency_mode
self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode)
# Synchronize device IDs
device_ids = [None, ] * self.group_size
local_device_id = self.runtime.get_local_device_id()
dist.all_gather_object(device_ids, local_device_id, group)
# Synchronize IPC handles
ipc_handles = [None, ] * self.group_size
local_ipc_handle = self.runtime.get_local_ipc_handle()
dist.all_gather_object(ipc_handles, local_ipc_handle, group)
# Synchronize NVSHMEM unique IDs
root_unique_id = None
if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
# Enable IBGDA for the low latency mode, which refers to "no package forwarding between NVLink and RDMA"
if low_latency_mode:
assert num_qps_per_rank > 0
os.environ['NVSHMEM_DISABLE_P2P'] = '1'
os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
# Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
os.environ['NVSHMEM_QP_DEPTH'] = '1024'
# NOTES: NVSHMEM initialization requires at least 256 MiB
os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'
# NOTES: make sure AR (Adaptive Routing) is turned off while running normal kernels, as we cannot verify AR status in the code
# Synchronize using the root ID
nvshmem_unique_ids = [None, ] * self.group_size
if (low_latency_mode and self.rank == 0) or (not low_latency_mode and self.runtime.get_rdma_rank() == 0):
root_unique_id = self.runtime.get_local_nvshmem_unique_id()
dist.all_gather_object(nvshmem_unique_ids, root_unique_id, group)
root_unique_id = nvshmem_unique_ids[0 if low_latency_mode else self.runtime.get_root_rdma_rank(True)]
# Make CPP runtime available
self.runtime.sync(device_ids, ipc_handles, root_unique_id)
assert self.runtime.is_available()
@staticmethod
def set_num_sms(new_num_sms: int) -> None:
"""
Set the number of SMs to use in high-throughput kernels.
Arguments:
new_num_sms: the new number to be set.
"""
assert new_num_sms % 2 == 0, 'The SM count must be even'
Buffer.num_sms = new_num_sms
@staticmethod
def capture() -> EventOverlap:
"""
Capture a CUDA event on the current stream, i.e. `torch.cuda.current_stream()`.
Returns:
event: the captured event.
"""
return EventOverlap(EventHandle())
@staticmethod
def get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int) -> int:
"""
Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16.
Arguments:
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
hidden: the hidden dimension of each token.
num_ranks: the number of EP group ranks.
num_experts: the number of all experts.
Returns:
size: the RDMA buffer size recommended.
"""
return deep_ep_cpp.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts)
def get_local_buffer_tensor(self, dtype: torch.dtype, size: Optional[torch.Size] = None,
offset: int = 0, use_rdma_buffer: bool = False) -> torch.Tensor:
"""
Get the raw buffer (slice supported) as a PyTorch tensor.
Argument:
dtype: the data type (PyTorch `dtype`) for the tensor.
size: the slice size (by elements) to get from the buffer.
offset: the offset of the beginning element.
use_rdma_buffer: whether to return the RDMA buffer.
"""
tensor = self.runtime.get_local_buffer_tensor(dtype, offset, use_rdma_buffer)
if size is None:
return tensor
assert tensor.numel() >= size.numel()
return tensor[:size.numel()].view(size)
@staticmethod
def get_dispatch_config(num_ranks: int) -> Config:
"""
Get a recommended dispatch config.
Argument:
num_ranks: the number of ranks.
Returns:
config: the recommended config.
"""
# Intranode
if num_ranks <= 8:
return Config(Buffer.num_sms, 6, 256, 6, 128)
# Internode
config_map = {
16: Config(Buffer.num_sms, 16, 288, 20, 128),
24: Config(Buffer.num_sms, 8, 288, 32, 128),
32: Config(Buffer.num_sms, 8, 288, 32, 128),
64: Config(Buffer.num_sms, 20, 288, 28, 128),
128: Config(Buffer.num_sms, 20, 560, 32, 128),
144: Config(Buffer.num_sms, 32, 720, 12, 128),
160: Config(Buffer.num_sms, 28, 720, 12, 128),
}
assert num_ranks in config_map, f'Unsupported number of EP ranks: {num_ranks}'
return config_map[num_ranks]
@staticmethod
def get_combine_config(num_ranks: int) -> Config:
"""
Get a recommended combine config.
Argument:
num_ranks: the number of ranks.
Returns:
config: the recommended config.
"""
# Intranode
if num_ranks <= 8:
return Config(Buffer.num_sms, 6, 256, 6, 128)
# Internode
config_map = {
16: Config(Buffer.num_sms, 2, 288, 28, 128),
24: Config(Buffer.num_sms, 1, 288, 20, 128),
32: Config(Buffer.num_sms, 1, 288, 20, 128),
64: Config(Buffer.num_sms, 1, 288, 20, 128),
128: Config(Buffer.num_sms, 1, 560, 12, 128),
144: Config(Buffer.num_sms, 2, 720, 8, 128),
160: Config(Buffer.num_sms, 2, 720, 8, 128),
}
assert num_ranks in config_map, f'Unsupported number of EP ranks: {num_ranks}'
return config_map[num_ranks]
# noinspection PyTypeChecker
def get_dispatch_layout(self, topk_idx: torch.Tensor, num_experts: int,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, EventOverlap]:
"""
Calculate the layout required for later communication.
Arguments:
topk_idx: `[num_tokens, num_topk]`, dtype must be `torch.int64`, the expert indices selected by each token,
`-1` means no selections.
num_experts: the number of experts.
previous_event: the event to wait before actually executing the kernel.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream.
Returns:
num_tokens_per_rank: `[num_ranks]` with `torch.int`, the number of tokens to be sent to each rank.
num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, the number of tokens to be sent to each RDMA
rank (with the same GPU index), return `None` for intranode settings.
num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert.
is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank.
event: the event after executing the kernel (valid only if `async_finish` is set).
"""
num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \
self.runtime.get_dispatch_layout(topk_idx, num_experts, getattr(previous_event, 'event', None),
async_finish, allocate_on_comm_stream)
return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event)
# noinspection PyTypeChecker
def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
handle: Optional[Tuple] = None,
num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
config: Optional[Config] = None,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], List[int], Tuple, EventOverlap]:
"""
Dispatch tokens to different ranks, both intranode and internode settings are supported.
Intranode kernels require all the ranks should be visible via NVLink.
Internode kernels require the ranks in a node should be visible via NVLink, while the ranks with the same GPU
index should be visible via RDMA. AR must be disabled.
Arguments:
x: `torch.Tensor` or tuple of `torch.Tensor`, for the first type, the shape must be `[num_tokens, hidden]`,
and type must be `torch.bfloat16`; for the second type, the first element of the tuple must be shaped as
`[num_tokens, hidden]` with type `torch.float8_e4m3fn`, the second must be `[num_tokens, hidden // 128]`
(requiring divisible) with type `torch.float`.
handle: an optional communication handle, if set, the CPU will reuse the layout information to save some time.
num_tokens_per_rank: `[num_ranks]` with `torch.int`, the number of tokens to be sent to each rank.
num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, the number of tokens to be sent to each RDMA
rank (with the same GPU index), return `None` for intranode settings.
is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank.
num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert.
topk_idx: `[num_tokens, num_topk]` with `torch.int64`, the expert indices selected by each token,
`-1` means no selections.
topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch.
expert_alignment: align the number of tokens received by each local expert to this variable.
config: the performance tuning config.
previous_event: the event to wait before actually executing the kernel.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream.
Returns:
recv_x: received tokens, the same type and tuple as the input `x`, but the number of tokens equals to the
received token count.
recv_topk_idx: received expert indices.
recv_topk_weights: received expert weights.
num_recv_tokens_per_expert_list: Python list shaped `[num_local_experts]`, the received token count by
each local expert, aligned to the input `expert_alignment`.
handle: the returned communication handle.
event: the event after executing the kernel (valid only if `async_finish` is set).
"""
# Default config
config = self.get_dispatch_config(self.group_size) if config is None else config
# Internode
if self.runtime.get_num_rdma_ranks() > 1:
return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
topk_idx, topk_weights, expert_alignment, config, previous_event, async_finish, allocate_on_comm_stream)
# Launch the kernel with cached or non-cached mode
x, x_scales = x if isinstance(x, tuple) else (x, None)
if handle is not None:
assert topk_idx is None and topk_weights is None
rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head = handle
num_recv_tokens = recv_src_idx.size(0)
recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
x, x_scales, None, None,
None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
else:
assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
self.runtime.intranode_dispatch(x, x_scales, topk_idx, topk_weights,
num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None,
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)
# noinspection PyTypeChecker
def combine(self, x: torch.Tensor, handle: Tuple,
topk_weights: Optional[torch.Tensor] = None,
config: Optional[Config] = None,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]:
"""
Combine (reduce) tokens (addition **without** weights) from different ranks, both intranode and internode
settings are supported.
Intranode kernels require all the ranks should be visible via NVLink.
Internode kernels require the ranks in a node should be visible via NVLink, while the ranks with the same GPU
index should be visible via RDMA. AR must be disabled.
Arguments:
x: `[num_tokens, hidden]` with `torch.bfloat16`, the tokens to send for reducing to its original ranks.
handle: a must-set communication handle, you can obtain this from the dispatch function.
topk_weights: `[num_tokens, num_topk]` with `torch.float`, the tokens' top-k weights for reducing to its original ranks.
config: the performance tuning config.
previous_event: the event to wait before actually executing the kernel.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream.
Returns:
recv_x: the reduced token from its dispatched ranks.
recv_topk_weights: the reduced top-k weights from its dispatch ranks.
event: the event after executing the kernel (valid only if `async_finish` is set).
"""
# Default config
config = self.get_combine_config(self.group_size) if config is None else config
# Internode
if self.runtime.get_num_rdma_ranks() > 1:
return self.internode_combine(x, handle, topk_weights, config, previous_event, async_finish, allocate_on_comm_stream)
# NOTES: the second `_` is for the sending side, so we should use the third one
rank_prefix_matrix, _, channel_prefix_matrix, src_idx, is_recv_token_in_rank, send_head = handle
# Launch the kernel
recv_x, recv_topk_weights, event = self.runtime.intranode_combine(
x, topk_weights,
src_idx, rank_prefix_matrix, channel_prefix_matrix, send_head, config,
getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
return recv_x, recv_topk_weights, EventOverlap(event)
# noinspection PyTypeChecker
def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
handle: Optional[Tuple] = None,
num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
config: Optional[Config] = None,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], List[int], Tuple, EventOverlap]:
"""
Internode dispatch implementation, for more details, please refer to the `dispatch` docs.
Normally, you should not directly call this function.
"""
assert config is not None
# Launch the kernel with cached or non-cached mode
x, x_scales = x if isinstance(x, tuple) else (x, None)
if handle is not None:
assert topk_idx is None and topk_weights is None
is_token_in_rank, \
rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, \
recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
recv_src_meta, send_rdma_head, send_nvl_head = handle
num_recv_tokens = recv_src_meta.size(0)
num_rdma_recv_tokens = send_nvl_head.size(0)
recv_x, recv_x_scales, _, _, _, _, _, _, _, _, _, _, _, _, event = self.runtime.internode_dispatch(
x, x_scales, topk_idx, topk_weights,
None, None, is_token_in_rank, None,
num_recv_tokens, num_rdma_recv_tokens,
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
else:
assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, \
rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, \
recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
recv_src_meta, send_rdma_head, send_nvl_head, event = self.runtime.internode_dispatch(
x, x_scales, topk_idx, topk_weights,
num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
0, 0, None, None, None, None,
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
handle = (is_token_in_rank,
rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
recv_src_meta, send_rdma_head, send_nvl_head)
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)
# noinspection PyTypeChecker
def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
topk_weights: Optional[torch.Tensor] = None,
config: Optional[Config] = None,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]:
"""
Internode combine implementation, for more details, please refer to the `combine` docs.
Normally, you should not directly call this function.
"""
assert config is not None
# Unpack handle
is_combined_token_in_rank, \
_, _, \
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, gbl_rank_prefix_sum, \
src_meta, send_rdma_head, send_nvl_head = handle
# Launch the kernel
combined_x, combined_topk_weights, event = self.runtime.internode_combine(
x, topk_weights,
src_meta, is_combined_token_in_rank,
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
send_rdma_head, send_nvl_head, config, getattr(previous_event, 'event', None),
async_finish, allocate_on_comm_stream)
return combined_x, combined_topk_weights, EventOverlap(event)
def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None:
"""
As low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer
if the buffer is dirty at some time.
For example, after running the normal dispatch/combine, you must run this function before executing any
low-latency kernel.
Arguments:
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
hidden: the hidden dimension of each token.
num_experts: the number of all experts.
"""
self.runtime.clean_low_latency_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)
# noinspection PyTypeChecker
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int, num_experts: int,
async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
"""
A low-latency implementation for dispatching with IBGDA **with implicit FP8 casting**.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink are fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensor at a single moment.
Arguments:
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`.
topk_idx: `torch.Tensor` with `torch.int64`, shaped as `[num_tokens, num_topk]`, only several top-k shapes
are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
Returns:
recv_x: a tuple with received tokens for each expert. The first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
expert receive. As mentioned before, all not tokens are valid in `recv_x`.
handle: the communication handle to be used in the `low_latency_combine` function.
event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
"""
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
self.runtime.low_latency_dispatch(x, topk_idx,
num_max_dispatch_tokens_per_rank, num_experts,
async_finish, return_recv_hook)
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, num_experts)
tensors_to_record = (x, topk_idx,
packed_recv_x, packed_recv_x_scales, packed_recv_count,
packed_recv_src_info, packed_recv_layout_range)
return (packed_recv_x, packed_recv_x_scales), packed_recv_count, handle, \
EventOverlap(event, tensors_to_record if async_finish else None), hook
# noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
handle: tuple, async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[torch.Tensor, EventOverlap, Callable]:
"""
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink are fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensor at a single moment.
Arguments:
x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
the local calculated tokens to be sent to this original rank and reduced.
topk_idx: `[num_combined_tokens, num_topk]` with `torch.int64`, the expert indices selected by the dispatched
tokens. `-1` indices (not selecting any expert) are supported. Note that, `num_combined_tokens` equals
to the number of dispatched tokens.
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
Returns:
combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`.
event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
"""
src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
num_max_dispatch_tokens_per_rank, num_experts,
async_finish, return_recv_hook)
tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
import torch
from typing import Any, Optional, Tuple
# noinspection PyUnresolvedReferences
from deep_ep_cpp import Config, EventHandle
class EventOverlap:
"""
A wrapper class to manage CUDA events, also for better overlapping convenience.
Attributes:
event: the CUDA event captured.
extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
"""
def __init__(self, event: Optional[EventHandle] = None,
extra_tensors: Optional[Tuple[torch.Tensor]] = None) -> None:
"""
Initialize the class.
Arguments:
event: the CUDA event captured.
extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
"""
self.event = event
# NOTES: we use extra tensors to achieve stream recording, otherwise,
# stream recording will be incompatible with CUDA graph.
self.extra_tensors = extra_tensors
def current_stream_wait(self) -> None:
"""
The current stream `torch.cuda.current_stream()` waits for the event to be finished.
"""
assert self.event is not None
self.event.current_stream_wait()
def __enter__(self) -> Any:
"""
Utility for overlapping and Python `with` syntax.
You can overlap the kernels on the current stream with the following example:
```python
event_overlap = event_after_all_to_all_kernels()
with event_overlap():
do_something_on_current_stream()
# After exiting the `with` scope, the current stream with wait the event to be finished.
```
"""
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""
Utility for overlapping and Python `with` syntax.
Please follow the example in the `__enter__` function.
"""
if self.event is not None:
self.event.current_stream_wait()
import os
import subprocess
import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
if __name__ == '__main__':
nvshmem_dir = os.getenv('NVSHMEM_DIR', None)
assert nvshmem_dir is not None and os.path.exists(nvshmem_dir), 'Failed to find NVSHMEM'
print(f'NVSHMEM directory: {nvshmem_dir}')
# TODO: currently, we only support Hopper architecture, we may add Ampere support later
os.environ['TORCH_CUDA_ARCH_LIST'] = '9.0'
cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable',
'-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes']
nvcc_flags = ['-O3', '-Xcompiler', '-O3', '-rdc=true', '--ptxas-options=--register-usage-level=10']
include_dirs = ['csrc/', f'{nvshmem_dir}/include']
sources = ['csrc/deep_ep.cpp',
'csrc/kernels/runtime.cu', 'csrc/kernels/intranode.cu',
'csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu']
library_dirs = [f'{nvshmem_dir}/lib']
# Disable aggressive PTX instructions
if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '0')):
cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
# Disable DLTO (default by PyTorch)
nvcc_dlink = ['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem']
extra_link_args = ['-l:libnvshmem.a', '-l:nvshmem_bootstrap_uid.so', f'-Wl,-rpath,{nvshmem_dir}/lib']
extra_compile_args = {
'cxx': cxx_flags,
'nvcc': nvcc_flags,
'nvcc_dlink': nvcc_dlink
}
# noinspection PyBroadException
try:
cmd = ['git', 'rev-parse', '--short', 'HEAD']
revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
except Exception as _:
revision = ''
setuptools.setup(
name='deep_ep',
version='1.0.0' + revision,
packages=setuptools.find_packages(
include=['deep_ep']
),
ext_modules=[
CUDAExtension(
name='deep_ep_cpp',
include_dirs=include_dirs,
library_dirs=library_dirs,
sources=sources,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args
)
],
cmdclass={
'build_ext': BuildExtension
}
)
import os
import time
import torch
import torch.distributed as dist
# noinspection PyUnresolvedReferences
import deep_ep
from utils import init_dist, bench, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back
# Test compatibility with low latency functions
import test_low_latency
def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Settings
num_tokens, hidden, num_topk_groups, num_topk, num_experts = 4096, 7168, min(num_nodes, 4), 8, (256 // num_ranks) * num_ranks
assert num_experts % num_ranks == 0 and num_local_ranks == 8
if local_rank == 0:
print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True)
# Random data
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
x_e4m3 = per_token_cast_to_fp8(x)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices
masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1]
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
rank_idx = topk_idx // (num_experts // num_ranks)
rank_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rank_idx, num_ranks)
rdma_rank_idx = rank_idx // num_local_ranks
rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
inplace_unique(rdma_rank_idx, num_nodes)
# RDMA dispatch counts
rdma_idx = topk_idx // (num_experts // num_nodes)
rdma_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rdma_idx, num_nodes)
num_rdma_token_sent = rdma_idx.ne(-1).sum().item()
# Expert meta
num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda')
for i in range(num_experts):
num_tokens_per_expert[i] = (topk_idx == i).sum()
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
# Rank layout meta
num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda')
num_tokens_per_rdma_rank = torch.empty((num_nodes, ), dtype=torch.int, device='cuda')
token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda')
for i in range(num_ranks):
num_tokens_per_rank[i] = (rank_idx == i).sum()
token_sel = (rank_idx == i).max(dim=-1)[0]
count = token_sel.sum().item()
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
tokens[:count] = torch.sort(tokens[:count])[0]
token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda')
for i in range(num_nodes):
num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
is_token_in_rank = token_idx_in_rank >= 0
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
ref_num_tokens_per_rank, ref_num_tokens_per_rdma_rank, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \
buffer.get_dispatch_layout(topk_idx, num_experts)
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
if local_rank == 0:
print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True)
print()
group.barrier()
time.sleep(1)
# Config
rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)
config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)
# Test dispatch
# noinspection PyShadowingNames
def check_data(check_x, recv_gbl_rank_prefix_sum):
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
check_start = 0
for i in range(num_ranks):
check_end = recv_gbl_rank_prefix_sum[i].item()
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
check_start = check_end
for previous_mode in (False, True):
for async_mode in (False, True):
for current_x in (x_pure_rand, x, x_e4m3):
for with_topk in (False, True):
if local_rank == 0:
print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='')
dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank,
'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode}
if with_topk:
dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
# Checks
recv_gbl_rank_prefix_sum = handle[-4]
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}'
assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
if current_x is not x_pure_rand:
check_data(recv_x, recv_gbl_rank_prefix_sum)
if with_topk:
# Check `topk_idx`
assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
for i, count in enumerate(recv_num_tokens_per_expert_list):
assert recv_topk_idx.eq(i).sum().item() == count
# Check `topk_weights`
if current_x is not x_pure_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
# Test cached dispatch (must without top-k staffs)
# NOTES: handle must be refreshed
if not with_topk:
dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
if current_x is not x_pure_rand:
check_data(recv_x, recv_gbl_rank_prefix_sum)
# Test combine
combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
if with_topk:
combine_args.update({'topk_weights': recv_topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
event.current_stream_wait() if async_mode else ()
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
ref_x = x_pure_rand if current_x is x_pure_rand else x
assert calc_diff(check_x, ref_x) < 5e-6
if with_topk:
check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1))
ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
# For later tuning
dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes
if local_rank == 0:
print(' passed', flush=True)
if local_rank == 0:
print()
# Tune dispatch performance
best_dispatch_results = None
fp8_factor = (1 + 4 / 128) / 2
for current_x in (x_e4m3, x):
best_time, best_results = 1e10, None
rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
for nvl_chunk_size in range(4, 33, 4):
for rdma_chunk_size in range(4, 33, 4):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': current_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.dispatch(**tune_args))[0]
if t < best_time:
best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size)
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ')
if local_rank == 0:
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)')
print()
if isinstance(current_x, tuple):
# Gather FP8 the best config from rank 0
best_dispatch_results = torch.tensor([best_results[0], best_results[1], best_results[2]], dtype=torch.int32, device='cuda')
all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]
dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
best_dispatch_results = all_best_fp8_results_list[0].tolist()
dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size, best_dispatch_results[2], rdma_buffer_size)
dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank,
'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert,
'config': dispatch_config if dispatch_config is not None else config}
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
# Tune combine performance
best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 5, 1):
for rdma_chunk_size in range(8, 33, 4):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.combine(**tune_args))[0]
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ')
if t < best_time:
best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size)
if local_rank == 0:
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)')
print()
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
# Please make sure AR (Adaptive Routing) is turned off when running normal internode kernels,
num_nodes = int(os.getenv('WORLD_SIZE', 1))
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
test_ll_compatibility = False
if test_ll_compatibility:
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility,
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
assert num_local_ranks == 8 and num_ranks > 8
torch.manual_seed(rank)
for i in (24, ):
test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
if local_rank == 0:
print()
# Test compatibility with low latency functions
if test_ll_compatibility:
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
if __name__ == '__main__':
num_processes = 8
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
import os
import time
import torch
import torch.distributed as dist
# noinspection PyUnresolvedReferences
import deep_ep
from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to_fp8, per_token_cast_back
# Test compatibility with low latency functions
import test_low_latency
def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Settings
num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks
assert num_experts % num_ranks == 0 and num_local_ranks == 8
if local_rank == 0:
print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True)
# Random data
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
x_e4m3 = per_token_cast_to_fp8(x)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
rank_idx = topk_idx // (num_experts // num_ranks)
rank_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rank_idx, num_ranks)
# Expert meta
num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda')
for i in range(num_experts):
num_tokens_per_expert[i] = (topk_idx == i).sum()
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
# Rank layout meta
num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda')
token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda')
for i in range(num_ranks):
num_tokens_per_rank[i] = (rank_idx == i).sum()
token_sel = (rank_idx == i).max(dim=-1)[0]
count = token_sel.sum().item()
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
tokens[:count] = torch.sort(tokens[:count])[0]
token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda')
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
is_token_in_rank = token_idx_in_rank >= 0
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
ref_num_tokens_per_rank, _, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \
buffer.get_dispatch_layout(topk_idx, num_experts)
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
if local_rank == 0:
print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True)
print()
group.barrier()
time.sleep(1)
# Config
nvl_buffer_size = 256
config = deep_ep.Config(num_sms, 8, nvl_buffer_size)
# Test dispatch
# noinspection PyShadowingNames
def check_data(check_x, rank_prefix_matrix):
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
check_start = 0
for i in range(num_ranks):
check_end = rank_prefix_matrix[i][rank].item()
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
check_start = check_end
for previous_mode in (False, True):
for async_mode in (False, True):
for current_x in (x_pure_rand, x, x_e4m3):
for with_topk in (False, True):
if local_rank == 0:
print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='')
dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'is_token_in_rank': is_token_in_rank,
'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode}
if with_topk:
dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
# Checks
rank_prefix_matrix = handle[0]
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}'
assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
if current_x is not x_pure_rand:
check_data(recv_x, rank_prefix_matrix)
if with_topk:
# Check `topk_idx`
assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
for i, count in enumerate(recv_num_tokens_per_expert_list):
assert recv_topk_idx.eq(i).sum().item() == count
# Check `topk_weights`
if current_x is not x_pure_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
check_data(recv_topk_weights, rank_prefix_matrix)
# Test cached dispatch (must without top-k staffs)
# NOTES: handle must be refreshed
if not with_topk:
dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
if current_x is not x_pure_rand:
check_data(recv_x, rank_prefix_matrix)
# Test combine
combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
if with_topk:
combine_args.update({'topk_weights': recv_topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
event.current_stream_wait() if async_mode else ()
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
ref_x = x_pure_rand if current_x is x_pure_rand else x
assert calc_diff(check_x, ref_x) < 5e-6
if with_topk:
check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1))
ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
# For later tuning
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
if local_rank == 0:
print(' passed', flush=True)
if local_rank == 0:
print()
# Tune dispatch performance
best_dispatch_results = None
fp8_factor = (1 + 4 / 128) / 2
for current_x in (x_e4m3, x):
best_time, best_results = 1e10, None
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
for nvl_chunk_size in range(4, 33, 4):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
tune_args = {'x': current_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.dispatch(**tune_args))[0]
if t < best_time:
best_time, best_results = t, (num_sms, nvl_chunk_size)
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ')
if local_rank == 0:
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)')
print()
if isinstance(current_x, tuple):
# Gather FP8 the best config from rank 0
best_dispatch_results = torch.tensor([best_results[0], best_results[1]], dtype=torch.int32, device='cuda')
all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]
dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
best_dispatch_results = all_best_fp8_results_list[0].tolist()
dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size)
dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank,
'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert,
'config': dispatch_config if dispatch_config is not None else config}
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
# Tune combine performance
best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 5, 1):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.combine(**tune_args))[0]
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ')
if t < best_time:
best_time, best_results = t, (num_sms, nvl_chunk_size)
if local_rank == 0:
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)')
print()
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
test_ll_compatibility, num_rdma_bytes = False, 0
if test_ll_compatibility:
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
buffer = deep_ep.Buffer(group, int(1e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
torch.manual_seed(rank)
for i in (24, ):
test_main(i, local_rank, num_local_ranks, num_ranks, rank, buffer, group)
if local_rank == 0:
print()
# Test compatibility with low latency functions
if test_ll_compatibility:
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
if __name__ == '__main__':
num_processes = 8
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
import random
import torch
import torch.distributed as dist
from functools import partial
import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back
def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, seed: int = 0):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks
# NOTES: the integers greater than 256 exceeds the BF16 precision limit
rank_offset = 128
assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
# Randomly mask some positions
for i in range(10):
topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1
# Check dispatch correctness
do_check = True
hash_value, num_times = 0, 0
for return_recv_hook in (False, True):
num_times += 1
for i in range((num_times % 2) + 1):
packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous())
simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i
recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i])
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices
int_mask = (2 ** 32) - 1
num_valid_tokens = recv_count.item()
assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
# Check received data
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
for j in range(num_ranks):
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
# Check combine correctness
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
if do_check:
diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0
assert diff < 1e-5, f'Error: diff={diff}'
hash_value ^= hash_tensor(combined_x)
def create_test_cast_with_outliers(num_outliers):
tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
tmp /= tmp.abs().amax(dim=1).view(-1, 1)
assert tmp.abs().amax().item() <= 1
# Create some amax outliers
for i in range(num_outliers):
tmp[random.randint(0, num_tokens - 1)] *= 1e3
return tmp
# noinspection PyShadowingNames
def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float)
mat_1 = torch.randn((8192, 8192), dtype=torch.float)
mat_0 @ mat_1
hook()
# noinspection PyShadowingNames
def test_func(return_recv_hook):
recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
# Calculate bandwidth
num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
for i in range(num_tokens):
num_selections = (topk_idx[i] != -1).sum().item()
num_dispatch_comm_bytes += num_fp8_bytes * num_selections
num_combine_comm_bytes += num_bf16_bytes * num_selections
# Dispatch + combine testing
avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)
# Separate profiling
for return_recv_hook in (False, True):
group.barrier()
dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
suppress_kineto_output=True)
if not return_recv_hook:
print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us')
else:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t * 2 * 1e6:.2f} us')
return hash_value
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
if local_rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
num_qps_per_rank=num_experts // num_ranks)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)
do_pressure_test = False
for seed in range(int(1e9) if do_pressure_test else 0):
if local_rank == 0:
print(f'Testing with seed {seed} ...', flush=True)
ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed)
for i in range(20):
assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}'
if __name__ == '__main__':
# TODO: you may modify NUMA binding for less CPU overhead
num_processes = 8
torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
import os
import sys
import numpy as np
import torch
import torch.distributed as dist
from typing import Optional
def init_dist(local_rank: int, num_local_ranks: int):
# NOTES: you may rewrite this function with your own cluster settings
ip = os.getenv('MASTER_ADDR', '127.0.0.1')
port = int(os.getenv('MASTER_PORT', '8361'))
num_nodes = int(os.getenv('WORLD_SIZE', 1))
node_rank = int(os.getenv('RANK', 0))
assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
dist.init_process_group(
backend='nccl',
init_method=f'tcp://{ip}:{port}',
world_size=num_nodes * num_local_ranks,
rank=node_rank * num_local_ranks + local_rank
)
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device('cuda')
torch.cuda.set_device(local_rank)
return dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes)))
def calc_diff(x: torch.Tensor, y: torch.Tensor):
x, y = x.double() + 1, y.double() + 1
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return (1 - sim).item()
def per_token_cast_to_fp8(x: torch.Tensor):
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
x_scales = x_scales.view(x_fp8.size(0), -1, 1)
return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)
def inplace_unique(x: torch.Tensor, num_slots: int):
assert x.dim() == 2
mask = x < 0
x_padded = x.masked_fill(mask, num_slots)
bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
bin_count = bin_count[:, :num_slots]
sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
x[:, :].fill_(-1)
valid_len = min(num_slots, x.size(1))
x[:, :valid_len] = sorted_bin_idx[:, :valid_len]
def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int):
num_tokens, num_experts = scores.shape
scores = scores.view(num_tokens, num_groups, -1)
mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
return (scores * mask).view(num_tokens, num_experts)
def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
# Flush L2 cache with 256 MB data
torch.cuda.synchronize()
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
# Warmup
for _ in range(num_warmups):
fn()
# Flush L2
cache.zero_()
# Testing
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
for i in range(num_tests):
# Record
start_events[i].record()
fn()
end_events[i].record()
if post_fn is not None:
post_fn()
torch.cuda.synchronize()
times = np.array([s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)])[1:]
return np.average(times), np.min(times), np.max(times)
class empty_suppress:
def __enter__(self):
return self
def __exit__(self, *_):
pass
class suppress_stdout_stderr:
def __enter__(self):
self.outnull_file = open(os.devnull, 'w')
self.errnull_file = open(os.devnull, 'w')
self.old_stdout_fileno_undup = sys.stdout.fileno()
self.old_stderr_fileno_undup = sys.stderr.fileno()
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
self.old_stdout = sys.stdout
self.old_stderr = sys.stderr
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
sys.stdout = self.outnull_file
sys.stderr = self.errnull_file
return self
def __exit__(self, *_):
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
os.close(self.old_stdout_fileno)
os.close(self.old_stderr_fileno)
self.outnull_file.close()
self.errnull_file.close()
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
trace_path: Optional[str] = None, barrier_comm_profiling: bool = False):
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
with suppress():
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) as prof:
for i in range(2):
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
if barrier_comm_profiling:
lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
lhs @ rhs
dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
for _ in range(num_tests):
fn()
prof.step()
# Parse the profiling table
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
is_tupled = isinstance(kernel_names, tuple)
prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])
for name in kernel_names:
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
# Save chrome traces
if trace_path is not None:
prof.export_chrome_trace(trace_path)
# Return average kernel times
units = {'ms': 1e3, 'us': 1e6}
kernel_times = []
for name in kernel_names:
for line in prof_lines:
if name in line:
time_str = line.split()[-2]
for unit, scale in units.items():
if unit in time_str:
kernel_times.append(float(time_str.replace(unit, '')) / scale)
break
break
return tuple(kernel_times) if is_tupled else kernel_times[0]
def hash_tensor(t: torch.Tensor):
return t.view(torch.int64).sum().item()
# Install NVSHMEM
## Important notices
**This project is neither sponsored nor supported by NVIDIA.**
**Use of NVIDIA NVSHMEM is governed by the terms at [NVSHMEM Software License Agreement](https://docs.nvidia.com/nvshmem/api/sla.html).**
## Prerequisites
1. [GDRCopy](https://github.com/NVIDIA/gdrcopy) (v2.4 and above recommended) is a low-latency GPU memory copy library based on NVIDIA GPUDirect RDMA technology, and *it requires kernel module installation with root privileges.*
2. Hardware requirements
- GPUDirect RDMA capable devices, see [GPUDirect RDMA Documentation](https://docs.nvidia.com/cuda/gpudirect-rdma/)
- InfiniBand GPUDirect Async (IBGDA) support, see [IBGDA Overview](https://developer.nvidia.com/blog/improving-network-performance-of-hpc-systems-using-nvidia-magnum-io-nvshmem-and-gpudirect-async/)
- For more detailed requirements, see [NVSHMEM Hardware Specifications](https://docs.nvidia.com/nvshmem/release-notes-install-guide/install-guide/abstract.html#hardware-requirements)
## Installation procedure
### 1. Install GDRCopy
GDRCopy requires kernel module installation on the host system. Complete these steps on the bare-metal host before container deployment:
#### Build and installation
```bash
git clone https://github.com/NVIDIA/gdrcopy
cd gdrcopy
make -j$(nproc)
sudo make prefix=/opt/gdrcopy install
```
#### Kernel module installation
```bash
cd packages
CUDA=/path/to/cuda ./build-deb-packages.sh
sudo dpkg -i gdrdrv-dkms_2.4-4_amd64.deb \
libgdrapi_2.4-4_amd64.deb \
gdrcopy-tests_2.4-4_amd64.deb \
gdrcopy_2.4-4_amd64.deb
sudo ./insmod.sh # Load kernel modules on bare-metal system
```
#### Container environment notes
For containerized environments:
1. Host: keep kernel modules loaded (`gdrdrv`)
2. Container: install DEB packages *without* rebuilding modules:
```bash
sudo dpkg -i gdrcopy_2.4-4_amd64.deb \
libgdrapi_2.4-4_amd64.deb \
gdrcopy-tests_2.4-4_amd64.deb
```
#### Verification
```bash
gdrcopy_copybw # Should show bandwidth test results
```
### 2. Acquiring NVSHMEM source code
Download NVSHMEM v3.1.7 from the [NVIDIA NVSHMEM Archive](https://developer.nvidia.com/nvshmem-archive).
### 3. Apply our custom patch
Navigate to your NVSHMEM source directory and apply our provided patch:
```bash
git apply /path/to/deep_ep/dir/third-party/nvshmem.patch
```
### 4. Configure NVIDIA driver
Enable IBGDA by modifying `/etc/modprobe.d/nvidia.conf`:
```bash
options nvidia NVreg_EnableStreamMemOPs=1 NVreg_RegistryDwords="PeerMappingOverride=1;"
```
Update kernel configuration:
```bash
sudo update-initramfs -u
sudo reboot
```
For more detailed configurations, please refer to the [NVSHMEM Installation Guide](https://docs.nvidia.com/nvshmem/release-notes-install-guide/install-guide/abstract.html).
### 5. Build and installation
The following example demonstrates building NVSHMEM with IBGDA support:
```bash
CUDA_HOME=/path/to/cuda && \
GDRCOPY_HOME=/path/to/gdrcopy && \
NVSHMEM_SHMEM_SUPPORT=0 \
NVSHMEM_UCX_SUPPORT=0 \
NVSHMEM_USE_NCCL=0 \
NVSHMEM_IBGDA_SUPPORT=1 \
NVSHMEM_PMIX_SUPPORT=0 \
NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
NVSHMEM_USE_GDRCOPY=1 \
cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/path/to/your/dir/to/install
cd build
make -j$(nproc)
make install
```
## Post-installation configuration
Set environment variables in your shell configuration:
```bash
export NVSHMEM_DIR=/path/to/your/dir/to/install # Use for DeepEP installation
export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"
export PATH="${NVSHMEM_DIR}/bin:$PATH"
```
## Verification
```bash
nvshmem-info -a # Should display details of nvshmem
```
From 9d784943e1032f15dd7cdd2599192937ba9d9343 Mon Sep 17 00:00:00 2001
From: Shangyan Zhou <sy.zhou@deepseek.com>
Date: Fri, 20 Dec 2024 10:57:12 +0800
Subject: [PATCH 1/5] Change QP creating order.
---
src/modules/transport/ibgda/ibgda.cpp | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp
index 31bc56a..ff02f50 100644
--- a/src/modules/transport/ibgda/ibgda.cpp
+++ b/src/modules/transport/ibgda/ibgda.cpp
@@ -2921,17 +2921,20 @@ int nvshmemt_ibgda_connect_endpoints(nvshmem_transport_t t, int *selected_dev_id
INFO(ibgda_state->log_level, "Creating %d RC QPs", device->rc.num_eps_per_pe);
for (int i = 0; i < num_rc_eps; ++i) {
// Do not create loopback to self
- if (i / device->rc.num_eps_per_pe == mype) {
+ int dst_pe = (i + 1 + mype) % n_pes;
+ int offset = i / n_pes;
+ int mapped_i = dst_pe * device->rc.num_eps_per_pe + offset;
+ if (dst_pe == mype) {
continue;
}
- status = ibgda_create_qp(&device->rc.eps[i], device, portid, i,
+ status = ibgda_create_qp(&device->rc.eps[mapped_i], device, portid, mapped_i,
NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
- "ibgda_create_dci failed on RC #%d.", i);
+ "ibgda_create_dci failed on RC #%d.", mapped_i);
- status = ibgda_get_rc_handle(&local_rc_handles[i], device->rc.eps[i], device);
+ status = ibgda_get_rc_handle(&local_rc_handles[mapped_i], device->rc.eps[mapped_i], device);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
- "ibgda_get_rc_handle failed on RC #%d.", i);
+ "ibgda_get_rc_handle failed on RC #%d.", mapped_i);
}
if (num_rc_eps) {
--
2.25.1
From 3cd3938bcbbabed7fb7675032afb02647ea9c2fe Mon Sep 17 00:00:00 2001
From: Shangyan Zhou <sy.zhou@deepseek.com>
Date: Mon, 23 Dec 2024 09:55:27 +0800
Subject: [PATCH 2/5] Disable timeout check
---
CMakeLists.txt | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 771ff98..9246d29 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -140,7 +140,7 @@ option(NVSHMEM_NVTX "Enable NVSHMEM NVTX support" ${NVSHMEM_NVTX_DEFAULT})
option(NVSHMEM_PMIX_SUPPORT "Enable Compilation of the PMIX bootstrap and PMIX specific code" $ENV{NVSHMEM_PMIX_SUPPORT})
option(NVSHMEM_SHMEM_SUPPORT "Enable Compilation of the SHMEM bootstrap and SHMEM specific code" $ENV{NVSHMEM_SHMEM_SUPPORT})
option(NVSHMEM_TEST_STATIC_LIB "Force tests to link only against the combined nvshmem.a binary" $ENV{NVSHMEM_TEST_STATIC_LIB})
-option(NVSHMEM_TIMEOUT_DEVICE_POLLING "Enable timeouts for NVSHMEM device-side polling functions (e.g. wait_until)" $ENV{NVSHMEM_TIMEOUT_DEVICE_POLLING})
+option(NVSHMEM_TIMEOUT_DEVICE_POLLING "Enable timeouts for NVSHMEM device-side polling functions (e.g. wait_until)" OFF)
option(NVSHMEM_TRACE "Enable NVSHMEM trace print events" $ENV{NVSHMEM_TRACE})
option(NVSHMEM_UCX_SUPPORT "Enable compilation of the UCX remote transport" $ENV{NVSHMEM_UCX_SUPPORT})
option(NVSHMEM_USE_DLMALLOC "Set dlmalloc as the NVSHMEM heap allocation method" $ENV{NVSHMEM_USE_DLMALLOC})
@@ -165,6 +165,7 @@ set(NVSHMEM_PREFIX ${NVSHMEM_PREFIX_DEFAULT} CACHE PATH "path to NVSHMEM install
set(PMIX_HOME ${PMIX_HOME_DEFAULT} CACHE PATH "path to PMIX installation")
set(SHMEM_HOME ${MPI_HOME} CACHE PATH "path to SHMEM installation")
set(UCX_HOME ${UCX_HOME_DEFAULT} CACHE PATH "path to UCX installation")
+set(NVSHMEM_TIMEOUT_DEVICE_POLLING OFF)
message(STATUS "NVSHMEM_PREFIX: ${NVSHMEM_PREFIX}")
message(STATUS "NVSHMEM_DEVEL: ${NVSHMEM_DEVEL}")
--
2.25.1
From 4e0eaff589d38f448715e43a935479451a41c0fe Mon Sep 17 00:00:00 2001
From: Shangyan Zhou <sy.zhou@deepseek.com>
Date: Fri, 10 Jan 2025 11:53:38 +0800
Subject: [PATCH 3/5] Add recv queue and recv cq for rc qps.
Let the ibgda rc qps use regular recv queue.
Add recv queue to ibgda dev qp.
IBGDA create recv cq
Setup recv cq.
fix recv queue.
Remove some useless idx.
Longer recv queue.
---
.../nvshmem_common_ibgda.h | 19 +++++-
src/modules/transport/ibgda/ibgda.cpp | 65 ++++++++++++++++---
2 files changed, 71 insertions(+), 13 deletions(-)
diff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h
index 32f6d02..7d4e250 100644
--- a/src/include/device_host_transport/nvshmem_common_ibgda.h
+++ b/src/include/device_host_transport/nvshmem_common_ibgda.h
@@ -168,14 +168,17 @@ typedef struct {
uint64_t get_head; // last wqe idx + 1 with a "fetch" operation (g, get, amo_fetch)
uint64_t get_tail; // last wqe idx + 1 polled with cst; get_tail > get_head is possible
} tx_wq;
+ struct {
+ uint64_t resv_head; // last reserved wqe idx + 1
+ } rx_wq;
struct {
uint64_t head;
uint64_t tail;
} ibuf;
char padding[NVSHMEMI_IBGDA_QP_MANAGEMENT_PADDING];
} __attribute__((__aligned__(8))) nvshmemi_ibgda_device_qp_management_v1;
-static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 96,
- "ibgda_device_qp_management_v1 must be 96 bytes.");
+static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 104,
+ "ibgda_device_qp_management_v1 must be 104 bytes.");
typedef nvshmemi_ibgda_device_qp_management_v1 nvshmemi_ibgda_device_qp_management_t;
@@ -199,9 +202,19 @@ typedef struct nvshmemi_ibgda_device_qp {
// May point to mvars.prod_idx or internal prod_idx
uint64_t *prod_idx;
} tx_wq;
+ struct {
+ uint16_t nwqes;
+ uint64_t tail;
+ void *wqe;
+ __be32 *dbrec;
+ void *bf;
+ nvshmemi_ibgda_device_cq_t *cq;
+ // May point to mvars.prod_idx or internal prod_idx
+ uint64_t *prod_idx;
+ } rx_wq;
nvshmemi_ibgda_device_qp_management_v1 mvars; // management variables
} nvshmemi_ibgda_device_qp_v1;
-static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 184, "ibgda_device_qp_v1 must be 184 bytes.");
+static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 248, "ibgda_device_qp_v1 must be 248 bytes.");
typedef nvshmemi_ibgda_device_qp_v1 nvshmemi_ibgda_device_qp_t;
diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp
index ff02f50..b8d6bc7 100644
--- a/src/modules/transport/ibgda/ibgda.cpp
+++ b/src/modules/transport/ibgda/ibgda.cpp
@@ -194,6 +194,7 @@ struct ibgda_ep {
off_t dbr_offset;
struct ibgda_cq *send_cq;
+ struct ibgda_cq *recv_cq;
struct ibv_ah *ah;
uint32_t user_index;
@@ -1520,7 +1521,8 @@ static int ibgda_create_cq_shared_objects(nvshmemt_ibgda_state_t *ibgda_state,
struct ibv_context *context = device->context;
- unsigned int num_cqs = device->dci.num_eps + device->rc.num_eps_per_pe * n_pes;
+ // Each RC qp has one send CQ and one recv CQ.
+ unsigned int num_cqs = device->dci.num_eps + device->rc.num_eps_per_pe * n_pes * 2;
assert(ibgda_qp_depth > 0);
size_t num_cqe = IBGDA_ROUND_UP_POW2_OR_0(ibgda_qp_depth);
@@ -1683,7 +1685,8 @@ static int ibgda_create_qp_shared_objects(nvshmemt_ibgda_state_t *ibgda_state,
}
// Allocate and map WQ buffer for all QPs.
- wq_buf_size_per_qp = num_wqebb * MLX5_SEND_WQE_BB; // num_wqebb is always a power of 2
+ // Todo: reduce the size of wq buffer.
+ wq_buf_size_per_qp = num_wqebb * MLX5_SEND_WQE_BB * 2; // num_wqebb is always a power of 2
wq_buf_size = wq_buf_size_per_qp * num_eps;
status = ibgda_nic_control_alloc(&wq_mobject, wq_buf_size, IBGDA_GPAGE_SIZE);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "cannot allocate wq buf.\n");
@@ -1864,8 +1867,11 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
int cqe_version = 0;
struct ibgda_cq *send_cq = NULL;
+ struct ibgda_cq *recv_cq = NULL;
size_t num_wqebb = IBGDA_ROUND_UP_POW2_OR_0(ibgda_qp_depth);
+ size_t num_recv_wqe = ibgda_qp_depth;
+ size_t recv_wqe_size = 16;
int status = 0;
@@ -1893,6 +1899,11 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
status = ibgda_create_cq(&send_cq, device);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_create_cq failed.\n");
+ if (qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC) {
+ status = ibgda_create_cq(&recv_cq, device);
+ NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_create_cq failed.\n");
+ }
+
ep = (struct ibgda_ep *)calloc(1, sizeof(struct ibgda_ep));
NVSHMEMI_NULL_ERROR_JMP(ep, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
"Unable to allocate mem for ep.\n");
@@ -1921,12 +1932,9 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
DEVX_SET(qpc, qp_context, pm_state, MLX5_QPC_PM_STATE_MIGRATED);
DEVX_SET(qpc, qp_context, pd, device->qp_shared_object.pdn);
DEVX_SET(qpc, qp_context, uar_page, uar_mobject->uar->page_id); // BF register
- DEVX_SET(qpc, qp_context, rq_type, IBGDA_SRQ_TYPE_VALUE); // Shared Receive Queue
- DEVX_SET(qpc, qp_context, srqn_rmpn_xrqn, device->qp_shared_object.srqn);
DEVX_SET(qpc, qp_context, cqn_snd, send_cq->cqn);
- DEVX_SET(qpc, qp_context, cqn_rcv, device->qp_shared_object.rcqn);
+ DEVX_SET(qpc, qp_context, cqn_rcv, qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC ? recv_cq->cqn : device->qp_shared_object.rcqn);
DEVX_SET(qpc, qp_context, log_sq_size, IBGDA_ILOG2_OR0(num_wqebb));
- DEVX_SET(qpc, qp_context, log_rq_size, 0);
DEVX_SET(qpc, qp_context, cs_req, 0); // Disable CS Request
DEVX_SET(qpc, qp_context, cs_res, 0); // Disable CS Response
DEVX_SET(qpc, qp_context, dbr_umem_valid, IBGDA_MLX5_UMEM_VALID_ENABLE); // Enable dbr_umem_id
@@ -1935,6 +1943,15 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
DEVX_SET(qpc, qp_context, dbr_umem_id, dbr_umem->umem_id); // DBR buffer
DEVX_SET(qpc, qp_context, user_index, qp_idx);
DEVX_SET(qpc, qp_context, page_offset, 0);
+ if (qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC){
+ DEVX_SET(qpc, qp_context, rq_type, 0); // Regular recv queue
+ DEVX_SET(qpc, qp_context, log_rq_size, IBGDA_ILOG2(num_recv_wqe)); // 4 wqe
+ DEVX_SET(qpc, qp_context, log_rq_stride, IBGDA_ILOG2(recv_wqe_size) - 4); // max recv wqe size = 16B
+ } else {
+ DEVX_SET(qpc, qp_context, rq_type, IBGDA_SRQ_TYPE_VALUE); // Shared Receive Queue, DC must use this.
+ DEVX_SET(qpc, qp_context, srqn_rmpn_xrqn, device->qp_shared_object.srqn);
+ DEVX_SET(qpc, qp_context, log_rq_size, 0);
+ }
ep->devx_qp = mlx5dv_devx_obj_create(context, cmd_in, sizeof(cmd_in), cmd_out, sizeof(cmd_out));
NVSHMEMI_NULL_ERROR_JMP(ep->devx_qp, status, NVSHMEMX_ERROR_INTERNAL, out,
@@ -1944,9 +1961,9 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
ep->portid = portid;
ep->sq_cnt = num_wqebb;
- ep->sq_buf_offset = 0;
+ ep->sq_buf_offset = num_recv_wqe * recv_wqe_size;
- ep->rq_cnt = 0;
+ ep->rq_cnt = num_recv_wqe;
ep->rq_buf_offset = 0;
ep->wq_mobject = device->qp_shared_object.wq_mobject;
@@ -1960,6 +1977,7 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
ep->uar_mobject = uar_mobject;
ep->send_cq = send_cq;
+ ep->recv_cq = recv_cq;
ep->qp_type = qp_type;
@@ -1971,6 +1989,7 @@ out:
if (status) {
if (uar_mobject) ibgda_unmap_and_free_qp_uar(uar_mobject);
if (send_cq) ibgda_destroy_cq(send_cq);
+ if (recv_cq) ibgda_destroy_cq(recv_cq);
if (ep) free(ep);
}
@@ -2269,6 +2288,10 @@ static int ibgda_destroy_ep(struct ibgda_ep *ep) {
ibgda_destroy_cq(ep->send_cq);
}
+ if (ep->recv_cq) {
+ ibgda_destroy_cq(ep->recv_cq);
+ }
+
if (ep->ah) {
ftable.destroy_ah(ep->ah);
}
@@ -2300,7 +2323,7 @@ static void ibgda_get_device_qp(nvshmemi_ibgda_device_qp_t *dev_qp, struct ibgda
dev_qp->qpn = ep->qpn;
assert(ep->wq_mobject->has_gpu_mapping);
- dev_qp->tx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset);
+ dev_qp->tx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset + ep->sq_buf_offset);
if (ibgda_nic_handler == IBGDA_NIC_HANDLER_GPU) {
assert(ep->dbr_mobject->has_gpu_mapping);
@@ -2312,6 +2335,12 @@ static void ibgda_get_device_qp(nvshmemi_ibgda_device_qp_t *dev_qp, struct ibgda
}
dev_qp->tx_wq.nwqes = ep->sq_cnt;
+ if (ep->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC) {
+ dev_qp->rx_wq.nwqes = ep->rq_cnt;
+ dev_qp->rx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset + ep->rq_buf_offset);
+ dev_qp->rx_wq.dbrec = (__be32 *)((uintptr_t)ep->dbr_mobject->aligned.gpu_ptr + ep->dbr_offset);
+ dev_qp->rx_wq.bf = (void *)ep->uar_mobject->aligned.gpu_ptr;
+ }
ibuf_dci_start = (uintptr_t)device->qp_shared_object.internal_buf.mem_object->aligned.gpu_ptr;
ibuf_rc_start = ibuf_dci_start + (size_per_dci * device->dci.num_eps);
@@ -2361,6 +2390,9 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
nvshmemi_ibgda_device_cq_t *cq_d = NULL;
nvshmemi_ibgda_device_cq_t *cq_h = NULL;
+ nvshmemi_ibgda_device_cq_t *recv_cq_d = NULL;
+ nvshmemi_ibgda_device_cq_t *recv_cq_h = NULL;
+
uint8_t *qp_group_switches_d = NULL;
const size_t mvars_offset = offsetof(nvshmemi_ibgda_device_qp_t, mvars);
@@ -2368,6 +2400,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
const size_t cons_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.cons_idx);
const size_t wqe_h_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.resv_head);
const size_t wqe_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.ready_head);
+ const size_t rx_resv_head_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.resv_head);
nvshmemi_ibgda_device_qp_map_type_t rc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
nvshmemi_ibgda_device_qp_map_type_t dc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
@@ -2405,7 +2438,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
num_dct_handles += device->dct.num_eps * n_pes;
num_dci_handles += device->dci.num_eps;
num_rc_handles += device->rc.num_eps_per_pe * n_pes;
- num_cq_handles += device->dci.num_eps + (device->rc.num_eps_per_pe * (n_pes - 1));
+ num_cq_handles += device->dci.num_eps + (device->rc.num_eps_per_pe * (n_pes - 1) * 2);
num_shared_dci_handles += device->dci.num_shared_eps;
}
num_elements = num_dct_handles - NVSHMEMI_IBGDA_MAX_CONST_DCTS;
@@ -2441,6 +2474,10 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
for (int i = 0; i < num_cq_handles; i++) {
nvshmemi_init_ibgda_device_cq(cq_h[i]);
}
+
+ recv_cq_h = (nvshmemi_ibgda_device_cq_t *)calloc(1, sizeof(*recv_cq_h));
+ NVSHMEMI_NULL_ERROR_JMP(recv_cq_h, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "recv_cq calloc err.");
+ nvshmemi_init_ibgda_device_cq(recv_cq_h[0]);
/* allocate host memory for dct, rc, cq, dci end */
/* allocate device memory for dct, rc, cq, dci start */
@@ -2544,6 +2581,14 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
}
++cq_idx;
+
+ rc_h[arr_idx].rx_wq.cq = &cq_d[cq_idx];
+
+ ibgda_get_device_cq(&cq_h[cq_idx], device->rc.eps[i]->recv_cq);
+ cq_h[cq_idx].resv_head = (uint64_t *)(base_mvars_d_addr + rx_resv_head_offset);
+ cq_h[cq_idx].qpn = rc_h[arr_idx].qpn;
+ cq_h[cq_idx].qp_type = rc_h[arr_idx].qp_type;
+ ++cq_idx;
}
}
}
--
2.25.1
From 0cc285269f154049f1c9775e07e306e03228eedc Mon Sep 17 00:00:00 2001
From: Shangyan Zhou <sy.zhou@deepseek.com>
Date: Sat, 8 Feb 2025 18:02:39 +0800
Subject: [PATCH 4/5] Maintain recv queue's cons_idx.
---
src/include/device_host_transport/nvshmem_common_ibgda.h | 5 +++--
src/modules/transport/ibgda/ibgda.cpp | 6 ++++--
2 files changed, 7 insertions(+), 4 deletions(-)
diff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h
index 7d4e250..502645d 100644
--- a/src/include/device_host_transport/nvshmem_common_ibgda.h
+++ b/src/include/device_host_transport/nvshmem_common_ibgda.h
@@ -170,6 +170,7 @@ typedef struct {
} tx_wq;
struct {
uint64_t resv_head; // last reserved wqe idx + 1
+ uint64_t cons_idx; // polled wqe idx + 1 (consumer index + 1)
} rx_wq;
struct {
uint64_t head;
@@ -177,7 +178,7 @@ typedef struct {
} ibuf;
char padding[NVSHMEMI_IBGDA_QP_MANAGEMENT_PADDING];
} __attribute__((__aligned__(8))) nvshmemi_ibgda_device_qp_management_v1;
-static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 104,
+static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 112,
"ibgda_device_qp_management_v1 must be 104 bytes.");
typedef nvshmemi_ibgda_device_qp_management_v1 nvshmemi_ibgda_device_qp_management_t;
@@ -214,7 +215,7 @@ typedef struct nvshmemi_ibgda_device_qp {
} rx_wq;
nvshmemi_ibgda_device_qp_management_v1 mvars; // management variables
} nvshmemi_ibgda_device_qp_v1;
-static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 248, "ibgda_device_qp_v1 must be 248 bytes.");
+static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 256, "ibgda_device_qp_v1 must be 248 bytes.");
typedef nvshmemi_ibgda_device_qp_v1 nvshmemi_ibgda_device_qp_t;
diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp
index b8d6bc7..a1cfe2e 100644
--- a/src/modules/transport/ibgda/ibgda.cpp
+++ b/src/modules/transport/ibgda/ibgda.cpp
@@ -1063,7 +1063,7 @@ static inline void ibgda_nic_control_free(struct ibgda_mem_object *mobject) {
ibgda_host_mem_free(mobject);
}
-static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device) {
+static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device, int cc = 1) {
int status = 0;
struct ibgda_cq *gcq = NULL;
@@ -1114,7 +1114,7 @@ static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device)
cq_context = DEVX_ADDR_OF(create_cq_in, cmd_in, cq_context);
DEVX_SET(cqc, cq_context, dbr_umem_valid, IBGDA_MLX5_UMEM_VALID_ENABLE);
DEVX_SET(cqc, cq_context, cqe_sz, MLX5_CQE_SIZE_64B);
- DEVX_SET(cqc, cq_context, cc, 0x1); // Use collapsed CQ
+ DEVX_SET(cqc, cq_context, cc, cc); // Use collapsed CQ
DEVX_SET(cqc, cq_context, oi, 0x1); // Allow overrun
DEVX_SET(cqc, cq_context, dbr_umem_id, dbr_umem->umem_id);
DEVX_SET(cqc, cq_context, log_cq_size, IBGDA_ILOG2_OR0(num_cqe));
@@ -2401,6 +2401,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
const size_t wqe_h_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.resv_head);
const size_t wqe_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.ready_head);
const size_t rx_resv_head_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.resv_head);
+ const size_t rx_cons_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.cons_idx);
nvshmemi_ibgda_device_qp_map_type_t rc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
nvshmemi_ibgda_device_qp_map_type_t dc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
@@ -2586,6 +2587,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
ibgda_get_device_cq(&cq_h[cq_idx], device->rc.eps[i]->recv_cq);
cq_h[cq_idx].resv_head = (uint64_t *)(base_mvars_d_addr + rx_resv_head_offset);
+ cq_h[cq_idx].cons_idx = (uint64_t *)(base_mvars_d_addr + rx_cons_offset);
cq_h[cq_idx].qpn = rc_h[arr_idx].qpn;
cq_h[cq_idx].qp_type = rc_h[arr_idx].qp_type;
++cq_idx;
--
2.25.1
From f91eb8510f8c9aa4f5769bd88434db5ab000e65a Mon Sep 17 00:00:00 2001
From: Shangyan Zhou <sy.zhou@deepseek.com>
Date: Tue, 11 Feb 2025 11:00:57 +0800
Subject: [PATCH 5/5] Init rx_wq counters.
---
src/include/device_host_transport/nvshmem_common_ibgda.h | 2 ++
1 file changed, 2 insertions(+)
diff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h
index 502645d..f0bc328 100644
--- a/src/include/device_host_transport/nvshmem_common_ibgda.h
+++ b/src/include/device_host_transport/nvshmem_common_ibgda.h
@@ -46,6 +46,8 @@
qp_man.tx_wq.cons_idx = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
qp_man.tx_wq.get_head = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
qp_man.tx_wq.get_tail = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
+ qp_man.rx_wq.resv_head = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
+ qp_man.rx_wq.cons_idx = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
qp_man.ibuf.head = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
qp_man.ibuf.tail = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
} while (0);
--
2.25.1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment