Unverified Commit 1cf85fb2 authored by Chenggang Zhao's avatar Chenggang Zhao Committed by GitHub
Browse files

Support 10-bit LogFMT (simulated version) (#284)



* Add LogFMT interface

* Update comments

* Add simulated code

* Fix comments

* Change to 128 channels

* Add notes

* Optimize performance

* optimize simulate logfmt 10bit

* Minor fix

* Stronger low latency tests

* Minor fix

* Stronger low latency tests for logfmt

* Optimize logfmt simulate: lg2/ex2 ptx, step_inv

* Minor fix

* Minor fix

* Add non-logfmt bench

* Fix value=0 for logfmt

* Optimize performance

* Refactor tests

---------
Co-authored-by: default avatarZhean Xu <xza@deepseek.com>
parent c50f3d6f
...@@ -1186,7 +1186,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio ...@@ -1186,7 +1186,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range, const torch::Tensor& src_info, const torch::Tensor& layout_range,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool zero_copy, bool async, bool return_recv_hook, bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out) { const std::optional<torch::Tensor>& out) {
#ifndef DISABLE_NVSHMEM #ifndef DISABLE_NVSHMEM
EP_HOST_ASSERT(low_latency_mode); EP_HOST_ASSERT(low_latency_mode);
...@@ -1247,6 +1247,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1247,6 +1247,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
next_clean_meta.first, next_clean_meta.second, next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, num_topk, num_experts, rank, num_ranks,
use_logfmt,
workspace, num_device_sms, workspace, num_device_sms,
launch_stream, phases, zero_copy); launch_stream, phases, zero_copy);
}; };
......
...@@ -147,7 +147,7 @@ public: ...@@ -147,7 +147,7 @@ public:
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range, const torch::Tensor& src_info, const torch::Tensor& layout_range,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool zero_copy, bool async, bool return_recv_hook, bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out = std::nullopt); const std::optional<torch::Tensor>& out = std::nullopt);
torch::Tensor torch::Tensor
......
...@@ -159,6 +159,7 @@ void combine(void* combined_x, ...@@ -159,6 +159,7 @@ void combine(void* combined_x,
int* next_clean, int num_next_clean_int, int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
bool use_logfmt,
void* workspace, int num_device_sms, void* workspace, int num_device_sms,
cudaStream_t stream, int phases, bool zero_copy); cudaStream_t stream, int phases, bool zero_copy);
......
...@@ -91,7 +91,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -91,7 +91,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// 2. The last warp for reading `topk_idx` and count for per-expert information // 2. The last warp for reading `topk_idx` and count for per-expert information
if (warp_id < num_warps - 1) { if (warp_id < num_warps - 1) {
constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16); constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0); EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerRead) == 0, "Invalid hidden");
EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization"); EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
const auto num_threads = (num_warps - 1) * 32; const auto num_threads = (num_warps - 1) * 32;
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead; const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
...@@ -125,7 +125,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -125,7 +125,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Reduce amax and scale // Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization"); EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
amax = half_warp_reduce_max(amax); amax = warp_reduce_max<16>(amax);
calculate_fp8_scales(amax, scale, scale_inv, round_scale); calculate_fp8_scales(amax, scale, scale_inv, round_scale);
if (lane_id == 0 or lane_id == 16) if (lane_id == 0 or lane_id == 16)
rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv; rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
...@@ -382,7 +382,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \ ...@@ -382,7 +382,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
#undef DISPATCH_LAUNCH_CASE #undef DISPATCH_LAUNCH_CASE
} }
template <int kHidden, int kNumMaxTopk> template <bool kUseLogFMT, int kHidden, int kNumMaxTopk>
__global__ __launch_bounds__(1024, 1) void __global__ __launch_bounds__(1024, 1) void
combine(void* combined_x, combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
...@@ -452,19 +452,90 @@ combine(void* combined_x, ...@@ -452,19 +452,90 @@ combine(void* combined_x,
const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row); const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
// Copy directly to local rank, or copy to buffer and issue RDMA // Copy directly to local rank, or copy to buffer and issue RDMA
auto src_idx = __ldg(local_src_info + token_idx); const auto src_idx = __shfl_sync(0xffffffff, __ldg(local_src_info + token_idx), 0);
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row); const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot; const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
if (dst_p2p_ptr == 0) {
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr); if (not zero_copy or dst_p2p_ptr != 0) {
if (not zero_copy) constexpr int kNumUnrolls = 4;
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global); constexpr int hidden_bf16_int4_pad = align(static_cast<int>(hidden_bf16_int4), 32 * kNumUnrolls);
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
} else { // Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr`
const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr); const auto cpy_src_int4_ptr = zero_copy ? reinterpret_cast<int4*>(buf_ptr) : x_int4;
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global); const auto cpy_dst_int4_ptr = dst_p2p_ptr == 0 ? reinterpret_cast<int4*>(buf_ptr) : reinterpret_cast<int4*>(dst_p2p_ptr);
#pragma unroll
for (int i = lane_id * kNumUnrolls; i < hidden_bf16_int4_pad; i += 32 * kNumUnrolls) {
// Read
int4 int4_values[kNumUnrolls];
if (i < hidden_bf16_int4) {
#pragma unroll
for (int k = 0; k < kNumUnrolls; ++ k)
int4_values[k] = ld_nc_global(cpy_src_int4_ptr + i + k);
}
auto bf16_values = reinterpret_cast<nv_bfloat16*>(int4_values);
auto uint32_values = reinterpret_cast<uint32_t*>(int4_values);
// Simulated cast
if constexpr (kUseLogFMT) {
constexpr float kThreshold = 1;
constexpr float kMinClip = 32; // `== log_2(2 ^ (2 ^ 5))`
constexpr int kNumBits = 10;
constexpr int kNumValues = 1 << (kNumBits - 1);
EP_STATIC_ASSERT(kHidden % (kNumElemsPerInt4 * 32) == 0 and kNumElemsPerInt4 == 8, "Invalid hidden");
// Local log amax
float log_abs_values[kNumElemsPerInt4 * kNumUnrolls], log_amax, log_amin, amax;
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4 * kNumUnrolls; ++ j) {
auto value = static_cast<float>(bf16_values[j]);
log_abs_values[j] = log2f_approx(fabsf(value));
amax = j == 0 ? value : fmaxf(amax, fabsf(value));
log_amax = j == 0 ? log_abs_values[j] : fmaxf(log_amax, log_abs_values[j]);
log_amin = value != 0 ? (j == 0 ? log_abs_values[j] : fminf(log_amin, log_abs_values[j])) : log_amin;
}
// Reduce per 128 channels
amax = warp_reduce_max<(16 / kNumUnrolls)>(amax);
log_amax = warp_reduce_max<(16 / kNumUnrolls)>(log_amax);
log_amin = fmaxf(warp_reduce_min<(16 / kNumUnrolls)>(log_amin), log_amax - kMinClip);
const auto step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
const auto step_inv = 1.0f / step;
const auto rounding = 2.0f - log2f_approx((1.0f + exp2f_approx(step)) * 0.5f) * step_inv;
// Use LogFMT only with `amax <= kThreshold` (maybe not all quarter-warps)
if (amax <= kThreshold and log_amin < log_amax) {
// Transform
auto transform = [=](const float& log_abs_value) -> nv_bfloat16 {
const auto encoded = floorf((log_abs_value - log_amin) * step_inv + rounding);
const auto decoded = exp2f_approx((encoded - 1) * step + log_amin);
return decoded;
};
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4 * kNumUnrolls; j += 2) {
auto bf162_pack = __nv_bfloat162(transform(log_abs_values[j]), transform(log_abs_values[j + 1]));
auto uint32_pack = *reinterpret_cast<uint32_t*>(&bf162_pack);
uint32_values[j / 2] = (uint32_values[j / 2] & 0x80008000) | uint32_pack;
}
}
__syncwarp();
}
// Store
EP_STATIC_ASSERT(hidden_bf16_int4 % kNumUnrolls == 0, "Invalid hidden");
if (i < hidden_bf16_int4) {
#pragma unroll
for (int k = 0; k < kNumUnrolls; ++ k)
st_na_global(cpy_dst_int4_ptr + i + k, int4_values[k]);
}
}
} }
// Issue RDMA
// NOTES: for zero-copy mode, we assume the data is already in the send buffer
if (dst_p2p_ptr == 0)
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
} }
// Put the finishing flag // Put the finishing flag
...@@ -545,6 +616,7 @@ void combine(void* combined_x, ...@@ -545,6 +616,7 @@ void combine(void* combined_x,
int* next_clean, int num_next_clean_int, int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
bool use_logfmt,
void* workspace, int num_device_sms, void* workspace, int num_device_sms,
cudaStream_t stream, int phases, bool zero_copy) { cudaStream_t stream, int phases, bool zero_copy) {
constexpr int kNumMaxTopk = 9; constexpr int kNumMaxTopk = 9;
...@@ -560,8 +632,13 @@ void combine(void* combined_x, ...@@ -560,8 +632,13 @@ void combine(void* combined_x,
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES); EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
EP_HOST_ASSERT(num_topk <= kNumMaxTopk); EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
// Online cast cannot use zero-copy
EP_HOST_ASSERT(not (zero_copy and use_logfmt));
#define COMBINE_LAUNCH_CASE(hidden) { \ #define COMBINE_LAUNCH_CASE(hidden) { \
auto combine_func = combine<hidden, kNumMaxTopk>; \ auto combine_func = use_logfmt ? \
combine<true, hidden, kNumMaxTopk> : \
combine<false, hidden, kNumMaxTopk>; \
LAUNCH_KERNEL(&cfg, combine_func, \ LAUNCH_KERNEL(&cfg, combine_func, \
combined_x, \ combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \ rdma_recv_x, rdma_recv_flag, rdma_send_x, \
......
...@@ -266,6 +266,18 @@ __device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value ...@@ -266,6 +266,18 @@ __device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value
::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w)); ::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
} }
__device__ __forceinline__ float log2f_approx(const float &x) {
float ret;
asm volatile("lg2.approx.f32 %0, %1;" : "=f"(ret) : "f"(x));
return ret;
}
__device__ __forceinline__ float exp2f_approx(const float &x) {
float ret;
asm volatile("ex2.approx.f32 %0, %1;" : "=f"(ret) : "f"(x));
return ret;
}
// TMA PTX instructions // TMA PTX instructions
#ifndef DISABLE_SM90_FEATURES #ifndef DISABLE_SM90_FEATURES
...@@ -333,12 +345,12 @@ __device__ __forceinline__ void tma_store_wait() { ...@@ -333,12 +345,12 @@ __device__ __forceinline__ void tma_store_wait() {
#endif #endif
template <typename dtype_t> template <typename dtype_t>
__host__ __device__ dtype_t ceil_div(dtype_t a, dtype_t b) { __host__ __device__ constexpr dtype_t ceil_div(dtype_t a, dtype_t b) {
return (a + b - 1) / b; return (a + b - 1) / b;
} }
template <typename dtype_t> template <typename dtype_t>
__host__ __device__ dtype_t align(dtype_t a, dtype_t b) { __host__ __device__ constexpr dtype_t align(dtype_t a, dtype_t b) {
return ceil_div<dtype_t>(a, b) * b; return ceil_div<dtype_t>(a, b) * b;
} }
...@@ -376,25 +388,6 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t& ptr, int src_lane_idx) { ...@@ -376,25 +388,6 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t& ptr, int src_lane_idx) {
return *reinterpret_cast<dtype_t*>(recv_int_values); return *reinterpret_cast<dtype_t*>(recv_int_values);
} }
__forceinline__ __device__ int warp_reduce_sum(int value) {
value += __shfl_xor_sync(0xffffffff, value, 16);
value += __shfl_xor_sync(0xffffffff, value, 8);
value += __shfl_xor_sync(0xffffffff, value, 4);
value += __shfl_xor_sync(0xffffffff, value, 2);
value += __shfl_xor_sync(0xffffffff, value, 1);
return value;
}
__forceinline__ __device__ float half_warp_reduce_max(float value) {
auto mask = __activemask();
// The mask be in `{0xffffffff, 0xffff}`
value = max(value, __shfl_xor_sync(mask, value, 8));
value = max(value, __shfl_xor_sync(mask, value, 4));
value = max(value, __shfl_xor_sync(mask, value, 2));
value = max(value, __shfl_xor_sync(mask, value, 1));
return value;
}
__forceinline__ __device__ int get_lane_id() { __forceinline__ __device__ int get_lane_id() {
int lane_id; int lane_id;
asm("mov.s32 %0, %laneid;" : "=r"(lane_id)); asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
...@@ -493,4 +486,40 @@ __forceinline__ __device__ void release_lock(int* mutex) { ...@@ -493,4 +486,40 @@ __forceinline__ __device__ void release_lock(int* mutex) {
atomic_exch_cta_release(mutex, 0); atomic_exch_cta_release(mutex, 0);
} }
// Operation functors
template <typename T> struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } };
template <typename T> struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } };
template <typename T> struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? a : b; } };
// Unified reduction function
template <uint32_t kNumLanes, typename T, typename Op>
__forceinline__ __device__ T warp_reduce(T value, Op op) {
EP_STATIC_ASSERT(kNumLanes == 32 or kNumLanes == 16 or kNumLanes == 8 or
kNumLanes == 4 or kNumLanes == 2 or kNumLanes == 1,
"Invalid number of lanes");
if constexpr (kNumLanes >= 32) value = op(value, __shfl_xor_sync(0xffffffff, value, 16));
if constexpr (kNumLanes >= 16) value = op(value, __shfl_xor_sync(0xffffffff, value, 8));
if constexpr (kNumLanes >= 8) value = op(value, __shfl_xor_sync(0xffffffff, value, 4));
if constexpr (kNumLanes >= 4) value = op(value, __shfl_xor_sync(0xffffffff, value, 2));
if constexpr (kNumLanes >= 2) value = op(value, __shfl_xor_sync(0xffffffff, value, 1));
return value;
}
// Convenience aliases
template < uint32_t kNumLanes = 32, typename T>
__forceinline__ __device__ T warp_reduce_sum(T value) {
return warp_reduce<kNumLanes, T>(value, ReduceSum<T>{});
}
template <uint32_t kNumLanes = 32, typename T>
__forceinline__ __device__ T warp_reduce_max(T value) {
return warp_reduce<kNumLanes, T>(value, ReduceMax<T>{});
}
template <uint32_t kNumLanes = 32, typename T>
__forceinline__ __device__ T warp_reduce_min(T value) {
return warp_reduce<kNumLanes, T>(value, ReduceMin<T>{});
}
} // namespace deep_ep } // namespace deep_ep
...@@ -562,7 +562,7 @@ class Buffer: ...@@ -562,7 +562,7 @@ class Buffer:
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
handle: tuple, zero_copy: bool = False, async_finish: bool = False, handle: tuple, use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False,
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \ return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \
Tuple[torch.Tensor, EventOverlap, Callable]: Tuple[torch.Tensor, EventOverlap, Callable]:
""" """
...@@ -581,6 +581,7 @@ class Buffer: ...@@ -581,6 +581,7 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor. tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function. handle: the communication handle given by the `dispatch` function.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
with `get_next_low_latency_combine_buffer`. with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set. async_finish: the current stream will not wait for the communication kernels to be finished if set.
...@@ -597,7 +598,8 @@ class Buffer: ...@@ -597,7 +598,8 @@ class Buffer:
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range, combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
num_max_dispatch_tokens_per_rank, num_experts, num_max_dispatch_tokens_per_rank, num_experts,
zero_copy, async_finish, return_recv_hook, out) use_logfmt, zero_copy, async_finish, return_recv_hook,
out)
tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x) tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
......
...@@ -9,7 +9,8 @@ from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_to ...@@ -9,7 +9,8 @@ from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_to
def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, seed: int = 0): rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer,
use_logfmt: bool = False, seed: int = 0):
torch.manual_seed(seed + rank) torch.manual_seed(seed + rank)
random.seed(seed + rank) random.seed(seed + rank)
...@@ -22,6 +23,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -22,6 +23,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset) x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1) x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs() topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
...@@ -33,70 +35,73 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -33,70 +35,73 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# Check dispatch correctness # Check dispatch correctness
do_check = True do_check = True
hash_value, num_times = 0, 0 hash_value, num_times = 0, 0
for return_recv_hook in (False, True): for current_x in (x, x_pure_rand):
for dispatch_use_fp8 in (False, True): for return_recv_hook in (False, True):
for round_scale in (False, True) if dispatch_use_fp8 else (False, ): for dispatch_use_fp8 in (False, True):
for use_ue8m0 in (False, True) if round_scale else (False, ): for round_scale in (False, True) if dispatch_use_fp8 else (False, ):
num_times += 1 for use_ue8m0 in (False, True) if round_scale else (False, ):
for i in range((num_times % 2) + 1): num_times += 1
cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda') for i in range((num_times % 2) + 1):
packed_recv_x, packed_recv_count, handle, event, hook = \ cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, packed_recv_x, packed_recv_count, handle, event, hook = \
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
hook() if return_recv_hook else event.current_stream_wait() async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x hook() if return_recv_hook else event.current_stream_wait()
simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \ packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
if dispatch_use_fp8 else packed_recv_x.clone() simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda') if dispatch_use_fp8 else packed_recv_x.clone()
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group) all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
for i in range(num_local_experts if do_check else 0): dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
expert_id = rank * num_local_experts + i for i in range(num_local_experts if do_check else 0):
recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i] expert_id = rank * num_local_experts + i
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i] recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices
int_mask = (2 ** 32) - 1 # Check expert indices
num_valid_tokens = recv_count.item() int_mask = (2 ** 32) - 1
assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}' num_valid_tokens = recv_count.item()
assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()' assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}'
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}' assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
# Check received data
recv_x = recv_x[:num_valid_tokens] # Check received data
recv_x_amin = recv_x[:, :-128].amin(dim=-1) if current_x is not x_pure_rand:
recv_src_info = recv_src_info[:num_valid_tokens] recv_x = recv_x[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1)) recv_x_amin = recv_x[:, :-128].amin(dim=-1)
if round_scale: recv_src_info = recv_src_info[:num_valid_tokens]
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007 assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
else: if round_scale:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0 assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
for j in range(num_ranks): else:
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item() assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
if not round_scale: for j in range(num_ranks):
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item() begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0 if not round_scale:
if dispatch_use_fp8: assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens]) assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens]) if dispatch_use_fp8:
else: hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens]) hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
else:
# Check combine correctness hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
for zero_copy in (False, True):
if zero_copy: # Check combine correctness
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x for zero_copy in (False, ) if use_logfmt else (False, True):
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') if zero_copy:
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
async_finish=not return_recv_hook, zero_copy=zero_copy, out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
return_recv_hook=return_recv_hook, out=out) combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
hook() if return_recv_hook else event.current_stream_wait() use_logfmt=use_logfmt,
if do_check: async_finish=not return_recv_hook, zero_copy=zero_copy,
diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) return_recv_hook=return_recv_hook, out=out)
assert torch.isnan(combined_x).sum().item() == 0 hook() if return_recv_hook else event.current_stream_wait()
assert diff < (7e-4 if round_scale else 1e-5), f'Error: {diff=}, {zero_copy=}' if do_check:
hash_value ^= hash_tensor(combined_x) diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0
assert diff < (7e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {zero_copy=}'
hash_value ^= hash_tensor(combined_x)
# noinspection PyShadowingNames # noinspection PyShadowingNames
def large_gemm_with_hook(hook): def large_gemm_with_hook(hook):
...@@ -106,16 +111,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -106,16 +111,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
hook() hook()
# noinspection PyShadowingNames # noinspection PyShadowingNames
def test_func(zero_copy: bool, return_recv_hook: bool): def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \ recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, buffer.low_latency_dispatch(x_pure_rand, topk_idx, num_tokens, num_experts,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook) use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None large_gemm_with_hook(hook) if return_recv_hook else None
if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
zero_copy=zero_copy, return_recv_hook=return_recv_hook) use_logfmt=use_logfmt, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None large_gemm_with_hook(hook) if return_recv_hook else None
# Calculate bandwidth # Calculate bandwidth
...@@ -127,14 +130,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -127,14 +130,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
num_combine_comm_bytes += num_bf16_bytes * num_selections num_combine_comm_bytes += num_bf16_bytes * num_selections
# Dispatch + combine testing # Dispatch + combine testing
avg_t, min_t, max_t = bench(partial(test_func, zero_copy=False, return_recv_hook=False)) avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, ' print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True) f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)
# Separate profiling # Separate profiling
for return_recv_hook in (False, True): for return_recv_hook in (False, True):
group.barrier() group.barrier()
dispatch_t, combine_t = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook), dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True, kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
suppress_kineto_output=True, num_kernels_per_period=2 if return_recv_hook else 1) suppress_kineto_output=True, num_kernels_per_period=2 if return_recv_hook else 1)
if not return_recv_hook: if not return_recv_hook:
...@@ -156,16 +159,20 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -156,16 +159,20 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
if local_rank == 0: if local_rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True) print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True, buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
num_qps_per_rank=num_experts // num_ranks) num_qps_per_rank=num_experts // num_ranks,
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1) allow_nvlink_for_low_latency_mode=not args.disable_nvlink)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=1)
do_pressure_test = False do_pressure_test = False
for seed in range(int(1e9) if do_pressure_test else 0): for seed in range(int(1e9) if do_pressure_test else 0):
if local_rank == 0: if local_rank == 0:
print(f'Testing with seed {seed} ...', flush=True) print(f'Testing with seed {seed} ...', flush=True)
ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=seed)
for i in range(20): for i in range(20):
assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}' assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the communication group # Destroy the communication group
dist.barrier() dist.barrier()
...@@ -174,6 +181,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -174,6 +181,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
if __name__ == '__main__': if __name__ == '__main__':
# TODO: you may modify NUMA binding for less CPU overhead # TODO: you may modify NUMA binding for less CPU overhead
# TODO: buggy with `num_tokens=512`
parser = argparse.ArgumentParser(description='Test low-latency EP kernels') parser = argparse.ArgumentParser(description='Test low-latency EP kernels')
parser.add_argument('--num-processes', type=int, default=8, parser.add_argument('--num-processes', type=int, default=8,
help='Number of processes to spawn (default: 8)') help='Number of processes to spawn (default: 8)')
...@@ -185,6 +193,10 @@ if __name__ == '__main__': ...@@ -185,6 +193,10 @@ if __name__ == '__main__':
help='Number of top-k experts (default: 8)') help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=288, parser.add_argument('--num-experts', type=int, default=288,
help='Number of experts (default: 288)') help='Number of experts (default: 288)')
parser.add_argument('--disable-nvlink', action='store_true',
help='Whether to disable NVLink for testing')
parser.add_argument('--use-logfmt', action='store_true',
help='Whether to test LogFMT combine')
args = parser.parse_args() args = parser.parse_args()
num_processes = args.num_processes num_processes = args.num_processes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment