config.hpp 9.89 KB
Newer Older
Chenggang Zhao's avatar
Chenggang Zhao committed
1
2
#pragma once

lijian6's avatar
lijian6 committed
3
4
#include "kernels/api.cuh"
#include "kernels/configs.cuh"
Chenggang Zhao's avatar
Chenggang Zhao committed
5
6
7
8
9
10
11
12
13
14
15
#include "kernels/exception.cuh"

namespace deep_ep {

struct Config {
    int num_sms;
    int num_max_nvl_chunked_send_tokens;
    int num_max_nvl_chunked_recv_tokens;
    int num_max_rdma_chunked_send_tokens;
    int num_max_rdma_chunked_recv_tokens;

lijian6's avatar
lijian6 committed
16
17
18
19
20
21
    Config(int num_sms, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
           int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens)
        : num_sms(num_sms), num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens),
          num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens),
          num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),
          num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {
Chenggang Zhao's avatar
Chenggang Zhao committed
22
        EP_HOST_ASSERT(num_sms >= 0);
lijian6's avatar
lijian6 committed
23
24
        EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and
                           num_max_nvl_chunked_recv_tokens > 0);
Chenggang Zhao's avatar
Chenggang Zhao committed
25
        EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens);
lijian6's avatar
lijian6 committed
26
27
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and
                           num_max_rdma_chunked_recv_tokens > 0);
Chenggang Zhao's avatar
Chenggang Zhao committed
28
29

        // Ceil up RDMA buffer size
lijian6's avatar
lijian6 committed
30
31
        this->num_max_rdma_chunked_recv_tokens =
            ALIGN<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
Chenggang Zhao's avatar
Chenggang Zhao committed
32
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
lijian6's avatar
lijian6 committed
33
34
35
36
        // NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always
        // have space to push
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <=
                           num_max_rdma_chunked_recv_tokens / 2);
Chenggang Zhao's avatar
Chenggang Zhao committed
37
38
39
40
41
    }

    size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
        // Below are some assumptions
        // TODO: add assertions
lijian6's avatar
lijian6 committed
42
        constexpr int kNumMaxTopK   = 128;
Chenggang Zhao's avatar
Chenggang Zhao committed
43
44
45
46
        constexpr int kNumMaxScales = 128;
        EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
        EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
        const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
lijian6's avatar
lijian6 committed
47
48
        const auto num_nvl_ranks  = std::min(num_ranks, NUM_MAX_NVL_PEERS);
        const int  num_channels   = num_sms / 2;
Chenggang Zhao's avatar
Chenggang Zhao committed
49
50
51
52

        size_t num_bytes = 0;
        num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;
lijian6's avatar
lijian6 committed
53
54
55
#ifndef DISABLE_ROCSHMEM
        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens *
                     internode::get_source_meta_bytes();
56
#endif
lijian6's avatar
lijian6 committed
57
58
59
60
61
62
        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK *
                     sizeof(int64_t);
        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK *
                     sizeof(float);
        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens *
                     kNumMaxScales * sizeof(float);
Chenggang Zhao's avatar
Chenggang Zhao committed
63
64
65
66
67
        num_bytes = ((num_bytes + 127) / 128) * 128;
        return num_bytes;
    }

    size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
lijian6's avatar
lijian6 committed
68
#ifndef DISABLE_ROCSHMEM
Chenggang Zhao's avatar
Chenggang Zhao committed
69
70
71
72
73
74
        // Legacy mode
        if (num_ranks <= NUM_MAX_NVL_PEERS)
            return 0;

        // Below are some assumptions
        // TODO: add assertions
lijian6's avatar
lijian6 committed
75
        constexpr int kNumMaxTopK   = 128;
Chenggang Zhao's avatar
Chenggang Zhao committed
76
77
78
79
        constexpr int kNumMaxScales = 128;
        EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
        EP_HOST_ASSERT(num_sms % 2 == 0);
        const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
lijian6's avatar
lijian6 committed
80
        const int num_channels   = num_sms / 2;
Chenggang Zhao's avatar
Chenggang Zhao committed
81
82
83

        size_t num_bytes = 0;
        num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
lijian6's avatar
lijian6 committed
84
85
86
87
88
89
90
91
92
93
94
95
        num_bytes +=
            num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
                     internode::get_source_meta_bytes() * 2;
        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
                     kNumMaxTopK * sizeof(int64_t) * 2;
        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
                     kNumMaxTopK * sizeof(float) * 2;
        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
                     kNumMaxScales * sizeof(float) * 2;
        num_bytes +=
            num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
Chenggang Zhao's avatar
Chenggang Zhao committed
96
97
        num_bytes = ((num_bytes + 127) / 128) * 128;
        return num_bytes;
98
#else
lijian6's avatar
lijian6 committed
99
100
        EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install "
                                  "rocSHMEM by following docs/install_dependencies.md");
101
#endif
Chenggang Zhao's avatar
Chenggang Zhao committed
102
103
104
105
106
107
    }
};

