#include #include #include #include #include #include #include "./kernels/api.cuh" #include "./kernels/configs.cuh" #include "deep_ep.hpp" namespace deep_ep { Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy, bool use_default_stream_as_comm_stream) : rank(rank), num_ranks(num_ranks), num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes), low_latency_mode(low_latency_mode), explicitly_destroy(explicitly_destroy), use_default_stream_as_comm_stream(use_default_stream_as_comm_stream), comm_stream(use_default_stream_as_comm_stream ? at::hip::getCurrentHIPStreamMasqueradingAsCUDA() : at::hip::getStreamFromPoolMasqueradingAsCUDA(true)) { // Metadata memory int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int); int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void *); int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int *); // Common checks EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (num_nvl_bytes <= std::numeric_limits::max() or num_rdma_bytes == 0)); EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (low_latency_mode or num_rdma_bytes <= std::numeric_limits::max())); EP_HOST_ASSERT(0 <= rank and rank < num_ranks and (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode)); EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); if (num_rdma_bytes > 0) EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS or low_latency_mode); // Get ranks CUDA_CHECK(hipGetDevice(&device_id)); rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); #ifdef DISABLE_ROCSHMEM EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and "rocSHMEM is disabled during compilation, please install rocSHMEM by " "following docs/install_dependencies.md"); #endif // Get device info hipDeviceProp_t device_prop = {}; CUDA_CHECK(hipGetDeviceProperties(&device_prop, device_id)); num_device_sms = device_prop.multiProcessorCount; if (num_nvl_bytes > 0) { // Local IPC: alloc local memory and set local IPC handles CUDA_CHECK(hipExtMallocWithFlags( &buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes, hipDeviceMallocUncached)); CUDA_CHECK(hipIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank])); buffer_ptrs_gpu = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes); // Set barrier signals barrier_signal_ptrs[nvl_rank] = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); barrier_signal_ptrs_gpu = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes); // No need to synchronize, will do a full device sync during `sync` CUDA_CHECK( hipMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream)); } // Create 32 MiB workspace CUDA_CHECK(hipMalloc(&workspace, NUM_WORKSPACE_BYTES)); CUDA_CHECK(hipMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream)); // MoE counter CUDA_CHECK(hipHostMalloc(&moe_recv_counter, sizeof(int64_t), hipHostMallocMapped)); CUDA_CHECK( hipHostGetDevicePointer(reinterpret_cast(&moe_recv_counter_mapped), const_cast(moe_recv_counter), 0)); *moe_recv_counter = -1; // MoE expert-level counter CUDA_CHECK(hipHostMalloc(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS, hipHostMallocMapped)); CUDA_CHECK( hipHostGetDevicePointer(reinterpret_cast(&moe_recv_expert_counter_mapped), const_cast(moe_recv_expert_counter), 0)); for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++i) moe_recv_expert_counter[i] = -1; // MoE RDMA-level counter if (num_rdma_ranks > 0) { CUDA_CHECK( hipHostMalloc(&moe_recv_rdma_counter, sizeof(int), hipHostMallocMapped)); CUDA_CHECK( hipHostGetDevicePointer(reinterpret_cast(&moe_recv_rdma_counter_mapped), const_cast(moe_recv_rdma_counter), 0)); *moe_recv_rdma_counter = -1; } } Buffer::~Buffer() noexcept(false) { if (not explicitly_destroy) { destroy(); } else if (not destroyed) { printf("WARNING: destroy() was not called before DeepEP buffer destruction, which can leak " "resources.\n"); fflush(stdout); } } bool Buffer::is_available() const { return available; } bool Buffer::is_internode_available() const { return is_available() and num_ranks > NUM_MAX_NVL_PEERS; } int Buffer::get_num_rdma_ranks() const { return num_rdma_ranks; } int Buffer::get_rdma_rank() const { return rdma_rank; } int Buffer::get_root_rdma_rank(bool global) const { return global ? nvl_rank : 0; } int Buffer::get_local_device_id() const { return device_id; } pybind11::bytearray Buffer::get_local_ipc_handle() const { return {ipc_handles[nvl_rank].reserved, HIP_IPC_HANDLE_SIZE}; } pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const { #ifndef DISABLE_ROCSHMEM EP_HOST_ASSERT(rdma_rank == 0 and "Only RDMA rank 0 can get ROCSHMEM unique ID"); auto unique_id = internode::get_unique_id(); return {reinterpret_cast(unique_id.data()), unique_id.size()}; #else EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install rocSHMEM by " "following docs/install_dependencies.md"); #endif } torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object &dtype, int64_t offset, bool use_rdma_buffer) const { torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype); auto element_bytes = static_cast(elementSize(casted_dtype)); auto base_ptr = static_cast(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset; auto num_bytes = use_rdma_buffer ? num_rdma_bytes : num_nvl_bytes; return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA)); } torch::Stream Buffer::get_comm_stream() const { return comm_stream; } void Buffer::destroy() { EP_HOST_ASSERT(not destroyed); // Synchronize CUDA_CHECK(hipDeviceSynchronize()); if (num_nvl_bytes > 0) { // Barrier intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream); CUDA_CHECK(hipDeviceSynchronize()); // Close remote IPC if (is_available()) { for (int i = 0; i < num_nvl_ranks; ++i) if (i != nvl_rank) CUDA_CHECK(hipIpcCloseMemHandle(buffer_ptrs[i])); } // Free local buffer and error flag CUDA_CHECK(hipFree(buffer_ptrs[nvl_rank])); } // Free ROCSHMEM #ifndef DISABLE_ROCSHMEM if (is_available() and num_rdma_bytes > 0) { CUDA_CHECK(hipDeviceSynchronize()); internode::barrier(); internode::free(rdma_buffer_ptr); internode::finalize(); } #endif // Free workspace and MoE counter CUDA_CHECK(hipFree(workspace)); CUDA_CHECK(hipFreeHost(const_cast(moe_recv_counter))); // Free chunked mode staffs CUDA_CHECK(hipFreeHost(const_cast(moe_recv_expert_counter))); destroyed = true; available = false; } void Buffer::sync(const std::vector &device_ids, const std::vector> &all_gathered_handles, const std::optional &root_unique_id_opt) { EP_HOST_ASSERT(not is_available()); // Sync IPC handles if (num_nvl_bytes > 0) { EP_HOST_ASSERT(num_ranks == device_ids.size()); EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size()); for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++i) { EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value()); auto handle_str = std::string(all_gathered_handles[offset + i].value()); EP_HOST_ASSERT(handle_str.size() == HIP_IPC_HANDLE_SIZE); if (offset + i != rank) { std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), HIP_IPC_HANDLE_SIZE); CUDA_CHECK(hipIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], hipIpcMemLazyEnablePeerAccess)); barrier_signal_ptrs[i] = reinterpret_cast(static_cast(buffer_ptrs[i]) + num_nvl_bytes); } else { EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), HIP_IPC_HANDLE_SIZE) == 0); } } // Copy all buffer and barrier signal pointers to GPU CUDA_CHECK(hipMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void *) * NUM_MAX_NVL_PEERS, hipMemcpyHostToDevice)); CUDA_CHECK(hipMemcpy(barrier_signal_ptrs_gpu, barrier_signal_ptrs, sizeof(int *) * NUM_MAX_NVL_PEERS, hipMemcpyHostToDevice)); CUDA_CHECK(hipDeviceSynchronize()); } #ifndef DISABLE_ROCSHMEM // Sync ROCSHMEM handles and allocate memory if (num_rdma_bytes > 0) { // Initialize NVSHMEM EP_HOST_ASSERT(root_unique_id_opt.has_value()); std::vector root_unique_id(root_unique_id_opt->size()); auto root_unique_id_str = root_unique_id_opt->cast(); std::memcpy(root_unique_id.data(), root_unique_id_str.c_str(), root_unique_id_opt->size()); auto nvshmem_rank = low_latency_mode ? rank : rdma_rank; auto num_nvshmem_ranks = low_latency_mode ? num_ranks : num_rdma_ranks; EP_HOST_ASSERT(nvshmem_rank == internode::init( root_unique_id, nvshmem_rank, num_nvshmem_ranks, low_latency_mode)); internode::barrier(); // Allocate rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES); // Clean buffer (mainly for low-latency mode) CUDA_CHECK(hipMemset(rdma_buffer_ptr, 0, num_rdma_bytes)); // Barrier internode::barrier(); CUDA_CHECK(hipDeviceSynchronize()); } #endif // Ready to use available = true; } std::tuple, torch::Tensor, torch::Tensor, std::optional> Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std::optional &previous_event, bool async, bool allocate_on_comm_stream) { EP_HOST_ASSERT(topk_idx.dim() == 2); EP_HOST_ASSERT(topk_idx.is_contiguous()); EP_HOST_ASSERT(num_experts > 0); // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() and async); at::hip::setCurrentHIPStreamMasqueradingAsCUDA(comm_stream); } if (not use_default_stream_as_comm_stream) { // Wait previous tasks to be finished if (previous_event.has_value()) { stream_wait(comm_stream, previous_event.value()); } else { stream_wait(comm_stream, compute_stream); } } auto num_tokens = static_cast(topk_idx.size(0)), num_topk = static_cast(topk_idx.size(1)); auto num_tokens_per_rank = torch::empty({num_ranks}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); auto num_tokens_per_rdma_rank = std::optional(); auto num_tokens_per_expert = torch::empty( {num_experts}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); auto is_token_in_rank = torch::empty( {num_tokens, num_ranks}, torch::TensorOptions().dtype(torch::kBool).device(torch::kCUDA)); if (is_internode_available()) num_tokens_per_rdma_rank = torch::empty( {num_rdma_ranks}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); layout::get_dispatch_layout( topk_idx.data_ptr(), num_tokens_per_rank.data_ptr(), num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr() : nullptr, num_tokens_per_expert.data_ptr(), is_token_in_rank.data_ptr(), num_tokens, num_topk, num_ranks, num_experts, comm_stream); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); for (auto &t : {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto &to : {num_tokens_per_rdma_rank}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { if (not use_default_stream_as_comm_stream) { stream_wait(compute_stream, comm_stream); } } // Switch back compute stream if (allocate_on_comm_stream) at::hip::setCurrentHIPStreamMasqueradingAsCUDA(compute_stream); return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event}; } std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional> Buffer::intranode_dispatch( const torch::Tensor &x, const std::optional &x_scales, const std::optional &topk_idx, const std::optional &topk_weights, const std::optional &num_tokens_per_rank, const torch::Tensor &is_token_in_rank, const std::optional &num_tokens_per_expert, int cached_num_recv_tokens, const std::optional &cached_rank_prefix_matrix, const std::optional &cached_channel_prefix_matrix, int expert_alignment, int num_worst_tokens, const Config &config, std::optional &previous_event, bool async, bool allocate_on_comm_stream) { bool cached_mode = cached_rank_prefix_matrix.has_value(); // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for // receiving. EP_HOST_ASSERT(config.num_sms % 2 == 0); int num_channels = config.num_sms / 2; if (cached_mode) { EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value()); EP_HOST_ASSERT(cached_channel_prefix_matrix.has_value()); } else { EP_HOST_ASSERT(num_tokens_per_rank.has_value()); EP_HOST_ASSERT(num_tokens_per_expert.has_value()); } // Type checks EP_HOST_ASSERT(is_token_in_rank.scalar_type() == torch::kBool); if (cached_mode) { EP_HOST_ASSERT(cached_rank_prefix_matrix->scalar_type() == torch::kInt32); EP_HOST_ASSERT(cached_channel_prefix_matrix->scalar_type() == torch::kInt32); } else { EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32); EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32); } // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); EP_HOST_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous()); EP_HOST_ASSERT(is_token_in_rank.size(0) == x.size(0) and is_token_in_rank.size(1) == num_ranks); if (cached_mode) { EP_HOST_ASSERT(cached_rank_prefix_matrix->dim() == 2 and cached_rank_prefix_matrix->is_contiguous()); EP_HOST_ASSERT(cached_rank_prefix_matrix->size(0) == num_ranks and cached_rank_prefix_matrix->size(1) == num_ranks); EP_HOST_ASSERT(cached_channel_prefix_matrix->dim() == 2 and cached_channel_prefix_matrix->is_contiguous()); EP_HOST_ASSERT(cached_channel_prefix_matrix->size(0) == num_ranks and cached_channel_prefix_matrix->size(1) == num_channels); } else { EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous()); EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0); EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS); EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous()); EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks); } auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); auto num_experts = cached_mode ? 0 : static_cast(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks; // Top-k checks int num_topk = 0; int64_t *topk_idx_ptr = nullptr; float *topk_weights_ptr = nullptr; EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); if (topk_idx.has_value()) { num_topk = static_cast(topk_idx->size(1)); EP_HOST_ASSERT(num_experts > 0); EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous()); EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0)); EP_HOST_ASSERT(num_topk == topk_weights->size(1)); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); topk_idx_ptr = topk_idx->data_ptr(); topk_weights_ptr = topk_weights->data_ptr(); } // FP8 scales checks float *x_scales_ptr = nullptr; int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0; if (x_scales.has_value()) { EP_HOST_ASSERT(x.element_size() == 1); EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt); EP_HOST_ASSERT(x_scales->dim() == 2); EP_HOST_ASSERT(x_scales->size(0) == num_tokens); num_scales = x_scales->dim() == 1 ? 1 : static_cast(x_scales->size(1)); x_scales_ptr = static_cast(x_scales->data_ptr()); scale_token_stride = static_cast(x_scales->stride(0)); scale_hidden_stride = static_cast(x_scales->stride(1)); } // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() and async); at::hip::setCurrentHIPStreamMasqueradingAsCUDA(comm_stream); } // Wait previous tasks to be finished if (not use_default_stream_as_comm_stream) { if (previous_event.has_value()) { stream_wait(comm_stream, previous_event.value()); } else { stream_wait(comm_stream, compute_stream); } } // Create handles (only return for non-cached mode) int num_recv_tokens = -1; auto rank_prefix_matrix = torch::Tensor(); auto channel_prefix_matrix = torch::Tensor(); std::vector num_recv_tokens_per_expert_list; torch::Tensor num_recv_tokens_per_expert = torch::empty( {num_local_experts}, torch::TensorOptions().dtype(torch::kLong).device(torch::kCUDA)); // Barrier or send sizes // To clean: channel start/end offset, head and tail int num_memset_int = num_channels * num_ranks * 4; if (cached_mode) { num_recv_tokens = cached_num_recv_tokens; rank_prefix_matrix = cached_rank_prefix_matrix.value(); channel_prefix_matrix = cached_channel_prefix_matrix.value(); // Copy rank prefix matrix and clean flags intranode::cached_notify_dispatch( rank_prefix_matrix.data_ptr(), num_memset_int, buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream); } else { rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); channel_prefix_matrix = torch::empty({num_ranks, num_channels}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); // Send sizes // Meta information: // - Size prefix by ranks, shaped as `[num_ranks, num_ranks]` // - Size prefix by experts (not used later), shaped as `[num_ranks, num_local_experts]` // NOTES: no more token dropping in this version *moe_recv_counter = -1; for (int i = 0; i < num_local_experts; ++i) moe_recv_expert_counter[i] = -1; EP_HOST_ASSERT(num_ranks * (num_ranks + num_local_experts) * sizeof(int) <= num_nvl_bytes); intranode::notify_dispatch( num_tokens_per_rank->data_ptr(), moe_recv_counter_mapped, num_ranks, num_tokens_per_expert->data_ptr(), moe_recv_expert_counter_mapped, num_recv_tokens_per_expert.data_ptr(), num_experts, num_tokens, is_token_in_rank.data_ptr(), channel_prefix_matrix.data_ptr(), rank_prefix_matrix.data_ptr(), num_memset_int, expert_alignment, buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, comm_stream, num_channels); if (num_worst_tokens > 0) { // No CPU sync, just allocate the worst case num_recv_tokens = num_worst_tokens; // Must be forward with top-k stuffs EP_HOST_ASSERT(topk_idx.has_value()); EP_HOST_ASSERT(topk_weights.has_value()); } else { // Synchronize total received tokens and tokens per expert auto start_time = std::chrono::high_resolution_clock::now(); while (true) { // Read total count num_recv_tokens = static_cast(*moe_recv_counter); // Read per-expert count bool ready = (num_recv_tokens >= 0); for (int i = 0; i < num_local_experts and ready; ++i) ready &= moe_recv_expert_counter[i] >= 0; if (ready) break; // Timeout check if (std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - start_time) .count() > NUM_CPU_TIMEOUT_SECS) throw std::runtime_error("DeepEP error: CPU recv timeout"); } num_recv_tokens_per_expert_list = std::vector( moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); } } // Allocate new tensors auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); auto recv_src_idx = torch::empty( {num_recv_tokens}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), recv_x_scales = std::optional(); auto recv_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); auto send_head = torch::empty({num_tokens, num_ranks}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); // Assign pointers int64_t *recv_topk_idx_ptr = nullptr; float *recv_topk_weights_ptr = nullptr; float *recv_x_scales_ptr = nullptr; if (topk_idx.has_value()) { recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options()); recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); recv_topk_idx_ptr = recv_topk_idx->data_ptr(); recv_topk_weights_ptr = recv_topk_weights->data_ptr(); } if (x_scales.has_value()) { recv_x_scales = x_scales->dim() == 1 ? torch::empty({num_recv_tokens}, x_scales->options()) : torch::empty({num_recv_tokens, num_scales}, x_scales->options()); recv_x_scales_ptr = static_cast(recv_x_scales->data_ptr()); } // Dispatch EP_HOST_ASSERT(num_ranks * num_ranks * sizeof(int) + // Size prefix matrix num_channels * num_ranks * sizeof(int) + // Channel start offset num_channels * num_ranks * sizeof(int) + // Channel end offset num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * recv_x.element_size() + // Data buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(int64_t) + // Top-k index buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) + // Top-k weight buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales // FP8 scale buffer <= num_nvl_bytes); intranode::dispatch( recv_x.data_ptr(), recv_x_scales_ptr, recv_src_idx.data_ptr(), recv_topk_idx_ptr, recv_topk_weights_ptr, recv_channel_prefix_matrix.data_ptr(), send_head.data_ptr(), x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr, is_token_in_rank.data_ptr(), channel_prefix_matrix.data_ptr(), num_tokens, num_worst_tokens, static_cast(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales, scale_token_stride, scale_hidden_stride, buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); for (auto &t : {x, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x, recv_src_idx, recv_channel_prefix_matrix, send_head}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto &to : {x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert, cached_channel_prefix_matrix, cached_rank_prefix_matrix, recv_topk_idx, recv_topk_weights, recv_x_scales}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { if (not use_default_stream_as_comm_stream) { stream_wait(compute_stream, comm_stream); } } // Switch back compute stream if (allocate_on_comm_stream) at::hip::setCurrentHIPStreamMasqueradingAsCUDA(compute_stream); // Return values return {recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, num_recv_tokens_per_expert, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event}; } std::tuple, std::optional> Buffer::intranode_combine(const torch::Tensor &x, const std::optional &topk_weights, const std::optional &bias_0, const std::optional &bias_1, const torch::Tensor &src_idx, const torch::Tensor &rank_prefix_matrix, const torch::Tensor &channel_prefix_matrix, const torch::Tensor &send_head, const Config &config, std::optional &previous_event, bool async, bool allocate_on_comm_stream) { EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); EP_HOST_ASSERT(src_idx.dim() == 1 and src_idx.is_contiguous() and src_idx.scalar_type() == torch::kInt32); EP_HOST_ASSERT(send_head.dim() == 2 and send_head.is_contiguous() and send_head.scalar_type() == torch::kInt32); EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 and rank_prefix_matrix.is_contiguous() and rank_prefix_matrix.scalar_type() == torch::kInt32); EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 and channel_prefix_matrix.is_contiguous() and channel_prefix_matrix.scalar_type() == torch::kInt32); // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for // receiving. EP_HOST_ASSERT(config.num_sms % 2 == 0); int num_channels = config.num_sms / 2; auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); auto num_recv_tokens = static_cast(send_head.size(0)); EP_HOST_ASSERT(src_idx.size(0) == num_tokens); EP_HOST_ASSERT(send_head.size(1) == num_ranks); EP_HOST_ASSERT(rank_prefix_matrix.size(0) == num_ranks and rank_prefix_matrix.size(1) == num_ranks); EP_HOST_ASSERT(channel_prefix_matrix.size(0) == num_ranks and channel_prefix_matrix.size(1) == num_channels); EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0); // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() and async); at::hip::setCurrentHIPStreamMasqueradingAsCUDA(comm_stream); } // Wait previous tasks to be finished if (not use_default_stream_as_comm_stream) { if (previous_event.has_value()) { stream_wait(comm_stream, previous_event.value()); } else { stream_wait(comm_stream, compute_stream); } } int num_topk = 0; auto recv_topk_weights = std::optional(); float *topk_weights_ptr = nullptr; float *recv_topk_weights_ptr = nullptr; if (topk_weights.has_value()) { EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); EP_HOST_ASSERT(topk_weights->size(0) == num_tokens); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); num_topk = static_cast(topk_weights->size(1)); topk_weights_ptr = topk_weights->data_ptr(); recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); recv_topk_weights_ptr = recv_topk_weights->data_ptr(); } // Launch barrier and reset queue head and tail EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= num_nvl_bytes); intranode::cached_notify_combine( buffer_ptrs_gpu, send_head.data_ptr(), num_channels, num_recv_tokens, num_channels * num_ranks * 2, barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream); // Assign bias pointers auto bias_opts = std::vector>({bias_0, bias_1}); void *bias_ptrs[2] = {nullptr, nullptr}; for (int i = 0; i < 2; ++i) if (bias_opts[i].has_value()) { auto bias = bias_opts[i].value(); EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous()); EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type()); EP_HOST_ASSERT(bias.size(0) == num_recv_tokens and bias.size(1) == hidden); bias_ptrs[i] = bias.data_ptr(); } // Combine data auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * x.element_size() + // Data buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) // Top-k weight buffer <= num_nvl_bytes); intranode::combine( at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), recv_x.data_ptr(), recv_topk_weights_ptr, x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1], src_idx.data_ptr(), rank_prefix_matrix.data_ptr(), channel_prefix_matrix.data_ptr(), send_head.data_ptr(), num_tokens, num_recv_tokens, hidden, num_topk, buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); for (auto &t : {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto &to : {topk_weights, recv_topk_weights, bias_0, bias_1}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { if (not use_default_stream_as_comm_stream) { stream_wait(compute_stream, comm_stream); } } // Switch back compute stream if (allocate_on_comm_stream) at::hip::setCurrentHIPStreamMasqueradingAsCUDA(compute_stream); return {recv_x, recv_topk_weights, event}; } std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, std::optional, torch::Tensor, std::optional, std::optional, std::optional, std::optional> Buffer::internode_dispatch(const torch::Tensor &x, const std::optional &x_scales, const std::optional &topk_idx, const std::optional &topk_weights, const std::optional &num_tokens_per_rank, const std::optional &num_tokens_per_rdma_rank, const torch::Tensor &is_token_in_rank, const std::optional &num_tokens_per_expert, int cached_num_recv_tokens, int cached_num_rdma_recv_tokens, const std::optional &cached_rdma_channel_prefix_matrix, const std::optional &cached_recv_rdma_rank_prefix_sum, const std::optional &cached_gbl_channel_prefix_matrix, const std::optional &cached_recv_gbl_rank_prefix_sum, int expert_alignment, const Config &config, std::optional &previous_event, bool async, bool allocate_on_comm_stream) { #ifndef DISABLE_ROCSHMEM // In dispatch, CPU will busy-wait until GPU receive tensor size metadata from other ranks, // which can be quite long. If users of DeepEP need to execute other Python code on other // threads, such as KV transfer, their code will get stuck due to GIL unless we release GIL // here. pybind11::gil_scoped_release release; const int num_channels = config.num_sms / 2; EP_HOST_ASSERT(config.num_sms % 2 == 0); EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); bool cached_mode = cached_rdma_channel_prefix_matrix.has_value(); if (cached_mode) { EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix.has_value()); EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum.has_value()); EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix.has_value()); EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum.has_value()); } else { EP_HOST_ASSERT(num_tokens_per_rank.has_value()); EP_HOST_ASSERT(num_tokens_per_rdma_rank.has_value()); EP_HOST_ASSERT(num_tokens_per_expert.has_value()); } // Type checks if (cached_mode) { EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->scalar_type() == torch::kInt32); EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->scalar_type() == torch::kInt32); EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->scalar_type() == torch::kInt32); EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->scalar_type() == torch::kInt32); } else { EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32); EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == torch::kInt32); EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32); } // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); if (cached_mode) { EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->dim() == 2 and cached_rdma_channel_prefix_matrix->is_contiguous()); EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == num_rdma_ranks and cached_rdma_channel_prefix_matrix->size(1) == num_channels); EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->dim() == 1 and cached_recv_rdma_rank_prefix_sum->is_contiguous()); EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->size(0) == num_rdma_ranks); EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->dim() == 2 and cached_gbl_channel_prefix_matrix->is_contiguous()); EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks and cached_gbl_channel_prefix_matrix->size(1) == num_channels); EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dim() == 1 and cached_recv_gbl_rank_prefix_sum->is_contiguous()); EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks); } else { EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous()); EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 1 and num_tokens_per_rdma_rank->is_contiguous()); EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous()); EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks); EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_rdma_ranks); EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0); EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS); } auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); auto num_experts = cached_mode ? 0 : static_cast(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks; // Top-k checks int num_topk = 0; int64_t *topk_idx_ptr = nullptr; float *topk_weights_ptr = nullptr; EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); if (topk_idx.has_value()) { num_topk = static_cast(topk_idx->size(1)); EP_HOST_ASSERT(num_experts > 0); EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous()); EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0)); EP_HOST_ASSERT(num_topk == topk_weights->size(1)); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); topk_idx_ptr = topk_idx->data_ptr(); topk_weights_ptr = topk_weights->data_ptr(); } // FP8 scales checks float *x_scales_ptr = nullptr; int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0; if (x_scales.has_value()) { EP_HOST_ASSERT(x.element_size() == 1); EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt); EP_HOST_ASSERT(x_scales->dim() == 2); EP_HOST_ASSERT(x_scales->size(0) == num_tokens); num_scales = x_scales->dim() == 1 ? 1 : static_cast(x_scales->size(1)); x_scales_ptr = static_cast(x_scales->data_ptr()); scale_token_stride = static_cast(x_scales->stride(0)); scale_hidden_stride = static_cast(x_scales->stride(1)); } // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! auto compute_stream = at::cuda::getCurrentCUDAStream(); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() and async); at::cuda::setCurrentCUDAStream(comm_stream); } // Wait previous tasks to be finished if (previous_event.has_value()) { stream_wait(comm_stream, previous_event.value()); } else { stream_wait(comm_stream, compute_stream); } // Create handles (only return for non-cached mode) int num_recv_tokens = -1, num_rdma_recv_tokens = -1; auto rdma_channel_prefix_matrix = torch::Tensor(); auto recv_rdma_rank_prefix_sum = torch::Tensor(); auto gbl_channel_prefix_matrix = torch::Tensor(); auto recv_gbl_rank_prefix_sum = torch::Tensor(); std::vector num_recv_tokens_per_expert_list; // Barrier or send sizes if (cached_mode) { num_recv_tokens = cached_num_recv_tokens; num_rdma_recv_tokens = cached_num_rdma_recv_tokens; rdma_channel_prefix_matrix = cached_rdma_channel_prefix_matrix.value(); recv_rdma_rank_prefix_sum = cached_recv_rdma_rank_prefix_sum.value(); gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value(); recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value(); // Just a barrier and clean flags internode::cached_notify( hidden_int4, num_scales, num_topk, num_topk, num_ranks, num_channels, 0, nullptr, nullptr, nullptr, nullptr, rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, barrier_signal_ptrs_gpu, rank, comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, true, low_latency_mode); } else { rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); recv_rdma_rank_prefix_sum = torch::empty( {num_rdma_ranks}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); recv_gbl_rank_prefix_sum = torch::empty( {num_ranks}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); // Send sizes *moe_recv_counter = -1, *moe_recv_rdma_counter = -1; for (int i = 0; i < num_local_experts; ++i) moe_recv_expert_counter[i] = -1; internode::notify_dispatch( num_tokens_per_rank->data_ptr(), moe_recv_counter_mapped, num_ranks, num_tokens_per_rdma_rank->data_ptr(), moe_recv_rdma_counter_mapped, num_tokens_per_expert->data_ptr(), moe_recv_expert_counter_mapped, num_experts, is_token_in_rank.data_ptr(), num_tokens, num_channels, hidden_int4, num_scales, num_topk, expert_alignment, rdma_channel_prefix_matrix.data_ptr(), recv_rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, barrier_signal_ptrs_gpu, rank, comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, low_latency_mode); // Synchronize total received tokens and tokens per expert auto start_time = std::chrono::high_resolution_clock::now(); while (true) { // Read total count num_recv_tokens = static_cast(*moe_recv_counter); num_rdma_recv_tokens = static_cast(*moe_recv_rdma_counter); // Read per-expert count bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0); for (int i = 0; i < num_local_experts and ready; ++i) ready &= moe_recv_expert_counter[i] >= 0; if (ready) break; // Timeout check if (std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - start_time) .count() > NUM_CPU_TIMEOUT_SECS) { printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, num_rdma_recv_tokens); for (int i = 0; i < num_local_experts; ++i) printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]); throw std::runtime_error("DeepEP error: timeout (dispatch CPU)"); } } num_recv_tokens_per_expert_list = std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); } // Allocate new tensors auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), recv_x_scales = std::optional(); auto recv_src_meta = std::optional(); auto recv_rdma_channel_prefix_matrix = std::optional(); auto recv_gbl_channel_prefix_matrix = std::optional(); auto send_rdma_head = std::optional(); auto send_nvl_head = std::optional(); if (not cached_mode) { recv_src_meta = torch::empty( {num_recv_tokens, internode::get_source_meta_bytes()}, torch::TensorOptions().dtype(torch::kByte).device(torch::kCUDA)); recv_rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); recv_gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); send_rdma_head = torch::empty({num_tokens, num_rdma_ranks}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); send_nvl_head = torch::empty({num_rdma_recv_tokens, NUM_MAX_NVL_PEERS}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); } // Assign pointers int64_t *recv_topk_idx_ptr = nullptr; float *recv_topk_weights_ptr = nullptr; float *recv_x_scales_ptr = nullptr; if (topk_idx.has_value()) { recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options()); recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); recv_topk_idx_ptr = recv_topk_idx->data_ptr(); recv_topk_weights_ptr = recv_topk_weights->data_ptr(); } if (x_scales.has_value()) { recv_x_scales = x_scales->dim() == 1 ? torch::empty({num_recv_tokens}, x_scales->options()) : torch::empty({num_recv_tokens, num_scales}, x_scales->options()); recv_x_scales_ptr = static_cast(recv_x_scales->data_ptr()); } // Launch data dispatch // NOTES: the buffer size checks are moved into the `.cu` file internode::dispatch( recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr, cached_mode ? nullptr : recv_src_meta->data_ptr(), x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr, cached_mode ? nullptr : send_rdma_head->data_ptr(), cached_mode ? nullptr : send_nvl_head->data_ptr(), cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr(), cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr(), rdma_channel_prefix_matrix.data_ptr(), recv_rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), is_token_in_rank.data_ptr(), num_tokens, hidden_int4, num_scales, num_topk, num_experts, scale_token_stride, scale_hidden_stride, rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, rank, num_ranks, cached_mode, comm_stream, num_channels, low_latency_mode); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); for (auto &t : {x, is_token_in_rank, recv_x, rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto &to : {x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, cached_rdma_channel_prefix_matrix, cached_recv_rdma_rank_prefix_sum, cached_gbl_channel_prefix_matrix, cached_recv_gbl_rank_prefix_sum, recv_topk_idx, recv_topk_weights, recv_x_scales, recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, send_rdma_head, send_nvl_head, recv_src_meta}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { stream_wait(compute_stream, comm_stream); } // Switch back compute stream if (allocate_on_comm_stream) at::cuda::setCurrentCUDAStream(compute_stream); // Return values return {recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, recv_src_meta, send_rdma_head, send_nvl_head, event}; #else EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install rocSHMEM by " "following docs/install_dependencies.md"); return {}; #endif } std::tuple, std::optional> Buffer::internode_combine( const torch::Tensor &x, const std::optional &topk_weights, const std::optional &bias_0, const std::optional &bias_1, const torch::Tensor &src_meta, const torch::Tensor &is_combined_token_in_rank, const torch::Tensor &rdma_channel_prefix_matrix, const torch::Tensor &rdma_rank_prefix_sum, const torch::Tensor &gbl_channel_prefix_matrix, const torch::Tensor &combined_rdma_head, const torch::Tensor &combined_nvl_head, const Config &config, std::optional &previous_event, bool async, bool allocate_on_comm_stream) { #ifndef DISABLE_ROCSHMEM const int num_channels = config.num_sms / 2; EP_HOST_ASSERT(config.num_sms % 2 == 0); // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); EP_HOST_ASSERT(src_meta.dim() == 2 and src_meta.is_contiguous() and src_meta.scalar_type() == torch::kByte); EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 and is_combined_token_in_rank.is_contiguous() and is_combined_token_in_rank.scalar_type() == torch::kBool); EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 and rdma_channel_prefix_matrix.is_contiguous() and rdma_channel_prefix_matrix.scalar_type() == torch::kInt32); EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 and rdma_rank_prefix_sum.is_contiguous() and rdma_rank_prefix_sum.scalar_type() == torch::kInt32); EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 and gbl_channel_prefix_matrix.is_contiguous() and gbl_channel_prefix_matrix.scalar_type() == torch::kInt32); EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and combined_rdma_head.scalar_type() == torch::kInt32); EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.is_contiguous() and combined_nvl_head.scalar_type() == torch::kInt32); auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); auto num_combined_tokens = static_cast(is_combined_token_in_rank.size(0)); EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0); EP_HOST_ASSERT(src_meta.size(1) == internode::get_source_meta_bytes()); EP_HOST_ASSERT(is_combined_token_in_rank.size(1) == num_ranks); EP_HOST_ASSERT(rdma_channel_prefix_matrix.size(0) == num_rdma_ranks and rdma_channel_prefix_matrix.size(1) == num_channels); EP_HOST_ASSERT(rdma_rank_prefix_sum.size(0) == num_rdma_ranks); EP_HOST_ASSERT(gbl_channel_prefix_matrix.size(0) == num_ranks and gbl_channel_prefix_matrix.size(1) == num_channels); EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.size(0) == num_combined_tokens and combined_rdma_head.size(1) == num_rdma_ranks); EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.size(1) == NUM_MAX_NVL_PEERS); // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! auto compute_stream = at::cuda::getCurrentCUDAStream(); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() and async); at::cuda::setCurrentCUDAStream(comm_stream); } // Wait previous tasks to be finished if (previous_event.has_value()) { stream_wait(comm_stream, previous_event.value()); } else { stream_wait(comm_stream, compute_stream); } // Top-k checks int num_topk = 0; auto combined_topk_weights = std::optional(); float *topk_weights_ptr = nullptr; float *combined_topk_weights_ptr = nullptr; if (topk_weights.has_value()) { EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); EP_HOST_ASSERT(topk_weights->size(0) == num_tokens); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); num_topk = static_cast(topk_weights->size(1)); topk_weights_ptr = topk_weights->data_ptr(); combined_topk_weights = torch::empty({num_combined_tokens, num_topk}, topk_weights->options()); combined_topk_weights_ptr = combined_topk_weights->data_ptr(); } // Extra check for avoid-dead-lock design EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <= config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks); // Launch barrier and reset queue head and tail internode::cached_notify( hidden_int4, 0, 0, num_topk, num_ranks, num_channels, num_combined_tokens, combined_rdma_head.data_ptr(), rdma_channel_prefix_matrix.data_ptr(), rdma_rank_prefix_sum.data_ptr(), combined_nvl_head.data_ptr(), rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, barrier_signal_ptrs_gpu, rank, comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, false, low_latency_mode); // Assign bias pointers auto bias_opts = std::vector>({bias_0, bias_1}); void *bias_ptrs[2] = {nullptr, nullptr}; for (int i = 0; i < 2; ++i) if (bias_opts[i].has_value()) { // EP_HOST_ASSERT(false and "bias is not supported in internode combine"); auto bias = bias_opts[i].value(); EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous()); EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type()); EP_HOST_ASSERT(bias.size(0) == num_combined_tokens and bias.size(1) == hidden); bias_ptrs[i] = bias.data_ptr(); } // Launch data combine auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); internode::combine( at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), combined_x.data_ptr(), combined_topk_weights_ptr, is_combined_token_in_rank.data_ptr(), x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1], combined_rdma_head.data_ptr(), combined_nvl_head.data_ptr(), src_meta.data_ptr(), rdma_channel_prefix_matrix.data_ptr(), rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), num_tokens, num_combined_tokens, hidden, num_topk, rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, rank, num_ranks, comm_stream, num_channels, low_latency_mode); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); for (auto &t : {x, src_meta, is_combined_token_in_rank, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, combined_x, combined_rdma_head, combined_nvl_head}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto &to : {topk_weights, combined_topk_weights, bias_0, bias_1}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { stream_wait(compute_stream, comm_stream); } // Switch back compute stream if (allocate_on_comm_stream) at::cuda::setCurrentCUDAStream(compute_stream); // Return values return {combined_x, combined_topk_weights, event}; #else EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install rocSHMEM by " "following docs/install_dependencies.md"); return {}; #endif } void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) { EP_HOST_ASSERT(false and "not support low latency"); } std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional>> Buffer::low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_idx, const std::optional &cumulative_local_expert_recv_stats, const std::optional &dispatch_wait_recv_cost_stats, int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8, bool round_scale, bool use_ue8m0, bool async, bool return_recv_hook) { EP_HOST_ASSERT(false and "not support low latency"); return {}; } std::tuple, std::optional>> Buffer::low_latency_combine(const torch::Tensor &x, const torch::Tensor &topk_idx, const torch::Tensor &topk_weights, const torch::Tensor &src_info, const torch::Tensor &layout_range, const std::optional &combine_wait_recv_cost_stats, int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook, const std::optional &out) { EP_HOST_ASSERT(false and "not support low latency"); return {}; } torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const { EP_HOST_ASSERT(false and "not support low latency"); return {}; } } // namespace deep_ep PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "DeepEP: an efficient expert-parallel communication library"; pybind11::class_(m, "Config") .def(pybind11::init(), py::arg("num_sms") = 20, py::arg("num_max_nvl_chunked_send_tokens") = 6, py::arg("num_max_nvl_chunked_recv_tokens") = 256, py::arg("num_max_rdma_chunked_send_tokens") = 6, py::arg("num_max_rdma_chunked_recv_tokens") = 256) .def("get_nvl_buffer_size_hint", &deep_ep::Config::get_nvl_buffer_size_hint) .def("get_rdma_buffer_size_hint", &deep_ep::Config::get_rdma_buffer_size_hint); m.def("get_low_latency_rdma_size_hint", &deep_ep::get_low_latency_rdma_size_hint); pybind11::class_(m, "EventHandle") .def(pybind11::init<>()) .def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait); pybind11::class_(m, "Buffer") .def(pybind11::init()) .def("is_available", &deep_ep::Buffer::is_available) .def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks) .def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank) .def("get_root_rdma_rank", &deep_ep::Buffer::get_root_rdma_rank) .def("get_local_device_id", &deep_ep::Buffer::get_local_device_id) .def("get_local_ipc_handle", &deep_ep::Buffer::get_local_ipc_handle) .def("get_local_nvshmem_unique_id", &deep_ep::Buffer::get_local_nvshmem_unique_id) .def("get_local_buffer_tensor", &deep_ep::Buffer::get_local_buffer_tensor) .def("get_comm_stream", &deep_ep::Buffer::get_comm_stream) .def("sync", &deep_ep::Buffer::sync) .def("destroy", &deep_ep::Buffer::destroy) .def("get_dispatch_layout", &deep_ep::Buffer::get_dispatch_layout) .def("intranode_dispatch", &deep_ep::Buffer::intranode_dispatch) .def("intranode_combine", &deep_ep::Buffer::intranode_combine) .def("internode_dispatch", &deep_ep::Buffer::internode_dispatch) .def("internode_combine", &deep_ep::Buffer::internode_combine) .def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer) .def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch) .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine) .def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer); // m.def("is_sm90_compiled", deep_ep::is_sm90_compiled); // m.attr("topk_idx_t") = py::cast(c10::CppTypeToScalarType::value); }