"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/flashmla.git" did not exist on "a1eef562a6fc6ed135df9dbd91a54dbb2e727060"
Commit b90320e2 authored by alpha-baby's avatar alpha-baby
Browse files

enhance warp copy

parent 7de7464e
...@@ -125,14 +125,12 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in ...@@ -125,14 +125,12 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
// Issue send // Issue send
// TODO: more light fence or barrier or signaling // TODO: more light fence or barrier or signaling
// TODO: overlap EP barrier and NVL cleaning // TODO: overlap EP barrier and NVL cleaning
for (int i = 0; i < kNumRDMARanks; ++i) { for (int i = warp_id; i < kNumRDMARanks; i+=num_threads/32) {
if (i != rdma_rank) { if (i != rdma_rank) {
if (warp_id == 0) { nvshmemi_ibgda_put_nbi_warp<true>(reinterpret_cast<uint64_t>(rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank)),
nvshmemi_ibgda_put_nbi_warp<true>(reinterpret_cast<uint64_t>(rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank)), reinterpret_cast<uint64_t>(rdma_recv_num_tokens_mixed.send_buffer(i)),
reinterpret_cast<uint64_t>(rdma_recv_num_tokens_mixed.send_buffer(i)), (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) * sizeof(int),
(NUM_MAX_NVL_PEERS + num_rdma_experts + 1) * sizeof(int), translate_dst_rdma_rank<kLowLatencyMode>(i, nvl_rank), 0, lane_id, 0);
translate_dst_rdma_rank<kLowLatencyMode>(i, nvl_rank), 0, lane_id, 0);
}
} else { } else {
UNROLLED_WARP_COPY(1, lane_id, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, UNROLLED_WARP_COPY(1, lane_id, NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank), rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
...@@ -140,6 +138,8 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in ...@@ -140,6 +138,8 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
ld_volatile_global, st_na_global); ld_volatile_global, st_na_global);
} }
} }
__syncthreads();
if (thread_id < kNumRDMARanks and thread_id != rdma_rank) if (thread_id < kNumRDMARanks and thread_id != rdma_rank)
nvshmemi_ibgda_quiet(translate_dst_rdma_rank<kLowLatencyMode>(thread_id, nvl_rank), 0); nvshmemi_ibgda_quiet(translate_dst_rdma_rank<kLowLatencyMode>(thread_id, nvl_rank), 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment