Unverified Commit abba6add authored by Shangyan Zhou, committed by GitHub
Browse files

Fix hidden_size % 128 != 0 in intranode kernels (#413)

* Fix hidden_size % 128 != 0

* Add `align_down()` function

* Use the full warp to wait for TMA store

* Support arbitrary hidden sizes in fp8 cast

* lint
parent 2012e310
......@@ -11,10 +11,15 @@ dtype_t ceil_div(dtype_t a, dtype_t b) {
}
template <typename dtype_t>
dtype_t align(dtype_t a, dtype_t b) {
dtype_t align_up(dtype_t a, dtype_t b) {
return ceil_div<dtype_t>(a, b) * b;
}
// Rounds `a` to a multiple of `b` by discarding the remainder of `a / b`.
// For integer types the division truncates toward zero, so for non-negative
// inputs the result is the largest multiple of `b` that does not exceed `a`.
// `b` must be non-zero.
template <typename dtype_t>
dtype_t align_down(dtype_t a, dtype_t b) {
    const dtype_t num_whole_units = a / b;
    return num_whole_units * b;
}
struct Config {
int num_sms;
int num_max_nvl_chunked_send_tokens;
......@@ -36,7 +41,7 @@ struct Config {
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0);
// Ceil up RDMA buffer size
this->num_max_rdma_chunked_recv_tokens = align<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
this->num_max_rdma_chunked_recv_tokens = align_up<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
// NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always have space to push
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2);
......@@ -160,7 +165,7 @@ struct LowLatencyLayout {
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
size_t signaling_buffer_bytes_aligned = align<size_t>(signaling_buffer_bytes, 128);
size_t signaling_buffer_bytes_aligned = align_up<size_t>(signaling_buffer_bytes, 128);
total_bytes += signaling_buffer_bytes_aligned * 2;
// Assign pointers
......
......@@ -43,7 +43,7 @@ int get_source_meta_bytes() {
__host__ __device__ __forceinline__
int get_num_bytes_per_token(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) {
return static_cast<int>(align(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), sizeof(int4)));
return static_cast<int>(align_up(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), sizeof(int4)));
}
__host__ __device__ __forceinline__
......@@ -1516,8 +1516,8 @@ combine(int4* combined_x, float* combined_topk_weights,
// Load data
auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token;
auto shifted_x = x + token_idx * hidden_int4;
if (elect_one_sync()) {
tma_store_wait<0>();
if (elect_one_sync()) {
tma_load_1d(tma_buffer, shifted_x, tma_mbarrier, hidden_bytes);
mbarrier_arrive_and_expect_tx(tma_mbarrier, hidden_bytes);
}
......
......@@ -263,7 +263,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
const auto num_aligned_scales = align<int>(num_scales, sizeof(float) / sizeof(scale_t));
const auto num_aligned_scales = align_up<int>(num_scales, sizeof(float) / sizeof(scale_t));
const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
// Shared between sub-warps in warp groups
......@@ -584,7 +584,7 @@ combine(void* combined_x,
// Use different unroll factors for send and recv phases
constexpr int kNumSendUnrolls = kHidden % (32 * 4 * sizeof(int4) / sizeof(nv_bfloat16)) == 0 ? 4 : 2;
constexpr int kNumRecvUnrolls = 2;
constexpr int hidden_bf16_int4_pad = align(static_cast<int>(hidden_bf16_int4), 32 * kNumSendUnrolls);
constexpr int hidden_bf16_int4_pad = align_up(static_cast<int>(hidden_bf16_int4), 32 * kNumSendUnrolls);
EP_STATIC_ASSERT(kHidden % (32 * 2 * sizeof(int4) / sizeof(nv_bfloat16)) == 0, "Invalid hidden");
EP_STATIC_ASSERT(kNumSendUnrolls <= kNumMaxUnrolls and kNumRecvUnrolls <= kNumMaxUnrolls, "Invalid unrolls");
EP_STATIC_ASSERT(hidden_bf16_int4 % kNumSendUnrolls == 0, "Invalid hidden");
......
......@@ -399,10 +399,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
#ifndef DISABLE_SM90_FEATURES
if (elect_one_sync()) {
#pragma unroll
for (int i = 0; i < 2; ++ i) {
tma_store_wait<0>();
if (elect_one_sync()) {
tma_load_1d(tma_buffer, shifted_buffer_x_int4 + i * half_hidden_int4, tma_mbarrier, half_hidden_bytes);
mbarrier_arrive_and_expect_tx(tma_mbarrier, half_hidden_bytes);
mbarrier_wait(tma_mbarrier, tma_phase);
......@@ -589,6 +589,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4);
int hidden_int4_aligned = align_down(hidden_int4, 32);
auto x_int4 = reinterpret_cast<const int4*>(x);
auto bias_0_int4 = reinterpret_cast<const int4*>(bias_0);
auto bias_1_int4 = reinterpret_cast<const int4*>(bias_1);
......@@ -791,7 +792,6 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
// Wait shared memory release
#ifndef DISABLE_SM90_FEATURES
if (elect_one_sync())
tma_store_wait<0>();
__syncwarp();
#endif
......@@ -837,8 +837,8 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
out_dtypes[j] = static_cast<dtype_t>(values[j]);
#ifndef DISABLE_SM90_FEATURES
if (i < hidden_int4_aligned) {
// Wait TMA arrival
if (elect_one_sync())
tma_store_wait<kNumStages - 1>();
__syncwarp();
......@@ -855,8 +855,11 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false);
}
__syncwarp();
#else
} else {
#endif
recv_int4[token_idx * hidden_int4 + i] = out_int4;
#ifndef DISABLE_SM90_FEATURES
}
#endif
}
......
......@@ -408,10 +408,15 @@ __host__ __device__ constexpr dtype_t ceil_div(dtype_t a, dtype_t b) {
}
template <typename dtype_t>
__host__ __device__ constexpr dtype_t align(dtype_t a, dtype_t b) {
__host__ __device__ constexpr dtype_t align_up(dtype_t a, dtype_t b) {
return ceil_div<dtype_t>(a, b) * b;
}
// Rounds `a` to a multiple of `b` by dropping the remainder of `a / b`.
// Integer division truncates toward zero, so for non-negative inputs this is
// the largest multiple of `b` not greater than `a`. `b` must be non-zero.
// Usable in constant expressions on both host and device.
template <typename dtype_t>
__host__ __device__ constexpr dtype_t align_down(dtype_t a, dtype_t b) {
    const dtype_t num_full_blocks = a / b;
    return num_full_blocks * b;
}
__forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id,
int& token_start_idx, int& token_end_idx) {
int num_tokens_per_sm = ceil_div(num_tokens, num_sms);
......
......@@ -43,23 +43,34 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
return (1 - sim).item()
def align_up(x, y):
    """Round ``x`` up to a multiple of ``y``.

    Computes ``(x + y - 1) // y * y`` with floor division, which for a
    positive integer ``y`` yields the smallest multiple of ``y`` that is
    greater than or equal to ``x``. ``y`` must be non-zero.
    """
    num_blocks = (x + y - 1) // y
    return num_blocks * y
def per_token_cast_to_fp8(x: torch.Tensor):
assert x.dim() == 2 and x.size(1) % 128 == 0
assert x.dim() == 2
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
aligned_n = align_up(n, 128)
x_padded = torch.nn.functional.pad(x, (0, aligned_n - n), mode='constant', value=0)
x_padded_view = x_padded.view(m, -1, 128)
x_amax = x_padded_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_padded_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, aligned_n)[:, :n].contiguous(), (x_amax / 448.0).view(m, -1)
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
if x_fp8.numel() == 0:
return x_fp8.to(torch.bfloat16)
assert x_fp8.dim() == 2
m, n = x_fp8.shape
aligned_n = align_up(n, 128)
x_fp8_padded = torch.nn.functional.pad(x_fp8, (0, aligned_n - n), mode='constant', value=0)
if x_scales.dtype == torch.int:
x_scales = x_scales.view(dtype=torch.uint8).to(torch.int) << 23
x_scales = x_scales.view(dtype=torch.float)
x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
x_fp32_padded = x_fp8_padded.to(torch.float32).view(x_fp8.size(0), -1, 128)
x_scales = x_scales.view(x_fp8.size(0), -1, 1)
return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)
return (x_fp32_padded * x_scales).view(x_fp8_padded.shape).to(torch.bfloat16)[:,:n].contiguous()
def inplace_unique(x: torch.Tensor, num_slots: int):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment