#pragma once

// NOTE(review): this copy of the header appears to have lost every piece of
// text that sat between angle brackets. The five system includes below have
// no header names (a bare `#include` will not preprocess). Template arguments
// are missing throughout the declarations further down: note the stray `>`
// tokens in e.g. `const std::vector> &all_gathered_handles` and
// `std::optional>>`. Restore the missing text from the canonical header
// before building; the declarations are otherwise kept token-for-token here.
#include
#include
#include
#include
#include
#include "kernels/configs.cuh"
#include "kernels/exception.cuh"
#include "config.hpp"
#include "event.hpp"

namespace deep_ep {

// Buffer: holds the state used by the dispatch / combine entry points declared
// below. That state is buffer pointers, HIP IPC handles, rank topology, the
// communication stream, and host-side receive counters. Only declarations live
// in this header; the member functions are defined in another translation unit.
struct Buffer {
    // The per-peer arrays below are sized by NUM_MAX_NVL_PEERS; this header
    // only accepts a value of exactly 8.
    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8");

private:
    // Low-latency mode buffer
    // NOTE(review): low_latency_buffer_idx presumably selects which
    // low-latency buffer is current; confirm in the implementation.
    int low_latency_buffer_idx = 0;
    bool low_latency_mode = false;

    // NVLink Buffer
    // Not initialized here; presumably set from the constructor argument of
    // the same name (confirm).
    int64_t num_nvl_bytes;
    // Host-side table with one pointer slot per possible NVLink peer (all
    // nullptr until filled in), plus a pointer-to-pointer member that
    // presumably holds a device-side copy of the table (confirm).
    void *buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
    void **buffer_ptrs_gpu = nullptr;
    // NOTE(review): a second per-peer pointer table with the same shape as
    // buffer_ptrs; how the two relate is not visible from this header.
    void* nvl_buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
    void** nvl_buffer_ptrs_gpu = nullptr;

    // NVSHMEM Buffer
    // NOTE(review): this is a HIP build (see hipIpcMemHandle_t below); confirm
    // which SHMEM implementation actually backs rdma_buffer_ptr.
    int64_t num_rdma_bytes;
    void *rdma_buffer_ptr = nullptr;

    // Shrink mode buffer
    // Presumably enabled via the constructor's `enable_shrink` argument.
    bool enable_shrink = false;
    int* mask_buffer_ptr = nullptr;
    int* sync_buffer_ptr = nullptr;

    // Device info and communication
    int device_id;
    int num_device_sms;
    // Presumably the global rank plus its RDMA-level and NVLink-level
    // components, with matching counts; confirm in the implementation.
    int rank, rdma_rank, nvl_rank;
    int num_ranks, num_rdma_ranks, num_nvl_ranks;
    // One HIP IPC memory-handle slot per possible NVLink peer.
    hipIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];

    // Stream for communication
    at::hip::HIPStreamMasqueradingAsCUDA comm_stream;

    // After IPC/NVSHMEM synchronization, this flag will be true
    bool available = false;

    // Whether explicit `destroy()` is required.
    bool explicitly_destroy;
    // Set to true once `destroy()` has been called.
    bool destroyed = false;

    // Barrier signals
    // Same layout as buffer_ptrs: a host table with one slot per peer plus a
    // pointer-to-pointer member (presumably its device-side copy).
    int *barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
    int **barrier_signal_ptrs_gpu = nullptr;

    // Workspace
    void *workspace = nullptr;

    // Host-side MoE info
    // Each counter below is a `volatile int *` paired with an `int *..._mapped`
    // pointer. NOTE(review): presumably a host-mapped allocation that the
    // device writes and the host polls (hence `volatile`); confirm in the .cpp.
    volatile int *moe_recv_counter = nullptr;
    int *moe_recv_counter_mapped = nullptr;

    // Host-side expert-level MoE info
    volatile int *moe_recv_expert_counter = nullptr;
    int *moe_recv_expert_counter_mapped = nullptr;

