Commit c6faca45 authored by Chenggang Zhao
Browse files

Code lint

parent c7033854
...@@ -1105,12 +1105,13 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1105,12 +1105,13 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank); EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(num_experts % num_ranks == 0); EP_HOST_ASSERT(num_experts % num_ranks == 0);
// Diagnosis tensors
if (cumulative_local_expert_recv_stats.has_value()) { if (cumulative_local_expert_recv_stats.has_value()) {
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt); EP_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->dim() == 1 and cumulative_local_expert_recv_stats->is_contiguous()); EP_HOST_ASSERT(cumulative_local_expert_recv_stats->dim() == 1 and cumulative_local_expert_recv_stats->is_contiguous());
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks); EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks);
} }
if (dispatch_wait_recv_cost_stats.has_value()) { if (dispatch_wait_recv_cost_stats.has_value()) {
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->scalar_type() == torch::kInt64); EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->scalar_type() == torch::kInt64);
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->dim() == 1 and dispatch_wait_recv_cost_stats->is_contiguous()); EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->dim() == 1 and dispatch_wait_recv_cost_stats->is_contiguous());
......
...@@ -281,12 +281,12 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -281,12 +281,12 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
shared_num_recv_tokens[warp_group_id] = num_recv_tokens; shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx; shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx); recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
// Add stats for diagnosis
if (cumulative_local_expert_recv_stats != nullptr) if (cumulative_local_expert_recv_stats != nullptr)
atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens); atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens);
if (dispatch_wait_recv_cost_stats != nullptr) if (dispatch_wait_recv_cost_stats != nullptr)
atomicAdd(reinterpret_cast<unsigned long long*>(dispatch_wait_recv_cost_stats + src_rank), atomicAdd(reinterpret_cast<unsigned long long*>(dispatch_wait_recv_cost_stats + src_rank), wait_recv_cost);
wait_recv_cost);
} }
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(num_warps_per_group * 32)); asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
num_recv_tokens = shared_num_recv_tokens[warp_group_id]; num_recv_tokens = shared_num_recv_tokens[warp_group_id];
...@@ -631,9 +631,10 @@ combine(void* combined_x, ...@@ -631,9 +631,10 @@ combine(void* combined_x,
auto start_time = clock64(); auto start_time = clock64();
while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0); while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0);
auto wait_recv_cost = clock64() - start_time; auto wait_recv_cost = clock64() - start_time;
if (combine_wait_recv_cost_stats != nullptr) if (combine_wait_recv_cost_stats != nullptr) {
atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats const auto& src_rank = responsible_expert_idx / num_local_experts;
+ responsible_expert_idx / num_local_experts), wait_recv_cost); atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats + src_rank), wait_recv_cost);
}
} }
} }
cg::this_grid().sync(); cg::this_grid().sync();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment