Unverified commit c874cb7a, authored by Zhean Xu, committed by GitHub
Browse files

Optimize low latency combine send with TMA (#299)



* feat: low latency combine inplace TMA

* optimize tma pointer with PatternVisitor

* Minor cleanup

* Add `elect_one_sync`

---------
Co-authored-by: Zhean Xu <xza@deepseek.com>
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
parent 898269fa
......@@ -407,7 +407,11 @@ combine(void* combined_x,
// Data type staffs
constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
constexpr int64_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
constexpr int kNumUnrolls = 4;
constexpr int hidden_bf16_int4_pad = align(static_cast<int>(hidden_bf16_int4), 32 * kNumUnrolls);
EP_STATIC_ASSERT(hidden_bf16_int4 % kNumUnrolls == 0, "Invalid hidden");
EP_STATIC_ASSERT(kNumUnrolls == 1 or kNumUnrolls == 2 or kNumUnrolls == 4, "Invalid unrolling factors");
// Message package
constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16);
......@@ -445,6 +449,36 @@ combine(void* combined_x,
int offset, num_tokens_to_send;
unpack2(layout, num_tokens_to_send, offset);
// TMA stuffs
constexpr int kNumTMABufferBytes = sizeof(int4) * 32 * kNumUnrolls;
constexpr int kNumStages = 3;
constexpr int kNumPrefetch = 1;
EP_STATIC_ASSERT(kNumStages == 3 and kNumPrefetch == 1, "Invalid stages");
extern __shared__ __align__(1024) uint8_t smem_buffer[];
auto smem_ptr = smem_buffer + warp_id * kNumStages * (kNumTMABufferBytes + 16);
uint32_t tma_phase[kNumStages] = {};
auto tma_buffer = PatternVisitor([=](const int& i) { return reinterpret_cast<int4*>(smem_ptr + i * (kNumTMABufferBytes + 16)); });
auto tma_mbarrier = PatternVisitor([=](const int& i) { return reinterpret_cast<uint64_t*>(smem_ptr + i * (kNumTMABufferBytes + 16) + kNumTMABufferBytes); });
EP_STATIC_ASSERT(kNumUnrolls * kNumStages <= 12, "TMA buffer size exceed limit");
// Initialize m-barriers
if (lane_id < kNumStages) {
mbarrier_init(tma_mbarrier[lane_id], 1);
fence_view_async_shared();
fence_barrier_init();
}
__syncwarp();
constexpr int kNumIters = hidden_bf16_int4_pad / (32 * kNumUnrolls);
auto tma_load_and_arrive = [&](const int& stage_idx, const int4* gmem_ptr, const int& num_bytes) {
tma_load_1d(tma_buffer[stage_idx], gmem_ptr, tma_mbarrier[stage_idx], num_bytes);
mbarrier_arrive_and_expect_tx(tma_mbarrier[stage_idx], num_bytes);
};
auto get_num_tma_bytes = [&](const int& offset_int4) {
return min(kNumTMABufferBytes, static_cast<int>((hidden_bf16_int4 - offset_int4) * sizeof(int4)));
};
// Issue IBGDA send
for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) {
const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
......@@ -458,24 +492,36 @@ combine(void* combined_x,
const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
if (not zero_copy or dst_p2p_ptr != 0) {
constexpr int kNumUnrolls = 4;
constexpr int hidden_bf16_int4_pad = align(static_cast<int>(hidden_bf16_int4), 32 * kNumUnrolls);
// Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr`
const auto cpy_src_int4_ptr = zero_copy ? reinterpret_cast<int4*>(buf_ptr) : x_int4;
const auto cpy_dst_int4_ptr = dst_p2p_ptr == 0 ? reinterpret_cast<int4*>(buf_ptr) : reinterpret_cast<int4*>(dst_p2p_ptr);
// Prefetch
if (elect_one_sync(lane_id))
tma_load_and_arrive(0, cpy_src_int4_ptr, get_num_tma_bytes(0));
__syncwarp();
#pragma unroll
for (int i = lane_id * kNumUnrolls; i < hidden_bf16_int4_pad; i += 32 * kNumUnrolls) {
for (int i = lane_id * kNumUnrolls, iter_idx = 0; i < hidden_bf16_int4_pad; i += 32 * kNumUnrolls, ++ iter_idx) {
// Read
int4 int4_values[kNumUnrolls];
if (i < hidden_bf16_int4) {
#pragma unroll
for (int k = 0; k < kNumUnrolls; ++ k)
int4_values[k] = ld_nc_global(cpy_src_int4_ptr + i + k);
}
auto bf16_values = reinterpret_cast<nv_bfloat16*>(int4_values);
int4 int4_values[kNumUnrolls] = {0};
auto uint32_values = reinterpret_cast<uint32_t*>(int4_values);
// Load the next iteration
// TODO: try `elect_one_sync`
const int& stage_idx = iter_idx % kNumStages;
const int& next_stage_idx = (iter_idx + 1) % kNumStages;
tma_store_wait<kNumStages - kNumPrefetch - 1>();
if (iter_idx + 1 < kNumIters and elect_one_sync(lane_id)) {
const auto& offset_int4 = i + 32 * kNumUnrolls;
tma_load_and_arrive(next_stage_idx, cpy_src_int4_ptr + offset_int4, get_num_tma_bytes(offset_int4));
}
__syncwarp();
// Wait the current TMA arrival
mbarrier_wait(tma_mbarrier[stage_idx], tma_phase[stage_idx]);
const auto& uint32_buffer = reinterpret_cast<uint32_t*>(tma_buffer[stage_idx] + lane_id * kNumUnrolls);
// Simulated cast
if constexpr (kUseLogFMT) {
constexpr float kThreshold = 1;
......@@ -486,13 +532,19 @@ combine(void* combined_x,
// Local log amax
float log_abs_values[kNumElemsPerInt4 * kNumUnrolls], log_amax, log_amin, amax;
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4 * kNumUnrolls; ++ j) {
auto value = static_cast<float>(bf16_values[j]);
auto log_aminmax = [&](const int &j, const float& value) {
log_abs_values[j] = log2f_approx(fabsf(value));
amax = j == 0 ? value : fmaxf(amax, fabsf(value));
log_amax = j == 0 ? log_abs_values[j] : fmaxf(log_amax, log_abs_values[j]);
log_amin = value != 0 ? (j == 0 ? log_abs_values[j] : fminf(log_amin, log_abs_values[j])) : log_amin;
};
#pragma unroll
for (int k = 0; k < kNumUnrolls * 4; ++ k) {
uint32_values[k] = uint32_buffer[k ^ (lane_id * kNumUnrolls / 8)];
auto bf162_values = *reinterpret_cast<__nv_bfloat162*>(uint32_values + k);
auto float2_values = __bfloat1622float2(bf162_values);
log_aminmax(k * 2, float2_values.x);
log_aminmax(k * 2 + 1, float2_values.y);
}
// Reduce per 128 channels
......@@ -513,25 +565,27 @@ combine(void* combined_x,
return decoded;
};
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4 * kNumUnrolls; j += 2) {
auto bf162_pack = __nv_bfloat162(transform(log_abs_values[j]), transform(log_abs_values[j + 1]));
for (int k = 0; k < kNumUnrolls * 4; ++ k) {
auto bf162_pack = __nv_bfloat162(transform(log_abs_values[k * 2]), transform(log_abs_values[k * 2 + 1]));
auto uint32_pack = *reinterpret_cast<uint32_t*>(&bf162_pack);
uint32_values[j / 2] = (uint32_values[j / 2] & 0x80008000) | uint32_pack;
uint32_buffer[k ^ (lane_id * kNumUnrolls / 8)] = (uint32_values[k] & 0x80008000) | uint32_pack;
}
}
__syncwarp();
tma_store_fence();
}
__syncwarp();
// Store
EP_STATIC_ASSERT(hidden_bf16_int4 % kNumUnrolls == 0, "Invalid hidden");
if (i < hidden_bf16_int4) {
#pragma unroll
for (int k = 0; k < kNumUnrolls; ++ k)
st_na_global(cpy_dst_int4_ptr + i + k, int4_values[k]);
}
if (elect_one_sync(lane_id))
tma_store_1d(tma_buffer[stage_idx], cpy_dst_int4_ptr + i, get_num_tma_bytes(i));
__syncwarp();
}
}
// Flush all stores
tma_store_wait();
__syncwarp();
// Issue RDMA
// NOTES: for zero-copy mode, we assume the data is already in the send buffer
if (dst_p2p_ptr == 0)
......@@ -635,10 +689,14 @@ void combine(void* combined_x,
// Online cast cannot use zero-copy
EP_HOST_ASSERT(not (zero_copy and use_logfmt));
constexpr int kNumTMABytesPerWarp = 12 * (512 + 16);
const int smem_size = kNumTMABytesPerWarp * num_warps;
#define COMBINE_LAUNCH_CASE(hidden) { \
auto combine_func = use_logfmt ? \
combine<true, hidden, kNumMaxTopk> : \
combine<false, hidden, kNumMaxTopk>; \
SET_SHARED_MEMORY_FOR_TMA(combine_func); \
LAUNCH_KERNEL(&cfg, combine_func, \
combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
......
......@@ -30,6 +30,19 @@ template<> struct VecInt<4> { using vec_t = int; };
template<> struct VecInt<8> { using vec_t = int64_t; };
template<> struct VecInt<16> { using vec_t = int4; };
// Adaptor that exposes a callable through subscript syntax:
// `visitor[i]` invokes the wrapped callable with `i` and returns its result.
template <typename FuncT>
struct PatternVisitor {
    // The wrapped callable; it is invoked on every subscript access
    FuncT func;

    // Stores the callable by forwarding the constructor argument into `func`.
    // `explicit` prevents implicit conversions from a bare callable.
    __device__ __host__
    explicit PatternVisitor(FuncT&& pattern)
        : func(std::forward<FuncT>(pattern)) {}

    // Evaluates the callable at `index`; nothing is cached between accesses
    __device__ __host__
    auto operator [](const uint32_t& index) {
        return func(index);
    }
};
// Aborts execution of the calling thread's kernel by issuing the PTX `trap` instruction.
// The statement is marked `volatile`, matching the other inline PTX in this file: it has no
// output operands, so without `volatile` the compiler would be free to assume it has no side
// effects and delete or move it.
__device__ __forceinline__ void trap() {
    asm volatile("trap;");
}
......@@ -281,6 +294,25 @@ __device__ __forceinline__ float exp2f_approx(const float &x) {
// TMA PTX instructions
#ifndef DISABLE_SM90_FEATURES
// Returns non-zero on the lane that should act on behalf of the warp, and 0 on the others.
//
// With SM90 features enabled, the choice is delegated to the PTX `elect.sync` instruction,
// issued with a full-warp member mask (0xffffffff): `pred` is set to 1 only where the
// instruction's predicate output is true.
// NOTE(review): per the PTX ISA, `elect.sync` elects exactly one leader among the lanes in
// the member mask, so all 32 lanes are expected to reach this call together -- confirm
// against the PTX documentation.
//
// With `DISABLE_SM90_FEATURES` defined, lane 0 is the elected lane.
//
// `lane_id` is only consulted by the fallback path; in the SM90 path the by-value copy is
// merely overwritten with the register output of `elect.sync` and is not read afterwards.
__device__ __forceinline__ uint32_t elect_one_sync(int lane_id) {
#ifndef DISABLE_SM90_FEATURES
// Stays 0 unless the predicated `mov` below fires on this lane
uint32_t pred = 0;
// Operands: %0 = lane_id (overwritten with %rx), %1 = pred, %2 = member mask
asm volatile(
"{\n"
".reg .b32 %%rx;\n"
".reg .pred %%px;\n"
" elect.sync %%rx|%%px, %2;\n"
"@%%px mov.s32 %1, 1;\n"
" mov.s32 %0, %%rx;\n"
"}\n"
: "+r"(lane_id), "+r"(pred)
: "r"(0xffffffff));
return pred;
#else
// Fallback without `elect.sync`: always pick lane 0
return lane_id == 0;
#endif
}
// Issues the PTX `fence.proxy.async.shared::cta` instruction.
// In the combine kernel it is called right after `mbarrier_init`, before the barriers are
// used by TMA loads.
// NOTE(review): per the PTX ISA this fence should make prior shared-memory writes done
// through ordinary (generic-proxy) accesses visible to the async proxy used by TMA --
// confirm against the PTX documentation.
__device__ __forceinline__ void fence_view_async_shared() {
asm volatile("fence.proxy.async.shared::cta; \n" :: );
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment