Unverified commit da6ca24e, authored by Tailing Yuan and committed by GitHub
Browse files

Make dtype of topk_idx configurable (#422)


Co-authored-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>
parent c939644c
...@@ -64,7 +64,7 @@ struct Config { ...@@ -64,7 +64,7 @@ struct Config {
#ifndef DISABLE_NVSHMEM #ifndef DISABLE_NVSHMEM
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes(); num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes();
#endif #endif
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t); num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float); num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float); num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float);
num_bytes = ((num_bytes + 127) / 128) * 128; num_bytes = ((num_bytes + 127) / 128) * 128;
...@@ -90,7 +90,7 @@ struct Config { ...@@ -90,7 +90,7 @@ struct Config {
num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int); num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2; num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2; num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t) * 2; num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2; num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2; num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2; num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
......
...@@ -269,7 +269,7 @@ Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, ...@@ -269,7 +269,7 @@ Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts,
if (is_internode_available()) if (is_internode_available())
num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
layout::get_dispatch_layout(topk_idx.data_ptr<int64_t>(), layout::get_dispatch_layout(topk_idx.data_ptr<topk_idx_t>(),
num_tokens_per_rank.data_ptr<int>(), num_tokens_per_rank.data_ptr<int>(),
num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr<int>() : nullptr, num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr<int>() : nullptr,
num_tokens_per_expert.data_ptr<int>(), num_tokens_per_expert.data_ptr<int>(),
...@@ -355,7 +355,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te ...@@ -355,7 +355,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
// Top-k checks // Top-k checks
int num_topk = 0; int num_topk = 0;
int64_t* topk_idx_ptr = nullptr; topk_idx_t* topk_idx_ptr = nullptr;
float* topk_weights_ptr = nullptr; float* topk_weights_ptr = nullptr;
EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
if (topk_idx.has_value()) { if (topk_idx.has_value()) {
...@@ -366,7 +366,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te ...@@ -366,7 +366,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0)); EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));
EP_HOST_ASSERT(num_topk == topk_weights->size(1)); EP_HOST_ASSERT(num_topk == topk_weights->size(1));
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
topk_idx_ptr = topk_idx->data_ptr<int64_t>(); topk_idx_ptr = topk_idx->data_ptr<topk_idx_t>();
topk_weights_ptr = topk_weights->data_ptr<float>(); topk_weights_ptr = topk_weights->data_ptr<float>();
} }
...@@ -476,13 +476,13 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te ...@@ -476,13 +476,13 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
auto send_head = torch::empty({num_tokens, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); auto send_head = torch::empty({num_tokens, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
// Assign pointers // Assign pointers
int64_t* recv_topk_idx_ptr = nullptr; topk_idx_t* recv_topk_idx_ptr = nullptr;
float* recv_topk_weights_ptr = nullptr; float* recv_topk_weights_ptr = nullptr;
float* recv_x_scales_ptr = nullptr; float* recv_x_scales_ptr = nullptr;
if (topk_idx.has_value()) { if (topk_idx.has_value()) {
recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options()); recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options());
recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
recv_topk_idx_ptr = recv_topk_idx->data_ptr<int64_t>(); recv_topk_idx_ptr = recv_topk_idx->data_ptr<topk_idx_t>();
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>(); recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
} }
if (x_scales.has_value()) { if (x_scales.has_value()) {
...@@ -499,7 +499,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te ...@@ -499,7 +499,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * recv_x.element_size() + // Data buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * recv_x.element_size() + // Data buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(int64_t) + // Top-k index buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(topk_idx_t) + // Top-k index buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) + // Top-k weight buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) + // Top-k weight buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales // FP8 scale buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales // FP8 scale buffer
<= num_nvl_bytes); <= num_nvl_bytes);
...@@ -720,7 +720,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te ...@@ -720,7 +720,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
// Top-k checks // Top-k checks
int num_topk = 0; int num_topk = 0;
int64_t* topk_idx_ptr = nullptr; topk_idx_t* topk_idx_ptr = nullptr;
float* topk_weights_ptr = nullptr; float* topk_weights_ptr = nullptr;
EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
if (topk_idx.has_value()) { if (topk_idx.has_value()) {
...@@ -731,7 +731,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te ...@@ -731,7 +731,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0)); EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));
EP_HOST_ASSERT(num_topk == topk_weights->size(1)); EP_HOST_ASSERT(num_topk == topk_weights->size(1));
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
topk_idx_ptr = topk_idx->data_ptr<int64_t>(); topk_idx_ptr = topk_idx->data_ptr<topk_idx_t>();
topk_weights_ptr = topk_weights->data_ptr<float>(); topk_weights_ptr = topk_weights->data_ptr<float>();
} }
...@@ -856,13 +856,13 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te ...@@ -856,13 +856,13 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
} }
// Assign pointers // Assign pointers
int64_t* recv_topk_idx_ptr = nullptr; topk_idx_t* recv_topk_idx_ptr = nullptr;
float* recv_topk_weights_ptr = nullptr; float* recv_topk_weights_ptr = nullptr;
float* recv_x_scales_ptr = nullptr; float* recv_x_scales_ptr = nullptr;
if (topk_idx.has_value()) { if (topk_idx.has_value()) {
recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options()); recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options());
recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
recv_topk_idx_ptr = recv_topk_idx->data_ptr<int64_t>(); recv_topk_idx_ptr = recv_topk_idx->data_ptr<topk_idx_t>();
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>(); recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
} }
if (x_scales.has_value()) { if (x_scales.has_value()) {
...@@ -1103,7 +1103,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1103,7 +1103,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0); EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0);
EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank); EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); EP_HOST_ASSERT(topk_idx.scalar_type() == c10::CppTypeToScalarType<topk_idx_t>::value);
EP_HOST_ASSERT(num_experts % num_ranks == 0); EP_HOST_ASSERT(num_experts % num_ranks == 0);
// Diagnosis tensors // Diagnosis tensors
...@@ -1173,7 +1173,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1173,7 +1173,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr, dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer, buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer, buffer.dispatch_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(), x.data_ptr(), topk_idx.data_ptr<topk_idx_t>(),
next_clean_meta.first, next_clean_meta.second, next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank, num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, num_topk, num_experts, rank, num_ranks,
...@@ -1223,7 +1223,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1223,7 +1223,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0); EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0);
EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1)); EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1));
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); EP_HOST_ASSERT(topk_idx.scalar_type() == c10::CppTypeToScalarType<topk_idx_t>::value);
EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous()); EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous());
EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank); EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32); EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32);
...@@ -1274,7 +1274,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1274,7 +1274,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
internode_ll::combine(combined_x.data_ptr(), internode_ll::combine(combined_x.data_ptr(),
buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer, buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer,
buffer.combine_rdma_send_buffer, buffer.combine_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(), x.data_ptr(), topk_idx.data_ptr<topk_idx_t>(), topk_weights.data_ptr<float>(),
src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(), src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(),
combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr, combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
next_clean_meta.first, next_clean_meta.second, next_clean_meta.first, next_clean_meta.second,
...@@ -1379,4 +1379,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -1379,4 +1379,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer); .def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer);
m.def("is_sm90_compiled", deep_ep::is_sm90_compiled); m.def("is_sm90_compiled", deep_ep::is_sm90_compiled);
m.attr("topk_idx_t") = py::cast(c10::CppTypeToScalarType<deep_ep::topk_idx_t>::value);
} }
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#include <vector> #include <vector>
#include "configs.cuh"
namespace deep_ep { namespace deep_ep {
// Intranode runtime // Intranode runtime
...@@ -31,7 +33,7 @@ void finalize(); ...@@ -31,7 +33,7 @@ void finalize();
// Layout kernels // Layout kernels
namespace layout { namespace layout {
void get_dispatch_layout(const int64_t* topk_idx, void get_dispatch_layout(const topk_idx_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank, int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts, int num_tokens, int num_topk, int num_ranks, int num_experts,
...@@ -53,8 +55,8 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, ...@@ -53,8 +55,8 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks, void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks,
cudaStream_t stream); cudaStream_t stream);
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, topk_idx_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, int* send_head, const void* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix, const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int scale_token_stride, int scale_hidden_stride, int scale_token_stride, int scale_hidden_stride,
...@@ -95,8 +97,8 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe ...@@ -95,8 +97,8 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool low_latency_mode); bool low_latency_mode);
void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta, void dispatch(void* recv_x, float* recv_x_scales, topk_idx_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, const void* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights,
int* send_rdma_head, int* send_nvl_head, int* send_rdma_head, int* send_nvl_head,
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
...@@ -145,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -145,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* cumulative_local_expert_recv_stats, int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats, int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx, const void* x, const topk_idx_t* topk_idx,
int* next_clean, int num_next_clean_int, int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
...@@ -155,7 +157,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -155,7 +157,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
void combine(void* combined_x, void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights, const void* x, const topk_idx_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range, const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats, int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int, int* next_clean, int num_next_clean_int,
......
...@@ -58,6 +58,20 @@ typedef int __nv_fp8x4_e4m3; ...@@ -58,6 +58,20 @@ typedef int __nv_fp8x4_e4m3;
typedef uint8_t __nv_fp8_storage_t; typedef uint8_t __nv_fp8_storage_t;
#endif #endif
namespace deep_ep {

// Bit width of `topk_idx_t`. Defaults to 64; predefine the macro at build time
// (e.g. `-DTOPK_IDX_BITS=32`) to select the 32-bit variant instead.
#ifndef TOPK_IDX_BITS
#define TOPK_IDX_BITS 64
#endif

// Reject unsupported widths up front with a readable diagnostic, instead of an
// obscure "intNN_t was not declared" error (or a silently accepted narrow type).
static_assert(TOPK_IDX_BITS == 32 || TOPK_IDX_BITS == 64,
              "TOPK_IDX_BITS must be either 32 or 64");

// Two-level expansion so that `TOPK_IDX_BITS` is replaced by its value *before*
// token pasting; pasting directly would produce `intTOPK_IDX_BITS_t`.
#define INT_BITS_T2(bits) int##bits##_t
#define INT_BITS_T(bits) INT_BITS_T2(bits)

// Signed integer element type of `topk_idx` data: int32_t or int64_t.
typedef INT_BITS_T(TOPK_IDX_BITS) topk_idx_t;

// The helper macros are an implementation detail; do not leak them to includers.
#undef INT_BITS_T
#undef INT_BITS_T2

}  // namespace deep_ep
#ifndef DISABLE_NVSHMEM #ifndef DISABLE_NVSHMEM
#include <nvshmem.h> #include <nvshmem.h>
#include <nvshmemx.h> #include <nvshmemx.h>
......
...@@ -355,8 +355,8 @@ constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) { ...@@ -355,8 +355,8 @@ constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
template <bool kLowLatencyMode, int kNumRDMARanks, bool kCachedMode, int kNumTMABytesPerWarp, template <bool kLowLatencyMode, int kNumRDMARanks, bool kCachedMode, int kNumTMABytesPerWarp,
int kNumDispatchRDMASenderWarps, int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)> int kNumDispatchRDMASenderWarps, int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)>
__global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1) __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1)
dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, SourceMeta* recv_src_meta, dispatch(int4* recv_x, float* recv_x_scales, topk_idx_t* recv_topk_idx, float* recv_topk_weights, SourceMeta* recv_src_meta,
const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, const int4* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights,
int* send_rdma_head, int* send_nvl_head, int* send_rdma_head, int* send_nvl_head,
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
...@@ -968,7 +968,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -968,7 +968,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Copy `topk_idx` and `topk_weights` // Copy `topk_idx` and `topk_weights`
if (lane_id < num_topk) { if (lane_id < num_topk) {
// Read // Read
auto idx_value = static_cast<int64_t>(ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id)); auto idx_value = static_cast<topk_idx_t>(ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id));
auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted + sizeof(int) * num_topk) + lane_id); auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted + sizeof(int) * num_topk) + lane_id);
auto recv_idx = recv_token_idx * num_topk + lane_id; auto recv_idx = recv_token_idx * num_topk + lane_id;
...@@ -991,8 +991,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -991,8 +991,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
} }
} }
void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta, void dispatch(void* recv_x, float* recv_x_scales, topk_idx_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, const void* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights,
int* send_rdma_head, int* send_nvl_head, int* send_rdma_head, int* send_nvl_head,
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
......
...@@ -44,7 +44,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -44,7 +44,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* cumulative_local_expert_recv_stats, int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats, int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx, const void* x, const topk_idx_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
int* next_clean, int num_next_clean_int, int* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank, int num_tokens, int num_max_dispatch_tokens_per_rank,
...@@ -340,7 +340,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -340,7 +340,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* cumulative_local_expert_recv_stats, int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats, int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx, const void* x, const topk_idx_t* topk_idx,
int* next_clean, int num_next_clean_int, int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
...@@ -555,7 +555,7 @@ template <bool kUseLogFMT, int kHidden, int kNumMaxTopk, int kNumMaxUnrolls> ...@@ -555,7 +555,7 @@ template <bool kUseLogFMT, int kHidden, int kNumMaxTopk, int kNumMaxUnrolls>
__global__ __launch_bounds__(1024, 1) void __global__ __launch_bounds__(1024, 1) void
combine(void* combined_x, combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights, const void* x, const topk_idx_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range, const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats, int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int, int* next_clean, int num_next_clean_int,
...@@ -914,7 +914,7 @@ combine(void* combined_x, ...@@ -914,7 +914,7 @@ combine(void* combined_x,
void combine(void* combined_x, void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights, const void* x, const topk_idx_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range, const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats, int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int, int* next_clean, int num_next_clean_int,
......
...@@ -163,8 +163,8 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, ...@@ -163,8 +163,8 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
template <int kNumRanks, int kNumThreads, int kNumTMABytesPerWarp> template <int kNumRanks, int kNumThreads, int kNumTMABytesPerWarp>
__global__ void __launch_bounds__(kNumThreads, 1) __global__ void __launch_bounds__(kNumThreads, 1)
dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, topk_idx_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, int* send_head, const int4* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix, const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int scale_token_stride, int scale_hidden_stride, int scale_token_stride, int scale_hidden_stride,
...@@ -211,12 +211,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to ...@@ -211,12 +211,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Channel data buffers, stored on the receiver side // Channel data buffers, stored on the receiver side
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4) // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int) // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
// `topk_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(int64_t) // `topk_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(topk_idx_t)
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float) // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
// `x_scales_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_scales * sizeof(float) // `x_scales_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_scales * sizeof(float)
auto channel_x_buffers = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4); auto channel_x_buffers = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
auto channel_src_idx_buffers = Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens); auto channel_src_idx_buffers = Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens);
auto channel_topk_idx_buffers = Buffer<int64_t>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk); auto channel_topk_idx_buffers = Buffer<topk_idx_t>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk); auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales); auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);
...@@ -466,8 +466,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to ...@@ -466,8 +466,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
} }
} }
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, topk_idx_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, int* send_head, const void* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix, const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int scale_token_stride, int scale_hidden_stride, int scale_token_stride, int scale_hidden_stride,
......
...@@ -7,7 +7,7 @@ namespace deep_ep { ...@@ -7,7 +7,7 @@ namespace deep_ep {
namespace layout { namespace layout {
template <int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM> template <int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM>
__global__ void get_dispatch_layout(const int64_t* topk_idx, __global__ void get_dispatch_layout(const topk_idx_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank, int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts) { int num_tokens, int num_topk, int num_ranks, int num_experts) {
...@@ -115,7 +115,7 @@ __global__ void get_dispatch_layout(const int64_t* topk_idx, ...@@ -115,7 +115,7 @@ __global__ void get_dispatch_layout(const int64_t* topk_idx,
} }
} }
void get_dispatch_layout(const int64_t* topk_idx, void get_dispatch_layout(const topk_idx_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank, int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts, int num_tokens, int num_topk, int num_ranks, int num_experts,
......
...@@ -4,4 +4,4 @@ from .utils import EventOverlap ...@@ -4,4 +4,4 @@ from .utils import EventOverlap
from .buffer import Buffer from .buffer import Buffer
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
from deep_ep_cpp import Config from deep_ep_cpp import Config, topk_idx_t
...@@ -285,8 +285,8 @@ class Buffer: ...@@ -285,8 +285,8 @@ class Buffer:
Calculate the layout required for later communication. Calculate the layout required for later communication.
Arguments: Arguments:
topk_idx: `[num_tokens, num_topk]`, dtype must be `torch.int64`, the expert indices selected by each token, topk_idx: `[num_tokens, num_topk]`, dtype must be `deep_ep.topk_idx_t` (typically `torch.int64`), the expert
`-1` means no selections. indices selected by each token, `-1` means no selections.
num_experts: the number of experts. num_experts: the number of experts.
previous_event: the event to wait before actually executing the kernel. previous_event: the event to wait before actually executing the kernel.
async_finish: the current stream will not wait for the communication kernels to be finished if set. async_finish: the current stream will not wait for the communication kernels to be finished if set.
...@@ -334,8 +334,8 @@ class Buffer: ...@@ -334,8 +334,8 @@ class Buffer:
rank (with the same GPU index), return `None` for intranode settings. rank (with the same GPU index), return `None` for intranode settings.
is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank. is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank.
num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert. num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert.
topk_idx: `[num_tokens, num_topk]` with `torch.int64`, the expert indices selected by each token, topk_idx: `[num_tokens, num_topk]` with `deep_ep.topk_idx_t` (typically `torch.int64`), the expert indices
`-1` means no selections. selected by each token, `-1` means no selections.
topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch. topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch.
expert_alignment: align the number of tokens received by each local expert to this variable. expert_alignment: align the number of tokens received by each local expert to this variable.
num_worst_tokens: the worst number of tokens to receive, if specified, there will be no CPU sync, and it num_worst_tokens: the worst number of tokens to receive, if specified, there will be no CPU sync, and it
...@@ -548,8 +548,8 @@ class Buffer: ...@@ -548,8 +548,8 @@ class Buffer:
Arguments: Arguments:
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`. supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`.
topk_idx: `torch.Tensor` with `torch.int64`, shaped as `[num_tokens, num_topk]`, only several top-k shapes topk_idx: `torch.Tensor` with `deep_ep.topk_idx_t` (typically `torch.int64`), shaped as `[num_tokens, num_topk]`,
are supported. `-1` indices (not selecting any expert) are supported. only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value. num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts. num_experts: the number of all experts.
cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape
...@@ -616,9 +616,9 @@ class Buffer: ...@@ -616,9 +616,9 @@ class Buffer:
Arguments: Arguments:
x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`, x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
the local calculated tokens to be sent to this original rank and reduced. the local calculated tokens to be sent to this original rank and reduced.
topk_idx: `[num_combined_tokens, num_topk]` with `torch.int64`, the expert indices selected by the dispatched topk_idx: `[num_combined_tokens, num_topk]` with `deep_ep.topk_idx_t` (typically `torch.int64`), the expert
tokens. `-1` indices (not selecting any expert) are supported. Note that, `num_combined_tokens` equals indices selected by the dispatched tokens. `-1` indices (not selecting any expert) are supported. Note that,
to the number of dispatched tokens. `num_combined_tokens` equals to the number of dispatched tokens.
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor. tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function. handle: the communication handle given by the `dispatch` function.
......
...@@ -80,6 +80,12 @@ if __name__ == '__main__': ...@@ -80,6 +80,12 @@ if __name__ == '__main__':
cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS') cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS') nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
# Bits of `topk_idx.dtype`, choices are 32 and 64
if "TOPK_IDX_BITS" in os.environ:
topk_idx_bits = int(os.environ['TOPK_IDX_BITS'])
cxx_flags.append(f'-DTOPK_IDX_BITS={topk_idx_bits}')
nvcc_flags.append(f'-DTOPK_IDX_BITS={topk_idx_bits}')
# Put them together # Put them together
extra_compile_args = { extra_compile_args = {
'cxx': cxx_flags, 'cxx': cxx_flags,
......
...@@ -35,9 +35,11 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -35,9 +35,11 @@ def test_main(args: argparse.Namespace, num_sms: int,
group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices
masked_scores = create_grouped_scores(scores, group_idx, num_nodes) masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1] topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1]
topk_idx = topk_idx.to(deep_ep.topk_idx_t)
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda') topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
rank_idx = topk_idx // (num_experts // num_ranks) rank_idx = topk_idx // (num_experts // num_ranks)
rank_idx = rank_idx.to(torch.int64)
rank_idx.masked_fill_(topk_idx == -1, -1) rank_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rank_idx, num_ranks) inplace_unique(rank_idx, num_ranks)
rdma_rank_idx = rank_idx // num_local_ranks rdma_rank_idx = rank_idx // num_local_ranks
......
...@@ -29,9 +29,11 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks ...@@ -29,9 +29,11 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1] topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
topk_idx = topk_idx.to(deep_ep.topk_idx_t)
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda') topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
rank_idx = topk_idx // (num_experts // num_ranks) rank_idx = topk_idx // (num_experts // num_ranks)
rank_idx = rank_idx.to(torch.int64)
rank_idx.masked_fill_(topk_idx == -1, -1) rank_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rank_idx, num_ranks) inplace_unique(rank_idx, num_ranks)
......
...@@ -37,6 +37,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -37,6 +37,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_idx = topk_idx.to(deep_ep.topk_idx_t)
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs() topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
# Randomly mask some positions # Randomly mask some positions
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment