Unverified Commit c5facf5c authored by Zhean Xu's avatar Zhean Xu Committed by GitHub
Browse files

Support 10-bit LogFMT Combine (#345)



* independent logfmt_simulate function

* draft: logfmt low latency combine

* Minor bug fixes

* Fix non-logfmt bugs

* Fix logfmt bugs

* Fix logfmt bugs

* Minor fix

* Minor fix

* Clean code

* Clean code

* Use fewer regs

* Use two warp groups

* Correct shared memory size

* Minor fix

* Minor fix

* More rigorous tests

* Clean code

* Use more SMs

* Use different unroll factor for send & recv

* Update csrc/kernels/internode_ll.cu
Co-authored-by: default avatarCopilot <175728472+Copilot@users.noreply.github.com>

* Update csrc/kernels/internode_ll.cu
Co-authored-by: default avatarCopilot <175728472+Copilot@users.noreply.github.com>

* Some renaming

* Some comments of tests

* Format `logfmt_encode`

* More lints

* Some refactors on sends

* Fix testing

* Fix bugs

* Renaming

* Use the full warp

* Unify combine recv

* Lint

* Lint

* Support 2560

* Fix meta buffer dtype

* Better encode calls

* Better amin/max writes

* Extra sync

* Read `topk_idx` by once

* Better specialization

* Read weights by once

* Rename

* Bug fixed

* Some renaming

* Fix local memory usage for sending

* Fix local memory usage for receiving

* Less writes

* Optimize performance

* Optimize performance

* Better performance

* Optimize performance

* Fix rounding

* Manually unroll

* Fix bench

---------
Co-authored-by: default avatarCopilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: default avatarChenggang Zhao <chenggangz@deepseek.com>
parent 26cf250a
......@@ -136,9 +136,10 @@ struct LowLatencyLayout {
// Message sizes
// NOTES: you should add a control `int4` for combine messages if you want to do data transformation
// NOTES: `num_scales * sizeof(nv_bfloat162)` means the per-128-channel min/max
EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));
size_t num_bytes_per_combine_msg = hidden * sizeof(nv_bfloat16);
size_t num_bytes_per_combine_msg = num_scales * sizeof(nv_bfloat162) + hidden * sizeof(nv_bfloat16);
// Send buffer
size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
......
......@@ -74,7 +74,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
using vec_t = std::conditional_t<kUseFP8, int2, int4>;
const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
......@@ -108,6 +108,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
// FP8 cast
EP_STATIC_ASSERT(hidden_bf16_int4 % 32 == 0, "Must use the full warp to reduce");
#pragma unroll
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
// Read
......@@ -391,7 +392,166 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
#undef DISPATCH_LAUNCH_CASE
}
template <bool kUseLogFMT, int kHidden, int kNumMaxTopk>
template <int kNumSendUnrolls>
__forceinline__ __device__ int logfmt_encode(void* buffer, nv_bfloat162 *shared_amaxmin, const int& lane_id) {
constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);
constexpr float kLogThreshold = 0;
constexpr float kMinClip = 32; // `== log_2(2 ^ (2 ^ 5))`
constexpr int kNumBits = 10;
constexpr int kNumValues = 1 << (kNumBits - 1);
int4 int4_values[kNumSendUnrolls];
const auto& uint32_values = reinterpret_cast<uint32_t*>(int4_values);
const auto& bf162_values = reinterpret_cast<nv_bfloat162*>(int4_values);
// Calculate lane offset
const auto& ld_buffer = reinterpret_cast<uint32_t*>(static_cast<uint8_t*>(buffer) + lane_id * (kNumSendUnrolls * sizeof(int4)));
const auto& st_buffer = reinterpret_cast<uint32_t*>(static_cast<uint8_t*>(buffer) + lane_id * (kNumSendUnrolls * sizeof(int4) * 10 / 16));
// Local log amax
auto bf162_amax = __nv_bfloat162(CUDART_ZERO_BF16, CUDART_ZERO_BF16);
auto bf162_amin = __nv_bfloat162(CUDART_INF_BF16, CUDART_INF_BF16);
uint32_t local_signs = 0;
#pragma unroll
for (int k = 0; k < kNumSendUnrolls * kNumElemsPerInt4 / 2; ++ k) {
// TODO: eliminate bank conflicts
uint32_values[k] = ld_buffer[k];
local_signs |= ((uint32_values[k] >> 15) & 1) << (k * 2);
local_signs |= ((uint32_values[k] >> 31) & 1) << (k * 2 + 1);
uint32_values[k] &= 0x7fff7fff;
bf162_amax = __hmax2(bf162_amax, bf162_values[k]);
bf162_amin = __hmin2(bf162_amin, bf162_values[k]);
}
// Reduce per 128 channels
// TODO: figure out how hardware do 2-byte min/max
auto amax = std::max(static_cast<float>(bf162_amax.x), static_cast<float>(bf162_amax.y));
auto amin = std::min(static_cast<float>(bf162_amin.x), static_cast<float>(bf162_amin.y));
constexpr static int kNumLanesToReduce = 128 * sizeof(nv_bfloat16) / (kNumSendUnrolls * sizeof(int4));
amax = warp_reduce_max<kNumLanesToReduce>(amax);
amin = warp_reduce_min<kNumLanesToReduce>(amin);
// Write min/max into the shared memory
if (shared_amaxmin != nullptr)
*shared_amaxmin = __nv_bfloat162(amax, amin);
__syncwarp();
// Calculate log amin/amax float
const auto& log_amax = log2f_approx(amax);
const auto& log_amin = fmaxf(log2f_approx(amin), log_amax - kMinClip);
const bool& enable_cast = warp_reduce_and<kNumLanesToReduce, true>(log_amax < kLogThreshold and log_amin < log_amax);
// Case into LogFMT-10 if satisfied
if (enable_cast) {
const auto step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
const auto step_inv = 1.0f / step;
const auto rounding = 2.0f - log2f_approx((1.0f + exp2f_approx(step)) * 0.5f) * step_inv;
const auto fused_rounding = rounding - log_amin * step_inv;
// Pack every 256 bits into 160 bits
EP_STATIC_ASSERT(kNumSendUnrolls == 2 or kNumSendUnrolls == 4, "kNumSendUnrolls == 2 or 4 only");
uint32_t encoded[kNumElemsPerInt4 * 2];
#pragma unroll 1
for (int i = 0; i < kNumSendUnrolls / 2; ++ i) {
#pragma unroll
for (int k = 0; k < kNumElemsPerInt4; ++ k) {
const auto& [x, y] = __bfloat1622float2(bf162_values[i * kNumElemsPerInt4 + k]);
encoded[k * 2 + 0] = __float2uint_rd(fmaxf(log2f_approx(x) * step_inv + fused_rounding, 0));
encoded[k * 2 + 1] = __float2uint_rd(fmaxf(log2f_approx(y) * step_inv + fused_rounding, 0));
}
st_buffer[i * 5 + 0] = (encoded[ 0] >> 0) | (encoded[ 1] << 9) | (encoded[ 2] << 18) | (encoded[ 3] << 27);
st_buffer[i * 5 + 1] = (encoded[ 3] >> 5) | (encoded[ 4] << 4) | (encoded[ 5] << 13) | (encoded[ 6] << 22) | (encoded[7] << 31);
st_buffer[i * 5 + 2] = (encoded[ 7] >> 1) | (encoded[ 8] << 8) | (encoded[ 9] << 17) | (encoded[10] << 26);
st_buffer[i * 5 + 3] = (encoded[10] >> 6) | (encoded[11] << 3) | (encoded[12] << 12) | (encoded[13] << 21) | (encoded[14] << 30);
st_buffer[i * 5 + 4] = (encoded[14] >> 2) | (encoded[15] << 7) | ((i == 0) ? (local_signs << 16) : (local_signs & 0xffff0000u));
}
tma_store_fence();
__syncwarp();
}
// Return TMA copy bytes
return enable_cast ? (32 * (kNumSendUnrolls * sizeof(int4) * 8 * 10 / 16 / 8)):
(32 * (kNumSendUnrolls * sizeof(int4)));
}
template <int kNumLanes, int kNumSendUnrolls, int kNumRecvUnrolls>
__forceinline__ __device__ void logfmt_check_amaxmin(uint8_t* meta_buffer, float2* shared_log_amax,
float2* shared_log_amin, int* shared_cast_info,
const int lane_id) {
constexpr float kLogThreshold = 0;
constexpr float kMinClip = 32; // `== log_2(2 ^ (2 ^ 5))`
bool enable_cast = true;
if (lane_id < kNumLanes) {
// Calculate log amin/amax float
auto amaxmin2 = reinterpret_cast<uint64_t*>(meta_buffer)[lane_id];
const auto& bf162_amaxmin = reinterpret_cast<__nv_bfloat162*>(&amaxmin2);
float log_amax[2], log_amin[2];
#pragma unroll
for (int i = 0; i < 2; ++ i) {
auto amax = static_cast<float>(bf162_amaxmin[i].x);
auto amin = static_cast<float>(bf162_amaxmin[i].y);
log_amax[i] = log2f_approx(amax);
log_amin[i] = amin == 0 ? log_amax[i] - kMinClip : fmaxf(log2f_approx(amin), log_amax[i] - kMinClip);
enable_cast = enable_cast and log_amax[i] < kLogThreshold and log_amin[i] < log_amax[i];
}
shared_log_amax[lane_id] = make_float2(log_amax[0], log_amax[1]);
shared_log_amin[lane_id] = make_float2(log_amin[0], log_amin[1]);
}
const auto& casted = warp_reduce_and<kNumSendUnrolls>(enable_cast) ? 1u << (lane_id / kNumRecvUnrolls): 0u;
const auto& num_casted_prefix = std::__popcount(warp_reduce_or<kNumRecvUnrolls, true>(casted) & ((1u << (lane_id / kNumRecvUnrolls)) - 1));
if (lane_id < kNumLanes and lane_id % kNumRecvUnrolls == 0)
shared_cast_info[lane_id / kNumRecvUnrolls] = (num_casted_prefix << 1) | (casted ? 1u : 0u);
__syncwarp();
}
template <int kNumRecvUnrolls>
__forceinline__ __device__ void decode_and_accumulate(uint32_t* ld_buffer, float* accum,
const float& log_amax, const float& log_amin,
const bool& enable_cast, const float& weight) {
if (enable_cast) {
constexpr int kNumBits = 10;
constexpr int kNumValues = 1 << (kNumBits - 1);
const auto& step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
auto decode = [=](const uint32_t &encoded, const uint32_t &sign) {
const auto decoded = encoded == 0 ? .0f : exp2f_approx((encoded - 1) * step + log_amin);
return sign ? -decoded : decoded;
};
EP_STATIC_ASSERT(kNumRecvUnrolls == 2 or kNumRecvUnrolls == 4, "kNumRecvUnrolls == 2 or 4 only");
#pragma unroll
for (int i = 0; i < kNumRecvUnrolls / 2; ++ i) {
uint32_t concat[6];
concat[0] = ld_buffer[i * 5];
#pragma unroll
for (int k = 1; k < 5; ++ k)
concat[k] = (ld_buffer[i * 5 + k - 1] >> (32 - k * 5)) | (ld_buffer[i * 5 + k] << (k * 5));
concat[5] = ld_buffer[i * 5 + 4] >> 7;
const uint32_t& local_signs = ld_buffer[i * 5 + 4] >> 16;
#pragma unroll
for (int k = 0; k < 5; ++ k) {
accum[i * 16 + k * 3 + 0] += decode((concat[k] >> 0) & 0x1ff, (local_signs >> (k * 3 + 0)) & 1) * weight;
accum[i * 16 + k * 3 + 1] += decode((concat[k] >> 9) & 0x1ff, (local_signs >> (k * 3 + 1)) & 1) * weight;
accum[i * 16 + k * 3 + 2] += decode((concat[k] >> 18) & 0x1ff, (local_signs >> (k * 3 + 2)) & 1) * weight;
}
accum[i * 16 + 15] += decode(concat[5] & 0x1ff, (local_signs >> 15) & 1) * weight;
}
} else {
#pragma unroll
for (int k = 0; k < kNumRecvUnrolls * 4; ++ k) {
auto bf16_pack = *reinterpret_cast<__nv_bfloat162*>(ld_buffer + k);
accum[k * 2 + 0] += static_cast<float>(bf16_pack.x) * weight;
accum[k * 2 + 1] += static_cast<float>(bf16_pack.y) * weight;
}
}
}
template <bool kUseLogFMT, int kHidden, int kNumMaxTopk, int kNumMaxUnrolls>
__global__ __launch_bounds__(1024, 1) void
combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
......@@ -405,26 +565,36 @@ combine(void* combined_x,
int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group,
int phases, bool zero_copy) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_sms = static_cast<int>(gridDim.x);
const auto sm_id = __shfl_sync(0xffffffff, static_cast<int>(blockIdx.x), 0);
const auto num_sms = __shfl_sync(0xffffffff, static_cast<int>(gridDim.x), 0);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto num_threads = static_cast<int>(blockDim.x);
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
const auto num_threads = __shfl_sync(0xffffffff, static_cast<int>(blockDim.x), 0);
const auto warp_id = __shfl_sync(0xffffffff, thread_id / 32, 0), lane_id = get_lane_id();
const auto num_local_experts = num_experts / num_ranks;
const auto warp_group_id = warp_id / num_warps_per_group;
const auto sub_warp_id = warp_id % num_warps_per_group;
const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// Data type staffs
constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);
constexpr int64_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
constexpr int kNumUnrolls = 4;
constexpr int hidden_bf16_int4_pad = align(static_cast<int>(hidden_bf16_int4), 32 * kNumUnrolls);
EP_STATIC_ASSERT(hidden_bf16_int4 % kNumUnrolls == 0, "Invalid hidden");
EP_STATIC_ASSERT(kNumUnrolls == 1 or kNumUnrolls == 2 or kNumUnrolls == 4, "Invalid unrolling factors");
// Use different unroll factors for send and recv phases
constexpr int kNumSendUnrolls = kHidden % (32 * 4 * sizeof(int4) / sizeof(nv_bfloat16)) == 0 ? 4 : 2;
constexpr int kNumRecvUnrolls = 2;
constexpr int hidden_bf16_int4_pad = align(static_cast<int>(hidden_bf16_int4), 32 * kNumSendUnrolls);
EP_STATIC_ASSERT(kHidden % (32 * 2 * sizeof(int4) / sizeof(nv_bfloat16)) == 0, "Invalid hidden");
EP_STATIC_ASSERT(kNumSendUnrolls <= kNumMaxUnrolls and kNumRecvUnrolls <= kNumMaxUnrolls, "Invalid unrolls");
EP_STATIC_ASSERT(hidden_bf16_int4 % kNumSendUnrolls == 0, "Invalid hidden");
EP_STATIC_ASSERT(kNumSendUnrolls >= kNumRecvUnrolls, "Invalid unroll factors");
// Message package
constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16);
EP_STATIC_ASSERT(kHidden % 128 == 0, "Invalid hidden");
constexpr int kNumDivisions = kHidden / 128;
constexpr int kNumMetaBytes = kNumDivisions * sizeof(nv_bfloat162);
constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16) + kNumMetaBytes;
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// Sending phase
......@@ -460,30 +630,30 @@ combine(void* combined_x,
unpack2(layout, num_tokens_to_send, offset);
// TMA stuffs
constexpr int kNumTMABufferBytes = sizeof(int4) * 32 * kNumUnrolls;
constexpr int kNumTMABufferBytes = sizeof(int4) * 32 * kNumSendUnrolls;
constexpr int kNumStages = 3;
constexpr int kNumPrefetch = 1;
EP_STATIC_ASSERT(kNumStages == 3 and kNumPrefetch == 1, "Invalid stages");
extern __shared__ __align__(1024) uint8_t smem_buffer[];
auto smem_ptr = smem_buffer + warp_id * kNumStages * (kNumTMABufferBytes + 16);
uint32_t tma_phase[kNumStages] = {};
auto tma_buffer = PatternVisitor([=](const int& i) { return reinterpret_cast<int4*>(smem_ptr + i * (kNumTMABufferBytes + 16)); });
auto tma_mbarrier = PatternVisitor([=](const int& i) { return reinterpret_cast<uint64_t*>(smem_ptr + i * (kNumTMABufferBytes + 16) + kNumTMABufferBytes); });
EP_STATIC_ASSERT(kNumUnrolls * kNumStages <= 12, "TMA buffer size exceed limit");
auto smem_ptr = smem_buffer + warp_id * (kNumStages * (kNumTMABufferBytes + 16) + kNumMetaBytes);
uint32_t tma_phase = 0;
auto tma_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<int4*>(smem_ptr + i * (kNumTMABufferBytes + 16)); });
auto full_barriers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint64_t*>(smem_ptr + i * (kNumTMABufferBytes + 16) + kNumTMABufferBytes); });
auto meta_buffers = kUseLogFMT ? reinterpret_cast<nv_bfloat162*>(smem_ptr + kNumStages * (kNumTMABufferBytes + 16)) : nullptr;
EP_STATIC_ASSERT(kNumSendUnrolls * kNumStages <= 12, "TMA buffer size exceed limit");
// Initialize m-barriers
if (lane_id < kNumStages) {
mbarrier_init(tma_mbarrier[lane_id], 1);
mbarrier_init(full_barriers[lane_id], 1);
fence_view_async_shared();
fence_barrier_init();
}
__syncwarp();
constexpr int kNumIters = hidden_bf16_int4_pad / (32 * kNumUnrolls);
constexpr int kNumIters = hidden_bf16_int4_pad / (32 * kNumSendUnrolls);
auto tma_load_and_arrive = [&](const int& stage_idx, const int4* gmem_ptr, const int& num_bytes) {
tma_load_1d(tma_buffer[stage_idx], gmem_ptr, tma_mbarrier[stage_idx], num_bytes);
mbarrier_arrive_and_expect_tx(tma_mbarrier[stage_idx], num_bytes);
tma_load_1d(tma_buffers[stage_idx], gmem_ptr, full_barriers[stage_idx], num_bytes);
mbarrier_arrive_and_expect_tx(full_barriers[stage_idx], num_bytes);
};
auto get_num_tma_bytes = [&](const int& offset_int4) {
return min(kNumTMABufferBytes, static_cast<int>((hidden_bf16_int4 - offset_int4) * sizeof(int4)));
......@@ -500,6 +670,7 @@ combine(void* combined_x,
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
int num_send_bytes = hidden * sizeof(nv_bfloat16);
if (not zero_copy or dst_p2p_ptr != 0) {
// Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr`
......@@ -511,95 +682,57 @@ combine(void* combined_x,
tma_load_and_arrive(0, cpy_src_int4_ptr, get_num_tma_bytes(0));
__syncwarp();
int tma_offset_bytes = kNumMetaBytes;
#pragma unroll
for (int i = lane_id * kNumUnrolls, iter_idx = 0; i < hidden_bf16_int4_pad; i += 32 * kNumUnrolls, ++ iter_idx) {
// Read
int4 int4_values[kNumUnrolls] = {0};
auto uint32_values = reinterpret_cast<uint32_t*>(int4_values);
for (int i = lane_id * kNumSendUnrolls, iter_idx = 0; i < hidden_bf16_int4_pad; i += 32 * kNumSendUnrolls, ++ iter_idx) {
// Load the next iteration
// TODO: try `elect_one_sync`
const int& stage_idx = iter_idx % kNumStages;
const int& next_stage_idx = (iter_idx + 1) % kNumStages;
tma_store_wait<kNumStages - kNumPrefetch - 1>();
if (iter_idx + 1 < kNumIters and elect_one_sync(lane_id)) {
const auto& offset_int4 = i + 32 * kNumUnrolls;
tma_store_wait<kNumStages - kNumPrefetch - 1>();
const auto& offset_int4 = i + 32 * kNumSendUnrolls;
tma_load_and_arrive(next_stage_idx, cpy_src_int4_ptr + offset_int4, get_num_tma_bytes(offset_int4));
}
__syncwarp();
// Wait the current TMA arrival
mbarrier_wait(tma_mbarrier[stage_idx], tma_phase[stage_idx]);
const auto& uint32_buffer = reinterpret_cast<uint32_t*>(tma_buffer[stage_idx] + lane_id * kNumUnrolls);
// Simulated cast
EP_STATIC_ASSERT(kNumStages < 32, "Too many stages");
mbarrier_wait<true>(full_barriers[stage_idx], tma_phase, stage_idx);
if constexpr (kUseLogFMT) {
constexpr float kThreshold = 1;
constexpr float kMinClip = 32; // `== log_2(2 ^ (2 ^ 5))`
constexpr int kNumBits = 10;
constexpr int kNumValues = 1 << (kNumBits - 1);
EP_STATIC_ASSERT(kHidden % (kNumElemsPerInt4 * 32) == 0 and kNumElemsPerInt4 == 8, "Invalid hidden");
// Local log amax
float log_abs_values[kNumElemsPerInt4 * kNumUnrolls], log_amax, log_amin, amax;
auto log_aminmax = [&](const int &j, const float& value) {
log_abs_values[j] = log2f_approx(fabsf(value));
amax = j == 0 ? value : fmaxf(amax, fabsf(value));
log_amax = j == 0 ? log_abs_values[j] : fmaxf(log_amax, log_abs_values[j]);
log_amin = value != 0 ? (j == 0 ? log_abs_values[j] : fminf(log_amin, log_abs_values[j])) : log_amin;
};
#pragma unroll
for (int k = 0; k < kNumUnrolls * 4; ++ k) {
uint32_values[k] = uint32_buffer[k ^ (lane_id * kNumUnrolls / 8)];
auto bf162_values = *reinterpret_cast<__nv_bfloat162*>(uint32_values + k);
auto float2_values = __bfloat1622float2(bf162_values);
log_aminmax(k * 2, float2_values.x);
log_aminmax(k * 2 + 1, float2_values.y);
}
// Reduce per 128 channels
amax = warp_reduce_max<(16 / kNumUnrolls)>(amax);
log_amax = warp_reduce_max<(16 / kNumUnrolls)>(log_amax);
log_amin = fmaxf(warp_reduce_min<(16 / kNumUnrolls)>(log_amin), log_amax - kMinClip);
const auto step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
const auto step_inv = 1.0f / step;
const auto rounding = 2.0f - log2f_approx((1.0f + exp2f_approx(step)) * 0.5f) * step_inv;
// Use LogFMT only with `amax <= kThreshold` (maybe not all quarter-warps)
if (amax <= kThreshold and log_amin < log_amax) {
// Transform
auto transform = [=](const float& log_abs_value) -> nv_bfloat16 {
const auto encoded = floorf((log_abs_value - log_amin) * step_inv + rounding);
const auto decoded = exp2f_approx((encoded - 1) * step + log_amin);
return decoded;
};
#pragma unroll
for (int k = 0; k < kNumUnrolls * 4; ++ k) {
auto bf162_pack = __nv_bfloat162(transform(log_abs_values[k * 2]), transform(log_abs_values[k * 2 + 1]));
auto uint32_pack = *reinterpret_cast<uint32_t*>(&bf162_pack);
uint32_buffer[k ^ (lane_id * kNumUnrolls / 8)] = (uint32_values[k] & 0x80008000) | uint32_pack;
}
}
tma_store_fence();
// Cast if possible
constexpr int kNumInt4PerDivision = 128 / kNumElemsPerInt4;
int num_tma_bytes = logfmt_encode<kNumSendUnrolls>(
tma_buffers[stage_idx],
// NOTES: only the leader lane will write the result
(i % kNumInt4PerDivision == 0) ? meta_buffers + i / kNumInt4PerDivision : nullptr,
lane_id);
if (elect_one_sync(lane_id))
tma_store_1d(tma_buffers[stage_idx], reinterpret_cast<uint8_t*>(cpy_dst_int4_ptr) + tma_offset_bytes, num_tma_bytes);
tma_offset_bytes += num_tma_bytes;
} else {
// BF16 original values
if (elect_one_sync(lane_id))
tma_store_1d(tma_buffers[stage_idx], cpy_dst_int4_ptr + i, get_num_tma_bytes(i));
}
__syncwarp();
}
// Store
// Store metadata (min/max values) for LogFMT
if constexpr (kUseLogFMT) {
num_send_bytes = tma_offset_bytes;
if (elect_one_sync(lane_id))
tma_store_1d(tma_buffer[stage_idx], cpy_dst_int4_ptr + i, get_num_tma_bytes(i));
__syncwarp();
}
tma_store_1d(meta_buffers, cpy_dst_int4_ptr, kNumMetaBytes);
}
// Flush all stores
tma_store_wait();
__syncwarp();
}
// Issue RDMA
// NOTES: for zero-copy mode, we assume the data is already in the send buffer
if (dst_p2p_ptr == 0)
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, num_send_bytes, dst_rank, local_expert_idx, lane_id, token_idx - offset);
}
// Put the finishing flag
......@@ -617,6 +750,14 @@ combine(void* combined_x,
atomic_add_release_global(atomic_clean_flag, -1);
}
__syncwarp();
// Destroy m-barriers
if (lane_id < kNumStages) {
mbarrier_inval(full_barriers[lane_id]);
fence_view_async_shared();
fence_barrier_init();
}
__syncwarp();
}
// Receiving phase
......@@ -639,43 +780,140 @@ combine(void* combined_x,
}
cg::this_grid().sync();
// Reduce tokens
EP_DEVICE_ASSERT(num_topk <= 32);
// Reassign warp groups
constexpr int kMaxNumGroups = 2;
const int num_decode_warps = hidden_bf16_int4_pad / (kNumRecvUnrolls * 32);
const int num_groups = min(kMaxNumGroups, (num_threads / 32) / (num_decode_warps + 1));
const int decode_warp_idx = __shfl_sync(0xffffffff, warp_id % (num_decode_warps + 1), 0);
const int group_idx = __shfl_sync(0xffffffff, warp_id / (num_decode_warps + 1), 0);
EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization");
for (int hidden_idx = thread_id; hidden_idx < hidden_bf16_int4; hidden_idx += num_threads) {
for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
// Read top-k indices and weights
int reg_topk_idx[kNumMaxTopk];
float reg_topk_weights[kNumMaxTopk];
#pragma unroll
EP_DEVICE_ASSERT(num_topk <= 32);
EP_DEVICE_ASSERT(num_groups > 0);
if (group_idx < num_groups) {
constexpr int kNumStages = 3;
constexpr int kNumTMABufferBytes = 16 * 2 + kHidden * 2;
constexpr int kNumBF16PerWarpBytes = 32 * kNumRecvUnrolls * kNumElemsPerInt4 * 2;
constexpr int kNumLogFMTPerWarpBytes = kNumBF16PerWarpBytes / 16 * 10;
constexpr int kNumDivisionBytes = kNumDivisions * sizeof(uint32_t);
constexpr int kNumBytesPerGroup = kNumStages * kNumTMABufferBytes + kHidden * 2 + kNumStages * kNumDivisionBytes * 3;
// Reallocate shared memory
const auto smem_group_buffer = smem_buffer + kNumBytesPerGroup * group_idx;
auto full_barriers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint64_t*>(smem_group_buffer + i * kNumTMABufferBytes); });
auto empty_barriers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint64_t*>(smem_group_buffer + i * kNumTMABufferBytes + 8); });
auto tma_ld_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint8_t* >(smem_group_buffer + i * kNumTMABufferBytes + 16); });
auto tma_st_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint32_t*>(smem_group_buffer + kNumStages * kNumTMABufferBytes + i * kNumBF16PerWarpBytes); });
// Redundant when logfmt is disabled
const auto smem_group_ptr = smem_group_buffer + kNumStages * kNumTMABufferBytes + kHidden * 2;
auto log_amax_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<float*>(smem_group_ptr + i * kNumDivisionBytes); });
auto log_amin_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<float*>(smem_group_ptr + kNumStages * kNumDivisionBytes + i * kNumDivisionBytes); });
auto cast_info_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<int*> (smem_group_ptr + kNumStages * kNumDivisionBytes * 2 + i * kNumDivisionBytes); });
uint32_t tma_phase = 0;
EP_STATIC_ASSERT(kNumStages < 32, "Too many stages");
if (decode_warp_idx == num_decode_warps)
tma_phase = (1 << kNumStages) - 1;
// Initialize m-barriers
if (decode_warp_idx == num_decode_warps and lane_id < kNumStages) {
mbarrier_init(full_barriers[lane_id], 1);
mbarrier_init(empty_barriers[lane_id], num_decode_warps);
}
asm volatile("bar.sync %0, %1;" :: "r"(group_idx + 1), "r"((num_decode_warps + 1) * 32));
int stage_idx = 0, topk_idx_by_lane = 0;
EP_STATIC_ASSERT(kNumMaxTopk <= 32, "Invalid number of topks");
if (decode_warp_idx == num_decode_warps) {
// TMA load warp
for (int token_idx = sm_id + num_sms * group_idx; token_idx < num_combined_tokens; token_idx += num_sms * num_groups) {
if (lane_id < num_topk)
topk_idx_by_lane = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + lane_id));
for (int i = 0; i < num_topk; ++ i) {
reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
int topk_idx_reg = __shfl_sync(0xffffffff, topk_idx_by_lane, i);
if (topk_idx_reg < 0)
continue;
mbarrier_wait<true>(empty_barriers[stage_idx], tma_phase, stage_idx);
auto buffer = static_cast<uint8_t*>(rdma_recv_x) + (topk_idx_reg * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot;
if constexpr (kUseLogFMT) {
logfmt_check_amaxmin<kNumDivisions / 2, kNumSendUnrolls, kNumRecvUnrolls>(
buffer, reinterpret_cast<float2*>(log_amax_buffers[stage_idx]),
reinterpret_cast<float2*>(log_amin_buffers[stage_idx]), cast_info_buffers[stage_idx], lane_id);
}
if (elect_one_sync(lane_id)) {
int num_casted = 0;
if constexpr (kUseLogFMT) {
const auto& info = cast_info_buffers[stage_idx][num_decode_warps - 1];
num_casted = (info >> 1) + (info & 1);
}
int num_tma_bytes = num_casted * kNumLogFMTPerWarpBytes + (num_decode_warps - num_casted) * kNumBF16PerWarpBytes;
tma_load_1d(tma_ld_buffers[stage_idx], buffer + (kUseLogFMT ? kNumMetaBytes : 0), full_barriers[stage_idx], num_tma_bytes);
mbarrier_arrive_and_expect_tx(full_barriers[stage_idx], num_tma_bytes);
}
__syncwarp();
stage_idx = (stage_idx + 1) % kNumStages;
}
}
} else {
// Reduction warps
float topk_weights_by_lane;
for (int token_idx = sm_id + num_sms * group_idx; token_idx < num_combined_tokens; token_idx += num_sms * num_groups) {
if (lane_id < num_topk) {
topk_idx_by_lane = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + lane_id));
topk_weights_by_lane = __ldg(topk_weights + token_idx * num_topk + lane_id);
}
__syncwarp();
float combined_values[kNumElemsPerInt4] = {0.0f};
#pragma unroll
for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// Read from sources
auto rdma_buffer_type = reinterpret_cast<const int*>(static_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
// Reduce
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + hidden_idx);
const auto x_bf16 = reinterpret_cast<nv_bfloat16*>(&x_vec);
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
float combined_values[kNumElemsPerInt4 * kNumRecvUnrolls] = {0.0f};
for (int i = 0; i < num_topk; ++ i) {
if (__shfl_sync(0xffffffff, topk_idx_by_lane, i) < 0)
continue;
const auto& topk_weight = __shfl_sync(0xffffffff, topk_weights_by_lane, i);
mbarrier_wait<true>(full_barriers[stage_idx], tma_phase, stage_idx);
if constexpr (kUseLogFMT) {
const auto& info = cast_info_buffers[stage_idx][decode_warp_idx];
bool enable_cast = info & 1;
int num_casted_prefix = info >> 1;
int tma_offset = kNumLogFMTPerWarpBytes * num_casted_prefix + kNumBF16PerWarpBytes * (decode_warp_idx - num_casted_prefix);
int division_idx = decode_warp_idx * (kNumRecvUnrolls * 2) + lane_id * kNumRecvUnrolls / 16;
decode_and_accumulate<kNumRecvUnrolls>(
reinterpret_cast<uint32_t*>(tma_ld_buffers[stage_idx] + tma_offset + (enable_cast ? kNumLogFMTPerWarpBytes : kNumBF16PerWarpBytes) / 32 * lane_id),
combined_values, log_amax_buffers[stage_idx][division_idx], log_amin_buffers[stage_idx][division_idx], enable_cast, topk_weight
);
} else {
int tma_offset = kNumBF16PerWarpBytes * decode_warp_idx;
decode_and_accumulate<kNumRecvUnrolls>(
reinterpret_cast<uint32_t*>(tma_ld_buffers[stage_idx] + tma_offset + kNumBF16PerWarpBytes / 32 * lane_id),
combined_values, 0, 0, false, topk_weight
);
}
// Write results
int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
auto combined_bf16 = reinterpret_cast<nv_bfloat16*>(&combined_values);
if (elect_one_sync(lane_id))
mbarrier_arrive(empty_barriers[stage_idx]);
stage_idx = (stage_idx + 1) % kNumStages;
}
tma_store_wait<0>();
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
combined_bf16[j] = static_cast<nv_bfloat16>(combined_values[j]);
(static_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[hidden_idx] = combined_int4;
for (int k = 0; k < kNumRecvUnrolls * 4; ++ k) {
auto combined_pack = __nv_bfloat162(combined_values[k * 2], combined_values[k * 2 + 1]);
tma_st_buffers[decode_warp_idx][kNumRecvUnrolls * 4 * lane_id + k] = *reinterpret_cast<uint32_t*>(&combined_pack);
}
tma_store_fence();
if (elect_one_sync(lane_id)) {
tma_store_1d(tma_st_buffers[decode_warp_idx],
static_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4 + decode_warp_idx * kNumRecvUnrolls * 32,
kNumBF16PerWarpBytes);
}
__syncwarp();
}
}
// Flush all stores
tma_store_wait<0>();
}
}
......@@ -693,10 +931,11 @@ void combine(void* combined_x,
constexpr int kNumMaxTopk = 9;
const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = 32 / num_warp_groups;
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm > 0);
const auto num_warps = num_warp_groups * num_warps_per_group;
const auto num_sms = ceil_div(num_experts, num_warp_groups);
const auto num_sms = max(ceil_div(num_experts, num_warp_groups), ceil_div(num_combined_tokens, num_recv_per_sm));
// Check workspace
auto atomic_clean_flag = static_cast<int*>(workspace);
......@@ -706,13 +945,26 @@ void combine(void* combined_x,
// Online cast cannot use zero-copy
EP_HOST_ASSERT(not (zero_copy and use_logfmt));
constexpr int kNumTMABytesPerWarp = 12 * (512 + 16);
const int smem_size = kNumTMABytesPerWarp * num_warps;
constexpr int kNumStages = 3;
constexpr int kNumMaxUnrolls = 4;
constexpr int kMaxNumGroups = 2;
// Send buffer size
const int num_meta_bytes = hidden / 128 * 4;
const int num_send_tma_bytes = 32 * sizeof(int4) * kNumMaxUnrolls + 16;
const int smem_send_size = num_warps * (kNumStages * num_send_tma_bytes + num_meta_bytes);
// Receive buffer size
const int num_recv_tma_bytes = 16 + hidden * 2;
const int smem_recv_size = kMaxNumGroups * (kNumStages * num_recv_tma_bytes + hidden * 2 + kNumStages * num_meta_bytes * 3);
// Total requirement
const int smem_size = max(smem_send_size, smem_recv_size);
#define COMBINE_LAUNCH_CASE(hidden) { \
auto combine_func = use_logfmt ? \
combine<true, hidden, kNumMaxTopk> : \
combine<false, hidden, kNumMaxTopk>; \
combine<true, hidden, kNumMaxTopk, kNumMaxUnrolls> : \
combine<false, hidden, kNumMaxTopk, kNumMaxUnrolls>; \
SET_SHARED_MEMORY_FOR_TMA(combine_func); \
LAUNCH_KERNEL(&cfg, combine_func, \
combined_x, \
......
......@@ -322,8 +322,15 @@ __device__ __forceinline__ void mbarrier_init(uint64_t* mbar_ptr, uint32_t arriv
asm volatile("mbarrier.init.shared::cta.b64 [%1], %0;" :: "r"(arrive_count), "r"(mbar_int_ptr));
}
__device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr, uint32_t& phase) {
__device__ __forceinline__ void mbarrier_inval(uint64_t* mbar_ptr) {
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
asm volatile("mbarrier.inval.shared::cta.b64 [%0];" :: "r"(mbar_int_ptr));
}
template <bool kWithMultiStages = false>
__device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr, uint32_t& phase, int stage_idx = 0) {
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
const auto& wait = kWithMultiStages ? (phase >> stage_idx) & 1 : phase;
asm volatile("{\n\t"
".reg .pred P1; \n\t"
"LAB_WAIT: \n\t"
......@@ -331,8 +338,8 @@ __device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr, uint32_t& phas
"@P1 bra DONE; \n\t"
"bra LAB_WAIT; \n\t"
"DONE: \n\t"
"}" :: "r"(mbar_int_ptr), "r"(phase), "r"(0x989680));
phase ^= 1;
"}" :: "r"(mbar_int_ptr), "r"(wait), "r"(0x989680));
phase ^= kWithMultiStages ? (1 << stage_idx) : 1;
}
__device__ __forceinline__ void mbarrier_arrive_and_expect_tx(uint64_t* mbar_ptr, int num_bytes) {
......@@ -340,6 +347,11 @@ __device__ __forceinline__ void mbarrier_arrive_and_expect_tx(uint64_t* mbar_ptr
asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" :: "r"(num_bytes), "r"(mbar_int_ptr));
}
__device__ __forceinline__ void mbarrier_arrive(uint64_t* mbar_ptr) {
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0]; \n\t" :: "r"(mbar_int_ptr));
}
__device__ __forceinline__ void tma_store_fence() {
asm volatile ("fence.proxy.async.shared::cta;");
}
......@@ -518,36 +530,56 @@ __forceinline__ __device__ void release_lock(int* mutex) {
template <typename T> struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } };
template <typename T> struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } };
template <typename T> struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? a : b; } };
template <typename T> struct ReduceAnd { __device__ T operator()(T a, T b) const { return a & b; } };
template <typename T> struct ReduceOr { __device__ T operator()(T a, T b) const { return a | b; } };
// Unified reduction function
template <uint32_t kNumLanes, typename T, typename Op>
template <int kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
__forceinline__ __device__ T warp_reduce(T value, Op op) {
EP_STATIC_ASSERT(kNumLanes == 32 or kNumLanes == 16 or kNumLanes == 8 or
kNumLanes == 4 or kNumLanes == 2 or kNumLanes == 1,
EP_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or
kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,
"Invalid number of lanes");
if constexpr (kNumLanes >= 32) value = op(value, __shfl_xor_sync(0xffffffff, value, 16));
if constexpr (kNumLanes >= 16) value = op(value, __shfl_xor_sync(0xffffffff, value, 8));
if constexpr (kNumLanes >= 8) value = op(value, __shfl_xor_sync(0xffffffff, value, 4));
if constexpr (kNumLanes >= 4) value = op(value, __shfl_xor_sync(0xffffffff, value, 2));
if constexpr (kNumLanes >= 2) value = op(value, __shfl_xor_sync(0xffffffff, value, 1));
constexpr uint32_t mask = 0xffffffff;
if constexpr (kIntergroupReduce) {
if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1));
if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2));
if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4));
if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8));
if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16));
} else {
if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16));
if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8));
if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4));
if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2));
if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1));
}
return value;
}
// Convenience aliases
template < uint32_t kNumLanes = 32, typename T>
template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_sum(T value) {
return warp_reduce<kNumLanes, T>(value, ReduceSum<T>{});
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceSum<T>{});
}
template <uint32_t kNumLanes = 32, typename T>
template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_max(T value) {
return warp_reduce<kNumLanes, T>(value, ReduceMax<T>{});
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceMax<T>{});
}
template <uint32_t kNumLanes = 32, typename T>
template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_min(T value) {
return warp_reduce<kNumLanes, T>(value, ReduceMin<T>{});
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceMin<T>{});
}
template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_and(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceAnd<T>{});
}
template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_or(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceOr<T>{});
}
} // namespace deep_ep
......@@ -27,7 +27,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
x_list = [x]
for i in range(4 if use_logfmt else 0):
# NOTES: make more LogFMT casts and also with some BF16
x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.5 * random.random())
# NOTES: the last one is for performance testing
# Most of the values in the perf case is lower than the threshold, casting most channels
x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
......@@ -39,7 +46,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# Check dispatch correctness
do_check = True
hash_value, num_times = 0, 0
for current_x in (x, x_pure_rand):
for current_x in x_list:
for return_recv_hook in (False, True):
for dispatch_use_fp8 in (False, True):
for round_scale in (False, True) if dispatch_use_fp8 else (False, ):
......@@ -71,7 +78,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
# Check received data
if current_x is not x_pure_rand:
if current_x is x:
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
......@@ -104,7 +111,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0
assert diff < (7e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {zero_copy=}'
assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}'
hash_value ^= hash_tensor(combined_x)
# noinspection PyShadowingNames
......@@ -117,7 +124,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# noinspection PyShadowingNames
def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x_pure_rand, topk_idx, num_tokens, num_experts,
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
......@@ -127,11 +134,12 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# Calculate bandwidth
num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
num_logfmt10_bytes = hidden * 10 / 8 + hidden / 128 * 4
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
for i in range(num_tokens):
num_selections = (topk_idx[i] != -1).sum().item()
num_dispatch_comm_bytes += num_fp8_bytes * num_selections
num_combine_comm_bytes += num_bf16_bytes * num_selections
num_combine_comm_bytes += (num_logfmt10_bytes if use_logfmt else num_bf16_bytes) * num_selections
# Dispatch + combine testing
avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
......
......@@ -53,7 +53,7 @@ def per_token_cast_to_fp8(x: torch.Tensor):
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
if x_scales.dtype == torch.int:
x_scales = x_scales.view(dtype=torch.int8).to(torch.int) << 23
x_scales = x_scales.view(dtype=torch.uint8).to(torch.int) << 23
x_scales = x_scales.view(dtype=torch.float)
x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
x_scales = x_scales.view(x_fp8.size(0), -1, 1)
......@@ -171,6 +171,7 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr
dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
for _ in range(num_tests):
fn()
torch.cuda.synchronize()
prof.step()
# Parse the profiling table
......@@ -219,4 +220,4 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr
def hash_tensor(t: torch.Tensor):
return t.view(torch.int64).sum().item()
return t.view(torch.int).sum().item()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment