Commit 830124e1 authored by lishen's avatar lishen
Browse files

量化scales传输size优化

parent d0fcf024
...@@ -135,8 +135,8 @@ struct LowLatencyLayout { ...@@ -135,8 +135,8 @@ struct LowLatencyLayout {
} }
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) { int num_ranks, int num_experts, int quant_group_size=0) {
const int num_scales = hidden / QUANTIZATION_GROUPSIZE; const int num_scales = quant_group_size == 0 ? 4 : hidden / QUANTIZATION_GROUPSIZE; // 应该是1,但是代码中为了满足int4对齐
// Dispatch and combine layout: // Dispatch and combine layout:
// - 2 symmetric odd/even send buffer // - 2 symmetric odd/even send buffer
...@@ -205,9 +205,9 @@ struct LowLatencyLayout { ...@@ -205,9 +205,9 @@ struct LowLatencyLayout {
}; };
inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) { int num_ranks, int num_experts, int quant_group_size=0) {
auto num_bytes = auto num_bytes =
LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts) LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size)
.total_bytes; .total_bytes;
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) *
NUM_BUFFER_ALIGNMENT_BYTES; NUM_BUFFER_ALIGNMENT_BYTES;
......
...@@ -1271,10 +1271,10 @@ Buffer::internode_combine( ...@@ -1271,10 +1271,10 @@ Buffer::internode_combine(
#endif #endif
} }
void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) { void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts, int quant_group_size) {
EP_HOST_ASSERT(low_latency_mode); EP_HOST_ASSERT(low_latency_mode);
auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size);
auto clean_meta_0 = layout.buffers[0].clean_meta(); auto clean_meta_0 = layout.buffers[0].clean_meta();
auto clean_meta_1 = layout.buffers[1].clean_meta(); auto clean_meta_1 = layout.buffers[1].clean_meta();
...@@ -1311,7 +1311,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1311,7 +1311,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto num_local_experts = num_experts / num_ranks; auto num_local_experts = num_experts / num_ranks;
// Buffer control // Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx]; auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
......
...@@ -172,7 +172,7 @@ public: ...@@ -172,7 +172,7 @@ public:
std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream); std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream);
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden,
int num_experts); int num_experts, int quant_group_size=0);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
......
...@@ -210,13 +210,13 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void ...@@ -210,13 +210,13 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
// Message package: hidden data, FP8 scales, index at source // Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use // NOTES: currently we have 3 reserved int fields for future use
using vec_t = typename std::conditional<kUseQuant8Bit, int2, int4>::type; using vec_t = typename std::conditional<kUseQuant8Bit, int2, int4>::type;
constexpr size_t num_bytes_per_msg = sizeof(int4) + (kUseQuant8Bit ? (kHidden + kNumScales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16))); constexpr size_t num_bytes_per_msg = sizeof(int4) +
(kUseQuant8Bit ? (kHidden + (kQuantGroupSize == 0 ? 4 : kNumScales) * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
EP_STATIC_ASSERT(num_bytes_per_msg % sizeof(int4) == 0, "Invalid message size"); EP_STATIC_ASSERT(num_bytes_per_msg % sizeof(int4) == 0, "Invalid message size");
constexpr size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4); constexpr size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
// Expert counts // Expert counts
constexpr int kNumMaxWarpGroups = 1024 / kWarpSize; __shared__ int shared_num_tokens_sent_per_expert[kMaxNumWarps];
__shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
// Sending phase // Sending phase
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
...@@ -230,7 +230,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void ...@@ -230,7 +230,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
constexpr int kNumThreadPerGroup = QUANTIZATION_GROUPSIZE / kNumElemsPerRead; constexpr int kNumThreadPerGroup = QUANTIZATION_GROUPSIZE / kNumElemsPerRead;
// EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0); // EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization"); EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
const auto num_threads = (num_warps - 1) * kWarpSize; const auto num_threads = num_warps * kWarpSize;
constexpr int hidden_bf16_int4 = kHidden / kNumElemsPerRead; constexpr int hidden_bf16_int4 = kHidden / kNumElemsPerRead;
for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) { for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
...@@ -375,7 +375,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void ...@@ -375,7 +375,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG); atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
} }
// This SM should be responsible for some destination experts, read `topk_idx` for them // This SM should be responsible for some destination experts, read `topk_idx` for them
int expert_count[kNumMaxWarpGroups] = {0}; int expert_count[kMaxNumWarps] = {0};
const auto expert_begin_idx = sm_id * num_warp_groups; const auto expert_begin_idx = sm_id * num_warp_groups;
const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts); const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);
...@@ -465,7 +465,7 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -465,7 +465,7 @@ LOW_LATENCY_DISPATCH_RECV:
(kQuantGroupSize == 0 ? 1 : num_aligned_scales); (kQuantGroupSize == 0 ? 1 : num_aligned_scales);
// Shared between sub-warps in warp groups // Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups]; __shared__ int shared_num_recv_tokens[kMaxNumWarps], shared_recv_token_begin_idx[kMaxNumWarps];
// Wait tokens to arrive // Wait tokens to arrive
// NOTES: using sub-warp 1 to overlap with sub-warp 0 // NOTES: using sub-warp 1 to overlap with sub-warp 0
......
...@@ -212,7 +212,7 @@ class Buffer: ...@@ -212,7 +212,7 @@ class Buffer:
@staticmethod @staticmethod
def get_low_latency_rdma_size_hint( def get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int, quant_group_size: int = 0
) -> int: ) -> int:
""" """
Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16. Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16.
...@@ -222,12 +222,13 @@ class Buffer: ...@@ -222,12 +222,13 @@ class Buffer:
hidden: the hidden dimension of each token. hidden: the hidden dimension of each token.
num_ranks: the number of EP group ranks. num_ranks: the number of EP group ranks.
num_experts: the number of all experts. num_experts: the number of all experts.
quant_group_size: the group size if use quant.
Returns: Returns:
size: the RDMA buffer size recommended. size: the RDMA buffer size recommended.
""" """
return deep_ep_cpp.get_low_latency_rdma_size_hint( return deep_ep_cpp.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size
) )
def get_comm_stream(self) -> torch.Stream: def get_comm_stream(self) -> torch.Stream:
...@@ -823,7 +824,7 @@ class Buffer: ...@@ -823,7 +824,7 @@ class Buffer:
return combined_x, combined_topk_weights, EventOverlap(event) return combined_x, combined_topk_weights, EventOverlap(event)
def clean_low_latency_buffer( def clean_low_latency_buffer(
self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int, quant_group_size: int = 0
) -> None: ) -> None:
""" """
As low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer As low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer
...@@ -835,8 +836,9 @@ class Buffer: ...@@ -835,8 +836,9 @@ class Buffer:
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value. num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
hidden: the hidden dimension of each token. hidden: the hidden dimension of each token.
num_experts: the number of all experts. num_experts: the number of all experts.
quant_group_size: the group size if use quant.
""" """
self.runtime.clean_low_latency_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts) self.runtime.clean_low_latency_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts, quant_group_size)
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment