"vscode:/vscode.git/clone" did not exist on "d95e723d5ac54ccc1ec9cc10e37bd5e3476e6150"
Unverified Commit da6ca24e authored by Tailing Yuan's avatar Tailing Yuan Committed by GitHub
Browse files

Make dtype of topk_idx configurable (#422)


Co-authored-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>
parent c939644c
......@@ -64,7 +64,7 @@ struct Config {
#ifndef DISABLE_NVSHMEM
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes();
#endif
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float);
num_bytes = ((num_bytes + 127) / 128) * 128;
......@@ -90,7 +90,7 @@ struct Config {
num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
......
......@@ -269,7 +269,7 @@ Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts,
if (is_internode_available())
num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
layout::get_dispatch_layout(topk_idx.data_ptr<int64_t>(),
layout::get_dispatch_layout(topk_idx.data_ptr<topk_idx_t>(),
num_tokens_per_rank.data_ptr<int>(),
num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr<int>() : nullptr,
num_tokens_per_expert.data_ptr<int>(),
......@@ -355,7 +355,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
// Top-k checks
int num_topk = 0;
int64_t* topk_idx_ptr = nullptr;
topk_idx_t* topk_idx_ptr = nullptr;
float* topk_weights_ptr = nullptr;
EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
if (topk_idx.has_value()) {
......@@ -366,7 +366,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));
EP_HOST_ASSERT(num_topk == topk_weights->size(1));
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
topk_idx_ptr = topk_idx->data_ptr<int64_t>();
topk_idx_ptr = topk_idx->data_ptr<topk_idx_t>();
topk_weights_ptr = topk_weights->data_ptr<float>();
}
......@@ -476,13 +476,13 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
auto send_head = torch::empty({num_tokens, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
// Assign pointers
int64_t* recv_topk_idx_ptr = nullptr;
topk_idx_t* recv_topk_idx_ptr = nullptr;
float* recv_topk_weights_ptr = nullptr;
float* recv_x_scales_ptr = nullptr;
if (topk_idx.has_value()) {
recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options());
recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
recv_topk_idx_ptr = recv_topk_idx->data_ptr<int64_t>();
recv_topk_idx_ptr = recv_topk_idx->data_ptr<topk_idx_t>();
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
}
if (x_scales.has_value()) {
......@@ -499,7 +499,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * recv_x.element_size() + // Data buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(int64_t) + // Top-k index buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(topk_idx_t) + // Top-k index buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) + // Top-k weight buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales // FP8 scale buffer
<= num_nvl_bytes);
......@@ -720,7 +720,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
// Top-k checks
int num_topk = 0;
int64_t* topk_idx_ptr = nullptr;
topk_idx_t* topk_idx_ptr = nullptr;
float* topk_weights_ptr = nullptr;
EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
if (topk_idx.has_value()) {
......@@ -731,7 +731,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));
EP_HOST_ASSERT(num_topk == topk_weights->size(1));
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
topk_idx_ptr = topk_idx->data_ptr<int64_t>();
topk_idx_ptr = topk_idx->data_ptr<topk_idx_t>();
topk_weights_ptr = topk_weights->data_ptr<float>();
}
......@@ -856,13 +856,13 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
}
// Assign pointers
int64_t* recv_topk_idx_ptr = nullptr;
topk_idx_t* recv_topk_idx_ptr = nullptr;
float* recv_topk_weights_ptr = nullptr;
float* recv_x_scales_ptr = nullptr;
if (topk_idx.has_value()) {
recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options());
recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
recv_topk_idx_ptr = recv_topk_idx->data_ptr<int64_t>();
recv_topk_idx_ptr = recv_topk_idx->data_ptr<topk_idx_t>();
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
}
if (x_scales.has_value()) {
......@@ -1103,7 +1103,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0);
EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(topk_idx.scalar_type() == c10::CppTypeToScalarType<topk_idx_t>::value);
EP_HOST_ASSERT(num_experts % num_ranks == 0);
// Diagnosis tensors
......@@ -1173,7 +1173,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
x.data_ptr(), topk_idx.data_ptr<topk_idx_t>(),
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
......@@ -1223,7 +1223,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0);
EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1));
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(topk_idx.scalar_type() == c10::CppTypeToScalarType<topk_idx_t>::value);
EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous());
EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32);
......@@ -1274,7 +1274,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
internode_ll::combine(combined_x.data_ptr(),
buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer,
buffer.combine_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
x.data_ptr(), topk_idx.data_ptr<topk_idx_t>(), topk_weights.data_ptr<float>(),
src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(),
combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
next_clean_meta.first, next_clean_meta.second,
......@@ -1379,4 +1379,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer);
m.def("is_sm90_compiled", deep_ep::is_sm90_compiled);
m.attr("topk_idx_t") = py::cast(c10::CppTypeToScalarType<deep_ep::topk_idx_t>::value);
}
......@@ -2,6 +2,8 @@
#include <vector>
#include "configs.cuh"
namespace deep_ep {
// Intranode runtime
......@@ -31,7 +33,7 @@ void finalize();
// Layout kernels
namespace layout {
void get_dispatch_layout(const int64_t* topk_idx,
void get_dispatch_layout(const topk_idx_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts,
......@@ -53,8 +55,8 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks,
cudaStream_t stream);
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, topk_idx_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const void* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int scale_token_stride, int scale_hidden_stride,
......@@ -95,8 +97,8 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool low_latency_mode);
void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
void dispatch(void* recv_x, float* recv_x_scales, topk_idx_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
const void* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights,
int* send_rdma_head, int* send_nvl_head,
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
......@@ -145,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
const void* x, const topk_idx_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
......@@ -155,7 +157,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const void* x, const topk_idx_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int,
......
......@@ -58,6 +58,20 @@ typedef int __nv_fp8x4_e4m3;
typedef uint8_t __nv_fp8_storage_t;
#endif
namespace deep_ep {

// Width, in bits, of the element type used for top-k expert indices.
// May be overridden from the build command line with `-DTOPK_IDX_BITS=<bits>`;
// falls back to 64 when the build does not define it.
#ifndef TOPK_IDX_BITS
#define TOPK_IDX_BITS 64
#endif

// Only 32- and 64-bit signed indices are intended. Without this guard the token
// pasting below would also accept e.g. 8 or 16 and silently pick a narrower type.
#if (TOPK_IDX_BITS != 32) && (TOPK_IDX_BITS != 64)
#error "TOPK_IDX_BITS must be either 32 or 64"
#endif

// Two-level expansion so that `TOPK_IDX_BITS` is replaced by its value before
// being pasted into `int<bits>_t`.
#define INT_BITS_T2(bits) int##bits##_t
#define INT_BITS_T(bits) INT_BITS_T2(bits)
typedef INT_BITS_T(TOPK_IDX_BITS) topk_idx_t;  // int32_t or int64_t
#undef INT_BITS_T
#undef INT_BITS_T2

}  // namespace deep_ep
#ifndef DISABLE_NVSHMEM
#include <nvshmem.h>
#include <nvshmemx.h>
......
......@@ -355,8 +355,8 @@ constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
template <bool kLowLatencyMode, int kNumRDMARanks, bool kCachedMode, int kNumTMABytesPerWarp,
int kNumDispatchRDMASenderWarps, int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)>
__global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1)
dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, SourceMeta* recv_src_meta,
const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
dispatch(int4* recv_x, float* recv_x_scales, topk_idx_t* recv_topk_idx, float* recv_topk_weights, SourceMeta* recv_src_meta,
const int4* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights,
int* send_rdma_head, int* send_nvl_head,
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
......@@ -968,7 +968,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Copy `topk_idx` and `topk_weights`
if (lane_id < num_topk) {
// Read
auto idx_value = static_cast<int64_t>(ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id));
auto idx_value = static_cast<topk_idx_t>(ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id));
auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted + sizeof(int) * num_topk) + lane_id);
auto recv_idx = recv_token_idx * num_topk + lane_id;
......@@ -991,8 +991,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
}
}
void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
void dispatch(void* recv_x, float* recv_x_scales, topk_idx_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
const void* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights,
int* send_rdma_head, int* send_nvl_head,
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
......
......@@ -44,7 +44,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
const void* x, const topk_idx_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
int* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank,
......@@ -340,7 +340,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
const void* x, const topk_idx_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
......@@ -555,7 +555,7 @@ template <bool kUseLogFMT, int kHidden, int kNumMaxTopk, int kNumMaxUnrolls>
__global__ __launch_bounds__(1024, 1) void
combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const void* x, const topk_idx_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int,
......@@ -914,7 +914,7 @@ combine(void* combined_x,
void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const void* x, const topk_idx_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int,
......
......@@ -163,8 +163,8 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
template <int kNumRanks, int kNumThreads, int kNumTMABytesPerWarp>
__global__ void __launch_bounds__(kNumThreads, 1)
dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, topk_idx_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const int4* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int scale_token_stride, int scale_hidden_stride,
......@@ -211,12 +211,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Channel data buffers, stored on the receiver side
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
// `topk_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(int64_t)
// `topk_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(topk_idx_t)
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
// `x_scales_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_scales * sizeof(float)
auto channel_x_buffers = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
auto channel_src_idx_buffers = Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens);
auto channel_topk_idx_buffers = Buffer<int64_t>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
auto channel_topk_idx_buffers = Buffer<topk_idx_t>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);
......@@ -466,8 +466,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
}
}
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, topk_idx_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const void* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int scale_token_stride, int scale_hidden_stride,
......
......@@ -7,7 +7,7 @@ namespace deep_ep {
namespace layout {
template <int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM>
__global__ void get_dispatch_layout(const int64_t* topk_idx,
__global__ void get_dispatch_layout(const topk_idx_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts) {
......@@ -115,7 +115,7 @@ __global__ void get_dispatch_layout(const int64_t* topk_idx,
}
}
void get_dispatch_layout(const int64_t* topk_idx,
void get_dispatch_layout(const topk_idx_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts,
......
......@@ -4,4 +4,4 @@ from .utils import EventOverlap
from .buffer import Buffer
# noinspection PyUnresolvedReferences
from deep_ep_cpp import Config
from deep_ep_cpp import Config, topk_idx_t
......@@ -285,8 +285,8 @@ class Buffer:
Calculate the layout required for later communication.
Arguments:
topk_idx: `[num_tokens, num_topk]`, dtype must be `torch.int64`, the expert indices selected by each token,
`-1` means no selections.
topk_idx: `[num_tokens, num_topk]`, dtype must be `deep_ep.topk_idx_t` (typically `torch.int64`), the expert
indices selected by each token, `-1` means no selections.
num_experts: the number of experts.
previous_event: the event to wait before actually executing the kernel.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
......@@ -334,8 +334,8 @@ class Buffer:
rank (with the same GPU index), return `None` for intranode settings.
is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank.
num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert.
topk_idx: `[num_tokens, num_topk]` with `torch.int64`, the expert indices selected by each token,
`-1` means no selections.
topk_idx: `[num_tokens, num_topk]` with `deep_ep.topk_idx_t` (typically `torch.int64`), the expert indices
selected by each token, `-1` means no selections.
topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch.
expert_alignment: align the number of tokens received by each local expert to this variable.
num_worst_tokens: the worst number of tokens to receive, if specified, there will be no CPU sync, and it
......@@ -548,8 +548,8 @@ class Buffer:
Arguments:
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`.
topk_idx: `torch.Tensor` with `torch.int64`, shaped as `[num_tokens, num_topk]`, only several top-k shapes
are supported. `-1` indices (not selecting any expert) are supported.
topk_idx: `torch.Tensor` with `deep_ep.topk_idx_t` (typically `torch.int64`), shaped as `[num_tokens, num_topk]`,
only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape
......@@ -616,9 +616,9 @@ class Buffer:
Arguments:
x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
the local calculated tokens to be sent to this original rank and reduced.
topk_idx: `[num_combined_tokens, num_topk]` with `torch.int64`, the expert indices selected by the dispatched
tokens. `-1` indices (not selecting any expert) are supported. Note that, `num_combined_tokens` equals
to the number of dispatched tokens.
topk_idx: `[num_combined_tokens, num_topk]` with `deep_ep.topk_idx_t` (typically `torch.int64`), the expert
indices selected by the dispatched tokens. `-1` indices (not selecting any expert) are supported. Note that,
`num_combined_tokens` equals to the number of dispatched tokens.
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function.
......
......@@ -80,6 +80,12 @@ if __name__ == '__main__':
cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
# Bits of `topk_idx.dtype`, choices are 32 and 64.
# When the `TOPK_IDX_BITS` environment variable is set, its value is forwarded
# as `-DTOPK_IDX_BITS=<bits>` to both the host and the NVCC compiler flags;
# when it is unset, no flag is added here.
if 'TOPK_IDX_BITS' in os.environ:
    topk_idx_bits = int(os.environ['TOPK_IDX_BITS'])
    # Reject unsupported widths up front instead of handing them to the compilers
    if topk_idx_bits not in (32, 64):
        raise ValueError(f'TOPK_IDX_BITS must be 32 or 64, got {topk_idx_bits}')
    cxx_flags.append(f'-DTOPK_IDX_BITS={topk_idx_bits}')
    nvcc_flags.append(f'-DTOPK_IDX_BITS={topk_idx_bits}')
# Put them together
extra_compile_args = {
'cxx': cxx_flags,
......
......@@ -35,9 +35,11 @@ def test_main(args: argparse.Namespace, num_sms: int,
group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices
masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1]
topk_idx = topk_idx.to(deep_ep.topk_idx_t)
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
rank_idx = topk_idx // (num_experts // num_ranks)
rank_idx = rank_idx.to(torch.int64)
rank_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rank_idx, num_ranks)
rdma_rank_idx = rank_idx // num_local_ranks
......
......@@ -29,9 +29,11 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
topk_idx = topk_idx.to(deep_ep.topk_idx_t)
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
rank_idx = topk_idx // (num_experts // num_ranks)
rank_idx = rank_idx.to(torch.int64)
rank_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rank_idx, num_ranks)
......
......@@ -37,6 +37,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_idx = topk_idx.to(deep_ep.topk_idx_t)
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
# Randomly mask some positions
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment