Unverified commit 1cf85fb2, authored by Chenggang Zhao, committed by GitHub
Browse files

Support 10-bit LogFMT (simulated version) (#284)



* Add LogFMT interface

* Update comments

* Add simulated code

* Fix comments

* Change to 128 channels

* Add notes

* Optimize performance

* optimize simulate logfmt 10bit

* Minor fix

* Stronger low latency tests

* Minor fix

* Stronger low latency tests for logfmt

* Optimize logfmt simulate: lg2/ex2 ptx, step_inv

* Minor fix

* Minor fix

* Add non-logfmt bench

* Fix value=0 for logfmt

* Optimize performance

* Refactor tests

---------
Co-authored-by: Zhean Xu <xza@deepseek.com>
parent c50f3d6f
......@@ -1186,7 +1186,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool zero_copy, bool async, bool return_recv_hook,
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out) {
#ifndef DISABLE_NVSHMEM
EP_HOST_ASSERT(low_latency_mode);
......@@ -1247,6 +1247,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
use_logfmt,
workspace, num_device_sms,
launch_stream, phases, zero_copy);
};
......
......@@ -147,7 +147,7 @@ public:
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool zero_copy, bool async, bool return_recv_hook,
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out = std::nullopt);
torch::Tensor
......
......@@ -159,6 +159,7 @@ void combine(void* combined_x,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_logfmt,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases, bool zero_copy);
......
......@@ -91,7 +91,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// 2. The last warp for reading `topk_idx` and count for per-expert information
if (warp_id < num_warps - 1) {
constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerRead) == 0, "Invalid hidden");
EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
const auto num_threads = (num_warps - 1) * 32;
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
......@@ -125,7 +125,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
amax = half_warp_reduce_max(amax);
amax = warp_reduce_max<16>(amax);
calculate_fp8_scales(amax, scale, scale_inv, round_scale);
if (lane_id == 0 or lane_id == 16)
rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
......@@ -382,7 +382,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
#undef DISPATCH_LAUNCH_CASE
}
template <int kHidden, int kNumMaxTopk>
template <bool kUseLogFMT, int kHidden, int kNumMaxTopk>
__global__ __launch_bounds__(1024, 1) void
combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
......@@ -452,21 +452,92 @@ combine(void* combined_x,
const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
// Copy directly to local rank, or copy to buffer and issue RDMA
auto src_idx = __ldg(local_src_info + token_idx);
const auto src_idx = __shfl_sync(0xffffffff, __ldg(local_src_info + token_idx), 0);
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
if (dst_p2p_ptr == 0) {
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy)
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
} else {
const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
if (not zero_copy or dst_p2p_ptr != 0) {
constexpr int kNumUnrolls = 4;
constexpr int hidden_bf16_int4_pad = align(static_cast<int>(hidden_bf16_int4), 32 * kNumUnrolls);
// Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr`
const auto cpy_src_int4_ptr = zero_copy ? reinterpret_cast<int4*>(buf_ptr) : x_int4;
const auto cpy_dst_int4_ptr = dst_p2p_ptr == 0 ? reinterpret_cast<int4*>(buf_ptr) : reinterpret_cast<int4*>(dst_p2p_ptr);
#pragma unroll
for (int i = lane_id * kNumUnrolls; i < hidden_bf16_int4_pad; i += 32 * kNumUnrolls) {
// Read
int4 int4_values[kNumUnrolls];
if (i < hidden_bf16_int4) {
#pragma unroll
for (int k = 0; k < kNumUnrolls; ++ k)
int4_values[k] = ld_nc_global(cpy_src_int4_ptr + i + k);
}
auto bf16_values = reinterpret_cast<nv_bfloat16*>(int4_values);
auto uint32_values = reinterpret_cast<uint32_t*>(int4_values);
// Simulated cast
// Simulated 10-bit LogFMT cast, compiled in only when `kUseLogFMT` is set.
// Each lane holds `kNumElemsPerInt4 * kNumUnrolls` BF16 values in `bf16_values`/`uint32_values`;
// statistics are shared across groups of `16 / kNumUnrolls` lanes (128 channels), and the
// magnitudes are quantized onto a log-domain grid and immediately decoded back into BF16 in place.
if constexpr (kUseLogFMT) {
    constexpr float kThreshold = 1;
    constexpr float kMinClip = 32; // `== log_2(2 ^ (2 ^ 5))`
    constexpr int kNumBits = 10;
    // One bit is the sign (kept separately below), the remaining bits index the magnitude
    constexpr int kNumValues = 1 << (kNumBits - 1);
    EP_STATIC_ASSERT(kHidden % (kNumElemsPerInt4 * 32) == 0 and kNumElemsPerInt4 == 8, "Invalid hidden");

    // Local statistics over this lane's values.
    // NOTES: the accumulators start from neutral elements so that
    //  - `amax` is always a maximum of absolute values (a negative first element must not seed it), and
    //  - `log_amin` is well-defined even if the first (or every) element of the lane is zero.
    float log_abs_values[kNumElemsPerInt4 * kNumUnrolls];
    float amax = 0.0f, log_amax = -INFINITY, log_amin = INFINITY;
    #pragma unroll
    for (int j = 0; j < kNumElemsPerInt4 * kNumUnrolls; ++ j) {
        const auto value = static_cast<float>(bf16_values[j]);
        const auto abs_value = fabsf(value);
        log_abs_values[j] = log2f_approx(abs_value);
        amax = fmaxf(amax, abs_value);
        log_amax = fmaxf(log_amax, log_abs_values[j]);
        // Zeros are excluded from the minimum (their logarithm is `-inf`)
        log_amin = value != 0 ? fminf(log_amin, log_abs_values[j]) : log_amin;
    }

    // Reduce per 128 channels
    amax = warp_reduce_max<(16 / kNumUnrolls)>(amax);
    log_amax = warp_reduce_max<(16 / kNumUnrolls)>(log_amax);
    // The represented range is clipped to at most `kMinClip` octaves below the maximum
    log_amin = fmaxf(warp_reduce_min<(16 / kNumUnrolls)>(log_amin), log_amax - kMinClip);

    // Code `1` decodes to `log_amin` and code `kNumValues - 1` decodes to `log_amax`
    const auto step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
    const auto step_inv = 1.0f / step;
    // `log2((1 + 2^step) / 2) / step` is the log-domain position of the arithmetic midpoint between
    // two adjacent levels, so the `floorf` below rounds to the nearest level in the linear domain;
    // the remaining `+1` is the code offset that `encoded - 1` removes again when decoding
    const auto rounding = 2.0f - log2f_approx((1.0f + exp2f_approx(step)) * 0.5f) * step_inv;

    // Use LogFMT only with `amax <= kThreshold` and a non-degenerate range
    // (decided per reduction group, so not necessarily for every group of the warp)
    if (amax <= kThreshold and log_amin < log_amax) {
        // Quantize a log-magnitude onto the grid and decode it straight back
        auto transform = [=](const float& log_abs_value) -> nv_bfloat16 {
            const auto encoded = floorf((log_abs_value - log_amin) * step_inv + rounding);
            const auto decoded = exp2f_approx((encoded - 1) * step + log_amin);
            return decoded;
        };

        #pragma unroll
        for (int j = 0; j < kNumElemsPerInt4 * kNumUnrolls; j += 2) {
            auto bf162_pack = __nv_bfloat162(transform(log_abs_values[j]), transform(log_abs_values[j + 1]));
            auto uint32_pack = *reinterpret_cast<uint32_t*>(&bf162_pack);
            // Keep the two original sign bits, replace both magnitudes
            uint32_values[j / 2] = (uint32_values[j / 2] & 0x80008000) | uint32_pack;
        }
    }
    __syncwarp();
}
// Store
EP_STATIC_ASSERT(hidden_bf16_int4 % kNumUnrolls == 0, "Invalid hidden");
if (i < hidden_bf16_int4) {
#pragma unroll
for (int k = 0; k < kNumUnrolls; ++ k)
st_na_global(cpy_dst_int4_ptr + i + k, int4_values[k]);
}
}
}
// Issue RDMA
// NOTES: for zero-copy mode, we assume the data is already in the send buffer
if (dst_p2p_ptr == 0)
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
}
// Put the finishing flag
EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 16);
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(num_warps_per_group * 32));
......@@ -545,6 +616,7 @@ void combine(void* combined_x,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_logfmt,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases, bool zero_copy) {
constexpr int kNumMaxTopk = 9;
......@@ -560,8 +632,13 @@ void combine(void* combined_x,
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
// Online cast cannot use zero-copy
EP_HOST_ASSERT(not (zero_copy and use_logfmt));
#define COMBINE_LAUNCH_CASE(hidden) { \
auto combine_func = combine<hidden, kNumMaxTopk>; \
auto combine_func = use_logfmt ? \
combine<true, hidden, kNumMaxTopk> : \
combine<false, hidden, kNumMaxTopk>; \
LAUNCH_KERNEL(&cfg, combine_func, \
combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
......
......@@ -266,6 +266,18 @@ __device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value
::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
}
// Approximate base-2 logarithm of `x`, computed with a single `lg2.approx.f32` PTX instruction
__device__ __forceinline__ float log2f_approx(const float &x) {
    float result;
    asm volatile("lg2.approx.f32 %0, %1;"
                 : "=f"(result)
                 : "f"(x));
    return result;
}
// Approximate `2 ^ x`, computed with a single `ex2.approx.f32` PTX instruction
__device__ __forceinline__ float exp2f_approx(const float &x) {
    float result;
    asm volatile("ex2.approx.f32 %0, %1;"
                 : "=f"(result)
                 : "f"(x));
    return result;
}
// TMA PTX instructions
#ifndef DISABLE_SM90_FEATURES
......@@ -333,12 +345,12 @@ __device__ __forceinline__ void tma_store_wait() {
#endif
// Ceiling division: returns `(a + b - 1) / b`, usable in constant expressions on host and device.
// NOTES: assumes non-negative `a`, positive `b`, and that `a + b - 1` does not overflow `dtype_t`.
template <typename dtype_t>
__host__ __device__ constexpr dtype_t ceil_div(dtype_t a, dtype_t b) {
    return (a + b - 1) / b;
}
// Rounds `a` up to a multiple of `b` (`ceil_div(a, b) * b`), usable in constant expressions
// on host and device, e.g. to initialize a `constexpr` padded loop bound.
template <typename dtype_t>
__host__ __device__ constexpr dtype_t align(dtype_t a, dtype_t b) {
    return ceil_div<dtype_t>(a, b) * b;
}
......@@ -376,25 +388,6 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t& ptr, int src_lane_idx) {
return *reinterpret_cast<dtype_t*>(recv_int_values);
}
// Sums `value` over all 32 lanes of the warp with an XOR-shuffle butterfly
// (offsets 16, 8, 4, 2, 1); every lane ends up with the full sum.
__forceinline__ __device__ int warp_reduce_sum(int value) {
    #pragma unroll
    for (int lane_offset = 16; lane_offset > 0; lane_offset >>= 1)
        value += __shfl_xor_sync(0xffffffff, value, lane_offset);
    return value;
}
// Max-reduces `value` within each group of 16 lanes using XOR shuffles (offsets 8, 4, 2, 1),
// so every participating lane returns the maximum of its 16-lane group.
// NOTE(review): the participation mask is taken from `__activemask()`, i.e. whichever lanes happen
// to be converged at this point; callers presumably enter with either the full warp or only the
// lower 16 lanes active — confirm before reusing this helper elsewhere.
__forceinline__ __device__ float half_warp_reduce_max(float value) {
auto mask = __activemask();
// The mask is expected to be one of `{0xffffffff, 0xffff}`
value = max(value, __shfl_xor_sync(mask, value, 8));
value = max(value, __shfl_xor_sync(mask, value, 4));
value = max(value, __shfl_xor_sync(mask, value, 2));
value = max(value, __shfl_xor_sync(mask, value, 1));
return value;
}
__forceinline__ __device__ int get_lane_id() {
int lane_id;
asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
......@@ -493,4 +486,40 @@ __forceinline__ __device__ void release_lock(int* mutex) {
atomic_exch_cta_release(mutex, 0);
}
// Operation functors
// Binary functor: `lhs + rhs`
template <typename T>
struct ReduceSum {
    __device__ T operator()(T lhs, T rhs) const {
        return lhs + rhs;
    }
};

// Binary functor: the larger operand (`rhs` unless `lhs > rhs`)
template <typename T>
struct ReduceMax {
    __device__ T operator()(T lhs, T rhs) const {
        if (lhs > rhs)
            return lhs;
        return rhs;
    }
};

// Binary functor: the smaller operand (`rhs` unless `lhs < rhs`)
template <typename T>
struct ReduceMin {
    __device__ T operator()(T lhs, T rhs) const {
        if (lhs < rhs)
            return lhs;
        return rhs;
    }
};
// Unified reduction function
// Reduces `value` with `op` inside aligned groups of `kNumLanes` lanes using an XOR-shuffle
// butterfly (offsets `kNumLanes / 2`, ..., 2, 1); every lane returns its group's result.
// NOTES: all shuffles use the full-warp mask `0xffffffff`, so the whole warp must call this
// function together, even when `kNumLanes < 32`.
template <uint32_t kNumLanes, typename T, typename Op>
__forceinline__ __device__ T warp_reduce(T value, Op op) {
    EP_STATIC_ASSERT(kNumLanes == 32 or kNumLanes == 16 or kNumLanes == 8 or
                     kNumLanes == 4 or kNumLanes == 2 or kNumLanes == 1,
                     "Invalid number of lanes");

    // `kNumLanes == 1` performs no exchange at all
    #pragma unroll
    for (int lane_offset = static_cast<int>(kNumLanes) / 2; lane_offset > 0; lane_offset >>= 1)
        value = op(value, __shfl_xor_sync(0xffffffff, value, lane_offset));
    return value;
}
// Convenience aliases
// Sum of `value` over each group of `kNumLanes` lanes (the whole warp by default)
template <uint32_t kNumLanes = 32, typename T>
__forceinline__ __device__ T warp_reduce_sum(T value) {
    return warp_reduce<kNumLanes, T, ReduceSum<T>>(value, ReduceSum<T>());
}
// Maximum of `value` over each group of `kNumLanes` lanes (the whole warp by default)
template <uint32_t kNumLanes = 32, typename T>
__forceinline__ __device__ T warp_reduce_max(T value) {
    return warp_reduce<kNumLanes, T, ReduceMax<T>>(value, ReduceMax<T>());
}
// Minimum of `value` over each group of `kNumLanes` lanes (the whole warp by default)
template <uint32_t kNumLanes = 32, typename T>
__forceinline__ __device__ T warp_reduce_min(T value) {
    return warp_reduce<kNumLanes, T, ReduceMin<T>>(value, ReduceMin<T>());
}
} // namespace deep_ep
......@@ -562,7 +562,7 @@ class Buffer:
# noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
handle: tuple, zero_copy: bool = False, async_finish: bool = False,
handle: tuple, use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False,
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \
Tuple[torch.Tensor, EventOverlap, Callable]:
"""
......@@ -581,6 +581,7 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function.
use_logfmt: whether to use an internal "LogFMT with dynamic per-128-channel cast" format (10 bits).
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
......@@ -597,7 +598,8 @@ class Buffer:
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
num_max_dispatch_tokens_per_rank, num_experts,
zero_copy, async_finish, return_recv_hook, out)
use_logfmt, zero_copy, async_finish, return_recv_hook,
out)
tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
......
......@@ -9,7 +9,8 @@ from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_to
def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, seed: int = 0):
rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer,
use_logfmt: bool = False, seed: int = 0):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
......@@ -22,6 +23,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
......@@ -33,6 +35,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# Check dispatch correctness
do_check = True
hash_value, num_times = 0, 0
for current_x in (x, x_pure_rand):
for return_recv_hook in (False, True):
for dispatch_use_fp8 in (False, True):
for round_scale in (False, True) if dispatch_use_fp8 else (False, ):
......@@ -41,7 +44,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
for i in range((num_times % 2) + 1):
cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
......@@ -64,6 +67,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
# Check received data
if current_x is not x_pure_rand:
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
......@@ -84,18 +88,19 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
# Check combine correctness
for zero_copy in (False, True):
for zero_copy in (False, ) if use_logfmt else (False, True):
if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
use_logfmt=use_logfmt,
async_finish=not return_recv_hook, zero_copy=zero_copy,
return_recv_hook=return_recv_hook, out=out)
hook() if return_recv_hook else event.current_stream_wait()
if do_check:
diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0
assert diff < (7e-4 if round_scale else 1e-5), f'Error: {diff=}, {zero_copy=}'
assert diff < (7e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {zero_copy=}'
hash_value ^= hash_tensor(combined_x)
# noinspection PyShadowingNames
......@@ -106,16 +111,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
hook()
# noinspection PyShadowingNames
def test_func(zero_copy: bool, return_recv_hook: bool):
def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
buffer.low_latency_dispatch(x_pure_rand, topk_idx, num_tokens, num_experts,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
zero_copy=zero_copy, return_recv_hook=return_recv_hook)
use_logfmt=use_logfmt, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
# Calculate bandwidth
......@@ -127,14 +130,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
num_combine_comm_bytes += num_bf16_bytes * num_selections
# Dispatch + combine testing
avg_t, min_t, max_t = bench(partial(test_func, zero_copy=False, return_recv_hook=False))
avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)
# Separate profiling
for return_recv_hook in (False, True):
group.barrier()
dispatch_t, combine_t = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
suppress_kineto_output=True, num_kernels_per_period=2 if return_recv_hook else 1)
if not return_recv_hook:
......@@ -156,16 +159,20 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
if local_rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
num_qps_per_rank=num_experts // num_ranks)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)
num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=1)
do_pressure_test = False
for seed in range(int(1e9) if do_pressure_test else 0):
if local_rank == 0:
print(f'Testing with seed {seed} ...', flush=True)
ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed)
ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=seed)
for i in range(20):
assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}'
assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the communication group
dist.barrier()
......@@ -174,6 +181,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
if __name__ == '__main__':
# TODO: you may modify NUMA binding for less CPU overhead
# TODO: buggy with `num_tokens=512`
parser = argparse.ArgumentParser(description='Test low-latency EP kernels')
parser.add_argument('--num-processes', type=int, default=8,
help='Number of processes to spawn (default: 8)')
......@@ -185,6 +193,10 @@ if __name__ == '__main__':
help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=288,
help='Number of experts (default: 288)')
parser.add_argument('--disable-nvlink', action='store_true',
help='Whether to disable NVLink for testing')
parser.add_argument('--use-logfmt', action='store_true',
help='Whether to test LogFMT combine')
args = parser.parse_args()
num_processes = args.num_processes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment