Commit 0a47402f authored by Chenggang Zhao
Browse files

Code cleanup

parent b6516358
...@@ -138,12 +138,14 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in ...@@ -138,12 +138,14 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
ld_volatile_global, st_na_global); ld_volatile_global, st_na_global);
} }
} }
__syncthreads(); __syncthreads();
// Wait for previous operations to finish
if (thread_id < kNumRDMARanks and thread_id != rdma_rank) if (thread_id < kNumRDMARanks and thread_id != rdma_rank)
nvshmemi_ibgda_quiet(translate_dst_rdma_rank<kLowLatencyMode>(thread_id, nvl_rank), 0); nvshmemi_ibgda_quiet(translate_dst_rdma_rank<kLowLatencyMode>(thread_id, nvl_rank), 0);
__syncthreads(); __syncthreads();
// Barrier
if (thread_id == 0) if (thread_id == 0)
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team); nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
__syncthreads(); __syncthreads();
......
...@@ -499,9 +499,9 @@ combine(void* combined_x, ...@@ -499,9 +499,9 @@ combine(void* combined_x,
cg::this_grid().sync(); cg::this_grid().sync();
// Reduce tokens // Reduce tokens
EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= 1024); EP_DEVICE_ASSERT(num_topk <= 32);
EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization"); EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization");
for (int k = thread_id; k < hidden_bf16_int4; k += num_threads) { for (int hidden_idx = thread_id; hidden_idx < hidden_bf16_int4; hidden_idx += num_threads) {
for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) { for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
// Read top-k indices and weights // Read top-k indices and weights
int reg_topk_idx[kNumMaxTopk]; int reg_topk_idx[kNumMaxTopk];
...@@ -520,7 +520,7 @@ combine(void* combined_x, ...@@ -520,7 +520,7 @@ combine(void* combined_x,
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type); auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
// Reduce // Reduce
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + k); auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + hidden_idx);
const auto x_bf16 = reinterpret_cast<nv_bfloat16*>(&x_vec); const auto x_bf16 = reinterpret_cast<nv_bfloat16*>(&x_vec);
#pragma unroll #pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j) for (int j = 0; j < kNumElemsPerInt4; ++ j)
...@@ -533,7 +533,7 @@ combine(void* combined_x, ...@@ -533,7 +533,7 @@ combine(void* combined_x,
#pragma unroll #pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j) for (int j = 0; j < kNumElemsPerInt4; ++ j)
combined_bf16[j] = static_cast<nv_bfloat16>(combined_values[j]); combined_bf16[j] = static_cast<nv_bfloat16>(combined_values[j]);
(static_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[k] = combined_int4; (static_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[hidden_idx] = combined_int4;
} }
} }
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment