layout.cu 6.34 KB
Newer Older
1
2
3
4
5
6
7
8
#include "configs.cuh"
#include "launch.cuh"

namespace deep_ep {

namespace layout {

/**
 * Counts, from the top-k expert choice of every token, how many tokens go to each expert,
 * to each rank and (optionally) to each RDMA rank, and fills the token-to-rank membership table.
 *
 * Block roles (the code calls a block an "SM"):
 *   - blocks [0, ceil(num_experts / kNumExpertsPerSM)) each own a window of kNumExpertsPerSM
 *     experts and write `num_tokens_per_expert` for that window, then return;
 *   - the following blocks each own a window of kNumRanksPerSM ranks and write
 *     `num_tokens_per_rank`, `is_token_in_rank` and, if non-null, `num_tokens_per_rdma_rank`.
 * The role depends only on blockIdx.x, so all threads of a block take the same branch and reach
 * the same __syncthreads().
 *
 * The token loops stride by kNumThreads and the shared scratch arrays have kNumThreads rows, so
 * the kernel expects to be launched with kNumThreads threads per block.
 *
 * @param topk_idx                  read as num_tokens rows of num_topk expert indices; an index
 *                                  outside the window owned by the block (including a negative
 *                                  one) is simply not counted by that block
 * @param num_tokens_per_rank       output, one counter per rank
 * @param num_tokens_per_rdma_rank  output, one counter per group of NUM_MAX_NVL_PEERS ranks;
 *                                  may be nullptr, in which case it is not written
 * @param num_tokens_per_expert     output, one counter per expert
 * @param is_token_in_rank          output, written as num_tokens rows of num_ranks flags
 * @param num_tokens                number of rows in `topk_idx`
 * @param num_topk                  number of expert indices per token
 * @param num_ranks                 total number of ranks
 * @param num_experts               total number of experts; experts are mapped to ranks in
 *                                  contiguous groups of num_experts / num_ranks
 */
template <int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM>
__global__ void get_dispatch_layout(const int64_t *topk_idx, int *num_tokens_per_rank,
                                    int *num_tokens_per_rdma_rank, int *num_tokens_per_expert,
                                    bool *is_token_in_rank, int num_tokens, int num_topk,
                                    int num_ranks, int num_experts) {
    // kNumRDMARanksPerSM below sizes a shared array; guard against a zero/truncated size here
    // instead of relying on the caller to have checked it.
    EP_STATIC_ASSERT(kNumRanksPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of ranks per SM");

    auto sm_id     = static_cast<int>(blockIdx.x);
    auto thread_id = static_cast<int>(threadIdx.x);

    // Count expert statistics
    __shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM];
    int expert_begin_idx = sm_id * kNumExpertsPerSM,
        expert_end_idx   = min(expert_begin_idx + kNumExpertsPerSM, num_experts);
    if (expert_begin_idx < expert_end_idx) {
        // Per-thread count: every thread owns one row of the shared scratch array
#pragma unroll
        for (int i = 0; i < kNumExpertsPerSM; ++i)
            num_tokens_per_expert_per_thread[thread_id][i] = 0;
#pragma unroll
        for (int i = thread_id; i < num_tokens; i += kNumThreads) {
            // Widen before multiplying so the row offset cannot overflow 32-bit int
            auto shifted_topk_idx = topk_idx + static_cast<int64_t>(i) * num_topk;
#pragma unroll
            for (int j = 0, expert_idx; j < num_topk; ++j) {
                expert_idx = static_cast<int>(shifted_topk_idx[j]);
                if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx)
                    ++num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx];
            }
        }
        // All per-thread rows must be complete before any thread reads a column
        __syncthreads();

        // Sum up: thread t reduces column t, hence one thread per expert of the window is needed
        EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM");
        if (expert_begin_idx + thread_id < expert_end_idx) {
            int sum = 0;
#pragma unroll
            for (int i = 0; i < kNumThreads; ++i)
                sum += num_tokens_per_expert_per_thread[i][thread_id];
            num_tokens_per_expert[expert_begin_idx + thread_id] = sum;
        }
        return;
    }