    // Host-side RDMA-level MoE info
    volatile int *moe_recv_rdma_counter = nullptr;
    int *moe_recv_rdma_counter_mapped = nullptr;

public:
    // Constructor parameters share their names with the private members above.
    Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy, bool enable_shrink);

    // Declared noexcept(false), so the destructor is permitted to throw.
    ~Buffer() noexcept(false);

    // ---- Read-only queries (all declared const) ----
    bool is_available() const;
    bool is_internode_available() const;
    int get_num_rdma_ranks() const;
    int get_rdma_rank() const;
    int get_root_rdma_rank(bool global) const;
    int get_local_device_id() const;
    pybind11::bytearray get_local_ipc_handle() const;
    pybind11::bytearray get_local_nvshmem_unique_id() const;
    torch::Tensor get_local_buffer_tensor(const pybind11::object &dtype, int64_t offset, bool use_rdma_buffer) const;
    torch::Stream get_comm_stream() const;

    // ---- Lifecycle ----
    // NOTE(review): presumably the IPC/NVSHMEM synchronization step after which
    // `available` becomes true; confirm in the implementation.
    // (Template arguments stripped in this copy; see note at top of file.)
    void sync(const std::vector &device_ids,
              const std::vector> &all_gathered_handles,
              const std::optional &root_unique_id_opt);
    void destroy();

    // ---- Layout, dispatch and combine ----
    // The trailing `previous_event`, `async`, `allocate_on_comm_stream`
    // parameters recur across these entry points. NOTE(review): their exact
    // stream/event semantics are defined in the implementation file.
    std::tuple, torch::Tensor, torch::Tensor, std::optional>
    get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
                        std::optional &previous_event, bool async, bool allocate_on_comm_stream);

    // Intra-node dispatch. The `cached_*` parameters are optional inputs.
    std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional>
    intranode_dispatch(const torch::Tensor &x, const std::optional &x_scales,
                       const std::optional &topk_idx, const std::optional &topk_weights,
                       const std::optional &num_tokens_per_rank, const torch::Tensor &is_token_in_rank,
                       const std::optional &num_tokens_per_expert,
                       int cached_num_recv_tokens,
                       const std::optional &cached_rank_prefix_matrix,
                       const std::optional &cached_channel_prefix_matrix,
                       int expert_alignment, int num_worst_tokens, const Config &config,
                       std::optional &previous_event, bool async, bool allocate_on_comm_stream);

    std::tuple, std::optional>
    intranode_combine(const torch::Tensor &x, const std::optional &topk_weights,
                      const std::optional &bias_0, const std::optional &bias_1,
                      const torch::Tensor &src_idx, const torch::Tensor &rank_prefix_matrix,
                      const torch::Tensor &channel_prefix_matrix, const torch::Tensor &send_head,
                      const Config &config,
                      std::optional &previous_event, bool async, bool allocate_on_comm_stream);

    // Inter-node dispatch. Like intranode_dispatch, but it also takes
    // RDMA-level token counts and RDMA/global cached prefix inputs.
    std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, std::optional, torch::Tensor, std::optional, std::optional, std::optional, std::optional>
    internode_dispatch(const torch::Tensor &x, const std::optional &x_scales,
                       const std::optional &topk_idx, const std::optional &topk_weights,
                       const std::optional &num_tokens_per_rank, const std::optional &num_tokens_per_rdma_rank,
                       const torch::Tensor &is_token_in_rank, const std::optional &num_tokens_per_expert,
                       int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
                       const std::optional &cached_rdma_channel_prefix_matrix,
                       const std::optional &cached_recv_rdma_rank_prefix_sum,
                       const std::optional &cached_gbl_channel_prefix_matrix,
                       const std::optional &cached_recv_gbl_rank_prefix_sum,
                       int expert_alignment, const Config &config,
                       std::optional &previous_event, bool async, bool allocate_on_comm_stream);

    std::tuple, std::optional>
    internode_combine(
            const torch::Tensor &x, const std::optional &topk_weights,
            const std::optional &bias_0, const std::optional &bias_1,
            const torch::Tensor &src_meta, const torch::Tensor &is_combined_token_in_rank,
            const torch::Tensor &rdma_channel_prefix_matrix, const torch::Tensor &rdma_rank_prefix_sum,
            const torch::Tensor &gbl_channel_prefix_matrix,
            const torch::Tensor &combined_rdma_head, const torch::Tensor &combined_nvl_head,
            const Config &config,
            std::optional &previous_event, bool async, bool allocate_on_comm_stream);

    // ---- Low-latency mode ----
    void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);

    std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional>>
    low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                         int num_max_dispatch_tokens_per_rank, int num_experts,
                         bool use_fp8, bool async, bool return_recv_hook);

    // `out` is optional and defaults to std::nullopt.
    std::tuple, std::optional>>
    low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
                        const torch::Tensor& src_info, const torch::Tensor& layout_range,
                        int num_max_dispatch_tokens_per_rank, int num_experts,
                        bool zero_copy, bool async, bool return_recv_hook,
                        const std::optional& out = std::nullopt);

    torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
};

} // namespace deep_ep