struct LowLatencyBuffer {
    int num_clean_int = 0;

lijian6's avatar
lijian6 committed
108
109
110
    void *dispatch_rdma_send_buffer       = nullptr;
    void *dispatch_rdma_recv_data_buffer  = nullptr;
    int  *dispatch_rdma_recv_count_buffer = nullptr;
Chenggang Zhao's avatar
Chenggang Zhao committed
111

lijian6's avatar
lijian6 committed
112
113
114
    void *combine_rdma_send_buffer      = nullptr;
    void *combine_rdma_recv_data_buffer = nullptr;
    int  *combine_rdma_recv_flag_buffer = nullptr;
Chenggang Zhao's avatar
Chenggang Zhao committed
115

lijian6's avatar
lijian6 committed
116
117
    void  *combine_rdma_send_buffer_data_start = nullptr;
    size_t num_bytes_per_combine_msg           = 0;
118

lijian6's avatar
lijian6 committed
119
    std::pair<int *, int> clean_meta() {
Chenggang Zhao's avatar
Chenggang Zhao committed
120
121
122
123
124
125
        EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
        return {dispatch_rdma_recv_count_buffer, num_clean_int};
    }
};

struct LowLatencyLayout {
lijian6's avatar
lijian6 committed
126
    size_t           total_bytes = 0;
Chenggang Zhao's avatar
Chenggang Zhao committed
127
128
    LowLatencyBuffer buffers[2];

lijian6's avatar
lijian6 committed
129
130
131
    template <typename out_ptr_t = void *, typename count_ptr_t = uint8_t *,
              typename in_ptr_t = void *>
    out_ptr_t advance(const in_ptr_t &ptr, size_t count) {
Chenggang Zhao's avatar
Chenggang Zhao committed
132
133
134
        return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count);
    }

lijian6's avatar
lijian6 committed
135
136
    LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
                     int num_ranks, int num_experts) {
Chenggang Zhao's avatar
Chenggang Zhao committed
137
138
139
140
141
142
143
144
        const int num_scales = hidden / 128;

        // Dispatch and combine layout:
        //  - 2 symmetric odd/even send buffer
        //  - 2 symmetric odd/even receive buffers
        //  - 2 symmetric odd/even signaling buffers

        // Message sizes
lijian6's avatar
lijian6 committed
145
146
147
148
149
150
151
        // NOTES: you should add a control `int4` for combine messages if you want to do data
        // transformation
        EP_HOST_ASSERT(num_scales * sizeof(float) <= static_cast<size_t>(hidden));
        size_t num_bytes_per_dispatch_msg =
            sizeof(int4) +
            std::max(hidden * sizeof(hip_bfloat16), hidden + num_scales * sizeof(float));
        size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16);
Chenggang Zhao's avatar
Chenggang Zhao committed
152
153

        // Send buffer
lijian6's avatar
lijian6 committed
154
155
156
157
        size_t dispatch_send_buffer_bytes =
            num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
        size_t combine_send_buffer_bytes =
            num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
Chenggang Zhao's avatar
Chenggang Zhao committed
158
159
160
161
162
163
        size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);
        EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0);
        total_bytes += send_buffer_bytes * 2;

        // Symmetric receive buffers
        // TODO: optimize memory usages
lijian6's avatar
lijian6 committed
164
165
166
167
168
169
        size_t dispatch_recv_data_buffer_bytes =
            num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
        size_t combine_recv_buffer_bytes =
            num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
        size_t recv_buffer_bytes =
            std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
Chenggang Zhao's avatar
Chenggang Zhao committed
170
171
172
173
174
        EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
        total_bytes += recv_buffer_bytes * 2;

        // Symmetric signaling buffers
        size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
lijian6's avatar
lijian6 committed
175
176
177
178
        size_t combine_recv_flag_buffer_bytes   = dispatch_recv_count_buffer_bytes;
        size_t signaling_buffer_bytes =
            std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
        size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128);
179
        total_bytes += signaling_buffer_bytes_aligned * 2;
Chenggang Zhao's avatar
Chenggang Zhao committed
180
181
182
183

        // Assign pointers
        // NOTES: we still leave some space for distinguishing dispatch/combine buffer,
        // so you may see some parameters are duplicated
lijian6's avatar
lijian6 committed
184
        for (int i = 0; i < 2; ++i) {
Chenggang Zhao's avatar
Chenggang Zhao committed
185
186
            buffers[i] = {
                static_cast<int>(signaling_buffer_bytes / sizeof(int)),
187
                advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
lijian6's avatar
lijian6 committed
188
189
190
                advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 +
                                         recv_buffer_bytes * i),
                advance<int *>(rdma_buffer, signaling_buffer_bytes_aligned * i),
191
                advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
lijian6's avatar
lijian6 committed
192
193
194
                advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 +
                                         recv_buffer_bytes * i),
                advance<int *>(rdma_buffer, signaling_buffer_bytes_aligned * i),
195
                advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
lijian6's avatar
lijian6 committed
196
                num_bytes_per_combine_msg};
Chenggang Zhao's avatar
Chenggang Zhao committed
197
198
199
200
        }
    }
};

lijian6's avatar
lijian6 committed
201
202
203
204
205
206
207
inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden,
                                             int num_ranks, int num_experts) {
    auto num_bytes =
        LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts)
            .total_bytes;
    return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) *
           NUM_BUFFER_ALIGNMENT_BYTES;
Chenggang Zhao's avatar
Chenggang Zhao committed
208
209
210
}

} // namespace deep_ep