    // Only reached by blocks that do not own an expert window
    if (num_tokens_per_rdma_rank != nullptr)
        EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS);

    // Count rank statistics
    constexpr int  kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS;
    __shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM];
    __shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM];
    // Index of the first block that handles ranks (= number of expert blocks)
    auto sm_begin       = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM;
    int  rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM,
         rank_end_idx   = min(rank_begin_idx + kNumRanksPerSM, num_ranks);
    int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS,
        rdma_rank_end_idx   = rank_end_idx / NUM_MAX_NVL_PEERS;
    if (rank_begin_idx < rank_end_idx) {
        // Experts owned by the ranks of this window form the range [expert_begin, expert_end)
        const auto num_expert_per_rank = num_experts / num_ranks;
        auto       expert_begin        = rank_begin_idx * num_expert_per_rank;
        auto       expert_end          = rank_end_idx * num_expert_per_rank;

        // Per-thread count
#pragma unroll
        for (int i = 0; i < kNumRanksPerSM; ++i)
            num_tokens_per_rank_per_thread[thread_id][i] = 0;
#pragma unroll
        for (int i = 0; i < kNumRDMARanksPerSM; ++i)
            num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0;
#pragma unroll
        for (int i = thread_id; i < num_tokens; i += kNumThreads) {
            // Widen before multiplying so the row offsets cannot overflow 32-bit int
            const auto token_offset         = static_cast<int64_t>(i);
            auto shifted_topk_idx           = topk_idx + token_offset * num_topk;
            int  is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0};
#pragma unroll
            for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) {
                expert_idx = static_cast<int>(shifted_topk_idx[j]);
                if (expert_begin <= expert_idx and expert_idx < expert_end) {
                    // Count single rank
                    rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx;
                    is_in_rank[rank_idx]++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS]++;
                }
            }

            // A token counts once per rank / RDMA rank no matter how many of its experts hit it
            auto shifted_is_token_in_rank = is_token_in_rank + token_offset * num_ranks;
#pragma unroll
            for (int j = 0; j + rank_begin_idx < rank_end_idx; ++j) {
                shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0);
                num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0);
            }

#pragma unroll
            for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++j)
                num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0);
        }
        // All per-thread rows must be complete before any thread reads a column
        __syncthreads();

        // Sum up: thread t reduces column t, hence one thread per rank of the window is needed
        EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM");
        if (rank_begin_idx + thread_id < rank_end_idx) {
            int sum = 0;
#pragma unroll
            for (int i = 0; i < kNumThreads; ++i)
                sum += num_tokens_per_rank_per_thread[i][thread_id];
            num_tokens_per_rank[rank_begin_idx + thread_id] = sum;
        }

        if (num_tokens_per_rdma_rank != nullptr and
            rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) {
            int sum = 0;
#pragma unroll
            for (int i = 0; i < kNumThreads; ++i)
                sum += num_tokens_per_rdma_rank_per_thread[i][thread_id];
            num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum;
        }
    }
}

lijian6's avatar
lijian6 committed
121
122
123
124
/**
 * Host-side launcher for the get_dispatch_layout kernel.
 *
 * Launches one block per window of kNumExpertsPerSM experts followed by one block per window of
 * kNumRanksPerSM ranks, each with kNumThreads threads, on `stream`. All pointer and size
 * arguments are forwarded to the kernel unchanged.
 *
 * NOTE(review): `cfg` is not declared in this function; it is presumably introduced by
 * SETUP_LAUNCH_CONFIG, which is defined outside this file - confirm.
 */
void get_dispatch_layout(const int64_t *topk_idx, int *num_tokens_per_rank,
                         int *num_tokens_per_rdma_rank, int *num_tokens_per_expert,
                         bool *is_token_in_rank, int num_tokens, int num_topk, int num_ranks,
                         int num_experts, hipStream_t stream) {
    // Compile-time tiling of the work across blocks
    constexpr int kNumThreads      = 256;
    constexpr int kNumExpertsPerSM = 4;
    constexpr int kNumRanksPerSM   = 8;
    EP_STATIC_ASSERT(kNumRanksPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of ranks per SM");

    // Ceil-divide so a partially filled last window still gets a block
    const int num_expert_blocks = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM;
    const int num_rank_blocks   = (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM;
    int       num_sms           = num_expert_blocks + num_rank_blocks;

    SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
    LAUNCH_KERNEL(&cfg, (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>),
                  topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert,
                  is_token_in_rank, num_tokens, num_topk, num_ranks, num_experts);
}

} // namespace layout

lijian6's avatar
lijian6 committed
138
} // namespace deep_ep