// !!! This is a file automatically generated by hipify!!! #include #pragma once #include "kernels/api.cuh" #include "kernels/configs.cuh" #include "kernels/exception.cuh" namespace deep_ep { struct Config { int num_sms; int num_max_nvl_chunked_send_tokens; int num_max_nvl_chunked_recv_tokens; int num_max_rdma_chunked_send_tokens; int num_max_rdma_chunked_recv_tokens; Config(int num_sms, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens) : num_sms(num_sms), num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens), num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens), num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens), num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) { EP_HOST_ASSERT(num_sms >= 0); EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0); EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens); EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0); // Ceil up RDMA buffer size this->num_max_rdma_chunked_recv_tokens = ALIGN(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens); EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens); // NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always // have space to push EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2); } size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const { // Below are some assumptions // TODO: add assertions constexpr int kNumMaxTopK = 128; constexpr int kNumMaxScales = 128; EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % (2 * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL) == 0); const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1); const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); const int num_channels = num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL; // 计算每个nvl通信数据包的数据量 size_t num_single_nvl_bag_bytes = hidden_bytes + // 数据缓冲区(Token Data)。存储从 RDMA 转发过来的 token 数据(x 张量) #ifndef DISABLE_ROCSHMEM internode::get_source_meta_bytes() + // 源元数据缓冲区(Source Metadata)。存储每个 token 的源信息(哪个 RDMA rank 发送的) #endif kNumMaxTopK * sizeof(int) + // TopK 索引缓冲区。存储每个 token 的 top-k 专家索引 kNumMaxTopK * sizeof(float) + // TopK 权重缓冲区。存储每个 token 的 top-k 专家权重 kNumMaxScales * sizeof(float); // Scale 缓冲区。存储每个 token 的量化缩放因子 // 计算每个 NVL channel 的控制信息所需的字节数,存储每个 NVL channel 的前缀索引信息,用于快速定位数据(nvl_channel_prefix_start、nvl_channel_prefix_end 等) size_t num_single_nvl_control_bytes = (2 * num_rdma_ranks + 3) * sizeof(int); // NVL 数据总的字节数 size_t num_bytes = (num_single_nvl_bag_bytes * num_max_nvl_chunked_recv_tokens + num_single_nvl_control_bytes) * num_channels * num_nvl_ranks; // 128 字节对齐,匹配 GPU 缓存行大小,优化内存访问。 num_bytes = ((num_bytes + 127) / 128) * 128; return num_bytes; } size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const { #ifndef DISABLE_ROCSHMEM // Legacy mode if (num_ranks <= NUM_MAX_NVL_PEERS) return 0; // Below are some assumptions // TODO: add assertions constexpr int kNumMaxTopK = 128; constexpr int kNumMaxScales = 128; EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0); EP_HOST_ASSERT(num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0); const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; const int num_channels = num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL; // 计算每个rdma通信数据包的数据量 size_t num_single_rdma_bag_bytes = hidden_bytes + // 数据缓冲区。存储实际的 token 数据(x 张量),对应代码中的 rdma_channel_data internode::get_source_meta_bytes() + // 源元数据缓冲区。存储每个 token 的源信息(SourceMeta) kNumMaxTopK * sizeof(int) + // 存储每个 token 的 top-k 专家索引。对应 topk_idx 数据 kNumMaxTopK * sizeof(float) + // 存储每个 token 的 top-k 专家权重。对应 topk_weights 数据 kNumMaxScales * sizeof(float) + // 存储每个 token 的缩放因子(x_scales) sizeof(int4); // 预留空间用于内存对齐和未来扩展 // 计算每个 RDMA channel 的控制信息(起始/结束索引)所需的字节数,对应代码中的 rdma_channel_meta size_t num_single_rdma_control_bytes = (NUM_MAX_NVL_PEERS * 2 + 4) * sizeof(int); // RDMA 数据总的字节数 size_t num_bytes = (num_single_rdma_bag_bytes * num_max_rdma_chunked_recv_tokens + num_single_rdma_control_bytes) * num_channels * num_rdma_ranks * 2; // 128 字节对齐(缓存行对齐),优化内存访问性能 num_bytes = ((num_bytes + 127) / 128) * 128; return num_bytes; #else EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install " "rocSHMEM by following docs/install_dependencies.md"); #endif } }; struct LowLatencyBuffer { int num_clean_int = 0; void* dispatch_rdma_send_buffer = nullptr; void* dispatch_rdma_recv_data_buffer = nullptr; int64_t* dispatch_rdma_recv_count_buffer = nullptr; void* combine_rdma_send_buffer = nullptr; void* combine_rdma_recv_data_buffer = nullptr; int64_t* combine_rdma_recv_flag_buffer = nullptr; void* combine_rdma_send_buffer_data_start = nullptr; size_t num_bytes_per_combine_msg = 0; std::pair clean_meta() { EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer); return {dispatch_rdma_recv_count_buffer, num_clean_int}; } }; struct LowLatencyLayout { size_t total_bytes = 0; LowLatencyBuffer buffers[2]; template out_ptr_t advance(const in_ptr_t &ptr, size_t count) { return reinterpret_cast(reinterpret_cast(ptr) + count); } LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts, bool enable_dispatch_ll_layered=false, int quant_group_size=0) { const int num_scales = hidden / QUANTIZATION_GROUPSIZE; const int num_nodes = num_ranks / NUM_MAX_NVL_PEERS; // 计算结点数 // Dispatch and combine layout: // - 2 symmetric odd/even send buffer // - 2 symmetric odd/even receive buffers // - 2 symmetric odd/even signaling buffers // Message sizes // NOTES: you should add a control `int4` for combine messages if you want to do data // transformation EP_HOST_ASSERT(num_scales * sizeof(float) <= static_cast(hidden)); size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(hip_bfloat16), hidden + (quant_group_size == 0 ? 4 : num_scales) * sizeof(float)); // 应该是1,但是代码中为了满足int4对齐 // 与internode_ll::combine 中的 num_bytes_per_slot 相等 size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16) + (enable_dispatch_ll_layered ? 0 : // 即enable_combine_overlap==true,执行函数combine_sbo num_scales * sizeof(__hip_bfloat162)); // Send buffer size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes); EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0); total_bytes += send_buffer_bytes * 2; // Symmetric receive buffers // TODO: optimize memory usages size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes); EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0); total_bytes += recv_buffer_bytes * 2; // Symmetric signaling buffers size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t); if (enable_dispatch_ll_layered) { dispatch_recv_count_buffer_bytes += NUM_MAX_NVL_PEERS * num_nodes * num_max_dispatch_tokens_per_rank * sizeof(int) + NUM_MAX_NVL_PEERS * sizeof(int); } size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); size_t signaling_buffer_bytes_aligned = ALIGN(signaling_buffer_bytes, 128); total_bytes += signaling_buffer_bytes_aligned * 2; // Assign pointers // NOTES: we still leave some space for distinguishing dispatch/combine buffer, // so you may see some parameters are duplicated for (int i = 0; i < 2; ++i) { buffers[i] = { static_cast(signaling_buffer_bytes / sizeof(int64_t)), // dispatch:send_buffer + recv_buffer + recv_count advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * i), // combine:send_buffer + recv_buffer + recv_count advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * i), // combine_rdma_send_buffer_data_start advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), // num_bytes_per_combine_msg }; } } }; inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts, bool enable_dispatch_ll_layered=false, int quant_group_size=0) { auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, enable_dispatch_ll_layered, quant_group_size) .total_bytes; return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES; } } // namespace deep_ep