Commit ebfe47e4 authored by Chenggang Zhao's avatar Chenggang Zhao
Browse files

Initial commit

parents
compile_commands.json
.idea
.DS_Store
*.pyc
build/
.cache/
.vscode/
*/cmake-build-*/
MIT License
Copyright (c) 2025 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# DeepEP
DeepEP is a communication library tailored for Mixture-of-Experts (MoE) and expert parallelism (EP). It provides high-throughput and low-latency all-to-all GPU kernels, which are also as known as MoE dispatch and combine. The library also supports low-precision operations, including FP8.
To align with the group-limited gating algorithm proposed in the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper, DeepEP offers a set of kernels optimized for asymmetric-domain bandwidth forwarding, such as forwarding data from NVLink domain to RDMA domain. These kernels deliver high throughput, making them suitable for both training and inference prefilling tasks. Additionally, they support SM (Streaming Multiprocessors) number control.
For latency-sensitive inference decoding, DeepEP includes a set of low-latency kernels with pure RDMA to minimize delays. The library also introduces a hook-based communication-computation overlapping method that does not occupy any SM resource.
Notice: the implementation in this library may have some slight differences from the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper.
## Performance
### Normal kernels with NVLink and RDMA forwarding
We test normal kernels on H800 (~160 GB/s NVLink maximum bandwidth), with each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth). And we follow the DeepSeek-V3/R1 pretraining setting (4096 tokens per batch, 7168 hidden, top-4 groups, top-8 experts, FP8 dispatching and BF16 combining).
| Type | Dispatch #EP | Bottleneck bandwidth | Combine #EP | Bottleneck bandwidth |
|:---------:|:------------:|:--------------------:|:-----------:|:--------------------:|
| Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) |
| Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) |
| Internode | 32 | 44 GB/s (RDMA) | 32 | 47 GB/s (RDMA) |
| Internode | 64 | 46 GB/s (RDMA) | 64 | 45 GB/s (RDMA) |
### Low-latency kernels with pure RDMA
We test low-latency kernels on H800 with each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth). And we follow a typical DeepSeek-V3/R1 production setting (128 tokens per batch, 7168 hidden, top-8 experts, FP8 dispatching and BF16 combining).
| Dispatch #EP | Latency | RDMA bandwidth | Combine #EP | Latency | RDMA bandwidth |
|:------------:|:-------:|:--------------:|:-----------:|:-------:|:--------------:|
| 8 | 163 us | 46 GB/s | 8 | 318 us | 46 GB/s |
| 16 | 173 us | 43 GB/s | 16 | 329 us | 44 GB/s |
| 32 | 182 us | 41 GB/s | 32 | 350 us | 41 GB/s |
| 64 | 186 us | 40 GB/s | 64 | 353 us | 41 GB/s |
| 128 | 192 us | 39 GB/s | 128 | 369 us | 39 GB/s |
| 256 | 194 us | 39 GB/s | 256 | 360 us | 40 GB/s |
## Quick start
### Requirements
- Hopper GPUs (may support more architectures or devices later)
- Python 3.8 and above
- CUDA 12.3 and above
- PyTorch 2.1 and above
- NVLink for intranode communication
- RDMA network for internode communication
### Download and install NVSHMEM dependency
DeepEP also depends on our modified NVSHMEM. Please refer to our [NVSHMEM Installation Guide](third-party/README.md) for instructions.
### Development
```bash
# Build and make symbolic links for SO files
NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py build
# You may modify the specific SO names according to your own platform
ln -s build/lib.linux-x86_64-cpython-38/deep_ep_cpp.cpython-38-x86_64-linux-gnu.so
# Run test cases
# NOTES: you may modify the `init_dist` function in `tests/utils.py`
# according to your own cluster settings, and launch into multiple nodes
python tests/test_intranode.py
python tests/test_internode.py
python tests/test_low_latency.py
```
### Installation
```bash
NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py install
```
Then, import `deep_ep` in your Python project, and enjoy!
## Network configurations
DeepEP is fully tested with InfiniBand networks. However, it is theoretically compatible with RDMA over Converged Ethernet (RoCE) as well.
### Traffic isolation
Traffic isolation is supported by InfiniBand through Virtual Lanes (VL).
To prevent interference between different types of traffic, we recommend segregating workloads across different virtual lanes as follows:
- workloads using normal kernels
- workloads using low-latency kernels
- other workloads
For DeepEP, you can control the virtual lane assignment by setting the `NVSHMEM_IB_SL` environment variable.
### Adaptive routing
Adaptive routing is an advanced routing feature provided by InfiniBand switches that can evenly distribute traffic across multiple paths. Currently, low-latency kernels support adaptive routing, while normal kernels do not (support may be added soon). **Enabling adaptive routing for normal internode kernels may lead to deadlocks or data corruption issues**.
For low-latency kernels, enabling adaptive routing can completely eliminate network congestion caused by routing conflicts, but it also introduces additional latency. We recommend the following configuration for optimal performance:
- enable adaptive routing in environments with heavy network loads
- use static routing in environments with light network loads
### Congestion control
Congestion control is disabled as we have not observed significant congestion in our production environment.
## Interfaces and examples
### Example use in model training or inference prefilling
The normal kernels can be used in model training or the inference prefilling phase (without the backward part) as the below example code shows.
```python
import torch
import torch.distributed as dist
from typing import List, Tuple, Optional, Union
from deep_ep import Buffer, EventOverlap
# Communication buffer (will allocate at runtime)
_buffer: Optional[Buffer] = None
# Set the number of SMs to use
# NOTES: this is a static variable
Buffer.set_num_sms(24)
# You may call this function at the framework initialization
def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
global _buffer
# NOTES: you may also replace `get_*_config` with your auto-tuned results via all the tests
num_nvl_bytes, num_rdma_bytes = 0, 0
for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())):
num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)
num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)
# Allocate a buffer if not existed or not enough buffer size
# NOTES: the adaptive routing configuration of the network **must be off**
if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes:
_buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
return _buffer
def get_hidden_bytes(x: torch.Tensor) -> int:
t = x[0] if isinstance(x, tuple) else x
return t.size(1) * max(t.element_size(), 2)
def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
topk_idx: torch.Tensor, topk_weights: torch.Tensor,
num_experts: int, previous_event: Optional[EventOverlap] = None) -> \
Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple, EventOverlap]:
# NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency
# of the dispatch kernel, it may be useful with communication-computation overlap. For more information, please
# refer to the docs of `Buffer.dispatch`
global _buffer
# Calculate layout before actual dispatch
num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = \
_buffer.get_dispatch_layout(topk_idx, num_experts,
previous_event=previous_event, async_finish=True,
allocate_on_comm_stream=previous_event is not None)
# Do MoE dispatch
# NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
# For more advanced usages, please refer to the docs of the `dispatch` function
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
_buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert,
previous_event=previous_event, async_finish=True,
allocate_on_comm_stream=True)
# For event management, please refer to the docs of the `EventOverlap` class
return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event
def dispatch_backward(grad_recv_x: torch.Tensor, grad_recv_topk_weights: torch.Tensor, handle: Tuple) -> \
Tuple[torch.Tensor, torch.Tensor, EventOverlap]:
global _buffer
# The backward process of MoE dispatch is actually a combine
# For more advanced usages, please refer to the docs of the `combine` function
combined_grad_x, combined_grad_recv_topk_weights, event = \
_buffer.combine(grad_recv_x, handle, topk_weights=grad_recv_topk_weights, async_finish=True)
# For event management, please refer to the docs of the `EventOverlap` class
return combined_grad_x, combined_grad_recv_topk_weights, event
def combine_forward(x: torch.Tensor, handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \
Tuple[torch.Tensor, EventOverlap]:
global _buffer
# Do MoE combine
# For more advanced usages, please refer to the docs of the `combine` function
combined_x, _, event = _buffer.combine(x, handle, async_finish=True, previous_event=previous_event,
allocate_on_comm_stream=previous_event is not None)
# For event management, please refer to the docs of the `EventOverlap` class
return combined_x, event
def combine_backward(grad_combined_x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \
Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], EventOverlap]:
global _buffer
# The backward process of MoE combine is actually a dispatch
# For more advanced usages, please refer to the docs of the `combine` function
grad_x, _, _, _, _, event = _buffer.dispatch(grad_combined_x, handle=handle, async_finish=True,
previous_event=previous_event,
allocate_on_comm_stream=previous_event is not None)
# For event management, please refer to the docs of the `EventOverlap` class
return grad_x, event
```
Moreover, inside the dispatch function, we may not know how many tokens to receive for the current rank. So an implicit CPU wait for GPU received count signal will be involved, as the following figure shows.
![normal](figures/normal.png)
### Example use in inference decoding
The low latency kernels can be used in the inference decoding phase as the below example code shows.
```python
import torch
import torch.distributed as dist
from typing import Tuple, Optional
from deep_ep import Buffer
# Communication buffer (will allocate at runtime)
# NOTES: there is no SM control API for the low-latency kernels
_buffer: Optional[Buffer] = None
# You may call this function at the framework initialization
def get_buffer(group: dist.ProcessGroup, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> Buffer:
# NOTES: the low-latency mode will consume much more space than the normal mode
# So we recommend that `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
global _buffer
num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts)
# Allocate a buffer if not existed or not enough buffer size
if _buffer is None or _buffer.group != group or not _buffer.low_latency_mode or _buffer.num_rdma_bytes < num_rdma_bytes:
# NOTES: for best performance, the QP number **must** be equal to the number of the local experts
assert num_experts % group.size() == 0
_buffer = Buffer(group, 0, num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_experts // group.size())
return _buffer
def low_latency_dispatch(hidden_states: torch.Tensor, topk_idx: torch.Tensor, num_max_dispatch_tokens_per_rank: int, num_experts: int):
global _buffer
# Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay)
recv_hidden_states, recv_expert_count, handle, event, hook = \
_buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
async_finish=False, return_recv_hook=True)
# NOTES: the actual tensor will not be received only if you call `hook()`,
# it is useful for double-batch overlapping, but **without any SM occupation**
# If you don't want to overlap, please set `return_recv_hook=False`
# Later, you can use our GEMM library to do the computation with this specific format
return recv_hidden_states, recv_expert_count, handle, event, hook
def low_latency_combine(hidden_states: torch.Tensor,
topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: Tuple):
global _buffer
# Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay)
combined_hidden_states, event_overlap, hook = \
_buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle,
async_finish=False, return_recv_hook=True)
# NOTES: the same behavior as described in the dispatch kernel
return combined_hidden_states, event_overlap, hook
```
For two micro-batch overlapping, you can refer to the following figure. With our receiving hook interface, the RDMA network traffics are happening in the background, without costing any GPU SMs from the computation part. But notice, the overlapped parts can be adjusted, i.e. the 4 parts of attention/dispatch/MoE/combine may not have the exact same execution time. You may adjust the stage settings according to your workload.
![low-latency](figures/low-latency.png)
## Notices
- For extreme performance, we discover and use an out-of-doc PTX instruction: `ld.global.nc.L1::no_allocate.L2::256B`. This instruction will lead to an undefined behavior: accessing volatile GPU memory with non-coherent read-only PTX modifiers `.nc`. But the correctness is tested to be guaranteed with `.L1::no_allocate` on Hopper architectures, and performance will be much better. If you find kernels not working on some other platforms, you may add `DISABLE_AGGRESSIVE_PTX_INSTRS=1` to `setup.py` and disable this, or file an issue.
- For better performance on your cluster, we recommend to run all the tests and use the best auto-tuned configuration. The default configurations are optimized on the DeepSeek's internal cluster.
## License
This code repository is released under [the MIT License](LICENSE), except for codes that reference NVSHMEM (including `csrc/kernels/ibgda_device.cuh` and `third-party/nvshmem.patch`), which are subject to [NVSHMEM SLA](https://docs.nvidia.com/nvshmem/api/sla.html).
## Citation
If you use this codebase, or otherwise found our work valuable, please cite:
```bibtex
@misc{deepep2025,
title={DeepEP: an efficient expert-parallel communication library},
author={Chenggang Zhao and Shangyan Zhou and Liyue Zhang and Chengqi Deng and Zhean Xu and Yuxuan Liu and Kuai Yu and Jiashi Li and Liang Zhao},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek-ai/DeepEP}},
}
```
# NOTES: this CMake is only for debugging; for setup, please use Torch extension
cmake_minimum_required(VERSION 3.10)
project(deep_ep LANGUAGES CUDA CXX)
set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC")
set(CUDA_SEPARABLE_COMPILATION ON)
list(APPEND CUDA_NVCC_FLAGS "-O3")
list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")
set(TORCH_CUDA_ARCH_LIST "9.0")
find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)
find_package(Torch REQUIRED)
find_package(NVSHMEM REQUIRED HINTS ${NVSHMEM_ROOT_DIR}/lib/cmake/nvshmem)
add_library(nvshmem ALIAS nvshmem::nvshmem)
add_library(nvshmem_host ALIAS nvshmem::nvshmem_host)
add_library(nvshmem_device ALIAS nvshmem::nvshmem_device)
# Seems bugs with CMake, NVCC 12 and C++ 17
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 14)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS} ${NVSHMEM_INCLUDE_DIR})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib ${NVSHMEM_LIB_DIR})
add_subdirectory(kernels)
# Link CPP and CUDA together
pybind11_add_module(deep_ep_cpp deep_ep.cpp)
target_link_libraries(deep_ep_cpp PRIVATE ${EP_CUDA_LIBRARIES} ${TORCH_LIBRARIES} torch_python)
#pragma once
#include "kernels/api.cuh"
#include "kernels/exception.cuh"
namespace deep_ep {
template <typename dtype_t>
dtype_t cell_div(dtype_t a, dtype_t b) {
return (a + b - 1) / b;
}
template <typename dtype_t>
dtype_t align(dtype_t a, dtype_t b) {
return cell_div<dtype_t>(a, b) * b;
}
struct Config {
int num_sms;
int num_max_nvl_chunked_send_tokens;
int num_max_nvl_chunked_recv_tokens;
int num_max_rdma_chunked_send_tokens;
int num_max_rdma_chunked_recv_tokens;
Config(int num_sms,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens) :
num_sms(num_sms),
num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens),
num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens),
num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),
num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {
EP_HOST_ASSERT(num_sms >= 0);
EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0);
EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
this->num_max_rdma_chunked_recv_tokens = align<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
}
size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
// Below are some assumptions
// TODO: add assertions
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
const int num_channels = num_sms / 2;
size_t num_bytes = 0;
num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes();
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float);
num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
}
size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
// Legacy mode
if (num_ranks <= NUM_MAX_NVL_PEERS)
return 0;
// Below are some assumptions
// TODO: add assertions
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_sms % 2 == 0);
const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
const int num_channels = num_sms / 2;
size_t num_bytes = 0;
num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
}
};
struct LowLatencyBuffer {
int num_clean_int = 0;
void* dispatch_rdma_send_buffer = nullptr;
void* dispatch_rdma_recv_data_buffer = nullptr;
int* dispatch_rdma_recv_count_buffer = nullptr;
int* dispatch_rdma_atomic_token_counter = nullptr;
void* combine_rdma_send_buffer = nullptr;
void* combine_rdma_recv_data_buffer = nullptr;
int* combine_rdma_recv_flag_buffer = nullptr;
std::pair<int*, int> clean_meta() {
EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
return {dispatch_rdma_recv_count_buffer, num_clean_int};
}
};
struct LowLatencyLayout {
size_t total_bytes = 0;
LowLatencyBuffer buffers[2];
template <typename out_ptr_t = void*, typename count_ptr_t = uint8_t*, typename in_ptr_t = void*>
out_ptr_t advance(const in_ptr_t& ptr, size_t count) {
return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count);
}
LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
const int num_scales = hidden / 128;
const int num_local_experts = num_experts / num_ranks;
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
// - 2 symmetric odd/even receive buffers
// - 2 symmetric odd/even signaling buffers
// Message sizes
EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
size_t num_bytes_per_dispatch_msg = hidden + num_scales * sizeof(float) + sizeof(int4);
size_t num_bytes_per_combine_msg = sizeof(int4) + hidden * sizeof(nv_bfloat16);
// Send buffer
size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);
EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0);
total_bytes += send_buffer_bytes * 2;
// Symmetric receive buffers
// TODO: optimize memory usages
size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
total_bytes += recv_buffer_bytes * 2;
// Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
size_t dispatch_recv_atomic_token_counter_bytes = num_local_experts * sizeof(int);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes + dispatch_recv_atomic_token_counter_bytes,
combine_recv_flag_buffer_bytes);
total_bytes += signaling_buffer_bytes * 2;
// Assign pointers
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
// so you may see some parameters are duplicated
for (int i = 0; i < 2; ++ i) {
buffers[i] = {
static_cast<int>(signaling_buffer_bytes / sizeof(int)),
advance(rdma_buffer, send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i + dispatch_recv_count_buffer_bytes),
advance(rdma_buffer, send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i)
};
}
}
};
size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes;
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES;
}
} // namespace deep_ep
This diff is collapsed.
#pragma once
// Forcibly disable NDEBUG
#ifdef NDEBUG
#undef NDEBUG
#endif
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/types.h>
#include <tuple>
#include <vector>
#include "config.hpp"
#include "event.hpp"
#include "kernels/configs.cuh"
#include "kernels/exception.cuh"
#ifndef TORCH_EXTENSION_NAME
#define TORCH_EXTENSION_NAME deep_ep_cpp
#endif
namespace deep_ep {
struct Buffer {
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8");
private:
// Low-latency mode buffer
int low_latency_buffer_idx = 0;
bool low_latency_mode = false;
// NVLink Buffer
int64_t num_nvl_bytes;
void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
void** buffer_ptrs_gpu = nullptr;
// NVSHMEM Buffer
int64_t num_rdma_bytes;
void* rdma_buffer_ptr = nullptr;
// Device info and communication
int device_id;
int rank, rdma_rank, nvl_rank;
int num_ranks, num_rdma_ranks, num_nvl_ranks;
cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
// Stream for communication
at::cuda::CUDAStream comm_stream;
// After IPC/NVSHMEM synchronization, this flag will be true
bool available = false;
// Task fifo
int head = 0;
int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int** task_fifo_ptrs_gpu = nullptr;
// Workspace
void* workspace = nullptr;
// Host-side MoE info
volatile int* moe_recv_counter = nullptr;
int* moe_recv_counter_mapped = nullptr;
// Host-side expert-level MoE info
volatile int* moe_recv_expert_counter = nullptr;
int* moe_recv_expert_counter_mapped = nullptr;
// Host-side RDMA-level MoE info
volatile int* moe_recv_rdma_counter = nullptr;
int* moe_recv_rdma_counter_mapped = nullptr;
private:
void move_fifo_slots(int num_slots = 1);
public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode);
~Buffer() noexcept(false);
bool is_available() const;
bool is_internode_available() const;
int get_num_rdma_ranks() const;
int get_rdma_rank() const;
int get_root_rdma_rank(bool global) const;
int get_local_device_id() const;
pybind11::bytearray get_local_ipc_handle() const;
pybind11::bytearray get_local_nvshmem_unique_id() const;
torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const;
void sync(const std::vector<int>& device_ids, const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles, const std::optional<pybind11::bytearray>& root_unique_id_opt);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional<EventHandle>& previous_event,
bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const std::optional<torch::Tensor>& num_tokens_per_rdma_rank,
const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,
const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool async, bool return_recv_hook);
};
} // namespace deep_ep
#include <ATen/cuda/CUDAContext.h>
#include <memory>
#include "kernels/exception.cuh"
namespace deep_ep {
struct EventHandle {
std::shared_ptr<torch::Event> event;
EventHandle() {
event = std::make_shared<torch::Event>(torch::kCUDA);
event->record(at::cuda::getCurrentCUDAStream());
}
explicit EventHandle(const at::cuda::CUDAStream& stream) {
event = std::make_shared<torch::Event>(torch::kCUDA);
event->record(stream);
}
EventHandle(const EventHandle& other) = default;
void current_stream_wait() const {
at::cuda::getCurrentCUDAStream().unwrap().wait(*event);
}
};
torch::Event create_event(const at::cuda::CUDAStream &s) {
auto event = torch::Event(torch::kCUDA);
event.record(s);
return event;
}
void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) {
EP_HOST_ASSERT(s_0.id() != s_1.id());
s_0.unwrap().wait(create_event(s_1));
}
void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) {
s.unwrap().wait(*event.event);
}
} // namespace deep_ep
function(add_deep_ep_library target_name source_file)
add_library(${target_name} STATIC ${source_file})
set_target_properties(${target_name} PROPERTIES
POSITION_INDEPENDENT_CODE ON
CXX_STANDARD_REQUIRED ON
CUDA_STANDARD_REQUIRED ON
CXX_STANDARD 14
CUDA_STANDARD 14
CUDA_SEPARABLE_COMPILATION ON
)
target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5)
endfunction()
add_deep_ep_library(intranode_cuda intranode.cu)
add_deep_ep_library(runtime_cuda runtime.cu)
add_deep_ep_library(internode_cuda internode.cu)
add_deep_ep_library(internode_ll_cuda internode_ll.cu)
# Later, we should link all libraries in `EP_CUDA_LIBRARIES`
set(EP_CUDA_LIBRARIES intranode_cuda runtime_cuda internode_cuda internode_ll_cuda PARENT_SCOPE)
#pragma once
#include <vector>
namespace deep_ep {
// Intranode runtime
namespace intranode {
void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
} // namespace intranode
// Internode runtime
namespace internode {
std::vector<uint8_t> get_unique_id();
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode);
void *alloc(size_t size, size_t alignment);
void free(void *ptr);
void barrier();
void finalize();
} // namespace internode
// Intranode kernels
namespace intranode {
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank,
cudaStream_t stream, int num_sms);
void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, int num_ranks,
cudaStream_t stream);
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms,
int num_max_send_tokens, int num_recv_buffer_tokens);
void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
void combine(cudaDataType_t type,
void* recv_x, float* recv_topk_weights,
const void* x, const float* topk_weights,
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms,
int num_max_send_tokens, int num_recv_buffer_tokens);
} // namespace intranode
// Internode kernels
namespace internode {
int get_source_meta_bytes();
void get_dispatch_layout(const int64_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts,
cudaStream_t stream);
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
const bool* is_token_in_rank, int num_tokens, int num_channels,
int hidden_int4, int num_scales, int num_topk, int expert_alignment,
int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
int** task_fifo_ptrs, int head, int rank,
cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool low_latency_mode);
void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
int* send_rdma_head, int* send_nvl_head,
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
const bool* is_token_in_rank,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, bool is_cached_dispatch,
cudaStream_t stream, int num_channels, bool low_latency_mode);
void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head,
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool is_cached_dispatch, bool low_latency_mode);
void combine(cudaDataType_t type,
void* combined_x, float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const void* x, const float* topk_weights,
const int* combined_rdma_head, const int* combined_nvl_head,
const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
int num_tokens, int num_combined_tokens, int hidden, int num_topk,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode);
} // namespace internode
// Internode low-latency kernels
namespace internode_ll {
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1,
cudaStream_t stream);
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases);
void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases);
} // namespace internode_ll
} // namespace deep_ep
#pragma once
#include "configs.cuh"
#include "exception.cuh"
namespace deep_ep {
template <typename dtype_t>
struct Buffer {
private:
uint8_t* ptr;
public:
int total_bytes;
__device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {}
__device__ __forceinline__ Buffer(void* &gbl_ptr, int num_elems, int offset = 0) {
total_bytes = num_elems * sizeof(dtype_t);
ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t);
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
}
__device__ __forceinline__ Buffer advance_also(void* &gbl_ptr) {
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
return *this;
}
__device__ __forceinline__ dtype_t* buffer() {
return reinterpret_cast<dtype_t*>(ptr);
}
__device__ __forceinline__ dtype_t& operator[](int idx) {
return buffer()[idx];
}
};
template <typename dtype_t, int kNumRanks = 1>
struct AsymBuffer {
private:
uint8_t* ptrs[kNumRanks];
int num_bytes;
public:
int total_bytes;
__device__ __forceinline__ AsymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1, int offset = 0) {
EP_STATIC_ASSERT(kNumRanks == 1, "");
num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms;
ptrs[0] = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
}
__device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1, int offset = 0) {
EP_STATIC_ASSERT(kNumRanks > 1, "");
num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms;
for (int i = 0; i < kNumRanks; ++ i) {
ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
}
}
__device__ __forceinline__ void advance(int shift) {
#pragma unroll
for (int i = 0; i < kNumRanks; ++ i)
ptrs[i] = ptrs[i] + shift * sizeof(dtype_t);
}
__device__ __forceinline__ AsymBuffer advance_also(void* &gbl_ptr) {
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
return *this;
}
template<int kNumAlsoRanks>
__device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) {
for (int i = 0; i < kNumAlsoRanks; ++ i)
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
return *this;
}
__device__ __forceinline__ dtype_t* buffer(int idx = 0) {
EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case");
return reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx);
}
__device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) {
EP_STATIC_ASSERT(kNumRanks > 1, "`buffer` is only available for single rank case");
return reinterpret_cast<dtype_t*>(ptrs[rank_idx] + num_bytes * idx);
}
};
template <typename dtype_t, bool kDecoupled = true>
struct SymBuffer {
private:
// NOTES: for non-decoupled case, `recv_ptr` is not used
uint8_t* send_ptr;
uint8_t* recv_ptr;
int num_bytes;
public:
int total_bytes;
__device__ __forceinline__ SymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1) {
num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1);
send_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id;
recv_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
}
__device__ __forceinline__ dtype_t* send_buffer(int idx = 0) {
EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for non-decoupled case");
return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
}
__device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) {
EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for non-decoupled case");
return reinterpret_cast<dtype_t*>(recv_ptr + num_bytes * idx);
}
__device__ __forceinline__ dtype_t* buffer(int idx = 0) {
EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for decoupled case");
return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
}
};
} // namespace deep_ep
#pragma once
#define NUM_MAX_NVL_PEERS 8
#define NUM_MAX_RDMA_PEERS 20
#define NUM_MAX_FIFO_SLOTS 32768
#define NUM_WORKSPACE_BYTES (32 * 1024 * 1024)
#define NUM_MAX_LOCAL_EXPERTS 1024
#define NUM_BUFFER_ALIGNMENT_BYTES 128
#define FINISHED_SUM_TAG 1024
#define NUM_CPU_TIMEOUT_SECS 100
#define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s
#define NUM_WAIT_NANOSECONDS 500
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
// Make CLion CUDA indexing work
#ifdef __CLION_IDE__
#define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier)
#define __CUDACC_RDC__ // NOLINT(*-reserved-identifier)
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
#define printf host_device_printf
#endif
// Remove Torch restrictions
#ifdef __CUDA_NO_HALF_CONVERSIONS__
#undef __CUDA_NO_HALF_CONVERSIONS__
#endif
#ifdef __CUDA_NO_HALF_OPERATORS__
#undef __CUDA_NO_HALF_OPERATORS__
#endif
#ifdef __CUDA_NO_HALF2_OPERATORS__
#undef __CUDA_NO_HALF2_OPERATORS__
#endif
#ifdef __CUDA_NO_BFLOAT16_CONVERSIONS__
#undef __CUDA_NO_BFLOAT16_CONVERSIONS__
#endif
#ifdef __CUDA_NO_BFLOAT162_OPERATORS__
#undef __CUDA_NO_BFLOAT162_OPERATORS__
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <nvshmem.h>
#include <nvshmemx.h>
#include <infiniband/mlx5dv.h>
#include <non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh>
#include <device_host_transport/nvshmem_common_ibgda.h>
#pragma once
#include <string>
#include <exception>
#include "configs.cuh"
#ifndef EP_STATIC_ASSERT
#define EP_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
#endif
class EPException: public std::exception {
private:
std::string message = {};
public:
explicit EPException(const char *name, const char* file, const int line, const std::string& error) {
message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'";
}
const char *what() const noexcept override { return message.c_str(); }
};
#ifndef CUDA_CHECK
#define CUDA_CHECK(cmd) \
do { \
cudaError_t e = (cmd); \
if (e != cudaSuccess) { \
throw EPException("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \
} \
} while (0)
#endif
#ifndef EP_HOST_ASSERT
#define EP_HOST_ASSERT(cond) \
do { \
if (not (cond)) { \
throw EPException("Assertion", __FILE__, __LINE__, #cond); \
} \
} while (0)
#endif
#ifndef EP_DEVICE_ASSERT
#define EP_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
asm("trap;"); \
} \
} while (0)
#endif
// Portions derived from NVSHMEM (https://developer.nvidia.com/nvshmem)
// Copyright (c) NVIDIA Corporation.
// Licensed under the NVSHMEM Software License Agreement (version: September 3, 2019).
// See full license at: https://docs.nvidia.com/nvshmem/api/sla.html
//
// Modified from original source:
// - nvshmem/src/include/non_abi/device/pt-to-pt/ibgda_device.cuh
#pragma once
#include "configs.cuh"
#include "exception.cuh"
#include "utils.cuh"
namespace deep_ep {
EP_STATIC_ASSERT(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, "Invalid QP minimum depth");
__device__ static __forceinline__
uint64_t HtoBE64(uint64_t x) {
uint64_t ret;
asm("{\n\t"
".reg .b32 ign;\n\t"
".reg .b32 lo;\n\t"
".reg .b32 hi;\n\t"
".reg .b32 new_lo;\n\t"
".reg .b32 new_hi;\n\t"
"mov.b64 {lo,hi}, %1;\n\t"
"prmt.b32 new_hi, lo, ign, 0x0123;\n\t"
"prmt.b32 new_lo, hi, ign, 0x0123;\n\t"
"mov.b64 %0, {new_lo,new_hi};\n\t"
"}" : "=l"(ret) : "l"(x));
return ret;
}
__device__ static __forceinline__
uint32_t HtoBE32(uint32_t x) {
uint32_t ret;
asm("{\n\t"
".reg .b32 ign;\n\t"
"prmt.b32 %0, %1, ign, 0x0123;\n\t"
"}" : "=r"(ret) : "r"(x));
return ret;
}
__device__ static __forceinline__
uint16_t HtoBE16(uint16_t x) {
// TODO: simplify PTX using 16-bit instructions
auto a = static_cast<uint32_t>(x);
uint32_t d;
asm volatile(
"{\n\t"
".reg .b32 mask;\n\t"
".reg .b32 ign;\n\t"
"mov.b32 mask, 0x4401;\n\t"
"mov.b32 ign, 0x0;\n\t"
"prmt.b32 %0, %1, ign, mask;\n\t"
"}"
: "=r"(d)
: "r"(a));
return static_cast<uint16_t>(d);
}
typedef struct mlx5_wqe_ctrl_seg __attribute__((__aligned__(8))) ibgda_ctrl_seg_t;
__device__ static __forceinline__
nvshmemi_ibgda_device_state_t* ibgda_get_state() {
return &nvshmemi_ibgda_device_state_d;
}
__device__ static __forceinline__
nvshmemi_ibgda_device_qp_t* ibgda_get_rc(int pe, int id) {
auto state = ibgda_get_state();
const auto num_rc_per_pe = ibgda_get_state()->num_rc_per_pe;
return &state->globalmem.rcs[pe * num_rc_per_pe + id % num_rc_per_pe];
}
__device__ static __forceinline__
void ibgda_lock_acquire(int *lock) {
while (atomicCAS(lock, 0, 1) == 1);
// Prevent reordering before the lock is acquired
memory_fence_cta();
}
__device__ static __forceinline__
void ibgda_lock_release(int *lock) {
memory_fence_cta();
// Prevent reordering before lock is released
st_na_relaxed(lock, 0);
}
__device__ static __forceinline__
void ibgda_update_dbr(nvshmemi_ibgda_device_qp_t *qp, uint32_t dbrec_head) {
// `DBREC` contains the index of the next empty `WQEBB`
__be32 dbrec_val;
__be32 *dbrec_ptr = qp->tx_wq.dbrec;
// This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(dbrec_head & 0xffff))`
asm("{\n\t"
".reg .b32 dbrec_head_16b;\n\t"
".reg .b32 ign;\n\t"
"and.b32 dbrec_head_16b, %1, 0xffff;\n\t"
"prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t"
"}"
: "=r"(dbrec_val)
: "r"(dbrec_head));
st_na_release(dbrec_ptr, dbrec_val);
}
__device__ static __forceinline__
void ibgda_ring_db(nvshmemi_ibgda_device_qp_t *qp, uint16_t prod_idx) {
auto bf_ptr = reinterpret_cast<uint64_t*>(qp->tx_wq.bf);
ibgda_ctrl_seg_t ctrl_seg = {
.opmod_idx_opcode = HtoBE32(prod_idx << 8),
.qpn_ds = HtoBE32(qp->qpn << 8)
};
EP_STATIC_ASSERT(sizeof(decltype(&ctrl_seg)) == sizeof(uint64_t), "");
st_na_release(bf_ptr, *(reinterpret_cast<uint64_t*>(&ctrl_seg)));
}
__device__ static __forceinline__
void ibgda_post_send(nvshmemi_ibgda_device_qp_t *qp, uint64_t new_prod_idx) {
nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars;
uint64_t old_prod_idx;
// Update `prod_idx` before ringing the doorbell, so that we know which index is needed in quiet/fence
ibgda_lock_acquire(&mvars->post_send_lock);
old_prod_idx = atomicMax(reinterpret_cast<unsigned long long int*>(&mvars->tx_wq.prod_idx), new_prod_idx);
if (new_prod_idx > old_prod_idx) {
ibgda_update_dbr(qp, new_prod_idx);
ibgda_ring_db(qp, new_prod_idx);
}
ibgda_lock_release(&mvars->post_send_lock);
}
template <bool kAlwaysDoPostSend>
__device__ static __forceinline__
void ibgda_submit_requests(nvshmemi_ibgda_device_qp_t *qp, uint64_t base_wqe_idx,
uint32_t num_wqes, int message_idx = 0) {
nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars;
uint64_t new_wqe_idx = base_wqe_idx + num_wqes;
// WQE writes must be finished first
__threadfence();
// Wait for prior WQE slots to be filled first
auto *ready_idx = reinterpret_cast<unsigned long long int*>(&mvars->tx_wq.ready_head);
while (atomicCAS(ready_idx, base_wqe_idx, new_wqe_idx) != base_wqe_idx);
// Always post, not in batch
constexpr int kNumRequestInBatch = 4;
if (kAlwaysDoPostSend or (message_idx + 1) % kNumRequestInBatch == 0)
ibgda_post_send(qp, new_wqe_idx);
}
__device__ static __forceinline__ void
ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *val, uint64_t raddr,
__be32 rkey, uint16_t wqe_idx, void **out_wqes, uint32_t imm) {
ibgda_ctrl_seg_t ctrl_seg;
struct mlx5_wqe_raddr_seg raddr_seg;
struct mlx5_wqe_inl_data_seg inl_seg;
auto *ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
auto *raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
auto *inl_seg_ptr = reinterpret_cast<mlx5_wqe_inl_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
auto *wqe_data_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(inl_seg_ptr) + sizeof(*inl_seg_ptr));
raddr_seg.raddr = HtoBE64(raddr);
raddr_seg.rkey = rkey;
raddr_seg.reserved = 0;
inl_seg.byte_count = HtoBE32(4 | MLX5_INLINE_SEG);
// `imm == std::numeric_limits<uint32_t>::max()` means no imm writes
ctrl_seg = {0};
ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3);
ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | (imm != std::numeric_limits<uint32_t>::max() ? MLX5_OPCODE_RDMA_WRITE_IMM : MLX5_OPCODE_RDMA_WRITE));
if (imm != std::numeric_limits<uint32_t>::max())
ctrl_seg.imm = HtoBE32(imm);
EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16");
EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16");
EP_STATIC_ASSERT(sizeof(*inl_seg_ptr) == 4, "sizeof(*inl_seg_ptr) == 4");
st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));
st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));
st_na_relaxed(reinterpret_cast<uint32_t*>(inl_seg_ptr), *reinterpret_cast<const uint32_t*>(&inl_seg));
st_na_relaxed(reinterpret_cast<uint32_t*>(wqe_data_ptr), *reinterpret_cast<const uint32_t*>(val));
}
__device__ static __forceinline__
uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey,
uint64_t raddr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) {
auto state = ibgda_get_state();
auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);
auto log2_cumem_granularity = state->log2_cumem_granularity;
// Local key
uint64_t idx = (laddr - heap_start) >> log2_cumem_granularity;
auto device_key = state->constmem.lkeys[idx];
auto lchunk_size = device_key.next_addr - laddr;
*lkey = device_key.key;
// Remote key
uint64_t roffset = raddr - heap_start;
idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) + dst_pe;
if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) {
device_key = state->constmem.rkeys[idx];
} else {
device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS];
}
*out_raddr = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset;
*out_rkey = device_key.key;
// Return the minimum of local and remote chunk sizes
auto rchunk_size = device_key.next_addr - roffset;
return min(lchunk_size, rchunk_size);
}
__device__ static __forceinline__ void
ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) {
auto state = ibgda_get_state();
auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);
uint64_t roffset = addr - heap_start;
uint64_t idx = ((roffset >> state->log2_cumem_granularity) * nvshmemi_device_state_d.npes) + dst_pe;
nvshmemi_ibgda_device_key_t device_key;
if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS)
device_key = state->constmem.rkeys[idx];
else
device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS];
*out_raddr = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset;
*out_rkey = device_key.key;
}
__device__ static __forceinline__ uint64_t
ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t *qp, uint32_t num_wqes) {
auto mvars = &qp->mvars;
return atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
}
__device__ static __forceinline__ void*
ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {
uint16_t cnt = qp->tx_wq.nwqes;
uint16_t idx = wqe_idx & (cnt - 1);
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT));
}
// Wait until wqe `idx - 1` is completed.
// This is a simplified version of NVSHMEM's `ibgda_poll_cq`. It can only be used for polling recv.
// Because we post recv and poll recv in the same thread, so we don't need to maintain queue status.
__device__ static __forceinline__ void
nvshmemi_ibgda_poll_recv(int dst_pe, int qp_id) {
auto qp = ibgda_get_rc(dst_pe, qp_id);
auto cq = qp->rx_wq.cq;
const uint32_t ncqes = cq->ncqes;
auto *cqe64 = reinterpret_cast<struct mlx5_cqe64*>(cq->cqe);
auto old_cons_idx = *cq->cons_idx;
*cq->cons_idx = old_cons_idx + 1;
// Wait until `wqe_counter >= old_cons_idx`
while ((static_cast<uint16_t>(old_cons_idx - HtoBE16(ld_na_relaxed(&cqe64->wqe_counter)) - 1) < ncqes));
}
__device__ static __forceinline__ void
nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t imm = std::numeric_limits<uint32_t>::max()) {
// Get rkey
// NOTES: the `p` operation will not cross multiple remote chunks
__be32 rkey;
uint64_t raddr;
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), dst_pe, &raddr, &rkey);
// Write WQEs
auto qp = ibgda_get_rc(dst_pe, qp_id);
uint64_t base_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
void *wqe_ptrs;
wqe_ptrs = ibgda_get_wqe_ptr(qp, base_wqe_idx);
ibgda_write_rdma_write_inl_wqe(qp, reinterpret_cast<const uint32_t*>(&value), raddr, rkey, base_wqe_idx, &wqe_ptrs, imm);
// Submit requests
ibgda_submit_requests<true>(qp, base_wqe_idx, 1);
}
__device__ static __forceinline__ void
ibgda_write_rdma_write_wqe(nvshmemi_ibgda_device_qp_t *qp, uint64_t laddr, __be32 lkey,
uint64_t raddr, __be32 rkey, uint32_t bytes, uint16_t wqe_idx,
void **out_wqes) {
ibgda_ctrl_seg_t ctrl_seg;
struct mlx5_wqe_raddr_seg raddr_seg;
struct mlx5_wqe_data_seg data_seg;
auto *ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
void *av_seg_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
struct mlx5_wqe_raddr_seg *raddr_seg_ptr;
struct mlx5_wqe_data_seg *data_seg_ptr;
raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(av_seg_ptr));
data_seg_ptr = reinterpret_cast<mlx5_wqe_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
raddr_seg.raddr = HtoBE64(raddr);
raddr_seg.rkey = rkey;
raddr_seg.reserved = 0;
data_seg.byte_count = HtoBE32(bytes);
data_seg.lkey = lkey;
data_seg.addr = HtoBE64(laddr);
ctrl_seg = {0};
ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3);
ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | MLX5_OPCODE_RDMA_WRITE);
EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16");
EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16");
EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == 16, "sizeof(*data_seg_ptr) == 16");
st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));
st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
}
__device__ static __forceinline__ void
ibgda_write_empty_recv_wqe(void *out_wqe) {
auto *data_seg_ptr = reinterpret_cast<struct mlx5_wqe_data_seg*>(out_wqe);
struct mlx5_wqe_data_seg data_seg;
// Make the first segment in the WQE invalid, then the entire list will be invalid
data_seg.byte_count = 0;
data_seg.lkey = HtoBE64(MLX5_INVALID_LKEY);
data_seg.addr = 0;
EP_STATIC_ASSERT(sizeof(mlx5_wqe_data_seg) == sizeof(int4), "Invalid data type length");
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
}
__device__ static __forceinline__ uint64_t
nvshmemi_ibgda_allocate_recvs(nvshmemi_ibgda_device_qp* qp) {
auto mvars = &qp->mvars;
// Allocate if not enough
constexpr int kMinIBGDARecvs = 32;
auto resv_head = mvars->rx_wq.resv_head;
auto num_valid_slots = resv_head - mvars->rx_wq.cons_idx;
if (num_valid_slots < kMinIBGDARecvs) {
resv_head = mvars->rx_wq.cons_idx + qp->rx_wq.nwqes;
mvars->rx_wq.resv_head = resv_head;
// Ensure WQE is written before `dbrec`
__be32 dbrec_val;
__be32 *dbrec_ptr = qp->rx_wq.dbrec;
// Compared to sending, for each QP, we only post recv in a single thread,
// so we don't need to do synchronization here
// This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(wqe_idx & 0xffff))`
asm("{\n\t"
".reg .b32 dbrec_head_16b;\n\t"
".reg .b32 ign;\n\t"
"and.b32 dbrec_head_16b, %1, 0xffff;\n\t"
"prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t"
"}" : "=r"(dbrec_val)
: "r"(static_cast<uint32_t>(resv_head)));
st_na_release(dbrec_ptr, dbrec_val);
}
// Return old number of slots
return num_valid_slots;
}
__device__ static __forceinline__ void
nvshmemi_ibgda_prepare_recvs(int dst_rank, int qp_id) {
// NOTES: only one thread can run this function
// TODO: consider this assertion for normal AR
EP_DEVICE_ASSERT(nvshmemi_ibgda_allocate_recvs(ibgda_get_rc(dst_rank, qp_id)) > 16);
}
__device__ static __forceinline__ void
nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
// Get lkey and rkey, store them into lanes
uint32_t num_wqes = 0;
__be32 my_lkey = 0;
uint64_t my_laddr = 0;
__be32 my_rkey = 0;
uint64_t my_raddr = 0;
uint64_t my_chunk_size = 0;
// Decide how many messages (theoretically 3 for maximum)
auto remaining_bytes = bytes;
while (remaining_bytes > 0) {
if (lane_id == num_wqes)
my_chunk_size = min(remaining_bytes, ibgda_get_lkey_and_rkey(my_laddr = req_lptr, &my_lkey, req_rptr, dst_pe, &my_raddr, &my_rkey));
// Move one more message
auto chunk_size = __shfl_sync(0xffffffff, my_chunk_size, static_cast<int>(num_wqes));
remaining_bytes -= chunk_size;
req_lptr += chunk_size;
req_rptr += chunk_size;
++ num_wqes;
}
EP_DEVICE_ASSERT(num_wqes <= 32);
// Process WQE
auto qp = ibgda_get_rc(dst_pe, qp_id);
uint64_t base_wqe_idx = 0;
if (lane_id == 0)
base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes);
base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0);
if (lane_id < num_wqes) {
auto wqe_ptr = ibgda_get_wqe_ptr(qp, base_wqe_idx + lane_id);
ibgda_write_rdma_write_wqe(qp, my_laddr, my_lkey, my_raddr, my_rkey, my_chunk_size,
base_wqe_idx, &wqe_ptr);
}
__syncwarp();
// Submit
if (lane_id == 0)
ibgda_submit_requests<false>(qp, base_wqe_idx, num_wqes, message_idx);
__syncwarp();
}
} // namespace deep_ep
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
#pragma once
#include "configs.cuh"
#ifndef SETUP_LAUNCH_CONFIG
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \
cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \
cudaLaunchAttribute attr[1]; \
attr[0].id = cudaLaunchAttributeCooperative; \
attr[0].val.cooperative = 1; \
cfg.attrs = attr; \
cfg.numAttrs = 1
#endif
#ifndef LAUNCH_KERNEL
#define LAUNCH_KERNEL(config, kernel, ...) CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__))
#endif
#define SWITCH_RANKS(case_macro) \
switch (num_ranks) { \
case 2: case_macro(2); \
case 4: case_macro(4); \
case 8: case_macro(8); \
default: EP_HOST_ASSERT(false and "Unsupported ranks"); \
} while (false)
#define SWITCH_RDMA_RANKS(case_macro) \
switch (num_ranks / NUM_MAX_NVL_PEERS) { \
case 2: case_macro(2); \
case 3: case_macro(3); \
case 4: case_macro(4); \
case 8: case_macro(8); \
case 16: case_macro(16); \
case 18: case_macro(18); \
case 20: case_macro(20); \
default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
} while (false)
#define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro) \
switch (num_ranks) { \
case 2: case_macro(dtype, 2); \
case 4: case_macro(dtype, 4); \
case 8: case_macro(dtype, 8); \
default: EP_HOST_ASSERT(false && "Unsupported ranks"); \
} while (false)
#define SWITCH_TYPES(case_macro) \
switch (type) { \
case CUDA_R_16BF: case_macro(nv_bfloat16); \
case CUDA_R_32F: case_macro(float); \
default: EP_HOST_ASSERT(false && "Unsupported type"); \
} while (false)
#define SWITCH_HIDDEN(case_macro) \
switch (hidden) { \
case 2560: case_macro(2560); \
case 5120: case_macro(5120); \
case 7168: case_macro(7168); \
default: EP_HOST_ASSERT(false && "Unsupported hidden"); \
} while (false)
#include <vector>
#include <cstring>
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "utils.cuh"
#include "ibgda_device.cuh"
namespace deep_ep {
namespace intranode {
template<int kNumRanks>
__global__ void barrier(int** task_fifo_ptrs, int head, int rank) {
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
}
void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) {
#define BARRIER_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, barrier<ranks>, task_fifo_ptrs, head, rank); \
break
SETUP_LAUNCH_CONFIG(1, 32, stream);
SWITCH_RANKS(BARRIER_LAUNCH_CASE);
#undef BARRIER_LAUNCH_CASE
}
} // namespace intranode
namespace internode {
nvshmem_team_t cpu_rdma_team = NVSHMEM_TEAM_INVALID;
nvshmem_team_config_t cpu_rdma_team_config;
std::vector<uint8_t> get_unique_id() {
nvshmemx_uniqueid_t unique_id;
nvshmemx_get_uniqueid(&unique_id);
std::vector<uint8_t> result(sizeof(nvshmemx_uniqueid_t));
std::memcpy(result.data(), &unique_id, sizeof(nvshmemx_uniqueid_t));
return result;
}
__global__ void ibgda_initialize_recv_queue(int rank) {
auto thread_idx = static_cast<int>(threadIdx.x);
auto num_threads = static_cast<int>(blockDim.x);
auto dst_rank = static_cast<int>(blockIdx.x);
if (dst_rank != rank) {
for (int qp_id = thread_idx; qp_id < ibgda_get_state()->num_rc_per_pe; qp_id += num_threads) {
auto qp = ibgda_get_rc(dst_rank, qp_id);
// Clean some necessary variables
for (int i = 0; i < qp->rx_wq.nwqes; ++ i)
ibgda_write_empty_recv_wqe(ibgda_get_wqe_ptr(qp, i));
qp->mvars.rx_wq.resv_head = 0;
qp->mvars.rx_wq.cons_idx = 0;
// Allocate receive slots
nvshmemi_ibgda_allocate_recvs(qp);
}
}
}
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode) {
nvshmemx_uniqueid_t root_unique_id;
nvshmemx_init_attr_t attr;
std::memcpy(&root_unique_id, root_unique_id_val.data(), sizeof(nvshmemx_uniqueid_t));
nvshmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
// Create sub-RDMA teams
// NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used
if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) {
EP_HOST_ASSERT(cpu_rdma_team == NVSHMEM_TEAM_INVALID);
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(nvshmem_team_split_strided(NVSHMEM_TEAM_WORLD, rank % NUM_MAX_NVL_PEERS, NUM_MAX_NVL_PEERS,
num_ranks / NUM_MAX_NVL_PEERS, &cpu_rdma_team_config, 0, &cpu_rdma_team) == 0);
EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
}
// Normal operations use IBRC, while low-latency operations use IBGDA
if (low_latency_mode) {
nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
bool ibgda_is_initialized = false;
cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice);
// Initialize recv queues for low-latency mode AR
ibgda_initialize_recv_queue<<<num_ranks, 128>>>(rank);
}
nvshmem_barrier_all();
return nvshmem_my_pe();
}
void* alloc(size_t size, size_t alignment) {
return nvshmem_align(alignment, size);
}
void free(void* ptr) {
nvshmem_free(ptr);
}
void barrier() {
nvshmem_barrier_all();
}
void finalize() {
if (cpu_rdma_team != NVSHMEM_TEAM_INVALID) {
nvshmem_team_destroy(cpu_rdma_team);
cpu_rdma_team = NVSHMEM_TEAM_INVALID;
}
nvshmem_finalize();
}
} // namespace internode
} // namespace deep_ep
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment