Commit ebfe47e4 authored by Chenggang Zhao

Initial commit

compile_commands.json
.idea
.DS_Store
*.pyc
build/
.cache/
.vscode/
*/cmake-build-*/
MIT License
Copyright (c) 2025 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# DeepEP
DeepEP is a communication library tailored for Mixture-of-Experts (MoE) and expert parallelism (EP). It provides high-throughput and low-latency all-to-all GPU kernels, which are also known as MoE dispatch and combine. The library also supports low-precision operations, including FP8.
To align with the group-limited gating algorithm proposed in the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper, DeepEP offers a set of kernels optimized for asymmetric-domain bandwidth forwarding, such as forwarding data from the NVLink domain to the RDMA domain. These kernels deliver high throughput, making them suitable for both training and inference prefilling tasks. Additionally, they let you control the number of SMs (streaming multiprocessors) they use.
For latency-sensitive inference decoding, DeepEP includes a set of low-latency kernels with pure RDMA to minimize delays. The library also introduces a hook-based communication-computation overlapping method that does not occupy any SM resource.
Notice: the implementation in this library may have some slight differences from the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper.
## Performance
### Normal kernels with NVLink and RDMA forwarding
We test normal kernels on H800 GPUs (~160 GB/s NVLink maximum bandwidth), each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth). We follow the DeepSeek-V3/R1 pretraining setting (4096 tokens per batch, 7168 hidden, top-4 groups, top-8 experts, FP8 dispatching and BF16 combining).
| Type | Dispatch #EP | Bottleneck bandwidth | Combine #EP | Bottleneck bandwidth |
|:---------:|:------------:|:--------------------:|:-----------:|:--------------------:|
| Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) |
| Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) |
| Internode | 32 | 44 GB/s (RDMA) | 32 | 47 GB/s (RDMA) |
| Internode | 64 | 46 GB/s (RDMA) | 64 | 45 GB/s (RDMA) |
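To put the table in context, the measured numbers can be compared against the approximate hardware peaks quoted above. The following is illustrative arithmetic only: the helper and constant names are made up for this sketch and are not part of DeepEP.

```python
# Approximate peaks quoted in the text above (GB/s).
# These names are illustrative and not part of the DeepEP API.
NVLINK_PEAK_GBPS = 160.0
RDMA_PEAK_GBPS = 400 / 8  # a 400 Gb/s InfiniBand link is 50 GB/s

# (type, #EP) -> (dispatch GB/s, combine GB/s), copied from the table
measured = {
    ("Intranode", 8): (153, 158),
    ("Internode", 16): (43, 43),
    ("Internode", 32): (44, 47),
    ("Internode", 64): (46, 45),
}


def utilization(kind: str, gbps: float) -> float:
    """Fraction of the quoted peak bandwidth reached by a measurement."""
    peak = NVLINK_PEAK_GBPS if kind == "Intranode" else RDMA_PEAK_GBPS
    return gbps / peak


for (kind, num_ep), (dispatch, combine) in measured.items():
    print(f"{kind:9s} EP{num_ep:<3d} "
          f"dispatch {utilization(kind, dispatch):.0%}, "
          f"combine {utilization(kind, combine):.0%}")
```

Under these quoted peaks, every configuration in the table reaches at least 85% of the bottleneck link's maximum bandwidth.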
### Low-latency kernels with pure RDMA
We test low-latency kernels on H800 GPUs, each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth). We follow a typical DeepSeek-V3/R1 production setting (128 tokens per batch, 7168 hidden, top-8 experts, FP8 dispatching and BF16 combining).
| Dispatch #EP | Latency | RDMA bandwidth | Combine #EP | Latency | RDMA bandwidth |
|:------------:|:-------:|:--------------:|:-----------:|:-------:|:--------------:|
| 8 | 163 us | 46 GB/s | 8 | 318 us | 46 GB/s |
| 16 | 173 us | 43 GB/s | 16 | 329 us | 44 GB/s |
| 32 | 182 us | 41 GB/s | 32 | 350 us | 41 GB/s |
| 64 | 186 us | 40 GB/s | 64 | 353 us | 41 GB/s |
| 128 | 192 us | 39 GB/s | 128 | 369 us | 39 GB/s |
| 256 | 194 us | 39 GB/s | 256 | 360 us | 40 GB/s |
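As a rough sanity check, the latency and bandwidth columns can be related through the per-token message sizes used by the low-latency buffer layout in this commit (FP8 payload plus scales plus a 16-byte header for dispatch, BF16 payload plus the header for combine). The sketch below assumes every token is sent once per selected expert; that accounting is an assumption of this example and may not be exactly how the reported bandwidth was computed.

```python
HIDDEN = 7168      # hidden size from the setting above
NUM_TOKENS = 128   # tokens per batch
NUM_TOPK = 8       # experts selected per token


def dispatch_msg_bytes(hidden: int) -> int:
    # FP8 payload + one float scale per 128 channels + 16-byte (int4) header
    num_scales = hidden // 128
    return hidden + num_scales * 4 + 16


def combine_msg_bytes(hidden: int) -> int:
    # 16-byte (int4) header + BF16 payload
    return 16 + hidden * 2


def implied_bandwidth_gbps(msg_bytes: int, latency_us: float) -> float:
    # Assumption: each token is sent once per selected expert.
    total_bytes = NUM_TOKENS * NUM_TOPK * msg_bytes
    return total_bytes / (latency_us * 1e-6) / 1e9


# EP8 row of the table: 163 us dispatch, 318 us combine
print(round(implied_bandwidth_gbps(dispatch_msg_bytes(HIDDEN), 163), 1))
print(round(implied_bandwidth_gbps(combine_msg_bytes(HIDDEN), 318), 1))
```

Under that assumption, both EP8 figures come out close to the 46 GB/s listed in the table, and the roughly 2x larger combine message is consistent with the roughly 2x higher combine latency.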
## Quick start
### Requirements
- Hopper GPUs (may support more architectures or devices later)
- Python 3.8 and above
- CUDA 12.3 and above
- PyTorch 2.1 and above
- NVLink for intranode communication
- RDMA network for internode communication
### Download and install NVSHMEM dependency
DeepEP also depends on our modified NVSHMEM. Please refer to our [NVSHMEM Installation Guide](third-party/README.md) for instructions.
### Development
```bash
# Build and make symbolic links for SO files
NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py build
# You may modify the specific SO names according to your own platform
ln -s build/lib.linux-x86_64-cpython-38/deep_ep_cpp.cpython-38-x86_64-linux-gnu.so
# Run test cases
# NOTES: you may modify the `init_dist` function in `tests/utils.py`
# according to your own cluster settings, and launch into multiple nodes
python tests/test_intranode.py
python tests/test_internode.py
python tests/test_low_latency.py
```
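Since the SO name and build directory in the `ln -s` command depend on the platform and Python version, a small script can print the names to expect on your machine. This is a sketch under one assumption: recent setuptools versions name the build directory `lib.<platform>-<cache_tag>`, while older ones use a different suffix, so treat the directory part as a hint only.

```python
import sys
import sysconfig


def expected_so_name(module: str = "deep_ep_cpp") -> str:
    # EXT_SUFFIX is e.g. ".cpython-38-x86_64-linux-gnu.so" on Linux with CPython 3.8
    return module + sysconfig.get_config_var("EXT_SUFFIX")


def expected_build_dir() -> str:
    # Assumption: recent setuptools uses lib.<platform>-<cache_tag>;
    # older versions name this directory differently.
    return f"build/lib.{sysconfig.get_platform()}-{sys.implementation.cache_tag}"


print(f"{expected_build_dir()}/{expected_so_name()}")
```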
### Installation
```bash
NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py install
```
Then, import `deep_ep` in your Python project, and enjoy!
## Network configurations
DeepEP is fully tested with InfiniBand networks. However, it is theoretically compatible with RDMA over Converged Ethernet (RoCE) as well.
### Traffic isolation
Traffic isolation is supported by InfiniBand through Virtual Lanes (VL).
To prevent interference between different types of traffic, we recommend segregating workloads across different virtual lanes as follows:
- workloads using normal kernels
- workloads using low-latency kernels
- other workloads
For DeepEP, you can control the virtual lane assignment by setting the `NVSHMEM_IB_SL` environment variable.
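A minimal sketch of how a launcher might set this variable per workload type is shown below. The service-level numbers are hypothetical placeholders and must match how your fabric maps service levels to virtual lanes; the helper itself is not part of DeepEP.

```python
import os

# Hypothetical service-level numbers; choose values that match your fabric's
# service-level-to-virtual-lane mapping.
SERVICE_LEVELS = {"normal": 0, "low_latency": 1}


def select_service_level(workload: str, env=os.environ) -> str:
    """Set NVSHMEM_IB_SL for the given workload type and return the value set."""
    value = str(SERVICE_LEVELS[workload])
    env["NVSHMEM_IB_SL"] = value
    return value


# Environment variables are typically read at initialization, so call this
# before the communication buffer is created.
select_service_level("low_latency")
```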
### Adaptive routing
Adaptive routing is an advanced routing feature provided by InfiniBand switches that can evenly distribute traffic across multiple paths. Currently, low-latency kernels support adaptive routing, while normal kernels do not (support may be added soon). **Enabling adaptive routing for normal internode kernels may lead to deadlocks or data corruption issues**.
For low-latency kernels, enabling adaptive routing can completely eliminate network congestion caused by routing conflicts, but it also introduces additional latency. We recommend the following configuration for optimal performance:
- enable adaptive routing in environments with heavy network loads
- use static routing in environments with light network loads
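The recommendations above can be summarized as a small decision helper. This is only a sketch: the function name and its string arguments are made up for illustration, and the actual switch-side routing configuration is outside its scope.

```python
def recommended_routing(kernel: str, heavy_network_load: bool) -> str:
    """Return "adaptive" or "static" following the recommendations above."""
    if kernel == "normal":
        # Normal kernels do not support adaptive routing at this point.
        return "static"
    if kernel == "low_latency":
        return "adaptive" if heavy_network_load else "static"
    raise ValueError(f"unknown kernel type: {kernel!r}")


print(recommended_routing("low_latency", heavy_network_load=True))
```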
### Congestion control
Congestion control is disabled as we have not observed significant congestion in our production environment.
## Interfaces and examples
### Example use in model training or inference prefilling
The normal kernels can be used in model training or the inference prefilling phase (without the backward part), as the example code below shows.
```python
import torch
import torch.distributed as dist
from typing import List, Tuple, Optional, Union

from deep_ep import Buffer, EventOverlap

# Communication buffer (will allocate at runtime)
_buffer: Optional[Buffer] = None

# Set the number of SMs to use
# NOTES: this is a static variable
Buffer.set_num_sms(24)


# You may call this function at the framework initialization
def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
    global _buffer

    # NOTES: you may also replace `get_*_config` with your auto-tuned results via all the tests
    num_nvl_bytes, num_rdma_bytes = 0, 0
    for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())):
        num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)
        num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)

    # Allocate a buffer if it does not exist or is not large enough
    # NOTES: the adaptive routing configuration of the network **must be off**
    if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes:
        _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
    return _buffer


def get_hidden_bytes(x: torch.Tensor) -> int:
    t = x[0] if isinstance(x, tuple) else x
    return t.size(1) * max(t.element_size(), 2)


def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                     topk_idx: torch.Tensor, topk_weights: torch.Tensor,
                     num_experts: int, previous_event: Optional[EventOverlap] = None) -> \
        Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple, EventOverlap]:
    # NOTES: the optional `previous_event` is a captured CUDA event that you want to make a dependency
    # of the dispatch kernel; it may be useful for communication-computation overlap. For more information, please
    # refer to the docs of `Buffer.dispatch`
    global _buffer

    # Calculate layout before actual dispatch
    num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = \
        _buffer.get_dispatch_layout(topk_idx, num_experts,
                                    previous_event=previous_event, async_finish=True,
                                    allocate_on_comm_stream=previous_event is not None)
    # Do MoE dispatch
    # NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
    # For more advanced usages, please refer to the docs of the `dispatch` function
    recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
        _buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
                         num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
                         is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert,
                         previous_event=previous_event, async_finish=True,
                         allocate_on_comm_stream=True)
    # For event management, please refer to the docs of the `EventOverlap` class
    return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event


def dispatch_backward(grad_recv_x: torch.Tensor, grad_recv_topk_weights: torch.Tensor, handle: Tuple) -> \
        Tuple[torch.Tensor, torch.Tensor, EventOverlap]:
    global _buffer

    # The backward process of MoE dispatch is actually a combine
    # For more advanced usages, please refer to the docs of the `combine` function
    combined_grad_x, combined_grad_recv_topk_weights, event = \
        _buffer.combine(grad_recv_x, handle, topk_weights=grad_recv_topk_weights, async_finish=True)

    # For event management, please refer to the docs of the `EventOverlap` class
    return combined_grad_x, combined_grad_recv_topk_weights, event


def combine_forward(x: torch.Tensor, handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \
        Tuple[torch.Tensor, EventOverlap]:
    global _buffer

    # Do MoE combine
    # For more advanced usages, please refer to the docs of the `combine` function
    combined_x, _, event = _buffer.combine(x, handle, async_finish=True, previous_event=previous_event,
                                           allocate_on_comm_stream=previous_event is not None)

    # For event management, please refer to the docs of the `EventOverlap` class
    return combined_x, event


def combine_backward(grad_combined_x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                     handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \
        Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], EventOverlap]:
    global _buffer

    # The backward process of MoE combine is actually a dispatch
    # For more advanced usages, please refer to the docs of the `dispatch` function
    grad_x, _, _, _, _, event = _buffer.dispatch(grad_combined_x, handle=handle, async_finish=True,
                                                 previous_event=previous_event,
                                                 allocate_on_comm_stream=previous_event is not None)

    # For event management, please refer to the docs of the `EventOverlap` class
    return grad_x, event
```
Moreover, inside the dispatch function, the current rank may not know in advance how many tokens it will receive. The CPU therefore implicitly waits for a received-count signal from the GPU, as the following figure shows.
![normal](figures/normal.png)
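The implicit wait can be pictured with a pure-Python simulation. This is not DeepEP code: a thread stands in for the GPU kernel and a plain object stands in for a host-visible counter, which starts at -1 (meaning "not written yet") and is polled by the CPU until a count arrives. All names here are hypothetical.

```python
import threading
import time


class RecvCounter:
    """Stand-in for a host-visible counter; -1 means "not written yet"."""

    def __init__(self) -> None:
        self.value = -1


def fake_gpu_kernel(counter: RecvCounter, num_tokens: int) -> None:
    time.sleep(0.01)            # pretend the notification kernel is running
    counter.value = num_tokens  # publish the number of tokens to receive


def cpu_wait(counter: RecvCounter, timeout_s: float = 5.0) -> int:
    deadline = time.monotonic() + timeout_s
    while counter.value < 0:    # poll until the "GPU" publishes the count
        if time.monotonic() > deadline:
            raise TimeoutError("timed out waiting for the receive count")
        time.sleep(0)           # yield to other threads
    return counter.value


counter = RecvCounter()
worker = threading.Thread(target=fake_gpu_kernel, args=(counter, 4096))
worker.start()
num_recv_tokens = cpu_wait(counter)     # only now can the receive buffer be sized
worker.join()
recv_buffer = [None] * num_recv_tokens  # stand-in for allocating the receive tensor
```

Because the host blocks until the count is known, a call sequence like this cannot be captured ahead of time, which matches the note in the example above that the normal dispatch is not compatible with CUDA graph.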
### Example use in inference decoding
The low-latency kernels can be used in the inference decoding phase, as the example code below shows.
```python
import torch
import torch.distributed as dist
from typing import Tuple, Optional

from deep_ep import Buffer

# Communication buffer (will allocate at runtime)
# NOTES: there is no SM control API for the low-latency kernels
_buffer: Optional[Buffer] = None


# You may call this function at the framework initialization
def get_buffer(group: dist.ProcessGroup, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> Buffer:
    # NOTES: the low-latency mode will consume much more space than the normal mode
    # So we recommend that `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
    global _buffer
    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts)

    # Allocate a buffer if it does not exist or is not large enough
    if _buffer is None or _buffer.group != group or not _buffer.low_latency_mode or _buffer.num_rdma_bytes < num_rdma_bytes:
        # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
        assert num_experts % group.size() == 0
        _buffer = Buffer(group, 0, num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_experts // group.size())
    return _buffer


def low_latency_dispatch(hidden_states: torch.Tensor, topk_idx: torch.Tensor, num_max_dispatch_tokens_per_rank: int, num_experts: int):
    global _buffer

    # Do MoE dispatch, compatible with CUDA graph (but you may need to restore some buffer status when you replay)
    recv_hidden_states, recv_expert_count, handle, event, hook = \
        _buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
                                     async_finish=False, return_recv_hook=True)

    # NOTES: the actual tensor will be received only after you call `hook()`,
    # it is useful for double-batch overlapping, but **without any SM occupation**
    # If you don't want to overlap, please set `return_recv_hook=False`
    # Later, you can use our GEMM library to do the computation with this specific format
    return recv_hidden_states, recv_expert_count, handle, event, hook


def low_latency_combine(hidden_states: torch.Tensor,
                        topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: Tuple):
    global _buffer

    # Do MoE combine, compatible with CUDA graph (but you may need to restore some buffer status when you replay)
    combined_hidden_states, event_overlap, hook = \
        _buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle,
                                    async_finish=False, return_recv_hook=True)

    # NOTES: the same behavior as described in the dispatch kernel
    return combined_hidden_states, event_overlap, hook
```
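To get a feel for why the low-latency mode consumes much more space, the sketch below mirrors the size arithmetic of the low-latency buffer layout from the config header in this commit, in plain Python. It is an approximation for estimation only: it omits the final alignment step, and the function name is made up for this example (use `Buffer.get_low_latency_rdma_size_hint` in real code).

```python
def low_latency_rdma_bytes(num_max_dispatch_tokens_per_rank: int, hidden: int,
                           num_ranks: int, num_experts: int) -> int:
    """Unaligned size of the low-latency layout (two buffers of each kind)."""
    int4_bytes, float_bytes, int_bytes, bf16_bytes = 16, 4, 4, 2
    num_scales = hidden // 128
    num_local_experts = num_experts // num_ranks

    # Per-token message sizes
    dispatch_msg = hidden + num_scales * float_bytes + int4_bytes
    combine_msg = int4_bytes + hidden * bf16_bytes

    # Send, receive and signaling buffers are shared by dispatch and combine
    send = max(num_max_dispatch_tokens_per_rank * dispatch_msg,
               num_experts * num_max_dispatch_tokens_per_rank * combine_msg)
    recv = max(num_experts * num_max_dispatch_tokens_per_rank * dispatch_msg,
               num_experts * num_max_dispatch_tokens_per_rank * combine_msg)
    signaling = max(num_experts * int_bytes + num_local_experts * int_bytes,
                    num_experts * int_bytes)
    return 2 * (send + recv + signaling)


# Example values only: 128 tokens per rank, 7168 hidden, 64 ranks, 256 experts
size = low_latency_rdma_bytes(128, 7168, 64, 256)
print(f"{size / 2**30:.2f} GiB")
```

The receive buffers scale with `num_experts * num_max_dispatch_tokens_per_rank`, which is why keeping the per-rank batch size small matters in this mode.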
For two-micro-batch overlapping, you can refer to the following figure. With our receiving hook interface, the RDMA network traffic happens in the background, without costing any GPU SMs from the computation part. Note that the overlapped parts can be adjusted, i.e. the 4 parts of attention/dispatch/MoE/combine may not have exactly the same execution time. You may adjust the stage settings according to your workload.
![low-latency](figures/low-latency.png)
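The hook pattern can be illustrated without any GPU. In the simulation below, a stand-in dispatch function issues its sends immediately and returns a hook; the caller runs other computation and calls the hook only when the received tokens are needed. The function names are hypothetical and the ordering is just one possible schedule, not the exact stages in the figure.

```python
from typing import Callable, List, Tuple

log: List[str] = []


def fake_low_latency_dispatch(batch: str) -> Tuple[str, Callable[[], None]]:
    """Stand-in for a call with `return_recv_hook=True`: issue sends, return a hook."""
    log.append(f"issue dispatch {batch}")

    def hook() -> None:
        # In the real kernels, this is where the receive actually completes.
        log.append(f"receive dispatch {batch}")

    return batch, hook


def compute(name: str) -> None:
    log.append(f"compute {name}")


# Two micro-batches: traffic for one batch is in flight while the other computes.
_, hook0 = fake_low_latency_dispatch("mb0")
compute("attention mb1")  # overlaps with mb0's dispatch traffic
hook0()                   # mb0's tokens are needed now, so finish receiving
_, hook1 = fake_low_latency_dispatch("mb1")
compute("MoE mb0")        # overlaps with mb1's dispatch traffic
hook1()

print("\n".join(log))
```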
## Notices
- For extreme performance, we discovered and use an undocumented PTX instruction: `ld.global.nc.L1::no_allocate.L2::256B`. Using it this way is undefined behavior: it accesses volatile GPU memory with the non-coherent read-only PTX modifier `.nc`. However, we have tested that results are correct with `.L1::no_allocate` on Hopper architectures, and performance is much better. If you find kernels not working on some other platforms, you may add `DISABLE_AGGRESSIVE_PTX_INSTRS=1` to `setup.py` to disable this, or file an issue.
- For better performance on your cluster, we recommend running all the tests and using the best auto-tuned configuration. The default configurations are optimized for DeepSeek's internal cluster.
## License
This code repository is released under [the MIT License](LICENSE), except for code that references NVSHMEM (including `csrc/kernels/ibgda_device.cuh` and `third-party/nvshmem.patch`), which is subject to the [NVSHMEM SLA](https://docs.nvidia.com/nvshmem/api/sla.html).
## Citation
If you use this codebase, or otherwise find our work valuable, please cite:
```bibtex
@misc{deepep2025,
title={DeepEP: an efficient expert-parallel communication library},
author={Chenggang Zhao and Shangyan Zhou and Liyue Zhang and Chengqi Deng and Zhean Xu and Yuxuan Liu and Kuai Yu and Jiashi Li and Liang Zhao},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek-ai/DeepEP}},
}
```
# NOTES: this CMake is only for debugging; for setup, please use Torch extension
cmake_minimum_required(VERSION 3.10)
project(deep_ep LANGUAGES CUDA CXX)
set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC")
set(CUDA_SEPARABLE_COMPILATION ON)
list(APPEND CUDA_NVCC_FLAGS "-O3")
list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")
set(TORCH_CUDA_ARCH_LIST "9.0")
find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)
find_package(Torch REQUIRED)
find_package(NVSHMEM REQUIRED HINTS ${NVSHMEM_ROOT_DIR}/lib/cmake/nvshmem)
add_library(nvshmem ALIAS nvshmem::nvshmem)
add_library(nvshmem_host ALIAS nvshmem::nvshmem_host)
add_library(nvshmem_device ALIAS nvshmem::nvshmem_device)
# There seem to be bugs with CMake, NVCC 12 and C++17
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 14)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS} ${NVSHMEM_INCLUDE_DIR})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib ${NVSHMEM_LIB_DIR})
add_subdirectory(kernels)
# Link CPP and CUDA together
pybind11_add_module(deep_ep_cpp deep_ep.cpp)
target_link_libraries(deep_ep_cpp PRIVATE ${EP_CUDA_LIBRARIES} ${TORCH_LIBRARIES} torch_python)
#pragma once
#include "kernels/api.cuh"
#include "kernels/exception.cuh"
namespace deep_ep {
template <typename dtype_t>
dtype_t cell_div(dtype_t a, dtype_t b) {
return (a + b - 1) / b;
}
template <typename dtype_t>
dtype_t align(dtype_t a, dtype_t b) {
return cell_div<dtype_t>(a, b) * b;
}
struct Config {
int num_sms;
int num_max_nvl_chunked_send_tokens;
int num_max_nvl_chunked_recv_tokens;
int num_max_rdma_chunked_send_tokens;
int num_max_rdma_chunked_recv_tokens;
Config(int num_sms,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens) :
num_sms(num_sms),
num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens),
num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens),
num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),
num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {
EP_HOST_ASSERT(num_sms >= 0);
EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0);
EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
this->num_max_rdma_chunked_recv_tokens = align<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
}
size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
// Below are some assumptions
// TODO: add assertions
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
const int num_channels = num_sms / 2;
size_t num_bytes = 0;
num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes();
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float);
num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
}
size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
// Legacy mode
if (num_ranks <= NUM_MAX_NVL_PEERS)
return 0;
// Below are some assumptions
// TODO: add assertions
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_sms % 2 == 0);
const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
const int num_channels = num_sms / 2;
size_t num_bytes = 0;
num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
}
};
struct LowLatencyBuffer {
int num_clean_int = 0;
void* dispatch_rdma_send_buffer = nullptr;
void* dispatch_rdma_recv_data_buffer = nullptr;
int* dispatch_rdma_recv_count_buffer = nullptr;
int* dispatch_rdma_atomic_token_counter = nullptr;
void* combine_rdma_send_buffer = nullptr;
void* combine_rdma_recv_data_buffer = nullptr;
int* combine_rdma_recv_flag_buffer = nullptr;
std::pair<int*, int> clean_meta() {
EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
return {dispatch_rdma_recv_count_buffer, num_clean_int};
}
};
struct LowLatencyLayout {
size_t total_bytes = 0;
LowLatencyBuffer buffers[2];
template <typename out_ptr_t = void*, typename count_ptr_t = uint8_t*, typename in_ptr_t = void*>
out_ptr_t advance(const in_ptr_t& ptr, size_t count) {
return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count);
}
LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
const int num_scales = hidden / 128;
const int num_local_experts = num_experts / num_ranks;
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffers
// - 2 symmetric odd/even receive buffers
// - 2 symmetric odd/even signaling buffers
// Message sizes
EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
size_t num_bytes_per_dispatch_msg = hidden + num_scales * sizeof(float) + sizeof(int4);
size_t num_bytes_per_combine_msg = sizeof(int4) + hidden * sizeof(nv_bfloat16);
// Send buffer
size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);
EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0);
total_bytes += send_buffer_bytes * 2;
// Symmetric receive buffers
// TODO: optimize memory usages
size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
total_bytes += recv_buffer_bytes * 2;
// Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
size_t dispatch_recv_atomic_token_counter_bytes = num_local_experts * sizeof(int);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes + dispatch_recv_atomic_token_counter_bytes,
combine_recv_flag_buffer_bytes);
total_bytes += signaling_buffer_bytes * 2;
// Assign pointers
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
// so you may see some parameters are duplicated
for (int i = 0; i < 2; ++ i) {
buffers[i] = {
static_cast<int>(signaling_buffer_bytes / sizeof(int)),
advance(rdma_buffer, send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i + dispatch_recv_count_buffer_bytes),
advance(rdma_buffer, send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i)
};
}
}
};
size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes;
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES;
}
} // namespace deep_ep
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <atomic>
#include <chrono>
#include <cuda_runtime.h>
#include <memory>
#include <pybind11/functional.h>
#include <torch/python.h>
#include "deep_ep.hpp"
#include "kernels/api.cuh"
#include "kernels/configs.cuh"
namespace deep_ep {
Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode):
rank(rank), num_ranks(num_ranks),
num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes),
low_latency_mode(low_latency_mode),
comm_stream(at::cuda::getStreamFromPool(true)) {
// Task fifo memory
int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS;
int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS;
int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS;
// Common checks
EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (num_nvl_bytes <= std::numeric_limits<int>::max() or num_rdma_bytes == 0));
EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (low_latency_mode or num_rdma_bytes <= std::numeric_limits<int>::max()));
EP_HOST_ASSERT(0 <= rank and rank < num_ranks and (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode));
EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
if (num_rdma_bytes > 0)
EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS or low_latency_mode);
// Get ranks
CUDA_CHECK(cudaGetDevice(&device_id));
rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
// Get device info
cudaDeviceProp device_prop = {};
CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));
if (num_nvl_bytes > 0) {
// Local IPC: alloc local memory and set local IPC handle
CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes));
CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
buffer_ptrs_gpu = reinterpret_cast<void**>(reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes);
// Set task fifo
EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0);
task_fifo_ptrs[nvl_rank] = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
task_fifo_ptrs_gpu = reinterpret_cast<int**>(reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes + buffer_ptr_bytes);
// No need to synchronize, will do a full device sync during `sync`
CUDA_CHECK(cudaMemsetAsync(task_fifo_ptrs[nvl_rank], 0, fifo_bytes, comm_stream));
}
// Create 32 MiB workspace
CUDA_CHECK(cudaMalloc(&workspace, NUM_WORKSPACE_BYTES));
CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream));
// MoE counter
CUDA_CHECK(cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped));
CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_counter_mapped, const_cast<int*>(moe_recv_counter), 0));
*moe_recv_counter = -1;
// MoE expert-level counter
CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS, cudaHostAllocMapped));
CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_expert_counter_mapped, const_cast<int*>(moe_recv_expert_counter), 0));
for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++ i)
moe_recv_expert_counter[i] = -1;
// MoE RDMA-level counter
if (num_rdma_ranks > 0) {
CUDA_CHECK(cudaMallocHost(&moe_recv_rdma_counter, sizeof(int), cudaHostAllocMapped));
CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast<int*>(moe_recv_rdma_counter), 0));
*moe_recv_rdma_counter = -1;
}
}
Buffer::~Buffer() noexcept(false) {
// Synchronize
CUDA_CHECK(cudaDeviceSynchronize());
if (num_nvl_bytes > 0) {
// Barrier
intranode::barrier(task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream);
move_fifo_slots();
CUDA_CHECK(cudaDeviceSynchronize());
// Close remote IPC
if (is_available()) {
for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank)
CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i]));
}
// Free local buffer
CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank]));
}
// Free NVSHMEM
if (num_rdma_bytes > 0) {
CUDA_CHECK(cudaDeviceSynchronize());
internode::barrier();
internode::free(rdma_buffer_ptr);
internode::finalize();
}
// Free workspace and MoE counter
CUDA_CHECK(cudaFree(workspace));
CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_counter)));
// Free chunked mode stuff
CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));
}
void Buffer::move_fifo_slots(int num_slots) {
head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS;
}
bool Buffer::is_available() const {
return available;
}
bool Buffer::is_internode_available() const {
return is_available() and num_ranks > NUM_MAX_NVL_PEERS;
}
int Buffer::get_num_rdma_ranks() const {
return num_rdma_ranks;
}
int Buffer::get_rdma_rank() const {
return rdma_rank;
}
int Buffer::get_root_rdma_rank(bool global) const {
return global ? nvl_rank : 0;
}
int Buffer::get_local_device_id() const {
return device_id;
}
pybind11::bytearray Buffer::get_local_ipc_handle() const {
return {ipc_handles[nvl_rank].reserved, CUDA_IPC_HANDLE_SIZE};
}
pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
EP_HOST_ASSERT(rdma_rank == 0 and "Only RDMA rank 0 can get NVSHMEM unique ID");
auto unique_id = internode::get_unique_id();
return {reinterpret_cast<const char*>(unique_id.data()), unique_id.size()};
}
torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const {
torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype);
auto element_bytes = static_cast<int64_t>(elementSize(casted_dtype));
auto base_ptr = reinterpret_cast<uint8_t*>(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset;
auto num_bytes = use_rdma_buffer ? num_rdma_bytes : num_nvl_bytes;
return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA));
}
void Buffer::sync(const std::vector<int> &device_ids,
const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
const std::optional<pybind11::bytearray>& root_unique_id_opt) {
EP_HOST_ASSERT(not is_available());
// Sync IPC handles
if (num_nvl_bytes > 0) {
EP_HOST_ASSERT(num_ranks == device_ids.size());
EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size());
for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++ i) {
EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value());
auto handle_str = std::string(all_gathered_handles[offset + i].value());
EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE);
if (offset + i != rank) {
std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
task_fifo_ptrs[i] = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
} else {
EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0);
}
}
// Copy all buffer and task pointers to GPU
CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu, task_fifo_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
CUDA_CHECK(cudaDeviceSynchronize());
}
// Sync NVSHMEM handles and allocate memory
if (num_rdma_bytes > 0) {
// Initialize NVSHMEM
EP_HOST_ASSERT(root_unique_id_opt.has_value());
std::vector<uint8_t> root_unique_id(root_unique_id_opt->size());
auto root_unique_id_str = root_unique_id_opt->cast<std::string>();
std::memcpy(root_unique_id.data(), root_unique_id_str.c_str(), root_unique_id_opt->size());
auto nvshmem_rank = low_latency_mode ? rank : rdma_rank;
auto num_nvshmem_ranks = low_latency_mode ? num_ranks : num_rdma_ranks;
EP_HOST_ASSERT(nvshmem_rank == internode::init(root_unique_id, nvshmem_rank, num_nvshmem_ranks, low_latency_mode));
internode::barrier();
// Allocate
rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES);
// Clean buffer (mainly for low-latency mode)
CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes));
// Barrier
internode::barrier();
CUDA_CHECK(cudaDeviceSynchronize());
}
// Ready to use
available = true;
}
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts,
std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
EP_HOST_ASSERT(topk_idx.dim() == 2);
EP_HOST_ASSERT(topk_idx.is_contiguous());
EP_HOST_ASSERT(num_experts > 0);
// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
}
// Wait for previous tasks to finish
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
auto num_tokens = static_cast<int>(topk_idx.size(0)), num_topk = static_cast<int>(topk_idx.size(1));
auto num_tokens_per_rank = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
auto num_tokens_per_rdma_rank = std::optional<torch::Tensor>();
auto num_tokens_per_expert = torch::empty({num_experts}, dtype(torch::kInt32).device(torch::kCUDA));
auto is_token_in_rank = torch::empty({num_tokens, num_ranks}, dtype(torch::kBool).device(torch::kCUDA));
if (is_internode_available())
num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
internode::get_dispatch_layout(topk_idx.data_ptr<int64_t>(),
num_tokens_per_rank.data_ptr<int>(),
num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr<int>() : nullptr,
num_tokens_per_expert.data_ptr<int>(),
is_token_in_rank.data_ptr<bool>(),
num_tokens, num_topk, num_ranks, num_experts,
comm_stream);
// Wait streams
std::optional<EventHandle> event;
if (async) {
event = EventHandle(comm_stream);
for (auto& t: {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) {
t.record_stream(comm_stream);
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
for (auto& to: {num_tokens_per_rdma_rank}) {
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
stream_wait(compute_stream, comm_stream);
}
// Switch back to the compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event};
}
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
bool cached_mode = cached_rank_prefix_matrix.has_value();
// One channel uses two blocks: even-numbered blocks for sending, odd-numbered blocks for receiving.
EP_HOST_ASSERT(config.num_sms % 2 == 0);
int num_channels = config.num_sms / 2;
if (cached_mode) {
EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value());
EP_HOST_ASSERT(cached_channel_prefix_matrix.has_value());
} else {
EP_HOST_ASSERT(num_tokens_per_rank.has_value());
EP_HOST_ASSERT(num_tokens_per_expert.has_value());
}
// Type checks
EP_HOST_ASSERT(is_token_in_rank.scalar_type() == torch::kBool);
if (cached_mode) {
EP_HOST_ASSERT(cached_rank_prefix_matrix->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(cached_channel_prefix_matrix->scalar_type() == torch::kInt32);
} else {
EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32);
}
// Shape and contiguous checks
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0);
EP_HOST_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous());
EP_HOST_ASSERT(is_token_in_rank.size(0) == x.size(0) and is_token_in_rank.size(1) == num_ranks);
if (cached_mode) {
EP_HOST_ASSERT(cached_rank_prefix_matrix->dim() == 2 and cached_rank_prefix_matrix->is_contiguous());
EP_HOST_ASSERT(cached_rank_prefix_matrix->size(0) == num_ranks and cached_rank_prefix_matrix->size(1) == num_ranks);
EP_HOST_ASSERT(cached_channel_prefix_matrix->dim() == 2 and cached_channel_prefix_matrix->is_contiguous());
EP_HOST_ASSERT(cached_channel_prefix_matrix->size(0) == num_ranks and cached_channel_prefix_matrix->size(1) == num_channels);
} else {
EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous());
EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0);
EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS);
EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous());
EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks);
}
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_experts = cached_mode ? 0 : static_cast<int>(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks;
// Top-k checks
int num_topk = 0;
int64_t* topk_idx_ptr = nullptr;
float* topk_weights_ptr = nullptr;
EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
if (topk_idx.has_value()) {
num_topk = static_cast<int>(topk_idx->size(1));
EP_HOST_ASSERT(num_experts > 0);
EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous());
EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));
EP_HOST_ASSERT(num_topk == topk_weights->size(1));
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
topk_idx_ptr = topk_idx->data_ptr<int64_t>();
topk_weights_ptr = topk_weights->data_ptr<float>();
}
// FP8 scales checks
float* x_scales_ptr = nullptr;
int num_scales = 0;
if (x_scales.has_value()) {
EP_HOST_ASSERT(x.element_size() == 1);
EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32);
EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous());
EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
x_scales_ptr = x_scales->data_ptr<float>();
}
// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
}
// Wait for previous tasks to finish
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
// Create handles (only returned in non-cached mode)
int num_recv_tokens = -1;
auto rank_prefix_matrix = torch::Tensor();
auto channel_prefix_matrix = torch::Tensor();
std::vector<int> num_recv_tokens_per_expert_list;
// Barrier or send sizes
// To clean: channel start/end offset, head and tail
int num_memset_int = num_channels * num_ranks * 4;
if (cached_mode) {
num_recv_tokens = cached_num_recv_tokens;
rank_prefix_matrix = cached_rank_prefix_matrix.value();
channel_prefix_matrix = cached_channel_prefix_matrix.value();
// Copy rank prefix matrix and clean flags
intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr<int>(), num_memset_int,
buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank, num_ranks,
comm_stream);
move_fifo_slots(2);
} else {
rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
// Send sizes
// Meta information:
// - Size prefix by ranks, shaped as `[num_ranks, num_ranks]`
// - Size prefix by experts (not used later), shaped as `[num_ranks, num_local_experts]`
// NOTES: no more token dropping in this version
*moe_recv_counter = -1;
for (int i = 0; i < num_local_experts; ++ i)
moe_recv_expert_counter[i] = -1;
EP_HOST_ASSERT(num_ranks * (num_ranks + num_local_experts) * sizeof(int) <= num_nvl_bytes);
intranode::notify_dispatch(num_tokens_per_rank->data_ptr<int>(), moe_recv_counter_mapped, num_ranks,
num_tokens_per_expert->data_ptr<int>(), moe_recv_expert_counter_mapped, num_experts,
num_tokens, is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
rank_prefix_matrix.data_ptr<int>(),
num_memset_int, expert_alignment,
buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank,
comm_stream, num_channels);
move_fifo_slots(3);
// Synchronize total received tokens and tokens per expert
auto start_time = std::chrono::high_resolution_clock::now();
while (true) {
// Read total count
num_recv_tokens = static_cast<int>(*moe_recv_counter);
// Read per-expert count
bool ready = (num_recv_tokens >= 0);
for (int i = 0; i < num_local_experts and ready; ++ i)
ready &= moe_recv_expert_counter[i] >= 0;
if (ready)
break;
// Timeout check
if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS)
throw std::runtime_error("DeepEP error: CPU recv timeout");
}
num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
}
// Allocate new tensors
auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
auto recv_src_idx = torch::empty({num_recv_tokens}, dtype(torch::kInt32).device(torch::kCUDA));
auto recv_topk_idx = std::optional<torch::Tensor>(), recv_topk_weights = std::optional<torch::Tensor>(), recv_x_scales = std::optional<torch::Tensor>();
auto recv_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
auto send_head = torch::empty({num_tokens, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
// Assign pointers
int64_t* recv_topk_idx_ptr = nullptr;
float* recv_topk_weights_ptr = nullptr;
float* recv_x_scales_ptr = nullptr;
if (topk_idx.has_value()) {
recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options());
recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
recv_topk_idx_ptr = recv_topk_idx->data_ptr<int64_t>();
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
}
if (x_scales.has_value()) {
recv_x_scales = x_scales->dim() == 1 ?
torch::empty({num_recv_tokens}, x_scales->options()) :
torch::empty({num_recv_tokens, num_scales}, x_scales->options());
recv_x_scales_ptr = recv_x_scales->data_ptr<float>();
}
// Dispatch
EP_HOST_ASSERT(num_ranks * num_ranks * sizeof(int) + // Size prefix matrix
num_channels * num_ranks * sizeof(int) + // Channel start offset
num_channels * num_ranks * sizeof(int) + // Channel end offset
num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * recv_x.element_size() + // Data buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(int64_t) + // Top-k index buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) + // Top-k weight buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales // FP8 scale buffer
<= num_nvl_bytes);
intranode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_src_idx.data_ptr<int>(), recv_topk_idx_ptr, recv_topk_weights_ptr, recv_channel_prefix_matrix.data_ptr<int>(),
send_head.data_ptr<int>(),
x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
num_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms,
config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens);
// Wait streams
std::optional<EventHandle> event;
if (async) {
event = EventHandle(comm_stream);
for (auto& t: {x, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x, recv_src_idx, recv_channel_prefix_matrix, send_head}) {
t.record_stream(comm_stream);
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
for (auto& to: {x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert, cached_channel_prefix_matrix, cached_rank_prefix_matrix, recv_topk_idx, recv_topk_weights, recv_x_scales}) {
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
stream_wait(compute_stream, comm_stream);
}
// Switch back to the compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
// Return values
return {recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event};
}
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
EP_HOST_ASSERT(src_idx.dim() == 1 and src_idx.is_contiguous() and src_idx.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(send_head.dim() == 2 and send_head.is_contiguous() and send_head.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 and rank_prefix_matrix.is_contiguous() and rank_prefix_matrix.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 and channel_prefix_matrix.is_contiguous() and channel_prefix_matrix.scalar_type() == torch::kInt32);
// One channel uses two blocks: even-numbered blocks for sending, odd-numbered blocks for receiving.
EP_HOST_ASSERT(config.num_sms % 2 == 0);
int num_channels = config.num_sms / 2;
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_recv_tokens = static_cast<int>(send_head.size(0));
EP_HOST_ASSERT(src_idx.size(0) == num_tokens);
EP_HOST_ASSERT(send_head.size(1) == num_ranks);
EP_HOST_ASSERT(rank_prefix_matrix.size(0) == num_ranks and rank_prefix_matrix.size(1) == num_ranks);
EP_HOST_ASSERT(channel_prefix_matrix.size(0) == num_ranks and channel_prefix_matrix.size(1) == num_channels);
EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0);
// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
}
// Wait for previous tasks to finish
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
int num_topk = 0;
auto recv_topk_weights = std::optional<torch::Tensor>();
float* topk_weights_ptr = nullptr;
float* recv_topk_weights_ptr = nullptr;
if (topk_weights.has_value()) {
EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
EP_HOST_ASSERT(topk_weights->size(0) == num_tokens);
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
num_topk = static_cast<int>(topk_weights->size(1));
topk_weights_ptr = topk_weights->data_ptr<float>();
recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
}
// Launch barrier and reset queue head and tail
EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= num_nvl_bytes);
intranode::cached_notify_combine(buffer_ptrs_gpu, send_head.data_ptr<int>(),
num_channels, num_recv_tokens, num_channels * num_ranks * 2,
task_fifo_ptrs_gpu, head, rank, num_ranks,
comm_stream);
// NOTES: this function uses two FIFO slots (barrier before and after)
move_fifo_slots(2);
// Combine data
auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * x.element_size() + // Data buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) // Top-k weight buffer
<= num_nvl_bytes);
intranode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
recv_x.data_ptr(), recv_topk_weights_ptr,
x.data_ptr(), topk_weights_ptr,
src_idx.data_ptr<int>(), rank_prefix_matrix.data_ptr<int>(), channel_prefix_matrix.data_ptr<int>(),
send_head.data_ptr<int>(), num_tokens, num_recv_tokens, hidden, num_topk,
buffer_ptrs_gpu, rank, num_ranks,
comm_stream, config.num_sms,
config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens);
// Wait streams
std::optional<EventHandle> event;
if (async) {
event = EventHandle(comm_stream);
for (auto& t: {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) {
t.record_stream(comm_stream);
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
for (auto& to: {topk_weights, recv_topk_weights}) {
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
stream_wait(compute_stream, comm_stream);
}
// Switch back to the compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
return {recv_x, recv_topk_weights, event};
}
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<EventHandle>>
Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const std::optional<torch::Tensor>& num_tokens_per_rdma_rank,
const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,
const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
const int num_channels = config.num_sms / 2;
EP_HOST_ASSERT(config.num_sms % 2 == 0);
EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS);
bool cached_mode = cached_rdma_channel_prefix_matrix.has_value();
if (cached_mode) {
EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix.has_value());
EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum.has_value());
EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix.has_value());
EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum.has_value());
} else {
EP_HOST_ASSERT(num_tokens_per_rank.has_value());
EP_HOST_ASSERT(num_tokens_per_rdma_rank.has_value());
EP_HOST_ASSERT(num_tokens_per_expert.has_value());
}
// Type checks
if (cached_mode) {
EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->scalar_type() == torch::kInt32);
} else {
EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32);
}
// Shape and contiguous checks
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0);
if (cached_mode) {
EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->dim() == 2 and cached_rdma_channel_prefix_matrix->is_contiguous());
EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == num_rdma_ranks and cached_rdma_channel_prefix_matrix->size(1) == num_channels);
EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->dim() == 1 and cached_recv_rdma_rank_prefix_sum->is_contiguous());
EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->size(0) == num_rdma_ranks);
EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->dim() == 2 and cached_gbl_channel_prefix_matrix->is_contiguous());
EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks and cached_gbl_channel_prefix_matrix->size(1) == num_channels);
EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dim() == 1 and cached_recv_gbl_rank_prefix_sum->is_contiguous());
EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks);
} else {
EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous());
EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 1 and num_tokens_per_rdma_rank->is_contiguous());
EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous());
EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks);
EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_rdma_ranks);
EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0);
EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS);
}
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1)), hidden_int4 = static_cast<int>(x.size(1) * x.element_size() / sizeof(int4));
auto num_experts = cached_mode ? 0 : static_cast<int>(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks;
// Top-k checks
int num_topk = 0;
int64_t* topk_idx_ptr = nullptr;
float* topk_weights_ptr = nullptr;
EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
if (topk_idx.has_value()) {
num_topk = static_cast<int>(topk_idx->size(1));
EP_HOST_ASSERT(num_experts > 0);
EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous());
EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));
EP_HOST_ASSERT(num_topk == topk_weights->size(1));
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
topk_idx_ptr = topk_idx->data_ptr<int64_t>();
topk_weights_ptr = topk_weights->data_ptr<float>();
}
// FP8 scales checks
float* x_scales_ptr = nullptr;
int num_scales = 0;
if (x_scales.has_value()) {
EP_HOST_ASSERT(x.element_size() == 1);
EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32);
EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous());
EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
x_scales_ptr = x_scales->data_ptr<float>();
}
// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
}
// Wait for previous tasks to finish
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
// Create handles (only returned in non-cached mode)
int num_recv_tokens = -1, num_rdma_recv_tokens = -1;
auto rdma_channel_prefix_matrix = torch::Tensor();
auto recv_rdma_rank_prefix_sum = torch::Tensor();
auto gbl_channel_prefix_matrix = torch::Tensor();
auto recv_gbl_rank_prefix_sum = torch::Tensor();
std::vector<int> num_recv_tokens_per_expert_list;
// Barrier or send sizes
if (cached_mode) {
num_recv_tokens = cached_num_recv_tokens;
num_rdma_recv_tokens = cached_num_rdma_recv_tokens;
rdma_channel_prefix_matrix = cached_rdma_channel_prefix_matrix.value();
recv_rdma_rank_prefix_sum = cached_recv_rdma_rank_prefix_sum.value();
gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value();
recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value();
// Just a barrier and clean flags
internode::cached_notify(hidden_int4, num_scales, num_topk, num_topk,
num_ranks, num_channels, 0, nullptr,
nullptr, nullptr, nullptr,
rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens,
task_fifo_ptrs_gpu, head, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, true, low_latency_mode);
move_fifo_slots(2);
} else {
rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
recv_rdma_rank_prefix_sum = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
recv_gbl_rank_prefix_sum = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
// Send sizes
*moe_recv_counter = -1, *moe_recv_rdma_counter = -1;
for (int i = 0; i < num_local_experts; ++ i)
moe_recv_expert_counter[i] = -1;
internode::notify_dispatch(num_tokens_per_rank->data_ptr<int>(), moe_recv_counter_mapped, num_ranks,
num_tokens_per_rdma_rank->data_ptr<int>(), moe_recv_rdma_counter_mapped,
num_tokens_per_expert->data_ptr<int>(), moe_recv_expert_counter_mapped, num_experts,
is_token_in_rank.data_ptr<bool>(), num_tokens, num_channels,
hidden_int4, num_scales, num_topk, expert_alignment,
rdma_channel_prefix_matrix.data_ptr<int>(), recv_rdma_rank_prefix_sum.data_ptr<int>(),
gbl_channel_prefix_matrix.data_ptr<int>(), recv_gbl_rank_prefix_sum.data_ptr<int>(),
rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens,
task_fifo_ptrs_gpu, head, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, low_latency_mode);
move_fifo_slots(3);
// Synchronize total received tokens and tokens per expert
auto start_time = std::chrono::high_resolution_clock::now();
while (true) {
// Read total count
num_recv_tokens = static_cast<int>(*moe_recv_counter);
num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);
// Read per-expert count
bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);
for (int i = 0; i < num_local_experts and ready; ++ i)
ready &= moe_recv_expert_counter[i] >= 0;
if (ready)
break;
// Timeout check
if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS) {
printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, num_rdma_recv_tokens);
for (int i = 0; i < num_local_experts; ++ i)
printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]);
throw std::runtime_error("DeepEP error: timeout (dispatch CPU)");
}
}
num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
}
// Allocate new tensors
auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
auto recv_topk_idx = std::optional<torch::Tensor>(), recv_topk_weights = std::optional<torch::Tensor>(), recv_x_scales = std::optional<torch::Tensor>();
auto recv_src_meta = std::optional<torch::Tensor>();
auto recv_rdma_channel_prefix_matrix = std::optional<torch::Tensor>();
auto recv_gbl_channel_prefix_matrix = std::optional<torch::Tensor>();
auto send_rdma_head = std::optional<torch::Tensor>();
auto send_nvl_head = std::optional<torch::Tensor>();
if (not cached_mode) {
recv_src_meta = torch::empty({num_recv_tokens, internode::get_source_meta_bytes()}, dtype(torch::kByte).device(torch::kCUDA));
recv_rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
recv_gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
send_rdma_head = torch::empty({num_tokens, num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
send_nvl_head = torch::empty({num_rdma_recv_tokens, NUM_MAX_NVL_PEERS}, dtype(torch::kInt32).device(torch::kCUDA));
}
// Assign pointers
int64_t* recv_topk_idx_ptr = nullptr;
float* recv_topk_weights_ptr = nullptr;
float* recv_x_scales_ptr = nullptr;
if (topk_idx.has_value()) {
recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options());
recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
recv_topk_idx_ptr = recv_topk_idx->data_ptr<int64_t>();
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
}
if (x_scales.has_value()) {
recv_x_scales = x_scales->dim() == 1 ?
torch::empty({num_recv_tokens}, x_scales->options()) :
torch::empty({num_recv_tokens, num_scales}, x_scales->options());
recv_x_scales_ptr = recv_x_scales->data_ptr<float>();
}
// Launch data dispatch
// NOTES: the buffer size checks are moved into the `.cu` file
internode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr,
cached_mode ? nullptr : recv_src_meta->data_ptr(),
x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
cached_mode ? nullptr : send_rdma_head->data_ptr<int>(), cached_mode ? nullptr : send_nvl_head->data_ptr<int>(),
cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr<int>(),
cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr<int>(),
rdma_channel_prefix_matrix.data_ptr<int>(), recv_rdma_rank_prefix_sum.data_ptr<int>(),
gbl_channel_prefix_matrix.data_ptr<int>(), recv_gbl_rank_prefix_sum.data_ptr<int>(),
num_tokens, hidden_int4, num_scales, num_topk, num_experts,
is_token_in_rank.data_ptr<bool>(),
rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens,
rank, num_ranks, cached_mode,
comm_stream, num_channels, low_latency_mode);
// Wait streams
std::optional<EventHandle> event;
if (async) {
event = EventHandle(comm_stream);
for (auto& t: {x, is_token_in_rank, recv_x,
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum}) {
t.record_stream(comm_stream);
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
for (auto& to: {x_scales, topk_idx, topk_weights,
num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert,
cached_rdma_channel_prefix_matrix, cached_recv_rdma_rank_prefix_sum,
cached_gbl_channel_prefix_matrix, cached_recv_gbl_rank_prefix_sum,
recv_topk_idx, recv_topk_weights, recv_x_scales,
recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, send_rdma_head, send_nvl_head,
recv_src_meta}) {
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
stream_wait(compute_stream, comm_stream);
}
// Switch back to the compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
// Return values
return {recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list,
rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
recv_src_meta, send_rdma_head, send_nvl_head, event};
}
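// NOTES: the CPU-side wait near the top of `internode_dispatch` appears to poll host-visible
// counters and throw once `NUM_CPU_TIMEOUT_SECS` has elapsed. The block below is a minimal
// sketch of that poll-with-timeout pattern in isolation. `poll_until` is a hypothetical helper,
// not a DeepEP function, and it uses `std::chrono::steady_clock` (monotonic) where the code
// above uses `high_resolution_clock`.

```cpp
#include <cassert>
#include <chrono>
#include <functional>
#include <stdexcept>

// Hypothetical helper: spin until `ready()` returns true, or throw after `timeout`.
inline void poll_until(const std::function<bool()>& ready, std::chrono::milliseconds timeout) {
    assert(timeout.count() >= 0);
    const auto start_time = std::chrono::steady_clock::now();
    while (not ready()) {
        // Check the elapsed time on every failed poll, as the dispatch loop does
        if (std::chrono::steady_clock::now() - start_time > timeout)
            throw std::runtime_error("timeout while polling");
    }
}
```

// Usage (illustrative): `poll_until([&]() { return counter >= 0; }, std::chrono::milliseconds(100));`
// where `counter` stands in for a value written by another agent, such as a mapped counter.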
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
const int num_channels = config.num_sms / 2;
EP_HOST_ASSERT(config.num_sms % 2 == 0);
// Shape and contiguous checks
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
EP_HOST_ASSERT(src_meta.dim() == 2 and src_meta.is_contiguous() and src_meta.scalar_type() == torch::kByte);
EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 and is_combined_token_in_rank.is_contiguous() and is_combined_token_in_rank.scalar_type() == torch::kBool);
EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 and rdma_channel_prefix_matrix.is_contiguous() and rdma_channel_prefix_matrix.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 and rdma_rank_prefix_sum.is_contiguous() and rdma_rank_prefix_sum.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 and gbl_channel_prefix_matrix.is_contiguous() and gbl_channel_prefix_matrix.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and combined_rdma_head.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.is_contiguous() and combined_nvl_head.scalar_type() == torch::kInt32);
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1)), hidden_int4 = static_cast<int>(x.size(1) * x.element_size() / sizeof(int4));
auto num_combined_tokens = static_cast<int>(is_combined_token_in_rank.size(0));
EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0);
EP_HOST_ASSERT(src_meta.size(1) == internode::get_source_meta_bytes());
EP_HOST_ASSERT(is_combined_token_in_rank.size(1) == num_ranks);
EP_HOST_ASSERT(rdma_channel_prefix_matrix.size(0) == num_rdma_ranks and rdma_channel_prefix_matrix.size(1) == num_channels);
EP_HOST_ASSERT(rdma_rank_prefix_sum.size(0) == num_rdma_ranks);
EP_HOST_ASSERT(gbl_channel_prefix_matrix.size(0) == num_ranks and gbl_channel_prefix_matrix.size(1) == num_channels);
EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.size(0) == num_combined_tokens and combined_rdma_head.size(1) == num_rdma_ranks);
EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.size(1) == NUM_MAX_NVL_PEERS);
// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
}
// Wait for previous tasks to finish
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
// Top-k checks
int num_topk = 0;
auto combined_topk_weights = std::optional<torch::Tensor>();
float* topk_weights_ptr = nullptr;
float* combined_topk_weights_ptr = nullptr;
if (topk_weights.has_value()) {
EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
EP_HOST_ASSERT(topk_weights->size(0) == num_tokens);
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
num_topk = static_cast<int>(topk_weights->size(1));
topk_weights_ptr = topk_weights->data_ptr<float>();
combined_topk_weights = torch::empty({num_combined_tokens, num_topk}, topk_weights->options());
combined_topk_weights_ptr = combined_topk_weights->data_ptr<float>();
}
// Extra checks for the deadlock-avoidance design
EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <= config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks);
// Launch barrier and reset queue head and tail
internode::cached_notify(hidden_int4, 0, 0, num_topk,
num_ranks, num_channels,
num_combined_tokens, combined_rdma_head.data_ptr<int>(),
rdma_channel_prefix_matrix.data_ptr<int>(), rdma_rank_prefix_sum.data_ptr<int>(), combined_nvl_head.data_ptr<int>(),
rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens,
task_fifo_ptrs_gpu, head, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, false, low_latency_mode);
move_fifo_slots(2);
// Launch data combine
auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
combined_x.data_ptr(), combined_topk_weights_ptr,
is_combined_token_in_rank.data_ptr<bool>(),
x.data_ptr(), topk_weights_ptr,
combined_rdma_head.data_ptr<int>(), combined_nvl_head.data_ptr<int>(),
src_meta.data_ptr(), rdma_channel_prefix_matrix.data_ptr<int>(), rdma_rank_prefix_sum.data_ptr<int>(), gbl_channel_prefix_matrix.data_ptr<int>(),
num_tokens, num_combined_tokens, hidden, num_topk,
rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens,
rank, num_ranks, comm_stream, num_channels, low_latency_mode);
// Wait streams
std::optional<EventHandle> event;
if (async) {
event = EventHandle(comm_stream);
for (auto& t: {x, src_meta,
is_combined_token_in_rank, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
combined_x, combined_rdma_head, combined_nvl_head}) {
t.record_stream(comm_stream);
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
for (auto& to: {topk_weights, combined_topk_weights}) {
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
stream_wait(compute_stream, comm_stream);
}
// Switch back to the compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
// Return values
return {combined_x, combined_topk_weights, event};
}
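// NOTES: read literally, the two deadlock-avoidance assertions in `internode_combine` require
// that the NVLink receive chunk size divides evenly by the number of RDMA ranks, and that the
// NVLink send chunk size fits into one such share. The sketch below restates them as a
// predicate. `ChunkConfig` and `satisfies_combine_chunk_checks` are hypothetical names used
// for illustration only; they are not part of `deep_ep::Config`.

```cpp
#include <cassert>

// Hypothetical stand-in for the two chunk-size fields used by the assertions
struct ChunkConfig {
    int num_max_nvl_chunked_send_tokens;
    int num_max_nvl_chunked_recv_tokens;
};

// Returns true when both host assertions would pass for the given number of RDMA ranks
inline bool satisfies_combine_chunk_checks(const ChunkConfig& config, int num_rdma_ranks) {
    assert(num_rdma_ranks > 0);
    if (config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks != 0)
        return false;
    return config.num_max_nvl_chunked_send_tokens <=
           config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks;
}
```

// With the binding defaults (send = 6, recv = 256), 2 RDMA ranks pass both checks, while
// 64 RDMA ranks fail the second one because 256 / 64 = 4 is smaller than 6.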
void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
EP_HOST_ASSERT(low_latency_mode);
auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
auto clean_meta_0 = layout.buffers[0].clean_meta();
auto clean_meta_1 = layout.buffers[1].clean_meta();
auto check_boundary = [=](void* ptr, size_t num_bytes) {
auto offset = reinterpret_cast<int64_t>(ptr) - reinterpret_cast<int64_t>(rdma_buffer_ptr);
EP_HOST_ASSERT(0 <= offset and offset + num_bytes <= num_rdma_bytes);
};
check_boundary(clean_meta_0.first, clean_meta_0.second * sizeof(int));
check_boundary(clean_meta_1.first, clean_meta_1.second * sizeof(int));
internode_ll::clean_low_latency_buffer(clean_meta_0.first, clean_meta_0.second,
clean_meta_1.first, clean_meta_1.second,
at::cuda::getCurrentCUDAStream());
}
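// NOTES: `check_boundary` above verifies that a region to be cleaned lies inside the RDMA
// buffer. The sketch below shows the same containment test as a standalone function.
// `is_within_buffer` is a hypothetical helper; it performs the sign checks first and then
// compares in unsigned arithmetic, which is one way (an assumption, not the source's method)
// to avoid mixing signed and unsigned operands.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: true if [ptr, ptr + num_bytes) lies inside [base, base + capacity)
inline bool is_within_buffer(const void* base, std::int64_t capacity,
                             const void* ptr, std::size_t num_bytes) {
    const auto offset = reinterpret_cast<std::intptr_t>(ptr) - reinterpret_cast<std::intptr_t>(base);
    if (offset < 0 or capacity < 0)
        return false;
    // Both values are non-negative here, so the unsigned conversions are lossless
    const auto u_offset = static_cast<std::uint64_t>(offset);
    const auto u_capacity = static_cast<std::uint64_t>(capacity);
    // Written as a subtraction so that `u_offset + num_bytes` cannot overflow
    return u_offset <= u_capacity and num_bytes <= u_capacity - u_offset;
}
```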
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool async, bool return_recv_hook) {
EP_HOST_ASSERT(low_latency_mode);
// Tensor checks
// Uses the `ptp128c` FP8 cast by default
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0);
EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(num_experts % num_ranks == 0);
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1));
int num_local_experts = num_experts / num_ranks;
// Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
// Wait for previous tasks to finish
// NOTES: the hook mode always launches on the current (compute) stream
auto compute_stream = at::cuda::getCurrentCUDAStream();
auto launch_stream = return_recv_hook ? compute_stream : comm_stream;
EP_HOST_ASSERT(not (async and return_recv_hook));
if (not return_recv_hook)
stream_wait(launch_stream, compute_stream);
// Allocate packed tensors
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(torch::kFloat8_e4m3fn));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::from_blob(buffer.dispatch_rdma_atomic_token_counter,
{num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
// Allocate column-major scales
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
auto packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
packed_recv_x_scales = torch::transpose(packed_recv_x_scales, 1, 2);
// Kernel launch
auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) {
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales.data_ptr<float>(),
packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
workspace, launch_stream, phases);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
// Wait streams
std::optional<EventHandle> event;
if (async) {
// NOTES: we must ensure that no tensor is deallocated before the stream-wait happens,
// so the Python API must wrap all tensors into the event handle.
event = EventHandle(launch_stream);
} else if (not return_recv_hook) {
stream_wait(compute_stream, launch_stream);
}
// Receiver callback
std::optional<std::function<void()>> recv_hook = std::nullopt;
if (return_recv_hook)
recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
// Return values
return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
}
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool async, bool return_recv_hook) {
EP_HOST_ASSERT(low_latency_mode);
// Tensor checks
EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks);
EP_HOST_ASSERT(x.size(1) == num_ranks * num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0);
EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1));
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous());
EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32);
EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous());
EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0));
EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());
EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);
auto hidden = static_cast<int>(x.size(2));
auto num_local_experts = num_experts / num_ranks, num_topk = static_cast<int>(topk_weights.size(1));
auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
// Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
// Wait for previous tasks to finish
// NOTES: the hook mode always launches on the current (compute) stream
auto compute_stream = at::cuda::getCurrentCUDAStream();
auto launch_stream = return_recv_hook ? compute_stream : comm_stream;
EP_HOST_ASSERT(not (async and return_recv_hook));
if (not return_recv_hook)
stream_wait(launch_stream, compute_stream);
// Allocate output tensor
auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
// Kernel launch
auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) {
internode_ll::combine(combined_x.data_ptr(),
buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer,
buffer.combine_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
workspace, launch_stream, phases);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
// Wait streams
std::optional<EventHandle> event;
if (async) {
// NOTES: we must ensure that no tensor is deallocated before the stream-wait happens,
// so the Python API must wrap all tensors into the event handle.
event = EventHandle(launch_stream);
} else if (not return_recv_hook) {
stream_wait(compute_stream, launch_stream);
}
// Receiver callback
std::optional<std::function<void()>> recv_hook = std::nullopt;
if (return_recv_hook)
recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
// Return values
return {combined_x, event, recv_hook};
}
} // namespace deep_ep
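// NOTES: `low_latency_dispatch` and `low_latency_combine` share two host-side patterns: they
// flip a double-buffer index on every call (`low_latency_buffer_idx ^= 1`), and with
// `return_recv_hook` they launch only the send phase and hand back a closure that launches
// the receive phase later. The sketch below models both without CUDA. `TwoPhaseDemo`,
// `kSendPhase`, and `kRecvPhase` are hypothetical names; the real `LOW_LATENCY_*` constants
// are defined elsewhere in DeepEP and their values are not assumed here.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical phase flags, combined as a bitmask like the launcher argument above
constexpr int kSendPhase = 1 << 0;
constexpr int kRecvPhase = 1 << 1;

// Hypothetical stand-in for the host-side state; "launches" are only recorded, not executed
struct TwoPhaseDemo {
    int buffer_idx = 0;
    std::vector<int> launched_phases;  // Phase mask of each simulated launch
    std::vector<int> used_buffers;     // Buffer index used by each simulated launch

    std::optional<std::function<void()>> run(bool return_recv_hook) {
        // Use the current buffer for this call, then flip the index for the next call
        const int current = buffer_idx;
        buffer_idx ^= 1;
        auto launcher = [this, current](int phases) {
            launched_phases.push_back(phases);
            used_buffers.push_back(current);
        };
        // Hook mode launches the send phase only; otherwise both phases run in one launch
        launcher(return_recv_hook ? kSendPhase : (kSendPhase | kRecvPhase));
        std::optional<std::function<void()>> recv_hook = std::nullopt;
        if (return_recv_hook)
            recv_hook = [launcher]() { launcher(kRecvPhase); };
        return recv_hook;
    }
};
```

// The returned hook captures `this` through `launcher`, so in this sketch the `TwoPhaseDemo`
// object must outlive the hook. The hook also keeps using the buffer index chosen when it was
// created, even though the index has already been flipped for the next call.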
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "DeepEP: an efficient expert-parallel communication library";
pybind11::class_<deep_ep::Config>(m, "Config")
.def(pybind11::init<int, int, int, int, int>(),
py::arg("num_sms") = 20,
py::arg("num_max_nvl_chunked_send_tokens") = 6, py::arg("num_max_nvl_chunked_recv_tokens") = 256,
py::arg("num_max_rdma_chunked_send_tokens") = 6, py::arg("num_max_rdma_chunked_recv_tokens") = 256)
.def("get_nvl_buffer_size_hint", &deep_ep::Config::get_nvl_buffer_size_hint)
.def("get_rdma_buffer_size_hint", &deep_ep::Config::get_rdma_buffer_size_hint);
m.def("get_low_latency_rdma_size_hint", &deep_ep::get_low_latency_rdma_size_hint);
pybind11::class_<deep_ep::EventHandle>(m, "EventHandle")
.def(pybind11::init<>())
.def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait);
pybind11::class_<deep_ep::Buffer>(m, "Buffer")
.def(pybind11::init<int, int, int64_t, int64_t, bool>())
.def("is_available", &deep_ep::Buffer::is_available)
.def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks)
.def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank)
.def("get_root_rdma_rank", &deep_ep::Buffer::get_root_rdma_rank)
.def("get_local_device_id", &deep_ep::Buffer::get_local_device_id)
.def("get_local_ipc_handle", &deep_ep::Buffer::get_local_ipc_handle)
.def("get_local_nvshmem_unique_id", &deep_ep::Buffer::get_local_nvshmem_unique_id)
.def("get_local_buffer_tensor", &deep_ep::Buffer::get_local_buffer_tensor)
.def("sync", &deep_ep::Buffer::sync)
.def("get_dispatch_layout", &deep_ep::Buffer::get_dispatch_layout)
.def("intranode_dispatch", &deep_ep::Buffer::intranode_dispatch)
.def("intranode_combine", &deep_ep::Buffer::intranode_combine)
.def("internode_dispatch", &deep_ep::Buffer::internode_dispatch)
.def("internode_combine", &deep_ep::Buffer::internode_combine)
.def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer)
.def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch)
.def("low_latency_combine", &deep_ep::Buffer::low_latency_combine);
}
#pragma once
// Forcibly disable NDEBUG
#ifdef NDEBUG
#undef NDEBUG
#endif
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/types.h>
#include <functional>
#include <optional>
#include <tuple>
#include <vector>
#include "config.hpp"
#include "event.hpp"
#include "kernels/configs.cuh"
#include "kernels/exception.cuh"
#ifndef TORCH_EXTENSION_NAME
#define TORCH_EXTENSION_NAME deep_ep_cpp
#endif
namespace deep_ep {
struct Buffer {
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8");
private:
// Low-latency mode buffer
int low_latency_buffer_idx = 0;
bool low_latency_mode = false;
// NVLink Buffer
int64_t num_nvl_bytes;
void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
void** buffer_ptrs_gpu = nullptr;
// NVSHMEM Buffer
int64_t num_rdma_bytes;
void* rdma_buffer_ptr = nullptr;
// Device info and communication
int device_id;
int rank, rdma_rank, nvl_rank;
int num_ranks, num_rdma_ranks, num_nvl_ranks;
cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
// Stream for communication
at::cuda::CUDAStream comm_stream;
// After IPC/NVSHMEM synchronization, this flag will be true
bool available = false;
// Task fifo
int head = 0;
int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int** task_fifo_ptrs_gpu = nullptr;
// Workspace
void* workspace = nullptr;
// Host-side MoE info
volatile int* moe_recv_counter = nullptr;
int* moe_recv_counter_mapped = nullptr;
// Host-side expert-level MoE info
volatile int* moe_recv_expert_counter = nullptr;
int* moe_recv_expert_counter_mapped = nullptr;
// Host-side RDMA-level MoE info
volatile int* moe_recv_rdma_counter = nullptr;
int* moe_recv_rdma_counter_mapped = nullptr;
private:
void move_fifo_slots(int num_slots = 1);
public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode);
~Buffer() noexcept(false);
bool is_available() const;
bool is_internode_available() const;
int get_num_rdma_ranks() const;
int get_rdma_rank() const;
int get_root_rdma_rank(bool global) const;
int get_local_device_id() const;
pybind11::bytearray get_local_ipc_handle() const;
pybind11::bytearray get_local_nvshmem_unique_id() const;
torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const;
void sync(const std::vector<int>& device_ids, const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles, const std::optional<pybind11::bytearray>& root_unique_id_opt);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional<EventHandle>& previous_event,
bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const std::optional<torch::Tensor>& num_tokens_per_rdma_rank,
const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,
const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool async, bool return_recv_hook);
};
} // namespace deep_ep
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <memory>
#include "kernels/exception.cuh"
namespace deep_ep {
struct EventHandle {
std::shared_ptr<torch::Event> event;
EventHandle() {
event = std::make_shared<torch::Event>(torch::kCUDA);
event->record(at::cuda::getCurrentCUDAStream());
}
explicit EventHandle(const at::cuda::CUDAStream& stream) {
event = std::make_shared<torch::Event>(torch::kCUDA);
event->record(stream);
}
EventHandle(const EventHandle& other) = default;
void current_stream_wait() const {
at::cuda::getCurrentCUDAStream().unwrap().wait(*event);
}
};
// NOTES: marked `inline` because this function is defined in a header
inline torch::Event create_event(const at::cuda::CUDAStream &s) {
auto event = torch::Event(torch::kCUDA);
event.record(s);
return event;
}
inline void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) {
EP_HOST_ASSERT(s_0.id() != s_1.id());
s_0.unwrap().wait(create_event(s_1));
}
inline void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) {
s.unwrap().wait(*event.event);
}
} // namespace deep_ep
function(add_deep_ep_library target_name source_file)
add_library(${target_name} STATIC ${source_file})
set_target_properties(${target_name} PROPERTIES
POSITION_INDEPENDENT_CODE ON
CXX_STANDARD_REQUIRED ON
CUDA_STANDARD_REQUIRED ON
CXX_STANDARD 14
CUDA_STANDARD 14
CUDA_SEPARABLE_COMPILATION ON
)
target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5)
endfunction()
add_deep_ep_library(intranode_cuda intranode.cu)
add_deep_ep_library(runtime_cuda runtime.cu)
add_deep_ep_library(internode_cuda internode.cu)
add_deep_ep_library(internode_ll_cuda internode_ll.cu)
# Later, we should link all libraries in `EP_CUDA_LIBRARIES`
set(EP_CUDA_LIBRARIES intranode_cuda runtime_cuda internode_cuda internode_ll_cuda PARENT_SCOPE)
#pragma once
#include <vector>
namespace deep_ep {
// Intranode runtime
namespace intranode {
void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
} // namespace intranode
// Internode runtime
namespace internode {
std::vector<uint8_t> get_unique_id();
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode);
void *alloc(size_t size, size_t alignment);
void free(void *ptr);
void barrier();
void finalize();
} // namespace internode
// Intranode kernels
namespace intranode {
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank,
cudaStream_t stream, int num_sms);
void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, int num_ranks,
cudaStream_t stream);
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms,
int num_max_send_tokens, int num_recv_buffer_tokens);
void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
void combine(cudaDataType_t type,
void* recv_x, float* recv_topk_weights,
const void* x, const float* topk_weights,
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms,
int num_max_send_tokens, int num_recv_buffer_tokens);
} // namespace intranode
// Internode kernels
namespace internode {
int get_source_meta_bytes();
void get_dispatch_layout(const int64_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts,
cudaStream_t stream);
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
const bool* is_token_in_rank, int num_tokens, int num_channels,
int hidden_int4, int num_scales, int num_topk, int expert_alignment,
int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
int** task_fifo_ptrs, int head, int rank,
cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool low_latency_mode);
void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
int* send_rdma_head, int* send_nvl_head,
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
const bool* is_token_in_rank,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, bool is_cached_dispatch,
cudaStream_t stream, int num_channels, bool low_latency_mode);
void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head,
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool is_cached_dispatch, bool low_latency_mode);
void combine(cudaDataType_t type,
void* combined_x, float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const void* x, const float* topk_weights,
const int* combined_rdma_head, const int* combined_nvl_head,
const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
int num_tokens, int num_combined_tokens, int hidden, int num_topk,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode);
} // namespace internode
// Internode low-latency kernels
namespace internode_ll {
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1,
cudaStream_t stream);
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases);
void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases);
} // namespace internode_ll
} // namespace deep_ep
#pragma once
#include "configs.cuh"
#include "exception.cuh"
namespace deep_ep {
template <typename dtype_t>
struct Buffer {
private:
uint8_t* ptr;
public:
int total_bytes;
__device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {}
__device__ __forceinline__ Buffer(void* &gbl_ptr, int num_elems, int offset = 0) {
total_bytes = num_elems * sizeof(dtype_t);
ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t);
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
}
__device__ __forceinline__ Buffer advance_also(void* &gbl_ptr) {
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
return *this;
}
__device__ __forceinline__ dtype_t* buffer() {
return reinterpret_cast<dtype_t*>(ptr);
}
__device__ __forceinline__ dtype_t& operator[](int idx) {
return buffer()[idx];
}
};
template <typename dtype_t, int kNumRanks = 1>
struct AsymBuffer {
private:
uint8_t* ptrs[kNumRanks];
int num_bytes;
public:
int total_bytes;
__device__ __forceinline__ AsymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1, int offset = 0) {
EP_STATIC_ASSERT(kNumRanks == 1, "This constructor is only available for single rank case");
num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms;
ptrs[0] = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
}
__device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1, int offset = 0) {
EP_STATIC_ASSERT(kNumRanks > 1, "This constructor is only available for multiple rank case");
num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms;
for (int i = 0; i < kNumRanks; ++ i) {
ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
}
}
__device__ __forceinline__ void advance(int shift) {
#pragma unroll
for (int i = 0; i < kNumRanks; ++ i)
ptrs[i] = ptrs[i] + shift * sizeof(dtype_t);
}
__device__ __forceinline__ AsymBuffer advance_also(void* &gbl_ptr) {
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
return *this;
}
template<int kNumAlsoRanks>
__device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) {
for (int i = 0; i < kNumAlsoRanks; ++ i)
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
return *this;
}
__device__ __forceinline__ dtype_t* buffer(int idx = 0) {
EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case");
return reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx);
}
__device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) {
EP_STATIC_ASSERT(kNumRanks > 1, "`buffer_by` is only available for multiple rank case");
return reinterpret_cast<dtype_t*>(ptrs[rank_idx] + num_bytes * idx);
}
};
template <typename dtype_t, bool kDecoupled = true>
struct SymBuffer {
private:
// NOTES: for non-decoupled case, `recv_ptr` is not used
uint8_t* send_ptr;
uint8_t* recv_ptr;
int num_bytes;
public:
int total_bytes;
__device__ __forceinline__ SymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1) {
num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1);
send_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id;
recv_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
}
__device__ __forceinline__ dtype_t* send_buffer(int idx = 0) {
EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for decoupled case");
return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
}
__device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) {
EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for decoupled case");
return reinterpret_cast<dtype_t*>(recv_ptr + num_bytes * idx);
}
__device__ __forceinline__ dtype_t* buffer(int idx = 0) {
EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for non-decoupled case");
return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
}
};
} // namespace deep_ep
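// Illustrative sketch (not part of the original source): a host-side mirror of the
// offset arithmetic in `SymBuffer`, useful for checking buffer sizes by hand.
// `SymLayout` and `sym_layout` are hypothetical names; only the formulas come from
// the constructor above.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical host-side record of what the `SymBuffer` constructor computes.
struct SymLayout {
    std::size_t send_offset;  // byte offset of this channel's send region
    std::size_t recv_offset;  // byte offset of this channel's recv region (meaningful only when decoupled)
    std::size_t total_bytes;  // bytes by which the global pointer is advanced
};

inline SymLayout sym_layout(std::size_t elem_bytes, int num_elems, int num_ranks,
                            int sm_id, int num_sms, bool decoupled) {
    assert(0 <= sm_id && sm_id < num_sms);
    const std::size_t num_bytes = elem_bytes * num_elems;         // one slot
    const std::size_t per_channel_bytes = num_bytes * num_ranks;  // one slot per rank
    SymLayout out;
    out.total_bytes = per_channel_bytes * num_sms * (decoupled ? 2 : 1);
    out.send_offset = per_channel_bytes * sm_id;
    out.recv_offset = per_channel_bytes * (sm_id + num_sms);      // second half of the region
    return out;
}
```

// In the decoupled case all send regions come first, followed by all recv regions,
// which is why `recv_offset` is shifted by `num_sms` channels.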
#pragma once
#define NUM_MAX_NVL_PEERS 8
#define NUM_MAX_RDMA_PEERS 20
#define NUM_MAX_FIFO_SLOTS 32768
#define NUM_WORKSPACE_BYTES (32 * 1024 * 1024)
#define NUM_MAX_LOCAL_EXPERTS 1024
#define NUM_BUFFER_ALIGNMENT_BYTES 128
#define FINISHED_SUM_TAG 1024
#define NUM_CPU_TIMEOUT_SECS 100
#define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s
#define NUM_WAIT_NANOSECONDS 500
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
// Make CLion CUDA indexing work
#ifdef __CLION_IDE__
#define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier)
#define __CUDACC_RDC__ // NOLINT(*-reserved-identifier)
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
#define printf host_device_printf
#endif
// Remove Torch restrictions
#ifdef __CUDA_NO_HALF_CONVERSIONS__
#undef __CUDA_NO_HALF_CONVERSIONS__
#endif
#ifdef __CUDA_NO_HALF_OPERATORS__
#undef __CUDA_NO_HALF_OPERATORS__
#endif
#ifdef __CUDA_NO_HALF2_OPERATORS__
#undef __CUDA_NO_HALF2_OPERATORS__
#endif
#ifdef __CUDA_NO_BFLOAT16_CONVERSIONS__
#undef __CUDA_NO_BFLOAT16_CONVERSIONS__
#endif
#ifdef __CUDA_NO_BFLOAT162_OPERATORS__
#undef __CUDA_NO_BFLOAT162_OPERATORS__
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <nvshmem.h>
#include <nvshmemx.h>
#include <infiniband/mlx5dv.h>
#include <non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh>
#include <device_host_transport/nvshmem_common_ibgda.h>
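// Illustrative sketch (not part of the original source): the values 1 and 2 of
// `LOW_LATENCY_SEND_PHASE` and `LOW_LATENCY_RECV_PHASE` suggest that the `phases`
// argument of the low-latency kernels is a bitmask. Under that assumption, a caller
// could test the flags as below; `runs_send` and `runs_recv` are hypothetical helpers.

```cpp
#include <cassert>

constexpr int kSendPhase = 1;  // same value as LOW_LATENCY_SEND_PHASE
constexpr int kRecvPhase = 2;  // same value as LOW_LATENCY_RECV_PHASE

// Hypothetical helper: true if the send half is requested.
inline bool runs_send(int phases) {
    assert((phases & ~(kSendPhase | kRecvPhase)) == 0);  // reject unknown bits
    return (phases & kSendPhase) != 0;
}

// Hypothetical helper: true if the receive half is requested.
inline bool runs_recv(int phases) {
    assert((phases & ~(kSendPhase | kRecvPhase)) == 0);  // reject unknown bits
    return (phases & kRecvPhase) != 0;
}
```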
#pragma once
#include <string>
#include <exception>
#include "configs.cuh"
#ifndef EP_STATIC_ASSERT
#define EP_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
#endif
class EPException: public std::exception {
private:
std::string message = {};
public:
explicit EPException(const char *name, const char* file, const int line, const std::string& error) {
message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'";
}
const char *what() const noexcept override { return message.c_str(); }
};
#ifndef CUDA_CHECK
#define CUDA_CHECK(cmd) \
do { \
cudaError_t e = (cmd); \
if (e != cudaSuccess) { \
throw EPException("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \
} \
} while (0)
#endif
#ifndef EP_HOST_ASSERT
#define EP_HOST_ASSERT(cond) \
do { \
if (not (cond)) { \
throw EPException("Assertion", __FILE__, __LINE__, #cond); \
} \
} while (0)
#endif
#ifndef EP_DEVICE_ASSERT
#define EP_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
asm("trap;"); \
} \
} while (0)
#endif
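// Illustrative sketch (not part of the original source): how a host-side caller might
// use `EP_HOST_ASSERT` and catch the resulting `EPException`. The class and macro are
// trimmed copies of the definitions above so that the sketch stands alone;
// `tokens_per_rank` and `describe` are hypothetical.

```cpp
#include <cassert>
#include <exception>
#include <string>

// Trimmed copy of `EPException` from above.
class EPException : public std::exception {
    std::string message;
public:
    EPException(const char* name, const char* file, int line, const std::string& error)
        : message(std::string("Failed: ") + name + " error " + file + ":" +
                  std::to_string(line) + " '" + error + "'") {}
    const char* what() const noexcept override { return message.c_str(); }
};

// Trimmed copy of `EP_HOST_ASSERT` from above.
#define EP_HOST_ASSERT(cond) \
    do { if (!(cond)) throw EPException("Assertion", __FILE__, __LINE__, #cond); } while (0)

// Hypothetical caller that validates an argument before using it.
inline int tokens_per_rank(int num_tokens, int num_ranks) {
    EP_HOST_ASSERT(num_ranks > 0);
    return num_tokens / num_ranks;
}

// Returns "ok" on success, or the exception text on failure.
inline std::string describe(int num_tokens, int num_ranks) {
    try {
        tokens_per_rank(num_tokens, num_ranks);
        return "ok";
    } catch (const EPException& e) {
        return e.what();
    }
}
```

// The message carries the failing expression verbatim, so a failed check reads like
// "Failed: Assertion error <file>:<line> 'num_ranks > 0'".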
// Portions derived from NVSHMEM (https://developer.nvidia.com/nvshmem)
// Copyright (c) NVIDIA Corporation.
// Licensed under the NVSHMEM Software License Agreement (version: September 3, 2019).
// See full license at: https://docs.nvidia.com/nvshmem/api/sla.html
//
// Modified from original source:
// - nvshmem/src/include/non_abi/device/pt-to-pt/ibgda_device.cuh
#pragma once
#include "configs.cuh"
#include "exception.cuh"
#include "utils.cuh"
namespace deep_ep {
EP_STATIC_ASSERT(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, "Invalid QP minimum depth");
__device__ static __forceinline__
uint64_t HtoBE64(uint64_t x) {
uint64_t ret;
asm("{\n\t"
".reg .b32 ign;\n\t"
".reg .b32 lo;\n\t"
".reg .b32 hi;\n\t"
".reg .b32 new_lo;\n\t"
".reg .b32 new_hi;\n\t"
"mov.b64 {lo,hi}, %1;\n\t"
"prmt.b32 new_hi, lo, ign, 0x0123;\n\t"
"prmt.b32 new_lo, hi, ign, 0x0123;\n\t"
"mov.b64 %0, {new_lo,new_hi};\n\t"
"}" : "=l"(ret) : "l"(x));
return ret;
}
__device__ static __forceinline__
uint32_t HtoBE32(uint32_t x) {
uint32_t ret;
asm("{\n\t"
".reg .b32 ign;\n\t"
"prmt.b32 %0, %1, ign, 0x0123;\n\t"
"}" : "=r"(ret) : "r"(x));
return ret;
}
__device__ static __forceinline__
uint16_t HtoBE16(uint16_t x) {
// TODO: simplify PTX using 16-bit instructions
auto a = static_cast<uint32_t>(x);
uint32_t d;
asm volatile(
"{\n\t"
".reg .b32 mask;\n\t"
".reg .b32 ign;\n\t"
"mov.b32 mask, 0x4401;\n\t"
"mov.b32 ign, 0x0;\n\t"
"prmt.b32 %0, %1, ign, mask;\n\t"
"}"
: "=r"(d)
: "r"(a));
return static_cast<uint16_t>(d);
}
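// Illustrative sketch (not part of the original source): a host emulation of the PTX
// `prmt.b32` byte permute in its default mode, to show why selector `0x0123` reverses
// the four bytes and `0x4401` swaps the low two. `prmt_b32` is a hypothetical helper,
// and the sign-replication bit of each selector nibble is ignored because the
// selectors above never set it.

```cpp
#include <cassert>
#include <cstdint>

inline uint32_t prmt_b32(uint32_t a, uint32_t b, uint32_t selector) {
    // Bytes 0..3 of the pool come from `a`, bytes 4..7 from `b`.
    const uint64_t pool = (static_cast<uint64_t>(b) << 32) | a;
    uint32_t d = 0;
    for (int i = 0; i < 4; ++i) {
        const uint32_t sel = (selector >> (4 * i)) & 0x7;  // source byte for result byte i
        const uint32_t byte = static_cast<uint32_t>((pool >> (8 * sel)) & 0xff);
        d |= byte << (8 * i);
    }
    return d;
}
```

// With `0x0123`, result byte 0 takes source byte 3, byte 1 takes byte 2, and so on,
// which is a 32-bit byte swap. With `0x4401` and a zero second operand, the low two
// bytes are swapped and the upper two are cleared.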
typedef struct mlx5_wqe_ctrl_seg __attribute__((__aligned__(8))) ibgda_ctrl_seg_t;
__device__ static __forceinline__
nvshmemi_ibgda_device_state_t* ibgda_get_state() {
return &nvshmemi_ibgda_device_state_d;
}
__device__ static __forceinline__
nvshmemi_ibgda_device_qp_t* ibgda_get_rc(int pe, int id) {
auto state = ibgda_get_state();
const auto num_rc_per_pe = ibgda_get_state()->num_rc_per_pe;
return &state->globalmem.rcs[pe * num_rc_per_pe + id % num_rc_per_pe];
}
__device__ static __forceinline__
void ibgda_lock_acquire(int *lock) {
while (atomicCAS(lock, 0, 1) == 1);
// Prevent reordering before the lock is acquired
memory_fence_cta();
}
__device__ static __forceinline__
void ibgda_lock_release(int *lock) {
// Make prior writes visible before the lock is released
memory_fence_cta();
st_na_relaxed(lock, 0);
}
__device__ static __forceinline__
void ibgda_update_dbr(nvshmemi_ibgda_device_qp_t *qp, uint32_t dbrec_head) {
// `DBREC` contains the index of the next empty `WQEBB`
__be32 dbrec_val;
__be32 *dbrec_ptr = qp->tx_wq.dbrec;
// This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(dbrec_head & 0xffff))`
asm("{\n\t"
".reg .b32 dbrec_head_16b;\n\t"
".reg .b32 ign;\n\t"
"and.b32 dbrec_head_16b, %1, 0xffff;\n\t"
"prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t"
"}"
: "=r"(dbrec_val)
: "r"(dbrec_head));
st_na_release(dbrec_ptr, dbrec_val);
}
__device__ static __forceinline__
void ibgda_ring_db(nvshmemi_ibgda_device_qp_t *qp, uint16_t prod_idx) {
auto bf_ptr = reinterpret_cast<uint64_t*>(qp->tx_wq.bf);
ibgda_ctrl_seg_t ctrl_seg = {
.opmod_idx_opcode = HtoBE32(prod_idx << 8),
.qpn_ds = HtoBE32(qp->qpn << 8)
};
EP_STATIC_ASSERT(sizeof(decltype(&ctrl_seg)) == sizeof(uint64_t), "");
st_na_release(bf_ptr, *(reinterpret_cast<uint64_t*>(&ctrl_seg)));
}
__device__ static __forceinline__
void ibgda_post_send(nvshmemi_ibgda_device_qp_t *qp, uint64_t new_prod_idx) {
nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars;
uint64_t old_prod_idx;
// Update `prod_idx` before ringing the doorbell, so that we know which index is needed in quiet/fence
ibgda_lock_acquire(&mvars->post_send_lock);
old_prod_idx = atomicMax(reinterpret_cast<unsigned long long int*>(&mvars->tx_wq.prod_idx), new_prod_idx);
if (new_prod_idx > old_prod_idx) {
ibgda_update_dbr(qp, new_prod_idx);
ibgda_ring_db(qp, new_prod_idx);
}
ibgda_lock_release(&mvars->post_send_lock);
}
template <bool kAlwaysDoPostSend>
__device__ static __forceinline__
void ibgda_submit_requests(nvshmemi_ibgda_device_qp_t *qp, uint64_t base_wqe_idx,
uint32_t num_wqes, int message_idx = 0) {
nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars;
uint64_t new_wqe_idx = base_wqe_idx + num_wqes;
// WQE writes must be finished first
__threadfence();
// Wait for prior WQE slots to be filled first
auto *ready_idx = reinterpret_cast<unsigned long long int*>(&mvars->tx_wq.ready_head);
while (atomicCAS(ready_idx, base_wqe_idx, new_wqe_idx) != base_wqe_idx);
// Post immediately if `kAlwaysDoPostSend`; otherwise post once every `kNumRequestInBatch` messages
constexpr int kNumRequestInBatch = 4;
if (kAlwaysDoPostSend or (message_idx + 1) % kNumRequestInBatch == 0)
ibgda_post_send(qp, new_wqe_idx);
}
__device__ static __forceinline__ void
ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *val, uint64_t raddr,
__be32 rkey, uint16_t wqe_idx, void **out_wqes, uint32_t imm) {
ibgda_ctrl_seg_t ctrl_seg;
struct mlx5_wqe_raddr_seg raddr_seg;
struct mlx5_wqe_inl_data_seg inl_seg;
auto *ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
auto *raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
auto *inl_seg_ptr = reinterpret_cast<mlx5_wqe_inl_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
auto *wqe_data_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(inl_seg_ptr) + sizeof(*inl_seg_ptr));
raddr_seg.raddr = HtoBE64(raddr);
raddr_seg.rkey = rkey;
raddr_seg.reserved = 0;
inl_seg.byte_count = HtoBE32(4 | MLX5_INLINE_SEG);
// `imm == std::numeric_limits<uint32_t>::max()` means no imm writes
ctrl_seg = {0};
ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3);
ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | (imm != std::numeric_limits<uint32_t>::max() ? MLX5_OPCODE_RDMA_WRITE_IMM : MLX5_OPCODE_RDMA_WRITE));
if (imm != std::numeric_limits<uint32_t>::max())
ctrl_seg.imm = HtoBE32(imm);
EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16");
EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16");
EP_STATIC_ASSERT(sizeof(*inl_seg_ptr) == 4, "sizeof(*inl_seg_ptr) == 4");
st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));
st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));
st_na_relaxed(reinterpret_cast<uint32_t*>(inl_seg_ptr), *reinterpret_cast<const uint32_t*>(&inl_seg));
st_na_relaxed(reinterpret_cast<uint32_t*>(wqe_data_ptr), *reinterpret_cast<const uint32_t*>(val));
}
__device__ static __forceinline__
uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey,
uint64_t raddr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) {
auto state = ibgda_get_state();
auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);
auto log2_cumem_granularity = state->log2_cumem_granularity;
// Local key
uint64_t idx = (laddr - heap_start) >> log2_cumem_granularity;
auto device_key = state->constmem.lkeys[idx];
auto lchunk_size = device_key.next_addr - laddr;
*lkey = device_key.key;
// Remote key
uint64_t roffset = raddr - heap_start;
idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) + dst_pe;
if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) {
device_key = state->constmem.rkeys[idx];
} else {
device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS];
}
*out_raddr = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset;
*out_rkey = device_key.key;
// Return the minimum of local and remote chunk sizes
auto rchunk_size = device_key.next_addr - roffset;
return min(lchunk_size, rchunk_size);
}
__device__ static __forceinline__ void
ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) {
auto state = ibgda_get_state();
auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);
uint64_t roffset = addr - heap_start;
uint64_t idx = ((roffset >> state->log2_cumem_granularity) * nvshmemi_device_state_d.npes) + dst_pe;
nvshmemi_ibgda_device_key_t device_key;
if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS)
device_key = state->constmem.rkeys[idx];
else
device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS];
*out_raddr = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset;
*out_rkey = device_key.key;
}
__device__ static __forceinline__ uint64_t
ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t *qp, uint32_t num_wqes) {
auto mvars = &qp->mvars;
return atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
}
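// Illustrative sketch (not part of the original source): a host model of the two
// counters used by `ibgda_reserve_wqe_slots` and `ibgda_submit_requests`. Slots are
// claimed with an atomic add, then published strictly in order with a compare-and-swap.
// `SendQueueModel` is hypothetical and leaves out the doorbell and batching logic.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>

struct SendQueueModel {
    std::atomic<uint64_t> resv_head{0};   // next slot index to hand out
    std::atomic<uint64_t> ready_head{0};  // every slot below this index is fully written

    // Claim `num_wqes` consecutive slots and return the first index.
    uint64_t reserve(uint32_t num_wqes) { return resv_head.fetch_add(num_wqes); }

    // Try to publish [base, base + num_wqes). This succeeds only once all earlier
    // reservations have been published; the device code spins until it does.
    bool try_publish(uint64_t base, uint32_t num_wqes) {
        uint64_t expected = base;
        return ready_head.compare_exchange_strong(expected, base + num_wqes);
    }
};
```

// Because `try_publish` only succeeds when `ready_head == base`, a later reservation
// cannot become visible before an earlier one, even if it was written first.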
__device__ static __forceinline__ void*
ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {
uint16_t cnt = qp->tx_wq.nwqes;
uint16_t idx = wqe_idx & (cnt - 1);
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT));
}
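// Illustrative sketch (not part of the original source): the ring-buffer indexing in
// `ibgda_get_wqe_ptr`, reproduced on the host. It assumes the slot count is a power of
// two and that `MLX5_SEND_WQE_SHIFT` is 6 (64-byte basic blocks, as in mlx5dv.h);
// `wqe_byte_offset` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr int kSendWqeShift = 6;  // assumed to match MLX5_SEND_WQE_SHIFT

// Byte offset of slot `wqe_idx` inside a ring of `nwqes` slots.
inline std::size_t wqe_byte_offset(uint16_t wqe_idx, uint16_t nwqes) {
    assert(nwqes != 0 && (nwqes & (nwqes - 1)) == 0);                 // power of two
    const uint16_t idx = wqe_idx & static_cast<uint16_t>(nwqes - 1);  // wrap without division
    return static_cast<std::size_t>(idx) << kSendWqeShift;
}
```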
// Wait until WQE `idx - 1` is completed.
// This is a simplified version of NVSHMEM's `ibgda_poll_cq`. It can only be used for polling recv.
// Recvs are posted and polled in the same thread, so we don't need to maintain queue status.
__device__ static __forceinline__ void
nvshmemi_ibgda_poll_recv(int dst_pe, int qp_id) {
auto qp = ibgda_get_rc(dst_pe, qp_id);
auto cq = qp->rx_wq.cq;
const uint32_t ncqes = cq->ncqes;
auto *cqe64 = reinterpret_cast<struct mlx5_cqe64*>(cq->cqe);
auto old_cons_idx = *cq->cons_idx;
*cq->cons_idx = old_cons_idx + 1;
// Wait until `wqe_counter >= old_cons_idx`
while ((static_cast<uint16_t>(old_cons_idx - HtoBE16(ld_na_relaxed(&cqe64->wqe_counter)) - 1) < ncqes));
}
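// Illustrative sketch (not part of the original source): the wraparound-safe comparison
// used in the polling loop above, isolated as a host function. `still_waiting` is a
// hypothetical name, and the sketch assumes `ncqes` is well below 65536.

```cpp
#include <cassert>
#include <cstdint>

// True while the 16-bit completion counter is still behind `old_cons_idx`.
inline bool still_waiting(uint64_t old_cons_idx, uint16_t wqe_counter, uint32_t ncqes) {
    // The distance is computed modulo 2^16. Once the counter catches up, the
    // difference minus one wraps to 0xffff, which is never below `ncqes`.
    return static_cast<uint16_t>(old_cons_idx - wqe_counter - 1) < ncqes;
}
```

// Truncating to 16 bits keeps the test correct when the hardware counter wraps from
// 0xffff back to 0 while the 64-bit consumer index keeps growing.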
__device__ static __forceinline__ void
nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t imm = std::numeric_limits<uint32_t>::max()) {
// Get rkey
// NOTES: the `p` operation will not cross multiple remote chunks
__be32 rkey;
uint64_t raddr;
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), dst_pe, &raddr, &rkey);
// Write WQEs
auto qp = ibgda_get_rc(dst_pe, qp_id);
uint64_t base_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
void *wqe_ptrs;
wqe_ptrs = ibgda_get_wqe_ptr(qp, base_wqe_idx);
ibgda_write_rdma_write_inl_wqe(qp, reinterpret_cast<const uint32_t*>(&value), raddr, rkey, base_wqe_idx, &wqe_ptrs, imm);
// Submit requests
ibgda_submit_requests<true>(qp, base_wqe_idx, 1);
}
__device__ static __forceinline__ void
ibgda_write_rdma_write_wqe(nvshmemi_ibgda_device_qp_t *qp, uint64_t laddr, __be32 lkey,
uint64_t raddr, __be32 rkey, uint32_t bytes, uint16_t wqe_idx,
void **out_wqes) {
ibgda_ctrl_seg_t ctrl_seg;
struct mlx5_wqe_raddr_seg raddr_seg;
struct mlx5_wqe_data_seg data_seg;
auto *ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
void *av_seg_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
struct mlx5_wqe_raddr_seg *raddr_seg_ptr;
struct mlx5_wqe_data_seg *data_seg_ptr;
raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(av_seg_ptr));
data_seg_ptr = reinterpret_cast<mlx5_wqe_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
raddr_seg.raddr = HtoBE64(raddr);
raddr_seg.rkey = rkey;
raddr_seg.reserved = 0;
data_seg.byte_count = HtoBE32(bytes);
data_seg.lkey = lkey;
data_seg.addr = HtoBE64(laddr);
ctrl_seg = {0};
ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3);
ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | MLX5_OPCODE_RDMA_WRITE);
EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16");
EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16");
EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == 16, "sizeof(*data_seg_ptr) == 16");
st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));
st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
}
__device__ static __forceinline__ void
ibgda_write_empty_recv_wqe(void *out_wqe) {
auto *data_seg_ptr = reinterpret_cast<struct mlx5_wqe_data_seg*>(out_wqe);
struct mlx5_wqe_data_seg data_seg;
// Make the first segment in the WQE invalid, then the entire list will be invalid
data_seg.byte_count = 0;
data_seg.lkey = HtoBE64(MLX5_INVALID_LKEY);
data_seg.addr = 0;
EP_STATIC_ASSERT(sizeof(mlx5_wqe_data_seg) == sizeof(int4), "Invalid data type length");
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
}
__device__ static __forceinline__ uint64_t
nvshmemi_ibgda_allocate_recvs(nvshmemi_ibgda_device_qp* qp) {
auto mvars = &qp->mvars;
// Allocate if not enough
constexpr int kMinIBGDARecvs = 32;
auto resv_head = mvars->rx_wq.resv_head;
auto num_valid_slots = resv_head - mvars->rx_wq.cons_idx;
if (num_valid_slots < kMinIBGDARecvs) {
resv_head = mvars->rx_wq.cons_idx + qp->rx_wq.nwqes;
mvars->rx_wq.resv_head = resv_head;
// Ensure WQE is written before `dbrec`
__be32 dbrec_val;
__be32 *dbrec_ptr = qp->rx_wq.dbrec;
// Unlike sending, recvs for each QP are posted by a single thread,
// so no synchronization is needed here
// This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(wqe_idx & 0xffff))`
asm("{\n\t"
".reg .b32 dbrec_head_16b;\n\t"
".reg .b32 ign;\n\t"
"and.b32 dbrec_head_16b, %1, 0xffff;\n\t"
"prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t"
"}" : "=r"(dbrec_val)
: "r"(static_cast<uint32_t>(resv_head)));
st_na_release(dbrec_ptr, dbrec_val);
}
// Return old number of slots
return num_valid_slots;
}
__device__ static __forceinline__ void
nvshmemi_ibgda_prepare_recvs(int dst_rank, int qp_id) {
// NOTES: only one thread can run this function
// TODO: consider this assertion for normal AR
EP_DEVICE_ASSERT(nvshmemi_ibgda_allocate_recvs(ibgda_get_rc(dst_rank, qp_id)) > 16);
}
__device__ static __forceinline__ void
nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
// Get lkey and rkey, store them into lanes
uint32_t num_wqes = 0;
__be32 my_lkey = 0;
uint64_t my_laddr = 0;
__be32 my_rkey = 0;
uint64_t my_raddr = 0;
uint64_t my_chunk_size = 0;
// Decide how many messages are needed (at most 3 in theory)
auto remaining_bytes = bytes;
while (remaining_bytes > 0) {
if (lane_id == num_wqes)
my_chunk_size = min(remaining_bytes, ibgda_get_lkey_and_rkey(my_laddr = req_lptr, &my_lkey, req_rptr, dst_pe, &my_raddr, &my_rkey));
// Move one more message
auto chunk_size = __shfl_sync(0xffffffff, my_chunk_size, static_cast<int>(num_wqes));
remaining_bytes -= chunk_size;
req_lptr += chunk_size;
req_rptr += chunk_size;
++ num_wqes;
}
EP_DEVICE_ASSERT(num_wqes <= 32);
// Process WQE
auto qp = ibgda_get_rc(dst_pe, qp_id);
uint64_t base_wqe_idx = 0;
if (lane_id == 0)
base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes);
base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0);
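if (lane_id < num_wqes) {
// Each lane writes its own WQE, so the control segment must carry that lane's WQE index
auto wqe_idx = base_wqe_idx + lane_id;
auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx);
ibgda_write_rdma_write_wqe(qp, my_laddr, my_lkey, my_raddr, my_rkey, my_chunk_size,
wqe_idx, &wqe_ptr);
}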
__syncwarp();
// Submit
if (lane_id == 0)
ibgda_submit_requests<false>(qp, base_wqe_idx, num_wqes, message_idx);
__syncwarp();
}
} // namespace deep_ep
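// Illustrative sketch (not part of the original source): the chunk-splitting loop of
// `nvshmemi_ibgda_put_nbi_warp`, modelled on the host. It assumes, as a simplification,
// that both the local and the remote heap are registered in fixed-size chunks of
// `granularity` bytes; the device code instead reads each chunk's end from the key
// tables. `Chunk`, `bytes_to_chunk_end` and `split_transfer` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

struct Chunk { uint64_t laddr, raddr, bytes; };

// Bytes from `addr` to the end of the registration chunk that contains it.
inline uint64_t bytes_to_chunk_end(uint64_t addr, uint64_t granularity) {
    return granularity - (addr % granularity);
}

// Split one transfer so that no piece crosses a local or remote chunk boundary.
inline std::vector<Chunk> split_transfer(uint64_t laddr, uint64_t raddr,
                                         uint64_t bytes, uint64_t granularity) {
    assert(granularity > 0);
    std::vector<Chunk> chunks;
    while (bytes > 0) {
        const uint64_t n = std::min({bytes,
                                     bytes_to_chunk_end(laddr, granularity),
                                     bytes_to_chunk_end(raddr, granularity)});
        chunks.push_back({laddr, raddr, n});
        laddr += n;
        raddr += n;
        bytes -= n;
    }
    return chunks;
}
```

// A transfer no larger than one chunk can cross at most one local and one remote
// boundary, which yields at most three pieces under this model.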
#include "configs.cuh"
#include "buffer.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "utils.cuh"
namespace deep_ep {
namespace internode {
extern nvshmem_team_t cpu_rdma_team;
template<int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM>
__global__ void __launch_bounds__(kNumThreads, 1)
get_dispatch_layout(const int64_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x);
// Count expert statistics
__shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM];
int expert_begin_idx = sm_id * kNumExpertsPerSM, expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts);
if (expert_begin_idx < expert_end_idx) {
// Per-thread count
#pragma unroll
for (int i = 0; i < kNumExpertsPerSM; ++ i)
num_tokens_per_expert_per_thread[thread_id][i] = 0;
#pragma unroll
for (int i = thread_id; i < num_tokens; i += kNumThreads) {
auto shifted_topk_idx = topk_idx + i * num_topk;
#pragma unroll
for (int j = 0, expert_idx; j < num_topk; ++ j) {
expert_idx = static_cast<int>(shifted_topk_idx[j]);
if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx)
++ num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx];
}
}
__syncthreads();
// Sum up
EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM");
if (expert_begin_idx + thread_id < expert_end_idx) {
int sum = 0;
#pragma unroll
for (int i = 0; i < kNumThreads; ++ i)
sum += num_tokens_per_expert_per_thread[i][thread_id];
num_tokens_per_expert[expert_begin_idx + thread_id] = sum;
}
return;
}
if (num_tokens_per_rdma_rank != nullptr)
EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS);
// Count rank statistics
constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS;
__shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM];
__shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM];
auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM;
int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks);
int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS;
if (rank_begin_idx < rank_end_idx) {
const auto num_expert_per_rank = num_experts / num_ranks;
auto expert_begin = rank_begin_idx * num_expert_per_rank;
auto expert_end = rank_end_idx * num_expert_per_rank;
// Per-thread count
#pragma unroll
for (int i = 0; i < kNumRanksPerSM; ++ i)
num_tokens_per_rank_per_thread[thread_id][i] = 0;
#pragma unroll
for (int i = 0; i < kNumRDMARanksPerSM; ++ i)
num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0;
#pragma unroll
for (int i = thread_id; i < num_tokens; i += kNumThreads) {
auto shifted_topk_idx = topk_idx + i * num_topk;
int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0};
#pragma unroll
for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) {
expert_idx = static_cast<int>(shifted_topk_idx[j]);
if (expert_begin <= expert_idx and expert_idx < expert_end) {
// Count single rank
rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx;
is_in_rank[rank_idx] ++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS] ++;
}
}
auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks;
#pragma unroll
for (int j = 0; j + rank_begin_idx < rank_end_idx; ++ j) {
shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0);
num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0);
}
#pragma unroll
for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++ j)
num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0);
}
__syncthreads();
// Sum up
EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM");
if (rank_begin_idx + thread_id < rank_end_idx) {
int sum = 0;
#pragma unroll
for (int i = 0; i < kNumThreads; ++ i)
sum += num_tokens_per_rank_per_thread[i][thread_id];
num_tokens_per_rank[rank_begin_idx + thread_id] = sum;
}
if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) {
int sum = 0;
#pragma unroll
for (int i = 0; i < kNumThreads; ++ i)
sum += num_tokens_per_rdma_rank_per_thread[i][thread_id];
num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum;
}
}
}
void get_dispatch_layout(const int64_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts,
cudaStream_t stream) {
constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8;
int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM;
EP_STATIC_ASSERT(kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of experts per SM");
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
LAUNCH_KERNEL(&cfg, (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>),
topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank,
num_tokens, num_topk, num_ranks, num_experts);
}
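// Illustrative sketch (not part of the original source): a single-threaded CPU reference
// for the counts that the `get_dispatch_layout` kernel produces, which can be handy for
// checking results on small inputs. `Layout` and `reference_layout` are hypothetical
// names, and the RDMA-rank counts are left out for brevity.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct Layout {
    std::vector<int> num_tokens_per_rank;    // [num_ranks]
    std::vector<int> num_tokens_per_expert;  // [num_experts]
    std::vector<bool> is_token_in_rank;      // [num_tokens * num_ranks]
};

inline Layout reference_layout(const std::vector<int64_t>& topk_idx, int num_tokens,
                               int num_topk, int num_ranks, int num_experts) {
    assert(num_ranks > 0 && num_experts % num_ranks == 0);
    assert(topk_idx.size() == static_cast<std::size_t>(num_tokens) * num_topk);
    const int experts_per_rank = num_experts / num_ranks;
    Layout out{std::vector<int>(num_ranks, 0), std::vector<int>(num_experts, 0),
               std::vector<bool>(static_cast<std::size_t>(num_tokens) * num_ranks, false)};
    for (int t = 0; t < num_tokens; ++t) {
        for (int k = 0; k < num_topk; ++k) {
            const int64_t e = topk_idx[t * num_topk + k];
            if (e < 0 || e >= num_experts)
                continue;  // out-of-range entries fall outside every range check in the kernel
            ++out.num_tokens_per_expert[e];
            out.is_token_in_rank[t * num_ranks + e / experts_per_rank] = true;
        }
        // A token counts once per rank, however many of that rank's experts it selects.
        for (int r = 0; r < num_ranks; ++r)
            out.num_tokens_per_rank[r] += out.is_token_in_rank[t * num_ranks + r];
    }
    return out;
}
```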
struct SourceMeta {
int src_rdma_rank, is_token_in_nvl_rank_bits;
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers");
__forceinline__ SourceMeta() = default;
// TODO: faster encoding
__device__ __forceinline__ SourceMeta(int rdma_rank, const bool* is_token_in_nvl_ranks) {
src_rdma_rank = rdma_rank;
is_token_in_nvl_rank_bits = is_token_in_nvl_ranks[0];
#pragma unroll
for (int i = 1; i < NUM_MAX_NVL_PEERS; ++ i)
is_token_in_nvl_rank_bits |= is_token_in_nvl_ranks[i] << i;
}
__device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const {
return (is_token_in_nvl_rank_bits >> nvl_rank) & 1;
}
};
EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`");
int get_source_meta_bytes() {
return sizeof(SourceMeta);
}
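// Illustrative sketch (not part of the original source): the bit packing that
// `SourceMeta` applies to the per-NVLink-rank flags, written as plain host functions.
// `pack_nvl_bits` and `is_in_nvl_rank` are hypothetical names.

```cpp
#include <cassert>

constexpr int kNumNvlPeers = 8;  // same value as NUM_MAX_NVL_PEERS

// Pack one flag per NVLink peer into the low bits of an int (bit i = peer i).
inline int pack_nvl_bits(const bool (&in_rank)[kNumNvlPeers]) {
    int bits = 0;
    for (int i = 0; i < kNumNvlPeers; ++i)
        bits |= static_cast<int>(in_rank[i]) << i;
    return bits;
}

// Read back the flag for a single peer.
inline bool is_in_nvl_rank(int bits, int nvl_rank) {
    assert(0 <= nvl_rank && nvl_rank < kNumNvlPeers);
    return (bits >> nvl_rank) & 1;
}
```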
__host__ __device__ __forceinline__
int get_num_bytes_per_rdma_token(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) {
return static_cast<int>(align(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), sizeof(int4)));
}
__host__ __device__ __forceinline__
std::pair<int, int> get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) {
// Return `int32_t` offset and count to clean
return {
(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) / sizeof(int),
(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms
};
}
__host__ __device__ __forceinline__
std::pair<int, int> get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens, int num_sms) {
// Return `int32_t` offset and count to clean
EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`");
return {
(num_nvl_recv_buffer_tokens * (hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float) + sizeof(SourceMeta)) * num_nvl_ranks * num_sms) / sizeof(int),
num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms,
};
}
template <bool kLowLatencyMode>
__forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank, const int nvl_rank) {
return kLowLatencyMode ? (dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank) : dst_rdma_rank;
}
template <bool kLowLatencyMode>
__forceinline__ __device__ void nvshmem_barrier_with_same_gpu_idx(const nvshmem_team_t& rdma_team) {
kLowLatencyMode ? void(nvshmem_barrier(rdma_team)) : nvshmem_barrier_all();
}
template <bool kLowLatencyMode, int kNumRDMARanks>
__global__ void
notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
const bool* is_token_in_rank, int num_tokens, int num_channels, int expert_alignment,
const int rdma_clean_offset, const int rdma_num_int_clean,
const int nvl_clean_offset, const int nvl_num_int_clean,
int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
void* rdma_buffer_ptr,
void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank,
const nvshmem_team_t rdma_team) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id();
auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32;
auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
auto num_rdma_experts = num_experts / kNumRDMARanks, num_nvl_experts = num_rdma_experts / NUM_MAX_NVL_PEERS;
if (sm_id == 0) {
// Communication with others
// Global barrier: the first warp does the intra-node sync, the second warp does the inter-node sync
EP_DEVICE_ASSERT(num_warps > 1);
EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
if (thread_id == 32)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
__syncthreads();
// Send numbers of tokens per rank/expert to RDMA ranks
auto rdma_buffer_ptr_int = reinterpret_cast<int*>(rdma_buffer_ptr);
auto rdma_recv_num_tokens_mixed = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, kNumRDMARanks);
// Clean up for later data dispatch
EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= rdma_clean_offset * sizeof(int));
#pragma unroll
for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
// Copy to send buffer
#pragma unroll
for (int i = thread_id; i < num_ranks; i += num_threads)
rdma_recv_num_tokens_mixed.send_buffer(i / NUM_MAX_NVL_PEERS)[i % NUM_MAX_NVL_PEERS] = num_tokens_per_rank[i];
#pragma unroll
for (int i = thread_id; i < num_experts; i += num_threads)
rdma_recv_num_tokens_mixed.send_buffer(i / num_rdma_experts)[NUM_MAX_NVL_PEERS + i % num_rdma_experts] = num_tokens_per_expert[i];
if (thread_id < kNumRDMARanks)
rdma_recv_num_tokens_mixed.send_buffer(thread_id)[NUM_MAX_NVL_PEERS + num_rdma_experts] = num_tokens_per_rdma_rank[thread_id];
__syncthreads();
// Issue send
// TODO: use a lighter fence, barrier, or signaling
// TODO: overlap EP barrier and NVL cleaning
if (thread_id < kNumRDMARanks) {
nvshmem_int_put_nbi(rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank), rdma_recv_num_tokens_mixed.send_buffer(thread_id),
NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
translate_dst_rdma_rank<kLowLatencyMode>(thread_id, nvl_rank));
}
__syncthreads();
if (thread_id == 0)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
__syncthreads();
// NVL buffers
auto nvl_send_buffer = thread_id < NUM_MAX_NVL_PEERS ? buffer_ptrs[thread_id] : nullptr;
auto nvl_recv_buffer = buffer_ptrs[nvl_rank];
auto nvl_reduced_num_tokens_per_expert = Buffer<int>(nvl_recv_buffer, num_rdma_experts).advance_also(nvl_send_buffer);
auto nvl_send_num_tokens_per_rank = AsymBuffer<int>(nvl_send_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
auto nvl_send_num_tokens_per_expert = AsymBuffer<int>(nvl_send_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);
auto nvl_recv_num_tokens_per_rank = AsymBuffer<int>(nvl_recv_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
auto nvl_recv_num_tokens_per_expert = AsymBuffer<int>(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);
// Clean up for later data dispatch
auto nvl_buffer_ptr_int = reinterpret_cast<int*>(buffer_ptrs[nvl_rank]);
EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank.total_bytes +
nvl_send_num_tokens_per_expert.total_bytes <= nvl_clean_offset * sizeof(int));
#pragma unroll
for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
// Reduce number of tokens per expert into the NVL send buffer
// TODO: may use NVSHMEM reduction
EP_DEVICE_ASSERT(num_rdma_experts <= num_threads);
if (thread_id < num_rdma_experts) {
int sum = 0;
#pragma unroll
for (int i = 0; i < kNumRDMARanks; ++ i)
sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + thread_id];
nvl_reduced_num_tokens_per_expert[thread_id] = sum;
}
__syncthreads();
// Reduce RDMA received tokens
if (thread_id == 0) {
int sum = 0;
#pragma unroll
for (int i = 0; i < kNumRDMARanks; ++ i) {
sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts];
recv_rdma_rank_prefix_sum[i] = sum;
}
while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1);
*moe_recv_rdma_counter_mapped = sum;
}
// Send numbers of tokens per rank/expert to NVL ranks
EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads);
if (thread_id < NUM_MAX_NVL_PEERS) {
#pragma unroll
for (int i = 0; i < kNumRDMARanks; ++ i)
nvl_send_num_tokens_per_rank.buffer(nvl_rank)[i] = rdma_recv_num_tokens_mixed.recv_buffer(i)[thread_id];
#pragma unroll
for (int i = 0; i < num_nvl_experts; ++ i)
nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] = nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i];
}
memory_fence();
__syncthreads();
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
__syncthreads();
// Reduce number of tokens per rank/expert
EP_DEVICE_ASSERT(num_nvl_experts <= num_threads);
if (thread_id == 0) {
int sum = 0;
#pragma unroll
for (int i = 0; i < num_ranks; ++ i) {
int src_rdma_rank = i / NUM_MAX_NVL_PEERS, src_nvl_rank = i % NUM_MAX_NVL_PEERS;
sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank];
recv_gbl_rank_prefix_sum[i] = sum;
}
while (ld_volatile_global(moe_recv_counter_mapped) != -1);
*moe_recv_counter_mapped = sum;
}
if (thread_id < num_nvl_experts) {
int sum = 0;
#pragma unroll
for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id];
sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1);
moe_recv_expert_counter_mapped[thread_id] = sum;
}
// Final barrier
__syncthreads();
if (thread_id == 32)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
} else {
// Calculate meta data
int dst_rdma_rank = sm_id - 1;
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
int token_start_idx, token_end_idx;
get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// Iterate over tokens
int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0};
for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32) {
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
auto is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + i * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS);
auto is_token_in_rank_values = reinterpret_cast<const bool*>(&is_token_in_rank_uint64);
#pragma unroll
for (int j = 0; j < NUM_MAX_NVL_PEERS; ++ j)
per_nvl_rank_count[j] += is_token_in_rank_values[j];
total_count += (is_token_in_rank_uint64 != 0);
}
// Warp reduce
total_count = warp_reduce_sum(total_count);
#pragma unroll
for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]);
// Write into channel matrix
if (lane_id == 0) {
#pragma unroll
for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * num_channels + channel_id] = per_nvl_rank_count[i];
rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] = total_count;
}
}
// Calculate prefix sum
__syncthreads();
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
if (thread_id < kNumRDMARanks) {
auto prefix_row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels;
#pragma unroll
for (int i = 1; i < num_channels; ++ i)
prefix_row[i] += prefix_row[i - 1];
}
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers");
if (thread_id < NUM_MAX_NVL_PEERS) {
auto prefix_row = gbl_channel_prefix_matrix + (dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * num_channels;
#pragma unroll
for (int i = 1; i < num_channels; ++ i)
prefix_row[i] += prefix_row[i - 1];
}
}
}
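/*
 * The metadata branch of the kernel above first stores per-channel counts and
 * then converts each row into an inclusive prefix sum. A minimal host-side
 * sketch of that conversion, and of how one channel's count can be recovered
 * by subtracting the previous entry, might look as follows. The function names
 * are hypothetical and the code is sequential, unlike the kernel.
 *
```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Turn per-channel counts into an inclusive prefix sum, in place.
inline void inclusive_prefix_sum(std::vector<int>& row) {
    for (std::size_t i = 1; i < row.size(); ++i)
        row[i] += row[i - 1];
}

// Recover the count of a single channel from the prefix row.
inline int channel_count(const std::vector<int>& prefix_row, int channel_id) {
    assert(channel_id >= 0 && channel_id < static_cast<int>(prefix_row.size()));
    return prefix_row[channel_id] - (channel_id > 0 ? prefix_row[channel_id - 1] : 0);
}
```
 *
 * With counts {3, 0, 5, 2} the row becomes {3, 3, 8, 10}, and the count of
 * channel 2 is recovered as 8 - 3 = 5.
 */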
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
const bool* is_token_in_rank, int num_tokens, int num_channels,
int hidden_int4, int num_scales, int num_topk, int expert_alignment,
int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
int** task_fifo_ptrs, int head, int rank,
cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool low_latency_mode) {
#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
auto notify_dispatch_func = low_latency_mode ? \
notify_dispatch<true, num_rdma_ranks> : notify_dispatch<false, num_rdma_ranks>; \
LAUNCH_KERNEL(&cfg, notify_dispatch_func, \
num_tokens_per_rank, moe_recv_counter_mapped, num_ranks, \
num_tokens_per_rdma_rank, moe_recv_rdma_counter_mapped, \
num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \
is_token_in_rank, num_tokens, num_channels, expert_alignment, \
rdma_clean_meta.first, rdma_clean_meta.second, \
nvl_clean_meta.first, nvl_clean_meta.second, \
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
rdma_buffer_ptr, \
buffer_ptrs, task_fifo_ptrs, head, rank, \
cpu_rdma_team); } break
constexpr int kNumThreads = 256;
const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
// Get clean meta
auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels);
auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes);
EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes);
EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
// Launch kernel
SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream);
SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
#undef NOTIFY_DISPATCH_LAUNCH_CASE
}
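/*
 * The dispatch kernel sends channel counts encoded as `-value - 1` and the
 * receiver polls until the slots turn negative. Since the buffers are zeroed
 * beforehand, this appears to let a legitimate count of zero be told apart
 * from a slot that has not been written yet. A minimal sketch of that
 * encoding, with hypothetical helper names:
 *
```cpp
#include <cassert>

// Encode a non-negative count so that every written slot is negative.
inline int encode_count(int value) {
    assert(value >= 0);
    return -value - 1;
}

// A zero-initialized slot stays non-negative until a value arrives.
inline bool has_arrived(int slot) {
    return slot < 0;
}

// Invert the encoding once the slot has been written.
inline int decode_count(int slot) {
    assert(slot < 0);
    return -slot - 1;
}
```
 *
 * A count of 0 is stored as -1, so it still reads as "arrived".
 */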
// Each token is sent to at most 8 RDMA ranks
constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
return num_rdma_ranks < 8 ? num_rdma_ranks : 8;
}
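/*
 * The dispatch kernel below coordinates senders and receivers through
 * monotonically increasing head and tail counters: a slot index is the counter
 * modulo the buffer capacity, and a producer waits while `tail - head` has
 * reached the capacity. The following single-threaded sketch illustrates only
 * that arithmetic; `RingCounters` is a hypothetical type, and the real kernel
 * uses volatile/acquire loads and remote signals instead of plain integers.
 *
```cpp
#include <cassert>

struct RingCounters {
    int capacity, head, tail;

    explicit RingCounters(int capacity) : capacity(capacity), head(0), tail(0) {
        assert(capacity > 0);
    }

    // The producer may claim a slot only while the buffer is not full.
    bool can_push() const { return tail - head < capacity; }

    // Claim the next slot; the counter keeps growing, the slot index wraps.
    int push() {
        assert(can_push());
        return (tail++) % capacity;
    }

    // Release the oldest slot so the producer can reuse it.
    int pop() {
        assert(head < tail);
        return (head++) % capacity;
    }
};
```
 *
 * With a capacity of 2, two pushes fill slots 0 and 1, the producer is then
 * blocked, and one pop frees slot 0 for the next push.
 */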
template <bool kLowLatencyMode, int kNumRDMARanks, bool kCachedMode,
int kNumDispatchRDMASenderWarps, int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)>
__global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1)
dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, SourceMeta* recv_src_meta,
const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
int* send_rdma_head, int* send_nvl_head,
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
const bool* is_token_in_rank,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks) {
enum class WarpRole {
kRDMASender,
kRDMASenderCoordinator,
kRDMAAndNVLForwarder,
kForwarderCoordinator,
kNVLReceivers
};
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32;
const auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id();
const auto num_channels = static_cast<int>(gridDim.x) / 2, channel_id = sm_id / 2;
const bool is_forwarder = sm_id % 2 == 0;
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
const auto role_meta = [=]() -> std::pair<WarpRole, int> {
if (is_forwarder) {
if (warp_id < NUM_MAX_NVL_PEERS) {
return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
} else {
return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS};
}
} else if (warp_id < kNumDispatchRDMASenderWarps) {
return {WarpRole::kRDMASender, -1};
} else if (warp_id == kNumDispatchRDMASenderWarps) {
return {WarpRole::kRDMASenderCoordinator, -1};
} else {
return {WarpRole::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS};
}
}();
auto warp_role = role_meta.first;
auto target_rank = role_meta.second; // Not applicable for RDMA senders
EP_DEVICE_ASSERT(num_warps == kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS);
// Data checks
EP_DEVICE_ASSERT(num_topk <= 32);
// RDMA symmetric layout
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
auto hidden_bytes = hidden_int4 * sizeof(int4);
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk);
auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
// NVL buffer layouts
// NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers"
void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr;
int rs_wr_rank = 0, ws_rr_rank = 0;
if (warp_role == WarpRole::kRDMAAndNVLForwarder)
rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;
if (warp_role == WarpRole::kNVLReceivers)
rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
// Allocate buffers
auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_x_scales = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_topk_idx = AsymBuffer<int>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_topk_weights = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr);
auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
// RDMA sender warp synchronization
__shared__ volatile int rdma_send_next_token_idx;
__shared__ volatile int rdma_send_channel_tail[kNumRDMARanks];
__shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks];
auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); };
// Forward warp synchronization
__shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];
__shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS];
auto sync_forwarder_smem = []() { asm volatile("bar.sync 1, %0;" :: "r"((NUM_MAX_NVL_PEERS + 1) * 32)); };
if (warp_role == WarpRole::kRDMASender) {
// Get tasks
int token_start_idx, token_end_idx;
get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// Clean shared memory
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
(warp_id == 0 and lane_id == 0) ? (rdma_send_next_token_idx = token_start_idx) : 0;
(warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
(warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_next_tail[lane_id] = 0) : 0;
// Send the number of tokens in this channel, encoded as `-value - 1`
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers");
for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
if (lane_id < NUM_MAX_NVL_PEERS) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
} else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
} else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
} else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
}
nvshmemx_int_put_nbi_warp(rdma_channel_meta.recv_buffer(rdma_rank), rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
nvshmem_fence();
sync_rdma_sender_smem();
// Iterate over tokens and copy into buffer
int64_t token_idx;
int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1;
auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
// Read RDMA rank existence
uint64_t is_token_in_rank_uint64 = 0;
if (lane_id < kNumRDMARanks)
is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
// Acquire sequential lock
while (lane_id == 0 and rdma_send_next_token_idx != token_idx);
__syncwarp();
// Acquire next tail
int rdma_tail_idx = -1;
if (is_token_in_rank_uint64 != 0) {
rdma_tail_idx = rdma_send_channel_next_tail[lane_id] ++;
while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens)
cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
}
__syncwarp();
// Store RDMA head for combine
if (lane_id < kNumRDMARanks and not kCachedMode)
send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;
// Update last token tail
if (last_rdma_tail_idx >= 0)
st_release_cta(const_cast<const int *>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
last_rdma_tail_idx = rdma_tail_idx;
// Release sequential lock
lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
// Broadcast tails
SourceMeta src_meta;
int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
void* dst_send_buffers[kNumTopkRDMARanks];
#pragma unroll
for (int i = 0, slot_idx; i < kNumRDMARanks; ++ i) if ((slot_idx = __shfl_sync(0xffffffff, rdma_tail_idx, i)) >= 0) {
slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens;
topk_ranks[num_topk_ranks] = i;
auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i);
auto recv_is_token_in_rank_values = reinterpret_cast<const bool*>(&recv_is_token_in_rank_uint64);
if (lane_id == num_topk_ranks)
src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values);
dst_send_buffers[num_topk_ranks ++] = reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token;
}
EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks);
// Copy `x` into symmetric send buffer
auto st_broadcast = [=](const int key, const int4& value) {
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
st_na_global(reinterpret_cast<int4*>(dst_send_buffers[j]) + key, value);
};
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast);
#pragma unroll
for (int i = 0; i < num_topk_ranks; ++ i)
dst_send_buffers[i] = reinterpret_cast<int4*>(dst_send_buffers[i]) + hidden_int4;
// Copy source metadata into symmetric send buffer
if (lane_id < num_topk_ranks)
st_na_global(reinterpret_cast<SourceMeta*>(dst_send_buffers[lane_id]), src_meta);
#pragma unroll
for (int i = 0; i < num_topk_ranks; ++ i)
dst_send_buffers[i] = reinterpret_cast<SourceMeta*>(dst_send_buffers[i]) + 1;
// Copy `x_scales` into symmetric send buffer
#pragma unroll
for (int i = lane_id; i < num_scales; i += 32) {
auto value = ld_nc_global(x_scales + token_idx * num_scales + i);
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value);
}
#pragma unroll
for (int i = 0; i < num_topk_ranks; ++ i)
dst_send_buffers[i] = reinterpret_cast<float*>(dst_send_buffers[i]) + num_scales;
// Copy `topk_idx` and `topk_weights` into symmetric send buffer
#pragma unroll
for (int i = lane_id; i < num_topk * num_topk_ranks; i += 32) {
auto rank_idx = i / num_topk, copy_idx = i % num_topk;
auto idx_value = static_cast<int>(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx));
auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx);
st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value);
st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value);
}
}
// Epilogue
// Acquire sequential lock
while (lane_id == 0 and rdma_send_next_token_idx != token_idx);
__syncwarp();
// Update last token tail
if (last_rdma_tail_idx >= 0)
st_release_cta(const_cast<const int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
// Release sequential lock
lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
} else if (warp_role == WarpRole::kRDMASenderCoordinator) {
// NOTES: ensures an issued put is never split at the end of the buffer
EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);
// Synchronize shared memory
sync_rdma_sender_smem();
// Get number of tokens to send for each RDMA rank
int num_tokens_to_send = 0;
if (lane_id < kNumRDMARanks) {
num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id];
if (channel_id > 0)
num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1];
}
// Iterate all RDMA ranks
int last_issued_tail = 0;
while (__any_sync(0xffffffff, num_tokens_to_send > 0)) {
for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++ i) {
int dst_rdma_rank = (i + channel_id) % kNumRDMARanks;
synced_num_tokens_to_send = __shfl_sync(0xffffffff, num_tokens_to_send, dst_rdma_rank);
if (synced_num_tokens_to_send == 0)
continue;
// Read progress
auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank);
auto processed_tail = ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank));
auto num_tokens_processed = processed_tail - synced_last_issued_tail;
if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens)
continue;
// Issue RDMA send
auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens);
EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 and num_tokens_to_issue <= synced_num_tokens_to_send);
if (dst_rdma_rank != rdma_rank) {
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
num_bytes_per_rdma_token * num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
nvshmem_fence();
} else {
// Lighter fence for local RDMA rank
memory_fence();
}
// Update tails
__syncwarp();
if (lane_id == dst_rdma_rank) {
last_issued_tail += num_tokens_to_issue;
num_tokens_to_send -= num_tokens_to_issue;
nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
}
}
} else if (warp_role == WarpRole::kRDMAAndNVLForwarder) {
// RDMA consumers and NVL producers
const auto dst_nvl_rank = target_rank;
const auto dst_rank = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank;
const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks);
const auto dst_rank_expert_end = dst_rank_expert_begin + (num_experts / num_ranks);
// Wait for counters to arrive
int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0;
EP_DEVICE_ASSERT(kNumRDMARanks <= 32);
auto start_time = clock64();
if (lane_id < kNumRDMARanks) {
while (true) {
auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank);
auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank);
auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2);
auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1);
if (meta_0 < 0 and meta_1 < 0 and meta_2 < 0 and meta_3 < 0) {
// Notify NVL ranks
int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;
EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and end_sum >= start_sum);
st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1);
st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1);
// Save RDMA channel received token count
src_rdma_channel_prefix = -meta_2 - 1;
auto src_rdma_channel_prefix_1 = -meta_3 - 1;
num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix;
if (not kCachedMode)
recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1;
src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1];
EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
break;
}
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3);
trap();
}
}
}
__syncwarp();
// Shift cached head
send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank;
// Wait for shared memory to be cleaned
sync_forwarder_smem();
// Forward tokens from RDMA buffer
// NOTES: always start from the local rank
int src_rdma_rank = sm_id % kNumRDMARanks;
int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0;
int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0;
while (__any_sync(0xffffffff, num_tokens_to_recv_from_rdma > 0)) {
// Check whether the destination queue has room, or wait for a buffer to be released
start_time = clock64();
while (lane_id == 0) {
int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;
if (num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens)
break;
cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n",
channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail);
trap();
}
}
__syncwarp();
// Find next source RDMA rank (round-robin)
start_time = clock64();
while (true) {
src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks;
if (__shfl_sync(0xffffffff, num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) {
if (lane_id == src_rdma_rank and cached_rdma_channel_head == cached_rdma_channel_tail)
cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank))
break;
}
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d\n",
channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma);
trap();
}
}
auto src_rdma_head = __shfl_sync(0xffffffff, cached_rdma_channel_head, src_rdma_rank);
auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank);
// Iterate over every token from the RDMA buffer
for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) {
auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes));
lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0;
bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
if (lane_id == src_rdma_rank) {
auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1;
rdma_nvl_token_idx += is_in_dst_nvl_rank;
if (not kCachedMode)
send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head;
}
if (not is_in_dst_nvl_rank)
continue;
// Get an empty slot
int dst_slot_idx = (cached_nvl_channel_tail ++) % num_max_nvl_chunked_recv_tokens;
// Copy data
UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
reinterpret_cast<int4*>(shifted),
ld_nc_global, st_na_global);
shifted = reinterpret_cast<int4*>(shifted) + hidden_int4;
// Copy source meta
if (lane_id == 0)
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta);
shifted = reinterpret_cast<SourceMeta*>(shifted) + 1;
// Copy `x_scales`
UNROLLED_WARP_COPY(1, lane_id, num_scales,
nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales,
reinterpret_cast<float*>(shifted),
ld_nc_global, st_na_global);
shifted = reinterpret_cast<float*>(shifted) + num_scales;
// Copy `topk_idx` and `topk_weights`
// NOTES: do not use `shifted` after this `if`, because only lanes with `lane_id < num_topk` advance it
if (lane_id < num_topk) {
// Read
auto idx_value = ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id);
shifted = reinterpret_cast<int*>(shifted) + num_topk;
auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted) + lane_id);
// Transform and write
idx_value = (idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1;
st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value);
weight_value = idx_value >= 0 ? weight_value : 0.0f;
st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value);
}
// Stop early once this round has filled one NVL send chunk (the free space verified above)
if ((++ num_tokens_sent) == num_max_nvl_chunked_send_tokens)
src_rdma_tail = i + 1;
}
// Sync head index
if (lane_id == src_rdma_rank)
forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail);
// Move tail index
__syncwarp();
if (lane_id == 0)
st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail);
}
// Retired
__syncwarp();
if (lane_id == 0)
forward_channel_retired[dst_nvl_rank] = true;
} else if (warp_role == WarpRole::kForwarderCoordinator) {
// Extra warps for forwarder coordinator should exit directly
if (target_rank > 0)
return;
// Forward warp coordinator
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
// Clean shared memory
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers");
#pragma unroll
for (int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += 32)
forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0;
if (lane_id < NUM_MAX_NVL_PEERS)
forward_channel_retired[lane_id] = false;
sync_forwarder_smem();
int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0;
while (true) {
// Find minimum head
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) if (not forward_channel_retired[i])
min_head = min(min_head, forward_channel_head[i][target_rdma]);
if (__all_sync(0xffffffff, min_head == std::numeric_limits<int>::max()))
break;
// Update remote head
if (min_head != std::numeric_limits<int>::max() and min_head > last_head and lane_id < kNumRDMARanks)
nvshmem_uint64_p(rdma_channel_head.buffer(rdma_rank), last_head = min_head,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
// Nanosleep and let other warps work
__nanosleep(NUM_WAIT_NANOSECONDS);
}
} else {
// NVL consumers
// Retrieve rank offset from barrier results (each lane's register stores an RDMA rank)
int src_nvl_rank = target_rank, total_offset = 0;
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
if (lane_id < kNumRDMARanks and lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0)
total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];
// Receive channel offsets
int start_offset = 0, end_offset = 0, num_tokens_to_recv;
auto start_time = clock64();
while (lane_id < kNumRDMARanks) {
start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id);
end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id);
if (start_offset < 0 and end_offset < 0) {
start_offset = -start_offset - 1, end_offset = -end_offset - 1;
total_offset += start_offset;
break;
}
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset, end_offset);
trap();
}
}
num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);
// Save for combine usage
if (lane_id < kNumRDMARanks and not kCachedMode)
recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset;
__syncwarp();
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
while (num_tokens_to_recv > 0) {
// Check channel status by lane 0
start_time = clock64();
while (lane_id == 0) {
// Ready to copy
if (cached_channel_head_idx != cached_channel_tail_idx)
break;
cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer());
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n",
channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx);
trap();
}
}
// Sync queue tail
cached_channel_tail_idx = __shfl_sync(0xffffffff, cached_channel_tail_idx, 0);
// Copy data
int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx, -- num_tokens_to_recv) {
int token_idx_in_buffer = (cached_channel_head_idx ++) % num_max_nvl_chunked_recv_tokens;
auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer);
int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank);
(lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0;
// Copy data
UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
recv_x + recv_token_idx * hidden_int4,
nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4,
ld_nc_global, st_na_global);
// Copy source meta
if (lane_id == 0 and not kCachedMode)
st_na_global(recv_src_meta + recv_token_idx, meta);
// Copy scales
UNROLLED_WARP_COPY(1, lane_id, num_scales,
recv_x_scales + recv_token_idx * num_scales,
nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales,
ld_nc_global, st_na_global);
// Copy `topk_idx` and `topk_weights`
if (lane_id < num_topk) {
auto recv_idx = recv_token_idx * num_topk + lane_id;
auto buffer_idx = token_idx_in_buffer * num_topk + lane_id;
st_na_global(recv_topk_idx + recv_idx, static_cast<int64_t>(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));
st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));
}
}
// Move queue
__syncwarp();
if (lane_id == 0)
st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
}
}
}
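// Illustrative sketch only (not part of DeepEP's API; every `example_*` name is hypothetical).
// It restates the queue arithmetic that the forwarder and NVL receiver warps above appear to rely on:
//   - `head` and `tail` are monotonically increasing counters, and a slot is `counter % capacity`;
//   - a producer may write a chunk only if `capacity - (tail - head) >= chunk`;
//   - channel prefix offsets are read back as `-value - 1`, presumably so that a
//     zero-initialized buffer can be told apart from a published offset of 0 (an assumption).
constexpr bool example_queue_has_room(int head, int tail, int capacity, int chunk) {
    // Free slots are the capacity minus the slots still in flight
    return capacity - (tail - head) >= chunk;
}
constexpr int example_slot_index(int counter, int capacity) {
    // Map a monotonically increasing counter onto the ring buffer
    return counter % capacity;
}
constexpr int example_encode_ready(int value) {
    // Non-negative `value` becomes strictly negative
    return -value - 1;
}
constexpr int example_decode_ready(int encoded) {
    // Inverse of `example_encode_ready`, as done by the NVL consumers above
    return -encoded - 1;
}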
void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
int* send_rdma_head, int* send_nvl_head,
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
const bool* is_token_in_rank,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, bool is_cached_dispatch,
cudaStream_t stream, int num_channels, bool low_latency_mode) {
constexpr int kNumDispatchRDMASenderWarps = 7;
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
auto dispatch_func = low_latency_mode ? \
(is_cached_dispatch ? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps> : dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>) : \
(is_cached_dispatch ? dispatch<false, num_rdma_ranks, true, kNumDispatchRDMASenderWarps> : dispatch<false, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>); \
LAUNCH_KERNEL(&cfg, dispatch_func, \
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_topk_idx, recv_topk_weights, reinterpret_cast<SourceMeta*>(recv_src_meta), \
reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
send_rdma_head, send_nvl_head, \
recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, \
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
num_tokens, hidden_int4, num_scales, num_topk, num_experts, \
is_token_in_rank, \
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \
buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \
rank, num_ranks); } break
EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));
SETUP_LAUNCH_CONFIG(num_channels * 2, (kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32, stream);
SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}
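// Illustrative sketch only (hypothetical helpers, not part of DeepEP's API).
// It spells out the launch shape requested by `SETUP_LAUNCH_CONFIG` above: two blocks per
// channel, and one warp per RDMA sender, plus one coordinator warp, plus one warp per NVL peer.
// The number of NVL peers is passed in explicitly because its real value is defined elsewhere;
// the 8 used in the example comment below is an assumption made for illustration.
constexpr int example_dispatch_num_blocks(int num_channels) {
    return num_channels * 2;
}
constexpr int example_dispatch_num_threads(int num_rdma_sender_warps, int num_nvl_peers) {
    return (num_rdma_sender_warps + 1 + num_nvl_peers) * 32;
}
// Example: 7 sender warps and an assumed 8 NVL peers give 16 warps, i.e. 512 threads per block;
// 10 channels give 20 blocks.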
template <bool kLowLatencyMode>
__global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean,
const int nvl_clean_offset, const int nvl_num_int_clean,
int* combined_rdma_head, int num_combined_tokens, int num_channels,
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
void* rdma_buffer_ptr,
void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank, int num_ranks,
bool is_cached_dispatch, const nvshmem_team_t rdma_team) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x);
auto num_threads = static_cast<int>(blockDim.x);
auto num_warps = num_threads / 32;
auto warp_id = thread_id / 32;
auto lane_id = get_lane_id();
auto nvl_rank = rank % NUM_MAX_NVL_PEERS;
auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
// SM 0 and SM 1 clean the RDMA and NVL buffers respectively; the remaining SMs fix up the combine heads
if (sm_id == 0) {
// Barrier for RDMA
if (thread_id == 0)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
__syncthreads();
// Clean
auto rdma_buffer_ptr_int = reinterpret_cast<int*>(rdma_buffer_ptr);
#pragma unroll
for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
nvshmem_fence();
__syncthreads();
// Barrier again
if (thread_id == 0)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
} else if (sm_id == 1) {
// Barrier for NVL
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
__syncthreads();
// Clean
auto nvl_buffer_ptr_int = reinterpret_cast<int*>(buffer_ptrs[nvl_rank]);
#pragma unroll
for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
memory_fence();
__syncthreads();
// Barrier again
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
} else if (sm_id == 2) {
if (is_cached_dispatch)
return;
EP_DEVICE_ASSERT(num_warps >= num_channels);
EP_DEVICE_ASSERT(num_rdma_ranks <= 32);
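// Iterate in reverse order so that each negative (missing) head is replaced by `-last_head - 1`, where `last_head` is the next valid head after it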
if (lane_id < num_rdma_ranks and warp_id < num_channels) {
int token_start_idx, token_end_idx;
get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx, token_end_idx);
// NOTES: `1 << 25` is a heuristic large sentinel, used when no valid head follows
int last_head = 1 << 25;
for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; -- token_idx) {
auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
if (current_head < 0) {
combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
} else {
last_head = current_head;
}
}
}
} else {
if (is_cached_dispatch)
return;
EP_DEVICE_ASSERT(num_warps >= num_channels);
EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr);
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Too many NVL peers");
if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) {
for (int dst_rdma_rank = sm_id - 3; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 3) {
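// Iterate in reverse order so that each negative (missing) head is replaced by `-last_head - 1`, where `last_head` is the next valid head after it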
int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];
int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id];
int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
token_start_idx += shift, token_end_idx += shift;
// NOTES: `1 << 25` is a heuristic large sentinel, used when no valid head follows
int last_head = 1 << 25;
#pragma unroll
for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; -- token_idx) {
auto current_head = __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
if (current_head < 0) {
combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1;
} else {
last_head = current_head;
}
}
}
}
}
}
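// Illustrative host-side sketch only (the name `example_fill_missing_heads` is hypothetical and
// not part of DeepEP's API). It mirrors, for one lane's column, the reverse scan used by the
// kernel above: a negative head means "token not sent to this rank", and it is rewritten as
// `-last_head - 1`, where `last_head` is the nearest valid head that follows it.
// Assumption: the whole column is scanned at once, whereas the kernel scans one channel's
// token range per warp.
inline void example_fill_missing_heads(int* heads, int num_tokens, int stride) {
    // Same heuristic sentinel as the kernel, used when no valid head follows
    int last_head = 1 << 25;
    for (int token_idx = num_tokens - 1; token_idx >= 0; -- token_idx) {
        int& current_head = heads[token_idx * stride];
        if (current_head < 0) {
            current_head = -last_head - 1;
        } else {
            last_head = current_head;
        }
    }
}
// Example: with `stride == 1`, the column {0, -1, -1, 3, -1} becomes {0, -4, -4, 3, -33554433}.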
void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head,
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool is_cached_dispatch, bool low_latency_mode) {
const int num_threads = std::max(128, 32 * num_channels);
const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
// Get clean meta
auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels);
auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes);
EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes);
EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
EP_HOST_ASSERT(num_channels * 2 > 3);
// Launch kernel
auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>;
SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);
LAUNCH_KERNEL(&cfg, cached_notify_func,
rdma_clean_meta.first, rdma_clean_meta.second,
nvl_clean_meta.first, nvl_clean_meta.second,
combined_rdma_head, num_combined_tokens, num_channels,
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head,
rdma_buffer_ptr,
buffer_ptrs, task_fifo_ptrs, head, rank, num_ranks,
is_cached_dispatch, cpu_rdma_team);
}
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx,
int lane_id, int hidden_int4, int num_topk,
int4* combined_row, float* combined_topk_weights,
int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
// Broadcast current heads
// Lane `i` holds the head of rank `i` and `is_token_in_rank`
EP_STATIC_ASSERT(kMaxNumRanks <= 32, "Too many ranks");
int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks];
#pragma unroll
for (int i = 0; i < kNumRanks; ++ i) if (__shfl_sync(0xffffffff, is_token_in_rank, i)) {
slot_indices[num_topk_ranks] = __shfl_sync(0xffffffff, head_idx, i) % num_max_recv_tokens;
topk_ranks[num_topk_ranks ++] = i;
}
EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks);
// Reduce data
#pragma unroll
for (int i = lane_id; i < hidden_int4; i += 32) {
// Read buffers
// TODO: maybe too many registers here
int4 recv_value_int4[kMaxNumRanks];
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);
// Reduce all-to-all results
float values[kDtypePerInt4] = {0};
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j) {
auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
#pragma unroll
for (int k = 0; k < kDtypePerInt4; ++ k)
values[k] += static_cast<float>(recv_value_dtypes[k]);
}
// Cast back to `dtype_t` and write
int4 out_int4;
auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
#pragma unroll
for (int j = 0; j < kDtypePerInt4; ++ j)
out_dtypes[j] = static_cast<dtype_t>(values[j]);
st_na_global(combined_row + i, out_int4);
}
// Reduce `topk_weights`
if (lane_id < num_topk) {
float value = 0;
#pragma unroll
for (int i = 0; i < num_topk_ranks; ++ i)
value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id);
st_na_global(combined_topk_weights + lane_id, value);
}
// Return the lowest rank that holds this token
// NOTES: `topk_ranks[0]` is uninitialized if `num_topk_ranks == 0`
return topk_ranks[0];
}
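// Illustrative host-side sketch only (the name `example_reduce_rows` is hypothetical and not
// part of DeepEP's API). It shows the reduction order that `combine_token` above follows for
// the hidden data: every contributing row is accumulated in `float`, and the sum is cast back
// to `dtype_t` once at the end, rather than being rounded after each addition.
// Assumption: `dtype_t` is convertible to and from `float` with `static_cast`.
template <typename dtype_t>
inline void example_reduce_rows(const dtype_t* const* rows, int num_rows, int hidden, dtype_t* out) {
    for (int i = 0; i < hidden; ++ i) {
        // Accumulate in single precision
        float value = 0;
        for (int j = 0; j < num_rows; ++ j)
            value += static_cast<float>(rows[j][i]);
        // Cast back to `dtype_t` and write
        out[i] = static_cast<dtype_t>(value);
    }
}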
template<bool kLowLatencyMode,
int kNumRDMARanks, typename dtype_t,
int kNumCombineForwarderWarps,
int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks),
int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1,
int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder,
int kNumRDMAReceivers = kNumForwarders + NUM_MAX_NVL_PEERS>
__global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, 1)
combine(int4* combined_x, float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const int4* x, const float* topk_weights,
const int* combined_rdma_head, const int* combined_nvl_head,
const SourceMeta* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
int num_tokens, int num_combined_tokens, int hidden, int num_topk,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks) {
enum class WarpRole {
kNVLSender,
kNVLAndRDMAForwarder,
kRDMAReceiver,
kCoordinator
};
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32;
const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
const auto num_channels = static_cast<int>(gridDim.x) / 2, channel_id = sm_id / 2;
const bool is_rdma_receiver_sm = sm_id % 2 == 1;
EP_DEVICE_ASSERT(num_topk <= 32);
EP_DEVICE_ASSERT(hidden % (sizeof(int4) / sizeof(dtype_t)) == 0);
const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));
// NOTES: we decouple a channel into 2 SMs
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
auto role_meta = [=]() -> std::pair<WarpRole, int> {
auto warp_id = thread_id / 32;
if (not is_rdma_receiver_sm) {
if (warp_id < NUM_MAX_NVL_PEERS) {
auto shuffled_warp_id = warp_id;
shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS;
return {WarpRole::kNVLSender, shuffled_warp_id};
} else if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) {
auto shuffled_warp_id = warp_id - NUM_MAX_NVL_PEERS;
shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders;
return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id};
} else {
return {WarpRole::kCoordinator, 0};
}
} else {
if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) {
return {WarpRole::kRDMAReceiver, warp_id};
} else {
return {WarpRole::kCoordinator, 0};
}
}
}();
auto warp_role = role_meta.first;
auto warp_id = role_meta.second;
EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + kNumForwarders + 1);
auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;
if (warp_role == WarpRole::kNVLSender) {
// NVL producers
const auto dst_nvl_rank = warp_id;
// NVL layouts
// NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank];
auto nvl_channel_x = AsymBuffer<int4>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
auto nvl_channel_topk_weights = AsymBuffer<float>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
auto nvl_channel_head = AsymBuffer<int>(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr);
auto nvl_channel_tail = AsymBuffer<int>(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
// Get tasks for each RDMA lane
int token_start_idx = 0, token_end_idx = 0;
if (lane_id < kNumRDMARanks) {
int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
token_start_idx = gbl_channel_prefix_matrix[prefix_idx];
token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
}
__syncwarp();
// NOTES: here the cached value of each lane is only responsible for a single RDMA buffer
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
// Iterate over all tokens and send by chunks
while (true) {
// Exit once every lane has sent all of its tokens
if (__all_sync(0xffffffff, token_start_idx >= token_end_idx))
break;
// Decide next RDMA buffer to send
bool is_lane_ready = false;
auto start_time = clock64();
while (true) {
int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx;
is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens;
if (__any_sync(0xffffffff, is_lane_ready))
break;
// Retry
if (lane_id < kNumRDMARanks and token_start_idx < token_end_idx)
cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id);
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n",
channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, ld_volatile_global(nvl_channel_head.buffer() + lane_id), cached_channel_tail_idx,
token_start_idx, token_end_idx);
trap();
}
}
// Sync token start index and count
for (int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++ current_rdma_idx) {
if (__shfl_sync(0xffffffff, (token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx))
continue;
// Sync token start index
auto token_idx = static_cast<int64_t>(__shfl_sync(0xffffffff, token_start_idx, current_rdma_idx));
int num_tokens_in_chunk = __shfl_sync(0xffffffff, min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx);
// Send by chunk
for (int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++ chunk_idx, ++ token_idx) {
// Get an empty slot
int dst_slot_idx = 0;
if (lane_id == current_rdma_idx) {
dst_slot_idx = (cached_channel_tail_idx ++) % num_max_nvl_chunked_recv_tokens_per_rdma;
dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx;
}
dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, current_rdma_idx);
// Copy data
auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4;
auto shifted_x = x + token_idx * hidden_int4;
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
// Copy source meta
if (lane_id == 0)
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx));
// Copy `topk_weights`
if (lane_id < num_topk)
st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, ld_nc_global(topk_weights + token_idx * num_topk + lane_id));
}
lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0;
}
// Move queue tail
__syncwarp();
if (lane_id < kNumRDMARanks and is_lane_ready)
st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx);
}
} else {
// Combiners and coordinators
// RDMA symmetric layout
auto hidden_bytes = hidden_int4 * sizeof(int4);
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk);
auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
// NVL layouts
void* local_nvl_buffer = buffer_ptrs[nvl_rank];
void* nvl_buffers[NUM_MAX_NVL_PEERS];
#pragma unroll
for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
nvl_buffers[i] = buffer_ptrs[i];
auto nvl_channel_x = AsymBuffer<int4>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_topk_weights = AsymBuffer<float>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_head = AsymBuffer<int, NUM_MAX_NVL_PEERS>(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer);
auto nvl_channel_tail = AsymBuffer<int>(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
// Combiner warp synchronization
__shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS];
__shared__ volatile bool forwarder_retired[kNumForwarders];
__shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks];
__shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers];
auto sync_forwarder_smem = [=]() { asm volatile("bar.sync 0, %0;" :: "r"((kNumForwarders + 1) * 32)); };
auto sync_rdma_receiver_smem = [=]() { asm volatile("bar.sync 1, %0;" :: "r"((kNumRDMAReceivers + 1) * 32)); };
if (warp_role == WarpRole::kNVLAndRDMAForwarder) {
// Receive from NVL ranks and forward to RDMA ranks
// NOTES: this part uses "large warps" (groups of `kNumWarpsPerForwarder` warps) for each RDMA rank
const auto dst_rdma_rank = warp_id / kNumWarpsPerForwarder;
const auto sub_warp_id = warp_id % kNumWarpsPerForwarder;
auto send_buffer = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank);
auto sync_large_warp = [=]() {
if (kNumWarpsPerForwarder == 1) {
__syncwarp();
} else {
asm volatile("bar.sync %0, %1;" :: "r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * 32));
}
};
EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16, "Barriers are not enough");
// Advance to the corresponding NVL buffer
nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4);
nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma);
nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk);
nvl_channel_head.advance(dst_rdma_rank);
nvl_channel_tail.advance(dst_rdma_rank);
// Clean shared memory and sync
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers");
lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[warp_id][lane_id] = 0) : 0;
lane_id == 0 ? (forwarder_retired[warp_id] = false) : false;
sync_forwarder_smem();
// Get count and cached head
int cached_nvl_channel_tail_idx = 0;
int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
num_tokens_to_combine -= num_tokens_prefix;
num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS;
// Iterate over all tokens and combine by chunks
for (int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) {
// Check that the destination queue has enough free slots, or wait for a buffer to be released
auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine);
auto num_chunked_tokens = token_end_idx - token_start_idx;
auto start_time = clock64();
while (sub_warp_id == 0 and lane_id == 0) {
// Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
// Here, `token_start_idx` is the actual tail
int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));
if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens)
break;
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n",
channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens);
trap();
}
}
sync_large_warp();
// Combine and write to the RDMA buffer
for (int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) {
// Read expected head
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
int expected_head = -1;
if (lane_id < NUM_MAX_NVL_PEERS)
expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
// Wait for the lanes to be ready
start_time = clock64();
while (cached_nvl_channel_tail_idx <= expected_head) {
cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id));
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) {
printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head);
trap();
}
}
// Combine current token
auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
reinterpret_cast<int4*>(shifted),
reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
// Update head
if (lane_id < NUM_MAX_NVL_PEERS)
expected_head < 0 ? (forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1) : (forwarder_nvl_head[warp_id][lane_id] = expected_head + 1);
}
sync_large_warp();
// Issue RDMA send
if (sub_warp_id == kNumWarpsPerForwarder - 1) {
if (dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
num_chunked_tokens * num_bytes_per_rdma_token,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
nvshmem_fence();
} else {
memory_fence();
}
// Write new RDMA tail
__syncwarp();
if (lane_id == 0)
nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
}
// Retired
__syncwarp();
if (lane_id == 0)
forwarder_retired[warp_id] = true;
} else if (warp_role == WarpRole::kRDMAReceiver) {
// Receive from RDMA ranks and write to the output tensor
// Clean shared memory and sync
EP_DEVICE_ASSERT(kNumRDMARanks <= 32);
lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[warp_id][lane_id] = 0) : 0;
lane_id == 0 ? (rdma_receiver_retired[warp_id] = false) : 0;
sync_rdma_receiver_smem();
// The same tokens as the dispatch process
int token_start_idx, token_end_idx;
get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// Iterate over all tokens and combine
int cached_channel_tail_idx = 0;
for (int64_t token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) {
// Read expected head
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
int expected_head = -1;
if (lane_id < kNumRDMARanks) {
expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id);
(expected_head < 0) ? (rdma_receiver_rdma_head[warp_id][lane_id] = -expected_head - 1) : (rdma_receiver_rdma_head[warp_id][lane_id] = expected_head);
}
// Wait for lanes to be ready
auto start_time = clock64();
while (cached_channel_tail_idx <= expected_head) {
cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head);
trap();
}
}
__syncwarp();
// Combine current token
auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
combined_x + token_idx * hidden_int4,
combined_topk_weights + token_idx * num_topk,
num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
}
// Retired
__syncwarp();
if (lane_id == 0)
rdma_receiver_retired[warp_id] = true;
} else {
// Coordinator
// Sync shared memory status
is_rdma_receiver_sm ? sync_rdma_receiver_smem() : sync_forwarder_smem();
const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;
int last_rdma_head = 0;
int last_nvl_head[kNumRDMARanks] = {0};
int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0;
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= 32, "Invalid number of forwarder warps");
while (true) {
// Retired
if (is_rdma_receiver_sm and __all_sync(0xffffffff, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))
break;
if (not is_rdma_receiver_sm and __all_sync(0xffffffff, lane_id >= kNumForwarders or forwarder_retired[lane_id]))
break;
// Find minimum head for RDMA ranks
if (is_rdma_receiver_sm) {
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i])
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if (min_head != std::numeric_limits<int>::max() and min_head > last_rdma_head and lane_id < kNumRDMARanks)
nvshmem_uint64_p(rdma_channel_head.buffer(rdma_rank), last_rdma_head = min_head,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
} else {
// Find minimum head for NVL ranks
#pragma unroll
for (int i = 0; i < kNumRDMARanks; ++ i) {
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for (int j = 0; j < num_warps_per_rdma_rank; ++ j) if (not forwarder_retired[i * num_warps_per_rdma_rank + j])
min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]);
if (min_head != std::numeric_limits<int>::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS)
st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head);
}
}
// Nanosleep and let other warps work
__nanosleep(NUM_WAIT_NANOSECONDS);
}
}
}
}
void combine(cudaDataType_t type,
void* combined_x, float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const void* x, const float* topk_weights,
const int* combined_rdma_head, const int* combined_nvl_head,
const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
int num_tokens, int num_combined_tokens, int hidden, int num_topk,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode) {
constexpr int kNumCombineForwarderWarps = 16;
#define COMBINE_LAUNCH_CASE(num_rdma_ranks) { \
auto combine_func = low_latency_mode ? \
combine<true, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps> : combine<false, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps>; \
LAUNCH_KERNEL(&cfg, combine_func, \
reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \
reinterpret_cast<const int4*>(x), topk_weights, \
combined_rdma_head, combined_nvl_head, \
reinterpret_cast<const SourceMeta*>(src_meta), rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
num_tokens, num_combined_tokens, hidden, num_topk, \
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \
buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \
rank, num_ranks); } break
int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1);
int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder;
EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
EP_HOST_ASSERT(type == CUDA_R_16BF);
SETUP_LAUNCH_CONFIG(num_channels * 2, (NUM_MAX_NVL_PEERS + num_forwarder_warps + 1) * 32, stream);
SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
}
} // namespace internode
} // namespace deep_ep
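The forwarder and receiver loops above rely on a chunked ring buffer with monotonically increasing head and tail counters. The following is a minimal host-side sketch of that bookkeeping, not the library's code: `has_room`, `slot_of`, `encode_skipped_head`, and `decode_head` are hypothetical helper names. Reading a negative `expected_head` as "this source contributes no data for the token, but its head may still advance" is an assumption drawn from how the RDMA receiver decodes the value.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model of the forwarder's RDMA check:
// `capacity - (tail - head) >= chunk`, with `head` and `tail` never wrapping.
inline bool has_room(std::int64_t head, std::int64_t tail, int capacity, int chunk) {
    assert(tail >= head && capacity > 0 && chunk >= 0);
    const std::int64_t used = tail - head;  // slots not yet released by the consumer
    return capacity - used >= chunk;
}

// A monotonic token index maps to a physical slot by taking it modulo the capacity.
inline int slot_of(std::int64_t token_idx, int capacity) {
    assert(token_idx >= 0 && capacity > 0);
    return static_cast<int>(token_idx % capacity);
}

// Assumed encoding: a source that sends nothing for a token stores `-h - 1`,
// so the sign marks "skip" while `h` is still recoverable.
inline int encode_skipped_head(int h) {
    assert(h >= 0);
    return -h - 1;
}

// Mirrors the receiver-side decode: negative values map back to `h`,
// non-negative values are used as-is.
inline int decode_head(int expected_head) {
    return expected_head < 0 ? -expected_head - 1 : expected_head;
}
```

With a capacity of 8 slots, a full chunk of 8 tokens fits only when `tail == head`; once the producer is one token ahead, the sender has to wait for the consumer to publish a newer head.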
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "ibgda_device.cuh"
namespace deep_ep {
namespace internode_ll {
template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
__global__ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1) {
// Barrier before cleaning (in case of unfinished chunked EP)
nvshmemx_barrier_all_block();
// Clean
auto thread_id = static_cast<int>(threadIdx.x);
#pragma unroll
for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)
clean_0[i] = 0;
#pragma unroll
for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
clean_1[i] = 0;
// Barrier after cleaning (make sure low-latency mode works correctly)
nvshmemx_barrier_all_block();
}
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1,
cudaStream_t stream) {
constexpr int kNumThreads = 256;
SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
LAUNCH_KERNEL(&cfg, clean_low_latency_buffer<kNumThreads>,
clean_0, num_clean_int_0, clean_1, num_clean_int_1);
}
template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, int* atomic_counter_per_local_expert,
int* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int phases) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
const auto num_sms = static_cast<int>(gridDim.x);
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
const auto num_local_experts = num_experts / num_ranks;
const auto warp_group_id = warp_id / kNumWarpsPerGroup;
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
// FP8 settings
constexpr int kNumPerChannels = 128;
constexpr float kFP8Margin = 1e-4, kFP8Amax = 448, kFP8AmaxInv = 1.0f / 448.0f;
const int num_scales = kHidden / kNumPerChannels;
const size_t hidden_int4 = kHidden / sizeof(int4);
// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
const size_t num_bytes_per_msg = kHidden + num_scales * sizeof(float) + sizeof(int4);
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
// Sending phase
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_DISPATCH_RECV;
// Expert counts
__shared__ int shared_num_tokens_sent_per_expert[kNumWarpGroups];
// There are 2 kinds of warps in this part:
// 1. All warps except the last one cast tokens to FP8 and send them to the top-k experts
// 2. The last warp reads `topk_idx` and counts the number of tokens sent to each expert
if (warp_id < num_warps - 1) {
constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
const auto num_threads = (num_warps - 1) * 32;
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
const auto rdma_x_int2 = reinterpret_cast<int2*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_int2) + kHidden);
const auto rdma_x_src_idx = reinterpret_cast<int*>(rdma_x_scales + num_scales);
// Overlap top-k index read and source token index write
auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
// FP8 cast
#pragma unroll
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
// Read and calculate local amax
auto int4_value = __ldg(x_int4 + i);
auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
float fp32_values[kNumElemsPerRead];
float amax = kFP8Margin, scale, scale_inv;
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; ++ j) {
fp32_values[j] = static_cast<float>(bf16_values[j]);
amax = fmaxf(amax, fabsf(fp32_values[j]));
}
// Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax, scale_inv = amax * kFP8AmaxInv;
if (lane_id == 0 or lane_id == 16)
rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
// Cast into send buffer
int2 int2_value;
auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; j += 2) {
float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);
}
rdma_x_int2[i] = int2_value;
}
asm volatile("bar.sync 1, %0;" :: "r"(num_threads));
// Issue IBGDA sends
if (dst_expert_idx >= 0) {
int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
slot_idx = __shfl_sync(0xffffffff, slot_idx, 0);
const auto dst_rank = dst_expert_idx / num_local_experts;
const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_int2);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
slot_idx * num_bytes_per_msg;
if (dst_rank != rank) {
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
} else {
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
}
// Increase counter after finishing
__syncwarp();
lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
}
}
} else if (warp_id == num_warps - 1) {
EP_DEVICE_ASSERT(num_sms > 1);
if (sm_id == 0) {
// The first SM is also responsible for checking QPs
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_local_experts);
// The first SM is also responsible for cleaning the next buffer
#pragma unroll
for (int i = lane_id; i < num_next_clean_int; i += 32)
next_clean[i] = 0;
// Notify before executing `int_p`
__syncwarp();
#pragma unroll
for (int i = lane_id; i < num_experts; i += 32)
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
}
// This SM should be responsible for some destination experts, read `topk_idx` for them
int expert_count[kNumWarpGroups] = {0};
const auto expert_begin_idx = sm_id * kNumWarpGroups;
const auto expert_end_idx = min(expert_begin_idx + kNumWarpGroups, num_experts);
// Per lane count
#pragma unroll 8
for (int i = lane_id; i < num_tokens * num_topk; i += 32) {
auto idx = static_cast<int>(__ldg(topk_idx + i));
if (idx >= expert_begin_idx and idx < expert_end_idx)
expert_count[idx - expert_begin_idx] ++;
}
// Warp reduce
#pragma unroll
for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
if (lane_id == 0) {
shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
}
}
}
__syncthreads();
// Issue count sends
if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
const auto dst_rank = responsible_expert_idx / num_local_experts;
const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * kNumWarpGroups];
// Wait until all local sends are issued, then send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
if (dst_rank != rank) {
nvshmemi_ibgda_rma_p(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1,
dst_rank, dst_expert_local_idx, 0);
nvshmemi_ibgda_prepare_recvs(dst_rank, dst_expert_local_idx);
} else {
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
}
// Clean workspace for next use
atomic_counter_per_expert[responsible_expert_idx] = 0;
atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
}
__syncwarp();
// Receiving phase
LOW_LATENCY_DISPATCH_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return;
// Receiving and packing
if (responsible_expert_idx < num_experts) {
const auto src_rank = responsible_expert_idx / num_local_experts;
const auto local_expert_idx = responsible_expert_idx % num_local_experts;
const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
const auto recv_x_scales = packed_recv_x_scales + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_scales;
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
// Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups];
// Wait for tokens to arrive
// NOTES: using sub-warp 1 to overlap with sub-warp 0
int num_recv_tokens, recv_token_begin_idx;
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
if (sub_warp_id == 1 and lane_id == 0) {
if (src_rank != rank) {
nvshmemi_ibgda_poll_recv(src_rank, local_expert_idx);
num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank);
EP_DEVICE_ASSERT(num_recv_tokens != 0);
} else {
while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
}
num_recv_tokens = -num_recv_tokens - 1;
recv_token_begin_idx = atomicAdd(atomic_counter_per_local_expert + local_expert_idx, num_recv_tokens);
shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
}
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
num_recv_tokens = shared_num_recv_tokens[warp_group_id];
recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
// Copy tokens
EP_DEVICE_ASSERT(num_scales <= 64);
for (int i = sub_warp_id; i < num_recv_tokens; i += kNumWarpsPerGroup) {
// Copy data
// NOTES: only 2 load iterations for 7K hidden with 7 unrolls
const auto src = reinterpret_cast<int4*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
const auto dst = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst, src, ld_nc_global, st_na_global);
// Copy scales
const auto src_scales = reinterpret_cast<float*>(rdma_recv_x_uint8 + i * num_bytes_per_msg + kHidden);
const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank;
auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0;
auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0;
lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f;
(lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f;
// Copy source info
const auto src_src_idx = reinterpret_cast<int*>(src_scales + num_scales);
if (lane_id == 0)
recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
__syncwarp();
}
}
}
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases) {
constexpr int kNumMaxTopK = 9;
constexpr int kNumWarpsPerGroup = 10;
constexpr int kNumWarpGroups = 3;
EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
const auto num_sms = cell_div(num_experts, kNumWarpGroups);
EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
// Workspace checks
auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
// Use the last part of `rdma_recv_count` as `atomic_counter_per_local_expert`
// NOTES: this part will be cleaned in `combine`
auto atomic_counter_per_local_expert = rdma_recv_count + num_ranks * (num_experts / num_ranks);
#define DISPATCH_LAUNCH_CASE(hidden) \
LAUNCH_KERNEL(&cfg, dispatch<kNumWarpGroups, kNumWarpsPerGroup, hidden>, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, atomic_counter_per_local_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, phases); break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}
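The low-latency dispatch kernel above decides that all sends for an expert have been issued by comparing a per-expert counter against `FINISHED_SUM_TAG * 2`, and it ships the token count as `-n - 1` so that zero can mean "nothing arrived yet". The sketch below models both ideas on the host under stated assumptions: `kTag` stands in for `FINISHED_SUM_TAG` (its real value is not shown in this file), `FinishCounter` and the `encode_count`/`decode_count` helpers are hypothetical names, and the model assumes the tag is larger than the number of tokens any one expert can receive from this rank.

```cpp
#include <cassert>

// Assumption: placeholder for FINISHED_SUM_TAG; the real constant lives elsewhere.
constexpr int kTag = 1024;

// Hypothetical host-side model of `atomic_finish_counter_per_expert[e]`.
struct FinishCounter {
    int value = 0;

    // SM 0 adds one tag after cleaning the next buffer.
    void on_clean_done() { value += kTag; }

    // The counting warp adds `tag - num_sent` once it knows the total.
    void on_count_known(int num_sent) {
        assert(num_sent >= 0 && num_sent < kTag);
        value += kTag - num_sent;
    }

    // Each issued token send adds one.
    void on_token_sent() { value += 1; }

    // All three contributions together sum to exactly two tags.
    bool all_issued() const { return value == 2 * kTag; }
};

// Token counts travel as `-n - 1`, so a zeroed receive slot is never a valid count.
inline int encode_count(int n) {
    assert(n >= 0);
    return -n - 1;
}

inline int decode_count(int raw) {
    assert(raw < 0);
    return -raw - 1;
}
```

The order of the three contributions does not matter, which is why the sending warps, the counting warp, and SM 0 can update the counter independently. Only the final sum is checked.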
template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* next_clean, int num_next_clean_int,
int* atomic_clean_flag,
int num_combined_tokens, int hidden, int num_topk,
int num_max_dispatch_tokens_per_rank,
int num_experts, int rank, int num_ranks,
int phases) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto num_threads = static_cast<int>(blockDim.x);
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
const auto num_local_experts = num_experts / num_ranks;
const auto warp_group_id = warp_id / kNumWarpsPerGroup;
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
// Data type settings
constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// Message package
// BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(nv_bfloat16);
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// Sending phase
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_COMBINE_RECV;
// Clean up next buffer
if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
#pragma unroll
for (int i = lane_id; i < num_next_clean_int; i += 32)
next_clean[i] = 0;
// Notify before executing `int_p`
__syncwarp();
if (lane_id == 0)
atomic_add_release_global(atomic_clean_flag, num_experts);
}
// Copy BF16 data and issue IBGDA sends
if (responsible_expert_idx < num_experts) {
const auto dst_rank = responsible_expert_idx / num_local_experts;
const auto local_expert_idx = responsible_expert_idx % num_local_experts;
const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
const auto local_x = reinterpret_cast<const int4*>(x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// Unpack layout
int offset, num_tokens_to_send;
unpack2(layout, num_tokens_to_send, offset);
// Issue IBGDA send
for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
// Copy directly to local rank, or copy to buffer and issue RDMA
auto src_idx = __ldg(local_src_info + token_idx);
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
if (dst_rank == rank) {
const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
} else {
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
}
}
// Put finishing flag
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32));
if (sub_warp_id == 1 and lane_id == 0) {
while (ld_acquire_global(atomic_clean_flag) == 0);
if (dst_rank != rank) {
nvshmemi_ibgda_rma_p(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx, 0);
} else {
st_na_release(rdma_recv_flag + global_expert_idx, 1);
}
atomic_add_release_global(atomic_clean_flag, -1);
}
__syncwarp();
}
// Receiving phase
LOW_LATENCY_COMBINE_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return;
// Wait for all ranks to arrive and notify PCIe usage
if (responsible_expert_idx < num_experts) {
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
if (sub_warp_id == 0 and lane_id == 0) {
// TODO: refactor QP indices
auto src_rank = responsible_expert_idx / num_local_experts;
auto src_expert_idx = responsible_expert_idx % num_local_experts;
if (src_rank != rank) {
nvshmemi_ibgda_poll_recv(src_rank, src_expert_idx);
} else {
while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) == 0);
}
}
}
cg::this_grid().sync();
// Reduce tokens (BF16 inputs, FP32 accumulation, BF16 output)
EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads);
EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization");
if (thread_id < hidden_bf16_int4) {
for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
// Read top-k indices and weights
int reg_topk_idx[kNumMaxTopk];
float reg_topk_weights[kNumMaxTopk];
#pragma unroll
for (int i = 0; i < num_topk; ++ i) {
reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
}
float combined_values[kNumElemsPerInt4] = {0.0f};
#pragma unroll
for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// Read from sources
auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);
// Reduce
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
const auto x_bf16 = reinterpret_cast<nv_bfloat16*>(&x_vec);
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
}
// Write results
int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
auto combined_bf16 = reinterpret_cast<nv_bfloat16*>(&combined_values);
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
combined_bf16[j] = static_cast<nv_bfloat16>(combined_values[j]);
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
}
}
}
void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases) {
constexpr int kNumWarpsPerGroup = 10;
constexpr int kNumWarpGroups = 3;
constexpr int kNumMaxTopk = 9;
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
const auto num_sms = cell_div(num_experts, kNumWarpGroups);
// Check workspace
auto atomic_clean_flag = reinterpret_cast<int*>(workspace);
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
#define COMBINE_LAUNCH_CASE(hidden) { \
auto combine_func = combine<kNumWarpGroups, kNumWarpsPerGroup, hidden, kNumMaxTopk>; \
LAUNCH_KERNEL(&cfg, combine_func, \
combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
next_clean, num_next_clean_int, \
atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
phases); } break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
}
} // namespace internode_ll
} // namespace deep_ep
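The low-latency combine kernel above addresses its receive buffer by slot, where each slot is one `int4`-sized header followed by the BF16 hidden vector, and then reduces the selected expert outputs with their top-k weights. The following host-side sketch reproduces that address arithmetic and the weighted sum. It is illustrative only: `slot_bytes`, `payload_offset`, and `weighted_combine` are hypothetical names, the header and BF16 sizes are written as literal constants, and plain `float` stands in for `nv_bfloat16`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr std::size_t kHeaderBytes = 16;  // sizeof(int4) on the device
constexpr std::size_t kBf16Bytes = 2;     // sizeof(nv_bfloat16)

// One slot: header followed by `hidden` BF16 values.
inline std::size_t slot_bytes(std::size_t hidden) {
    return kHeaderBytes + hidden * kBf16Bytes;
}

// Byte offset of the hidden data for (global expert, source token index).
inline std::size_t payload_offset(int expert_idx, int token_idx,
                                  int max_tokens_per_rank, std::size_t hidden) {
    assert(expert_idx >= 0 && token_idx >= 0 && token_idx < max_tokens_per_rank);
    const std::size_t slot =
        static_cast<std::size_t>(expert_idx) * max_tokens_per_rank + token_idx;
    return slot * slot_bytes(hidden) + kHeaderBytes;
}

// Weighted sum over the selected experts; a negative index means "not selected".
inline std::vector<float> weighted_combine(const std::vector<int>& topk_idx,
                                           const std::vector<float>& topk_weights,
                                           const std::vector<std::vector<float>>& rows,
                                           std::size_t hidden) {
    assert(topk_idx.size() == topk_weights.size() && topk_idx.size() == rows.size());
    std::vector<float> out(hidden, 0.0f);
    for (std::size_t i = 0; i < topk_idx.size(); ++i) {
        if (topk_idx[i] < 0)
            continue;
        assert(rows[i].size() == hidden);
        for (std::size_t j = 0; j < hidden; ++j)
            out[j] += rows[i][j] * topk_weights[i];
    }
    return out;
}
```

In this sketch a hidden size of 7168 gives 14352 bytes per slot, and the payload of the first slot starts 16 bytes into the buffer.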
#include "configs.cuh"
#include "buffer.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "utils.cuh"
namespace deep_ep {
namespace intranode {
template<int kNumRanks>
__global__ void
notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
int num_tokens, int num_channels, const bool* is_token_in_rank, int* channel_prefix_matrix,
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
auto lane_id = thread_id % 32, warp_id = thread_id / 32, num_warps = num_threads / 32;
if (sm_id == 0) {
// Barrier first
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
move_fifo_slots<kNumRanks>(head);
__syncthreads();
int *per_rank_buffer, *per_expert_buffer;
if (thread_id < kNumRanks) {
per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[thread_id]);
per_expert_buffer = per_rank_buffer + kNumRanks * kNumRanks;
}
// After this loop:
// - `per_rank_buffer[rank][i, j]` means the number of tokens from rank i to rank j
// - `per_expert_buffer[rank][i, j]` means the number of tokens from rank i to local expert j
int num_experts_per_rank = num_experts / kNumRanks;
if (thread_id < kNumRanks) {
#pragma unroll
for (int i = 0; i < kNumRanks; ++ i)
per_rank_buffer[rank * kNumRanks + i] = num_tokens_per_rank[i];
#pragma unroll
for (int i = 0; i < num_experts_per_rank; ++ i)
per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i];
}
__syncthreads();
// Wait for all ranks to be finished
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
move_fifo_slots<kNumRanks>(head);
__syncthreads();
// Sum per-rank counts and return to CPU
// Also pre-compute the prefix sum for data sending
auto local_per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[rank]);
if (thread_id < kNumRanks) {
#pragma unroll
for (int i = 1; i < kNumRanks; ++ i)
local_per_rank_buffer[i * kNumRanks + thread_id] += local_per_rank_buffer[(i - 1) * kNumRanks + thread_id];
if (thread_id == rank)
*moe_recv_counter_mapped = local_per_rank_buffer[(kNumRanks - 1) * kNumRanks + rank];
}
// Sum per-expert counts, round up to `expert_alignment`, and return to CPU
auto local_per_expert_buffer = local_per_rank_buffer + kNumRanks * kNumRanks;
if (thread_id < num_experts_per_rank) {
int sum = 0;
#pragma unroll
for (int i = 0; i < kNumRanks; ++ i)
sum += local_per_expert_buffer[i * num_experts_per_rank + thread_id];
sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
moe_recv_expert_counter_mapped[thread_id] = sum;
}
__syncthreads();
// Copy rank size prefix matrix to another tensor
#pragma unroll
for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
rank_prefix_matrix_copy[i] = local_per_rank_buffer[i];
// Extra memset for later communication queue
#pragma unroll
for (int i = thread_id; i < num_memset_int; i += num_threads)
local_per_expert_buffer[i] = 0;
// Barrier
memory_fence();
__syncthreads();
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
} else {
int dst_rank = sm_id - 1;
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
int token_start_idx, token_end_idx;
get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// Iterate over tokens
int count = 0;
for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32)
count += is_token_in_rank[i * kNumRanks + dst_rank];
count = warp_reduce_sum(count);
if (lane_id == 0)
channel_prefix_matrix[dst_rank * num_channels + channel_id] = count;
}
__syncthreads();
// Pre-compute prefix sum for all channels
if (thread_id == 0) {
#pragma unroll
for (int i = 1; i < num_channels; ++ i)
channel_prefix_matrix[dst_rank * num_channels + i] += channel_prefix_matrix[dst_rank * num_channels + i - 1];
}
}
}
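// Illustrative host-side sketch, not part of the library: what one non-zero block of the
// kernel above computes for its `dst_rank`, written sequentially. The real
// `get_channel_task_range` is not shown here, so `channel_range_sketch` assumes an even
// ceiling split of tokens over channels; both helper names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical even split of tokens over channels (an assumption, see above).
inline void channel_range_sketch(int num_tokens, int num_channels, int channel_id,
                                 int& start, int& end) {
    int per_channel = (num_tokens + num_channels - 1) / num_channels;
    start = std::min(per_channel * channel_id, num_tokens);
    end = std::min(start + per_channel, num_tokens);
}

// Row `dst_rank` of the channel prefix matrix: the inclusive prefix sum of the
// per-channel counts of tokens routed to `dst_rank`.
// `is_token_in_rank` is row-major [num_tokens][num_ranks].
inline std::vector<int> channel_prefix_row(const std::vector<bool>& is_token_in_rank,
                                           int num_tokens, int num_ranks,
                                           int num_channels, int dst_rank) {
    std::vector<int> row(num_channels, 0);
    for (int c = 0; c < num_channels; ++c) {
        int start, end;
        channel_range_sketch(num_tokens, num_channels, c, start, end);
        for (int t = start; t < end; ++t)
            row[c] += is_token_in_rank[t * num_ranks + dst_rank] ? 1 : 0;
        if (c > 0)
            row[c] += row[c - 1];
    }
    return row;
}
```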
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank,
cudaStream_t stream, int num_channels) {
#define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, notify_dispatch<ranks>, \
num_tokens_per_rank, moe_recv_counter_mapped, \
num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \
num_tokens, num_channels, is_token_in_rank, channel_prefix_matrix, \
rank_prefix_matrix_copy, num_memset_int, expert_alignment, \
buffer_ptrs, task_fifo_ptrs, head, rank); \
break
constexpr int kNumThreads = 128;
EP_HOST_ASSERT(num_experts % num_ranks == 0);
EP_HOST_ASSERT(num_experts / num_ranks <= kNumThreads and num_ranks <= kNumThreads);
SETUP_LAUNCH_CONFIG(1 + num_ranks, kNumThreads, stream);
SWITCH_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
#undef NOTIFY_DISPATCH_LAUNCH_CASE
}
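// Illustrative host-side sketch, not part of the library: the two pieces of arithmetic that
// block 0 of `notify_dispatch` performs after the ranks have exchanged their counts. The
// function names `rank_prefix_in_place` and `align_up` are hypothetical.

```cpp
#include <cassert>
#include <vector>

// In-place column-wise inclusive prefix sum on a row-major [num_ranks][num_ranks] matrix.
// Entry (i, j) starts as the number of tokens sent from rank i to rank j; afterwards it
// holds the number of tokens sent to rank j by ranks 0..i.
inline void rank_prefix_in_place(std::vector<int>& m, int num_ranks) {
    for (int j = 0; j < num_ranks; ++j)
        for (int i = 1; i < num_ranks; ++i)
            m[i * num_ranks + j] += m[(i - 1) * num_ranks + j];
}

// Round a per-expert token count up to a multiple of `expert_alignment`.
inline int align_up(int count, int expert_alignment) {
    return (count + expert_alignment - 1) / expert_alignment * expert_alignment;
}
```

// After the prefix sum, the last row of column `rank` is the total number of tokens that
// rank will receive, which is the value written to `moe_recv_counter_mapped`.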
template<int kNumRanks>
__global__ void
cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank) {
// A simplified version for cached handles
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
move_fifo_slots<kNumRanks>(head);
__syncthreads();
// Copy and clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
#pragma unroll
for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
ptr[i] = rank_prefix_matrix[i];
#pragma unroll
for (int i = thread_id; i < num_memset_int; i += num_threads)
ptr[kNumRanks * kNumRanks + i] = 0;
memory_fence();
__syncthreads();
// Barrier after cleaning
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
}
void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void** buffer_ptrs, int** task_fifo_ptrs,
int head, int rank, int num_ranks, cudaStream_t stream) {
#define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, cached_notify_dispatch<ranks>, \
rank_prefix_matrix, num_memset_int, buffer_ptrs, task_fifo_ptrs, head, rank); \
break
SETUP_LAUNCH_CONFIG(1, 128, stream);
SWITCH_RANKS(CACHED_NOTIFY_DISPATCH_LAUNCH_CASE);
#undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE
}
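// Intranode dispatch.
// Even-numbered blocks are senders: warp `r` pushes this channel's tokens for rank `r` into a
// ring buffer stored on the receiver side. Odd-numbered blocks are receivers: warp `r` drains
// the ring buffer filled by rank `r` into the `recv_*` outputs.
template<int kNumRanks>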
__global__ void __launch_bounds__(kNumRanks * 32, 1)
dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
void **buffer_ptrs, int rank,
int num_max_send_tokens, int num_recv_buffer_tokens) {
const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const bool is_sender = sm_id % 2 == 0;
EP_DEVICE_ASSERT(num_sms % 2 == 0);
// Each warp is responsible for a single rank
const auto num_channels = num_sms / 2;
const auto responsible_rank = (static_cast<int>(thread_id)) / 32;
// Even-numbered blocks for sending, odd-numbered blocks for receiving
const auto responsible_channel = sm_id / 2;
int num_experts_per_rank = num_experts / kNumRanks;
EP_DEVICE_ASSERT(num_experts_per_rank > 0 or num_topk == 0);
EP_DEVICE_ASSERT(num_topk <= 32);
EP_DEVICE_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
EP_DEVICE_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));
// Calculate pointers by the specific layout
// `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
int target_rank = is_sender ? rank : responsible_rank;
auto num_channels_total = num_channels * kNumRanks;
auto channel_rank_offset = responsible_channel * kNumRanks + target_rank;
// Channel buffer metadata
// Senders are responsible for tails, and receivers are responsible for heads
// Stored on the receiver side
// The retired signals are actually boolean flags, but to align with 16 bytes, we make it `int64_t`
// `start_offset`: kNumChannels * kNumRanks * sizeof(int)
// `end_offset`: kNumChannels * kNumRanks * sizeof(int)
// `head_idx`: kNumChannels * kNumRanks * sizeof(int)
// `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
auto channel_start_offset = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
auto channel_end_offset = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
auto channel_head_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
auto channel_tail_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
// Channel data buffers, stored on the receiver side
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
// `topk_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(int64_t)
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
// `x_scales_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_scales * sizeof(float)
auto channel_x_buffers = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
auto channel_src_idx_buffers = Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens);
auto channel_topk_idx_buffers = Buffer<int64_t>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);
if (is_sender) {
// Workers for sending
constexpr int num_send_warps = kNumRanks;
const auto send_thread_id = thread_id;
const auto send_warp_id = send_thread_id / 32;
const auto send_lane_id = send_thread_id % 32;
EP_DEVICE_ASSERT(kNumRanks <= 32);
EP_DEVICE_ASSERT(num_send_warps == kNumRanks and send_warp_id == responsible_rank);
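// Send offsets encoded as `-value - 1`, e.g. 0 -> -1, 1 -> -2
// NOTES: the metadata is cleared to zero and the receiver spins while it reads 0,
// so the encoding keeps a real offset of 0 distinguishable from "not written yet"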
if (send_lane_id == 0) {
int value = responsible_channel > 0 ? channel_prefix_matrix[send_warp_id * num_channels + responsible_channel - 1] : 0;
st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1);
value = channel_prefix_matrix[send_warp_id * num_channels + responsible_channel];
st_relaxed_sys_global(channel_end_offset.buffer(), -value - 1);
}
__syncwarp();
// Get tasks
int token_start_idx, token_end_idx;
get_channel_task_range(num_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx);
// Iterate over all tokens and send by chunks
int cached_channel_tail_idx = 0;
for (int64_t token_idx = token_start_idx; token_idx < token_end_idx;) {
// Check destination queue emptiness, or wait for a buffer to be released (rare cases)
auto start_time = clock64();
while (send_lane_id == 0) {
// NOTES: we only consider the worst case, because counting the real number is time-consuming
int num_used_slots = cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens)
break;
// Rare cases to loop again
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP timeout for dispatch senders, rank %d, responsible_channel = %d\n", rank, responsible_channel);
trap();
}
}
__syncwarp();
int chunk_token_idx = 0;
while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
if (send_lane_id == 0)
send_head[token_idx * kNumRanks + send_warp_id] = is_token_in_rank[token_idx * kNumRanks + send_warp_id] ? cached_channel_tail_idx : -1;
// Skip if not selected
if (not is_token_in_rank[token_idx * kNumRanks + send_warp_id]) {
token_idx ++;
continue;
}
// Get an empty slot
int dst_slot_idx = (cached_channel_tail_idx ++) % num_recv_buffer_tokens;
// Copy data
auto shifted_channel_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
auto shifted_x = x + token_idx * hidden_int4;
UNROLLED_WARP_COPY(5, send_lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x,
__ldg, st_na_global);
// Copy source index
if (send_lane_id == 0)
channel_src_idx_buffers[dst_slot_idx] = static_cast<int>(token_idx);
// Copy `topk_idx` and `topk_weights` with transformed index
if (send_lane_id < num_topk) {
// Top-k index
int recv_expert_begin = send_warp_id * num_experts_per_rank, recv_expert_end = (send_warp_id + 1) * num_experts_per_rank;
auto idx_value = __ldg(topk_idx + token_idx * num_topk + send_lane_id);
idx_value = (idx_value >= recv_expert_begin and idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1;
channel_topk_idx_buffers[dst_slot_idx * num_topk + send_lane_id] = idx_value;
// Top-k weights
auto weight_value = __ldg(topk_weights + token_idx * num_topk + send_lane_id);
weight_value = (idx_value >= 0) ? weight_value : 0.0f;
channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = weight_value;
}
// Copy `x_scales`
#pragma unroll
for (int i = send_lane_id; i < num_scales; i += 32)
channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
// Move token index
chunk_token_idx ++, token_idx ++;
}
// Move tail index
__syncwarp();
if (send_lane_id == 0)
st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx);
}
} else {
// Workers for receiving and copying into buffer
constexpr int num_recv_warps = kNumRanks;
const auto recv_thread_id = thread_id;
const auto recv_warp_id = recv_thread_id / 32;
const auto recv_lane_id = recv_thread_id % 32;
EP_DEVICE_ASSERT(kNumRanks <= 32 and recv_warp_id == responsible_rank);
EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps == kNumRanks);
// Calculate offset first
auto rank_prefix_matrix = reinterpret_cast<int*>(buffer_ptrs[rank]);
int rank_offset = recv_warp_id > 0 ? rank_prefix_matrix[(recv_warp_id - 1) * kNumRanks + rank] : 0;
// Receive channel offset
int total_offset, num_tokens_to_recv;
while (recv_lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
while (recv_lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
if (recv_lane_id == 0) {
total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1;
recv_channel_offset[recv_warp_id * num_channels + responsible_channel] = total_offset;
num_tokens_to_recv -= total_offset;
}
total_offset = __shfl_sync(0xffffffff, total_offset, 0);
total_offset += rank_offset;
num_tokens_to_recv = __shfl_sync(0xffffffff, num_tokens_to_recv, 0);
auto start_time = clock64();
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
while (num_tokens_to_recv > 0) {
// Check channel status by lane 0
while (recv_lane_id == 0) {
cached_channel_tail_idx = ld_acquire_sys_global(channel_tail_idx.buffer());
// Ready to copy
if (cached_channel_head_idx != cached_channel_tail_idx)
break;
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP timeout for dispatch receivers, rank %d, responsible_channel = %d, tokens remained: %d\n", rank, responsible_channel, num_tokens_to_recv);
trap();
}
}
// Sync queue tail
cached_channel_tail_idx = __shfl_sync(0xffffffff, cached_channel_tail_idx, 0);
// Copy data
int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx) {
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
UNROLLED_WARP_COPY(5, recv_lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4,
ld_nc_global, st_na_global);
}
// Copy `src_idx`
#pragma unroll 4
for (int chunk_idx = cached_channel_head_idx + recv_lane_id; chunk_idx < cached_channel_tail_idx; chunk_idx += 32)
recv_src_idx[total_offset + chunk_idx - cached_channel_head_idx] = ld_nc_global(channel_src_idx_buffers.buffer() + chunk_idx % num_recv_buffer_tokens);
// Copy `topk_idx` and `topk_weights`
#pragma unroll 4
for (int idx = recv_lane_id; idx < num_recv_tokens * num_topk; idx += 32) {
int chunk_idx = idx / num_topk, token_topk_idx = idx % num_topk;
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
auto recv_idx = static_cast<int64_t>(total_offset + chunk_idx) * num_topk + token_topk_idx;
auto buffer_idx = token_idx_in_buffer * num_topk + token_topk_idx;
recv_topk_idx[recv_idx] = ld_nc_global(channel_topk_idx_buffers.buffer() + buffer_idx);
recv_topk_weights[recv_idx] = ld_nc_global(channel_topk_weights_buffers.buffer() + buffer_idx);
}
// Copy `x_scales`
#pragma unroll 4
for (int i = recv_lane_id; i < num_recv_tokens * num_scales; i += 32) {
int chunk_idx = i / num_scales, scales_idx = i % num_scales;
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
recv_x_scales[static_cast<int64_t>(total_offset + chunk_idx) * num_scales + scales_idx] =
ld_nc_global(channel_x_scales_buffers.buffer() + token_idx_in_buffer * num_scales + scales_idx);
}
// Move queue
cached_channel_head_idx += num_recv_tokens;
total_offset += num_recv_tokens;
__syncwarp();
if (recv_lane_id == 0)
st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx);
// Exit
num_tokens_to_recv -= num_recv_tokens;
}
}
}
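// Illustrative sketch, not part of the library: the offset encoding used for
// `channel_start_offset` and `channel_end_offset` in the kernel above. The helper names are
// hypothetical; the point is that an encoded offset is never 0, so a zero-initialized slot
// can double as a "not written yet" flag.

```cpp
#include <cassert>

// Sender side: map a non-negative offset to a strictly negative stored value.
inline int encode_offset(int value) { return -value - 1; }

// Receiver side: recover the offset from the stored value.
inline int decode_offset(int stored) { return -stored - 1; }

// A slot still holding 0 has not been written by the sender.
inline bool is_ready(int stored) { return stored != 0; }
```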
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
#define DISPATCH_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, dispatch<ranks>, \
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
is_token_in_rank, channel_prefix_matrix, \
num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
buffer_ptrs, rank, \
num_max_send_tokens, num_recv_buffer_tokens); \
break
// Even-numbered blocks for sending, odd-numbered blocks for receiving.
EP_HOST_ASSERT(num_sms % 2 == 0);
SETUP_LAUNCH_CONFIG(num_sms, num_ranks * 32, stream);
SWITCH_RANKS(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}
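// Illustrative single-threaded sketch, not part of the library: the head/tail flow control
// of one channel queue in `dispatch`. `RingSketch` is a hypothetical name, and draining the
// whole queue at once is a simplification of the receiver loop.

```cpp
#include <cassert>

// Monotonic head/tail counters over a ring of `capacity` slots.
struct RingSketch {
    int capacity;
    int head = 0;  // advanced by the receiver
    int tail = 0;  // advanced by the sender

    explicit RingSketch(int capacity) : capacity(capacity) {}

    // Conservative check: require room for a full chunk even if fewer tokens get sent.
    bool can_send(int max_chunk) const { return capacity - (tail - head) >= max_chunk; }

    // Claim the next slot and return its index inside the ring.
    int push() { return (tail++) % capacity; }

    // Consume everything published so far; returns the number of tokens drained.
    int drain() {
        int n = tail - head;
        head = tail;
        return n;
    }
};
```

// Because the counters only grow, `tail - head` is the number of occupied slots and the
// slot index is obtained by a modulo, which matches `dst_slot_idx` in the sender loop.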
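// Combine notification for cached handles.
// Block 0 clears the queue region between two barriers. Block `1 + c` scans channel `c` of
// `send_head` backwards and rewrites each negative entry to `-last_head - 1`, where
// `last_head` is the head of the nearest later token that was actually sent to that rank.
template<int kNumRanks>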
__global__ void
cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
int** task_fifo_ptrs, int head, int rank) {
const auto sm_id = static_cast<int>(blockIdx.x);
if (sm_id == 0) {
// Barrier before cleaning
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
move_fifo_slots<kNumRanks>(head);
__syncthreads();
// Clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
#pragma unroll
for (int i = thread_id; i < num_memset_int; i += num_threads)
ptr[i] = 0;
memory_fence();
__syncthreads();
// Barrier after cleaning
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
} else {
const auto channel_id = sm_id - 1;
const auto thread_id = static_cast<int>(threadIdx.x);
const auto rank_id = thread_id / 32;
const auto lane_id = thread_id % 32;
int token_start_idx, token_end_idx;
get_channel_task_range(num_recv_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// NOTES: `1 << 25` is a heuristic large number
int last_head = 1 << 25;
#pragma unroll
for (int token_idx_tail = token_end_idx - 1; token_idx_tail >= token_start_idx; token_idx_tail -= 32) {
int token_idx = token_idx_tail - lane_id, expected_head = 0;
auto current_head = (token_idx >= token_start_idx) ? __ldg(send_head + token_idx * kNumRanks + rank_id) : -1;
for (int i = 0; i < min(32, token_idx_tail - token_start_idx + 1); ++ i) {
head = __shfl_sync(0xffffffff, current_head, i);
if (head < 0) {
if (lane_id == i)
expected_head = -last_head - 1;
} else {
last_head = head;
}
}
if (current_head < 0 and token_idx >= token_start_idx)
send_head[token_idx * kNumRanks + rank_id] = expected_head;
}
}
}
void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
int num_recv_tokens, int num_memset_int,
int** task_fifo_ptrs, int head, int rank, int num_ranks,
cudaStream_t stream) {
#define CACHED_NOTIFY_COMBINE(ranks) \
LAUNCH_KERNEL(&cfg, cached_notify_combine<ranks>, \
buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, task_fifo_ptrs, head, rank); \
break
const int num_threads = std::max(128, 32 * num_ranks);
EP_HOST_ASSERT(num_ranks <= num_threads);
EP_HOST_ASSERT(num_threads <= 1024);
EP_HOST_ASSERT(1 + num_channels <= num_channels * 2);
SETUP_LAUNCH_CONFIG(1 + num_channels, num_threads, stream);
SWITCH_RANKS(CACHED_NOTIFY_COMBINE);
#undef CACHED_NOTIFY_COMBINE
}
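// Illustrative host-side sketch, not part of the library: a sequential equivalent of the
// warp-level backward scan in `cached_notify_combine`, for one channel and one rank column
// of `send_head`. The function name `backfill_heads` is hypothetical.

```cpp
#include <cassert>
#include <vector>

// Entries < 0 (token not sent to this rank) are rewritten to `-next_head - 1`, where
// `next_head` is the head of the nearest later token that was sent, or `1 << 25`
// (the same heuristic large number as in the kernel) if there is none.
inline void backfill_heads(std::vector<int>& heads) {
    int last_head = 1 << 25;
    for (int t = static_cast<int>(heads.size()) - 1; t >= 0; --t) {
        if (heads[t] < 0)
            heads[t] = -last_head - 1;
        else
            last_head = heads[t];
    }
}
```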
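// Intranode combine.
// Even-numbered blocks are senders: a group of warps per rank pushes tokens back through the
// ring buffers. Odd-numbered blocks are receivers: warp 0 publishes the queue heads, and the
// remaining warps sum the copies of each token in `float` and write the result to `recv_x`.
template<typename dtype_t, int kNumRanks, int kNumThreads>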
__global__ void __launch_bounds__(kNumThreads, 1)
combine(dtype_t* recv_x, float* recv_topk_weights,
const dtype_t* x, const float* topk_weights,
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
void **buffer_ptrs, int rank,
int num_max_send_tokens, int num_recv_buffer_tokens) {
const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_channels = num_sms / 2;
const bool is_sender = sm_id % 2 == 0;
const int responsible_channel = sm_id / 2;
EP_DEVICE_ASSERT(num_topk <= 32);
constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4);
auto x_int4 = reinterpret_cast<const int4*>(x);
auto recv_int4 = reinterpret_cast<int4*>(recv_x);
if (is_sender) {
// Workers for sending
// Several warps are responsible for a single rank
constexpr int num_send_warps = kNumThreads / 32;
constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
const auto num_threads_per_rank = num_send_warps_per_rank * 32;
const auto send_thread_id = thread_id;
const auto send_lane_id = send_thread_id % 32;
const auto send_rank_id = thread_id / num_threads_per_rank;
const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;
// Calculate pointers by the specific layout
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[send_rank_id]));
auto num_channels_total = num_channels * kNumRanks;
auto channel_rank_offset = responsible_channel * kNumRanks + rank;
// Channel metadata and data buffers
// `head_idx`: kNumChannels * kNumRanks * sizeof(int)
// `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
auto channel_head_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
auto channel_tail_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
auto channel_x_buffers = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
auto channel_src_idx_buffers = Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens);
auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
// Get tasks
// NOTES: `channel_offset` is already shifted
int rank_offset = send_rank_id > 0 ? rank_prefix_matrix[(send_rank_id - 1) * kNumRanks + rank] : 0;
int num_rank_tokens = rank_prefix_matrix[send_rank_id * kNumRanks + rank] - rank_offset;
int channel_offset = channel_prefix_matrix[send_rank_id * num_channels + responsible_channel];
int num_channel_tokens = (responsible_channel == num_channels - 1 ? num_rank_tokens : channel_prefix_matrix[send_rank_id * num_channels + responsible_channel + 1]) - channel_offset;
int token_start_idx = rank_offset + channel_offset, token_end_idx = rank_offset + channel_offset + num_channel_tokens;
// Iterate over all tokens and send by chunks
int current_channel_tail_idx = 0;
for (int64_t token_idx = token_start_idx; token_idx < token_end_idx; ) {
// Check destination queue emptiness, or wait for a buffer to be released (rare cases)
auto start_time = clock64();
int num_round_tokens = min(num_max_send_tokens, token_end_idx - static_cast<int>(token_idx));
while (send_lane_id == 0) {
// NOTES: we only consider the worst case, because counting the real number is time-consuming
int num_used_slots = current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens)
break;
// Rare cases to loop again
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP timeout for combine senders, rank %d, responsible_channel = %d\n", rank, responsible_channel);
trap();
}
}
__syncwarp();
// Send by chunk
#pragma unroll
for (int i = send_warp_id_in_rank; i < num_round_tokens; i += num_send_warps_per_rank) {
// Get an empty slot
int dst_slot_idx = (current_channel_tail_idx + i) % num_recv_buffer_tokens;
// Copy data
auto shifted_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
auto shifted_x = x_int4 + (token_idx + i) * hidden_int4;
UNROLLED_WARP_COPY(4, send_lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
// Send source index
if (send_lane_id == 0)
channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i);
// Send `topk_weights`
if (num_topk > 0 and send_lane_id < num_topk)
channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + send_lane_id);
}
token_idx += num_round_tokens;
current_channel_tail_idx += num_round_tokens;
// Move tail index
asm volatile("bar.sync %0, %1;" :: "r"(send_rank_id), "r"(num_threads_per_rank));
if (send_lane_id == 0 and send_warp_id_in_rank == 0)
st_release_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx);
}
} else {
// Workers for receiving
// One warp for moving the queue head, others for reduction
constexpr int num_recv_warps = kNumThreads / 32;
const auto recv_warp_id = thread_id / 32;
const auto recv_lane_id = thread_id % 32;
EP_DEVICE_ASSERT(kNumRanks <= 32 and kNumThreads > 32);
EP_DEVICE_ASSERT(thread_id >= 0 and kNumThreads % 32 == 0);
// Shared head, tail and retired flags for receiver warps
__shared__ volatile int warp_channel_head_idx[num_recv_warps][kNumRanks];
__shared__ volatile int channel_tail_idx[kNumRanks];
__shared__ volatile bool warp_retired[num_recv_warps];
if (thread_id < num_recv_warps)
warp_retired[thread_id] = false;
if (recv_lane_id < kNumRanks)
warp_channel_head_idx[recv_warp_id][recv_lane_id] = 0;
if (thread_id < kNumRanks)
channel_tail_idx[thread_id] = 0;
asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads));
if (thread_id < 32) {
int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + recv_lane_id;
int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;
// Queue head updater
int last_head = 0;
while (recv_lane_id < kNumRanks) {
// Check retired
bool retired = true;
#pragma unroll
for (int i = 1; i < num_recv_warps; ++ i)
retired = retired and warp_retired[i];
if (retired)
break;
// Update queue tail
channel_tail_idx[recv_lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);
// Update minimum head
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for (int i = 1; i < num_recv_warps; ++ i) if (not warp_retired[i])
min_head = min(min_head, warp_channel_head_idx[i][recv_lane_id]);
if (min_head != std::numeric_limits<int>::max() and min_head > last_head)
st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head);
}
} else {
// Receivers
// Channel metadata
// All lanes read the data buffers, but only the per-rank lanes use `head/tail/src_idx`
Buffer<int4> channel_x_buffers[kNumRanks];
Buffer<float> channel_topk_weights_buffers[kNumRanks];
// Calculate pointers by the specific layout
#pragma unroll
for (int i = 0; i < kNumRanks; ++ i) {
auto channel_rank_offset = responsible_channel * kNumRanks + i;
auto num_channels_total = num_channels * kNumRanks;
// `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
channel_x_buffers[i] = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
channel_topk_weights_buffers[i] = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
}
// The same tokens as the dispatch process
int token_start_idx, token_end_idx;
get_channel_task_range(num_recv_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx);
// Iterate over all tokens and combine
for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) {
// Read expected head
int expected_head = -1;
if (recv_lane_id < kNumRanks) {
expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id);
warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
}
auto start_time = clock64();
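// NOTES: test `expected_head` first, so lanes beyond `kNumRanks` (where it stays -1) never index `channel_tail_idx`
while (expected_head >= 0 and channel_tail_idx[recv_lane_id] <= expected_head) {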
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head);
trap();
}
}
__syncwarp();
// Broadcast current heads
int num_topk_ranks = 0, topk_ranks[kNumRanks], slot_indices[kNumRanks];
#pragma unroll
for (int i = 0; i < kNumRanks; ++ i) {
auto expected_head_i = __shfl_sync(0xffffffff, expected_head, i);
if (expected_head_i >= 0) {
slot_indices[num_topk_ranks] = expected_head_i % num_recv_buffer_tokens;
topk_ranks[num_topk_ranks ++] = i;
}
}
// Reduce data
#pragma unroll
for (int i = recv_lane_id; i < hidden_int4; i += 32) {
// Read buffers
int4 recv_value_int4[kNumRanks];
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
recv_value_int4[j] = ld_nc_global(channel_x_buffers[topk_ranks[j]].buffer() + slot_indices[j] * hidden_int4 + i);
// Reduce all-to-all results
float values[kDtypePerInt4] = {0};
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j) {
auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
#pragma unroll
for (int k = 0; k < kDtypePerInt4; ++ k)
values[k] += static_cast<float>(recv_value_dtypes[k]);
}
// Cast back to `dtype_t` and write
int4 out_int4;
auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
#pragma unroll
for (int j = 0; j < kDtypePerInt4; ++ j)
out_dtypes[j] = static_cast<dtype_t>(values[j]);
recv_int4[token_idx * hidden_int4 + i] = out_int4;
}
// Reduce `topk_weights`
if (recv_lane_id < num_topk) {
float value = 0;
#pragma unroll
for (int i = 0; i < num_topk_ranks; ++ i)
value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + recv_lane_id);
recv_topk_weights[token_idx * num_topk + recv_lane_id] = value;
}
}
// Retired
__syncwarp();
if (recv_lane_id == 0)
warp_retired[recv_warp_id] = true;
}
}
}
void combine(cudaDataType_t type,
void* recv_x, float* recv_topk_weights,
const void* x, const float* topk_weights,
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms,
int num_max_send_tokens, int num_recv_buffer_tokens) {
constexpr int kNumThreads = 768;
#define COMBINE_LAUNCH_CASE(dtype, ranks) \
LAUNCH_KERNEL(&cfg, (combine<dtype, ranks, kNumThreads>), \
reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
reinterpret_cast<const dtype*>(x), topk_weights, \
src_idx, rank_prefix_matrix, channel_prefix_matrix, \
send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
buffer_ptrs, rank, \
num_max_send_tokens, num_recv_buffer_tokens); \
break
#define COMBINE_DTYPE_LAUNCH_CASE(dtype) SWITCH_RANKS_WITH_DTYPE(dtype, COMBINE_LAUNCH_CASE); break
// Even-numbered blocks for sending, odd-numbered blocks for receiving
EP_HOST_ASSERT(num_sms % 2 == 0);
EP_HOST_ASSERT(kNumThreads >= num_ranks * 32);
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
SWITCH_TYPES(COMBINE_DTYPE_LAUNCH_CASE);
#undef COMBINE_DTYPE_LAUNCH_CASE
#undef COMBINE_LAUNCH_CASE
}
} // namespace intranode
} // namespace deep_ep
#pragma once
#include "configs.cuh"
#ifndef SETUP_LAUNCH_CONFIG
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \
cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \
cudaLaunchAttribute attr[1]; \
attr[0].id = cudaLaunchAttributeCooperative; \
attr[0].val.cooperative = 1; \
cfg.attrs = attr; \
cfg.numAttrs = 1
#endif
#ifndef LAUNCH_KERNEL
#define LAUNCH_KERNEL(config, kernel, ...) CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__))
#endif
#define SWITCH_RANKS(case_macro) \
switch (num_ranks) { \
case 2: case_macro(2); \
case 4: case_macro(4); \
case 8: case_macro(8); \
default: EP_HOST_ASSERT(false and "Unsupported ranks"); \
} while (false)
#define SWITCH_RDMA_RANKS(case_macro) \
switch (num_ranks / NUM_MAX_NVL_PEERS) { \
case 2: case_macro(2); \
case 3: case_macro(3); \
case 4: case_macro(4); \
case 8: case_macro(8); \
case 16: case_macro(16); \
case 18: case_macro(18); \
case 20: case_macro(20); \
default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
} while (false)
#define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro) \
switch (num_ranks) { \
case 2: case_macro(dtype, 2); \
case 4: case_macro(dtype, 4); \
case 8: case_macro(dtype, 8); \
default: EP_HOST_ASSERT(false && "Unsupported ranks"); \
} while (false)
#define SWITCH_TYPES(case_macro) \
switch (type) { \
case CUDA_R_16BF: case_macro(nv_bfloat16); \
case CUDA_R_32F: case_macro(float); \
default: EP_HOST_ASSERT(false && "Unsupported type"); \
} while (false)
#define SWITCH_HIDDEN(case_macro) \
switch (hidden) { \
case 2560: case_macro(2560); \
case 5120: case_macro(5120); \
case 7168: case_macro(7168); \
default: EP_HOST_ASSERT(false && "Unsupported hidden"); \
} while (false)
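The `SWITCH_*` macros above turn a runtime value into a template argument. The sketch below reproduces the pattern in plain host C++ under stated assumptions: `ranks_squared`, `SKETCH_SWITCH_RANKS`, `SKETCH_CASE`, and `dispatch` are hypothetical names, and an exception replaces `EP_HOST_ASSERT` so the unsupported case can be observed.

```cpp
#include <cassert>
#include <stdexcept>

// Stand-in for a kernel templated on the rank count
template <int kNumRanks>
int ranks_squared() { return kNumRanks * kNumRanks; }

// Same shape as SWITCH_RANKS. The trailing `while (false)` turns the caller's
// semicolon into an empty loop statement, so the macro reads like a function call.
#define SKETCH_SWITCH_RANKS(case_macro) \
    switch (num_ranks) { \
        case 2: case_macro(2); \
        case 4: case_macro(4); \
        case 8: case_macro(8); \
        default: throw std::invalid_argument("Unsupported ranks"); \
    } while (false)

int dispatch(int num_ranks) {
    int result = 0;
    // The case body ends in `break` (no semicolon), so cases do not fall through
#define SKETCH_CASE(ranks) result = ranks_squared<ranks>(); break
    SKETCH_SWITCH_RANKS(SKETCH_CASE);
#undef SKETCH_CASE
    return result;
}
```

Each supported value instantiates its own template specialization at compile time, and any other value reaches the `default` branch.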
#include <vector>
#include <cstring>
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "utils.cuh"
#include "ibgda_device.cuh"
namespace deep_ep {
namespace intranode {
template<int kNumRanks>
__global__ void barrier(int** task_fifo_ptrs, int head, int rank) {
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
}
void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) {
#define BARRIER_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, barrier<ranks>, task_fifo_ptrs, head, rank); \
break
SETUP_LAUNCH_CONFIG(1, 32, stream);
SWITCH_RANKS(BARRIER_LAUNCH_CASE);
#undef BARRIER_LAUNCH_CASE
}
} // namespace intranode
namespace internode {
nvshmem_team_t cpu_rdma_team = NVSHMEM_TEAM_INVALID;
nvshmem_team_config_t cpu_rdma_team_config;
std::vector<uint8_t> get_unique_id() {
nvshmemx_uniqueid_t unique_id;
nvshmemx_get_uniqueid(&unique_id);
std::vector<uint8_t> result(sizeof(nvshmemx_uniqueid_t));
std::memcpy(result.data(), &unique_id, sizeof(nvshmemx_uniqueid_t));
return result;
}
__global__ void ibgda_initialize_recv_queue(int rank) {
auto thread_idx = static_cast<int>(threadIdx.x);
auto num_threads = static_cast<int>(blockDim.x);
auto dst_rank = static_cast<int>(blockIdx.x);
if (dst_rank != rank) {
for (int qp_id = thread_idx; qp_id < ibgda_get_state()->num_rc_per_pe; qp_id += num_threads) {
auto qp = ibgda_get_rc(dst_rank, qp_id);
// Clean some necessary variables
for (int i = 0; i < qp->rx_wq.nwqes; ++ i)
ibgda_write_empty_recv_wqe(ibgda_get_wqe_ptr(qp, i));
qp->mvars.rx_wq.resv_head = 0;
qp->mvars.rx_wq.cons_idx = 0;
// Allocate receive slots
nvshmemi_ibgda_allocate_recvs(qp);
}
}
}
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode) {
nvshmemx_uniqueid_t root_unique_id;
nvshmemx_init_attr_t attr;
std::memcpy(&root_unique_id, root_unique_id_val.data(), sizeof(nvshmemx_uniqueid_t));
nvshmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
// Create sub-RDMA teams
// NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used
if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) {
EP_HOST_ASSERT(cpu_rdma_team == NVSHMEM_TEAM_INVALID);
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(nvshmem_team_split_strided(NVSHMEM_TEAM_WORLD, rank % NUM_MAX_NVL_PEERS, NUM_MAX_NVL_PEERS,
num_ranks / NUM_MAX_NVL_PEERS, &cpu_rdma_team_config, 0, &cpu_rdma_team) == 0);
EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
}
// Normal operations use IBRC, while low-latency operations use IBGDA
if (low_latency_mode) {
nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
bool ibgda_is_initialized = false;
CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice));
// Initialize recv queues for low-latency mode AR
ibgda_initialize_recv_queue<<<num_ranks, 128>>>(rank);
}
nvshmem_barrier_all();
return nvshmem_my_pe();
}
void* alloc(size_t size, size_t alignment) {
return nvshmem_align(alignment, size);
}
void free(void* ptr) {
nvshmem_free(ptr);
}
void barrier() {
nvshmem_barrier_all();
}
void finalize() {
if (cpu_rdma_team != NVSHMEM_TEAM_INVALID) {
nvshmem_team_destroy(cpu_rdma_team);
cpu_rdma_team = NVSHMEM_TEAM_INVALID;
}
nvshmem_finalize();
}
} // namespace internode
} // namespace deep_ep
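The `nvshmem_team_split_strided` call in `init` passes `rank % NUM_MAX_NVL_PEERS` as the start, `NUM_MAX_NVL_PEERS` as the stride, and `num_ranks / NUM_MAX_NVL_PEERS` as the size. The sketch below lists which PEs such a split would group together. It assumes a value of 8 for `NUM_MAX_NVL_PEERS`, which is defined outside this section, and `rdma_team_members` is a hypothetical helper.

```cpp
#include <cassert>
#include <vector>

// Assumed value for illustration only; the real constant comes from configs.cuh
constexpr int kAssumedNvlPeers = 8;

// PEs that share a strided team with `rank`: same local index, one per node
std::vector<int> rdma_team_members(int rank, int num_ranks) {
    const int start = rank % kAssumedNvlPeers;      // Local index within the node
    const int stride = kAssumedNvlPeers;            // Step from one node to the next
    const int size = num_ranks / kAssumedNvlPeers;  // Number of nodes
    std::vector<int> members;
    for (int i = 0; i < size; ++i)
        members.push_back(start + i * stride);
    return members;
}
```

Under that assumption, rank 10 of 32 has local index 2, so its team holds the PEs with local index 2 on each of the four nodes.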
#pragma once
#include "exception.cuh"
#define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
{ \
constexpr int kLoopStride = 32 * (UNROLL_FACTOR); \
typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type unrolled_values[(UNROLL_FACTOR)]; \
auto __src = (SRC); \
auto __dst = (DST); \
for (int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \
_Pragma("unroll") \
for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \
unrolled_values[__j] = LD_FUNC(__src + __i + __j * 32); \
_Pragma("unroll") \
for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \
ST_FUNC(__dst + __i + __j * 32, unrolled_values[__j]); \
} \
for (int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N); __i += 32) \
ST_FUNC(__dst + __i, LD_FUNC(__src + __i)); \
}
namespace deep_ep {
template <int kBytes>
struct VecInt {};
template<> struct VecInt<1> { using vec_t = int8_t; };
template<> struct VecInt<2> { using vec_t = int16_t; };
template<> struct VecInt<4> { using vec_t = int; };
template<> struct VecInt<8> { using vec_t = int64_t; };
template<> struct VecInt<16> { using vec_t = int4; };
__device__ __forceinline__ void trap() {
asm("trap;");
}
__device__ __forceinline__ void memory_fence() {
asm volatile("fence.acq_rel.sys;":: : "memory");
}
__device__ __forceinline__ void memory_fence_gpu() {
asm volatile("fence.acq_rel.gpu;":: : "memory");
}
__device__ __forceinline__ void memory_fence_cta() {
asm volatile("fence.acq_rel.cta;":: : "memory");
}
__device__ __forceinline__ void st_relaxed_sys_global(const int *ptr, int val) {
asm volatile("st.relaxed.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
}
__device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) {
asm volatile("st.release.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
}
__device__ __forceinline__ void st_release_cta(const int *ptr, int val) {
asm volatile("st.release.cta.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
}
__device__ __forceinline__ int ld_acquire_sys_global(const int *ptr) {
int ret;
asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ uint64_t ld_acquire_sys_global(const uint64_t *ptr) {
uint64_t ret;
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ int ld_acquire_global(const int *ptr) {
int ret;
asm volatile("ld.acquire.gpu.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ int atomic_add_release_sys_global(const int* ptr, int value) {
int ret;
asm volatile("atom.add.release.sys.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
return ret;
}
__device__ __forceinline__ int atomic_add_release_global(const int* ptr, int value) {
int ret;
asm volatile("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
return ret;
}
__device__ __forceinline__ int ld_acquire_cta(const int *ptr) {
int ret;
asm volatile("ld.acquire.cta.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ uint8_t ld_na_relaxed(const uint8_t *ptr) {
uint16_t ret;
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(ret) : "l"(ptr));
return static_cast<uint8_t>(ret);
}
__device__ __forceinline__ uint16_t ld_na_relaxed(const uint16_t *ptr) {
uint16_t ret;
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ uint32_t ld_na_relaxed(const uint32_t *ptr) {
uint32_t ret;
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ uint64_t ld_na_relaxed(const uint64_t *ptr) {
uint64_t ret;
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ int ld_volatile_global(const int *ptr) {
int ret;
asm volatile("ld.volatile.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ float ld_volatile_global(const float *ptr) {
float ret;
asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ int64_t ld_volatile_global(const int64_t *ptr) {
int64_t ret;
asm volatile("ld.volatile.global.s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ uint64_t ld_volatile_global(const uint64_t *ptr) {
uint64_t ret;
asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
}
#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
#define LD_NC_FUNC "ld.global.nc.L1::no_allocate.L2::256B"
#else
#define LD_NC_FUNC "ld.volatile.global"
#endif
// `ld.global.nc.L1::no_allocate` will be translated into `LDG.E.NA.[width].CONSTANT` in SASS,
// which does not have cache allocation, and `CONSTANT` memory does not have coherence control,
// so we have to control them by queue semantics
template <typename dtype_t>
__device__ __forceinline__ dtype_t ld_nc_global(const dtype_t *ptr) {
auto ret = ld_nc_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr));
return *reinterpret_cast<dtype_t*>(&ret);
}
template <>
__device__ __forceinline__ uint8_t ld_nc_global(const uint8_t *ptr) {
uint16_t ret;
// NOTES: we must use `uint16_t` as inline ASM does not support 8-bit constraint letter (`h` below means unsigned 16-bit)
asm volatile(LD_NC_FUNC ".u8 %0, [%1];" : "=h"(ret) : "l"(ptr));
return static_cast<uint8_t>(ret);
}
template <>
__device__ __forceinline__ int ld_nc_global(const int *ptr) {
int ret;
asm volatile(LD_NC_FUNC ".s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
template <>
__device__ __forceinline__ int64_t ld_nc_global(const int64_t *ptr) {
int64_t ret;
asm volatile(LD_NC_FUNC ".s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
}
template <>
__device__ __forceinline__ float ld_nc_global(const float *ptr) {
float ret;
asm volatile(LD_NC_FUNC ".f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
return ret;
}
template <>
__device__ __forceinline__ int2 ld_nc_global(const int2 *ptr) {
int2 ret;
asm volatile(LD_NC_FUNC ".v2.s32 {%0, %1}, [%2];" : "=r"(ret.x), "=r"(ret.y) : "l"(ptr));
return ret;
}
template <>
__device__ __forceinline__ int4 ld_nc_global(const int4 *ptr) {
int4 ret;
asm volatile(LD_NC_FUNC ".v4.s32 {%0, %1, %2, %3}, [%4];"
: "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
return ret;
}
__device__ __forceinline__ void st_na_relaxed(const uint8_t *ptr, uint8_t val) {
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b8 [%0], %1;" : : "l"(ptr), "h"(static_cast<uint16_t>(val)));
}
__device__ __forceinline__ void st_na_relaxed(const uint16_t *ptr, uint16_t val) {
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b16 [%0], %1;" : : "l"(ptr), "h"(val));
}
__device__ __forceinline__ void st_na_relaxed(const uint32_t *ptr, uint32_t val) {
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
}
__device__ __forceinline__ void st_na_relaxed(const int *ptr, int val) {
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
}
__device__ __forceinline__ void st_na_relaxed(const int4 *ptr, int4 val) {
asm volatile("st.relaxed.gpu.global.L1::no_allocate.v4.s32 [%0], {%1, %2, %3, %4};"
: : "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
}
__device__ __forceinline__ void st_na_release(const int *ptr, int val) {
asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
}
__device__ __forceinline__ void st_na_release(const uint32_t *ptr, uint32_t val) {
asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
}
__device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val) {
asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val));
}
// `st.global.L1::no_allocate` will be translated into `ST.E.NA.[width]` in SASS,
// which does not allocate a cache line in L1 (and presumably not in L2 either)
#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
#define ST_NA_FUNC "st.global.L1::no_allocate"
#else
#define ST_NA_FUNC "st.global"
#endif
template <typename dtype_t>
__device__ __forceinline__ void st_na_global(const dtype_t *ptr, const dtype_t& value) {
st_na_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr),
*reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(&value));
}
template <>
__device__ __forceinline__ void st_na_global(const int *ptr, const int& value) {
asm volatile(ST_NA_FUNC ".s32 [%0], %1;" ::"l"(ptr), "r"(value));
}
template <>
__device__ __forceinline__ void st_na_global(const int64_t *ptr, const int64_t& value) {
asm volatile(ST_NA_FUNC ".s64 [%0], %1;" ::"l"(ptr), "l"(value));
}
template <>
__device__ __forceinline__ void st_na_global(const float *ptr, const float& value) {
asm volatile(ST_NA_FUNC ".f32 [%0], %1;" ::"l"(ptr), "f"(value));
}
template <>
__device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value) {
asm volatile(ST_NA_FUNC ".v4.s32 [%0], {%1, %2, %3, %4};"
::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
}
template <typename dtype_t>
__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
return (a + b - 1) / b;
}
template <typename dtype_t>
__host__ __device__ dtype_t align(dtype_t a, dtype_t b) {
return cell_div<dtype_t>(a, b) * b;
}
__forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id,
int& token_start_idx, int& token_end_idx) {
int num_tokens_per_sm = cell_div(num_tokens, num_sms);
token_start_idx = min(num_tokens_per_sm * sm_id, num_tokens);
token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens);
}
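The partitioning in `get_channel_task_range` can be checked on the host. This is a minimal sketch with hypothetical names (`ceil_div_host`, `channel_task_range_host`) that copies the arithmetic of `cell_div` and the two `min` clamps.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>

// Host-side copy of the ceiling division (named `cell_div` in the source)
int ceil_div_host(int a, int b) { return (a + b - 1) / b; }

// Returns the half-open token range [start, end) assigned to `sm_id`
std::pair<int, int> channel_task_range_host(int num_tokens, int num_sms, int sm_id) {
    const int per_sm = ceil_div_host(num_tokens, num_sms);
    const int start = std::min(per_sm * sm_id, num_tokens);  // Clamp for trailing SMs
    const int end = std::min(start + per_sm, num_tokens);    // Clamp the last chunk
    return {start, end};
}
```

With 5 tokens over 4 SMs, each SM is offered 2 tokens, so the ranges are `[0, 2)`, `[2, 4)`, `[4, 5)`, and an empty `[5, 5)` for the last SM.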
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) {
EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
dtype_b_t packed;
auto unpacked_ptr = reinterpret_cast<dtype_a_t*>(&packed);
unpacked_ptr[0] = x, unpacked_ptr[1] = y;
return packed;
}
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) {
EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
auto unpacked_ptr = reinterpret_cast<const dtype_a_t*>(&packed);
x = unpacked_ptr[0], y = unpacked_ptr[1];
}
template <typename dtype_t>
__device__ __forceinline__ dtype_t broadcast(dtype_t& ptr, int src_lane_idx) {
EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0, "`dtype_t` size must be a multiple of `sizeof(int)`");
auto send_int_values = reinterpret_cast<int*>(&ptr);
int recv_int_values[sizeof(dtype_t) / sizeof(int)];
#pragma unroll
for (int i = 0; i < sizeof(dtype_t) / sizeof(int); ++ i)
recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], src_lane_idx);
return *reinterpret_cast<dtype_t*>(recv_int_values);
}
__forceinline__ __device__ int warp_reduce_sum(int value) {
value += __shfl_xor_sync(0xffffffff, value, 16);
value += __shfl_xor_sync(0xffffffff, value, 8);
value += __shfl_xor_sync(0xffffffff, value, 4);
value += __shfl_xor_sync(0xffffffff, value, 2);
value += __shfl_xor_sync(0xffffffff, value, 1);
return value;
}
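The five `__shfl_xor_sync` steps in `warp_reduce_sum` form a butterfly reduction. The sketch below emulates it on the host with a 32-element array standing in for the warp; `butterfly_sum` is a hypothetical name, and the emulation assumes all 32 lanes participate, matching the `0xffffffff` mask.

```cpp
#include <array>
#include <cassert>

// Emulates the XOR-shuffle butterfly: at each step, lane i adds the value of lane i ^ offset
std::array<int, 32> butterfly_sum(std::array<int, 32> lanes) {
    for (int offset = 16; offset >= 1; offset /= 2) {
        std::array<int, 32> next{};
        for (int lane = 0; lane < 32; ++lane)
            next[lane] = lanes[lane] + lanes[lane ^ offset];
        lanes = next;
    }
    return lanes;  // Every lane now holds the sum over the whole warp
}
```

Unlike a tree reduction that leaves the result in lane 0 only, the butterfly leaves the full sum in every lane, so no broadcast is needed afterwards.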
__forceinline__ __device__ float half_warp_reduce_max(float value) {
auto mask = __activemask();
// The mask must be in `{0xffffffff, 0xffff}`
value = max(value, __shfl_xor_sync(mask, value, 8));
value = max(value, __shfl_xor_sync(mask, value, 4));
value = max(value, __shfl_xor_sync(mask, value, 2));
value = max(value, __shfl_xor_sync(mask, value, 1));
return value;
}
__forceinline__ __device__ int get_lane_id() {
int lane_id;
asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
return lane_id;
}
template <int kNumRanks>
__forceinline__ __device__ void move_fifo_slots(int &head) {
head = (head + kNumRanks) % NUM_MAX_FIFO_SLOTS;
}
template <int kNumRanks>
__device__ __forceinline__ bool not_finished(int *task, int expected) {
auto result = false;
auto lane_id = threadIdx.x % 32;
if (lane_id < kNumRanks)
result = ld_volatile_global(task + lane_id) != expected;
return __any_sync(0xffffffff, result);
}
template <int kNumRanks>
__forceinline__ __device__ void
timeout_check(int **task_fifo_ptrs, int head, int rank, int expected, int tag = 0) {
auto start_time = clock64();
while (not_finished<kNumRanks>(task_fifo_ptrs[rank] + head, expected)) {
if (clock64() - start_time > NUM_TIMEOUT_CYCLES and threadIdx.x == 0) {
printf("DeepEP timeout check failed: %d (rank = %d)\n", tag, rank);
trap();
}
}
}
template <int kNumRanks>
__forceinline__ __device__ void
barrier_device(int **task_fifo_ptrs, int head, int rank, int tag = 0) {
auto thread_id = static_cast<int>(threadIdx.x);
EP_DEVICE_ASSERT(kNumRanks <= 32);
if (thread_id < kNumRanks) {
atomicAdd_system(task_fifo_ptrs[rank] + head + thread_id, FINISHED_SUM_TAG);
memory_fence();
atomicSub_system(task_fifo_ptrs[thread_id] + head + rank, FINISHED_SUM_TAG);
}
timeout_check<kNumRanks>(task_fifo_ptrs, head, rank, 0, tag);
}
} // namespace deep_ep
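The add/subtract handshake in `barrier_device` can be modeled on the host. This is a sketch under stated assumptions: `kAssumedTag` stands in for `FINISHED_SUM_TAG` (defined outside this section), the FIFO slots are assumed to start at zero, and `Fifo`, `barrier_arrive`, and `barrier_done` are hypothetical names.

```cpp
#include <cassert>
#include <vector>

// Assumed tag value for illustration only; the real FINISHED_SUM_TAG is defined elsewhere
constexpr int kAssumedTag = 1024;

// fifo[r][j] models slot `head + j` of rank r's task FIFO
using Fifo = std::vector<std::vector<int>>;

// What rank `rank` contributes when it reaches the barrier
void barrier_arrive(Fifo& fifo, int rank) {
    const int n = static_cast<int>(fifo.size());
    for (int peer = 0; peer < n; ++peer) {
        fifo[rank][peer] += kAssumedTag;  // Models atomicAdd_system on the local FIFO
        fifo[peer][rank] -= kAssumedTag;  // Models atomicSub_system on the peer's FIFO
    }
}

// Models the exit condition of timeout_check: all local slots are back to zero
bool barrier_done(const Fifo& fifo, int rank) {
    for (int value : fifo[rank])
        if (value != 0)
            return false;
    return true;
}
```

Each local slot receives `+tag` from its owner and `-tag` from exactly one peer, so a rank's slots return to zero only after every peer has arrived.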