// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#pragma once

#include "kernels/api.cuh"
#include "kernels/configs.cuh"
#include "kernels/exception.cuh"

namespace deep_ep {

struct Config {
    int num_sms;
    int num_max_nvl_chunked_send_tokens;
    int num_max_nvl_chunked_recv_tokens;
    int num_max_rdma_chunked_send_tokens;
    int num_max_rdma_chunked_recv_tokens;

lijian6's avatar
lijian6 committed
18
19
20
21
22
23
    Config(int num_sms, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
           int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens)
        : num_sms(num_sms), num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens),
          num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens),
          num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),
          num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {
Chenggang Zhao's avatar
Chenggang Zhao committed
24
        EP_HOST_ASSERT(num_sms >= 0);
lijian6's avatar
lijian6 committed
25
26
        EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and
                           num_max_nvl_chunked_recv_tokens > 0);
Chenggang Zhao's avatar
Chenggang Zhao committed
27
        EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens);
lijian6's avatar
lijian6 committed
28
29
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and
                           num_max_rdma_chunked_recv_tokens > 0);
Chenggang Zhao's avatar
Chenggang Zhao committed
30
31

        // Ceil up RDMA buffer size
lijian6's avatar
lijian6 committed
32
33
        this->num_max_rdma_chunked_recv_tokens =
            ALIGN<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
Chenggang Zhao's avatar
Chenggang Zhao committed
34
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
lijian6's avatar
lijian6 committed
35
36
37
38
        // NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always
        // have space to push
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <=
                           num_max_rdma_chunked_recv_tokens / 2);
Chenggang Zhao's avatar
Chenggang Zhao committed
39
40
41
42
43
    }

    size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
        // Below are some assumptions
        // TODO: add assertions
lijian6's avatar
lijian6 committed
44
        constexpr int kNumMaxTopK   = 128;
Chenggang Zhao's avatar
Chenggang Zhao committed
45
46
        constexpr int kNumMaxScales = 128;
        EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
47
        EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % (2 * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL) == 0);
Chenggang Zhao's avatar
Chenggang Zhao committed
48
        const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
lijian6's avatar
lijian6 committed
49
        const auto num_nvl_ranks  = std::min(num_ranks, NUM_MAX_NVL_PEERS);
50
        const int  num_channels   = num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
Chenggang Zhao's avatar
Chenggang Zhao committed
51

52
53
54
        // 计算每个nvl通信数据包的数据量
        size_t num_single_nvl_bag_bytes =
            hidden_bytes +                          // 数据缓冲区(Token Data)。存储从 RDMA 转发过来的 token 数据(x 张量)
lijian6's avatar
lijian6 committed
55
#ifndef DISABLE_ROCSHMEM
56
            internode::get_source_meta_bytes() +    // 源元数据缓冲区(Source Metadata)。存储每个 token 的源信息(哪个 RDMA rank 发送的)
57
#endif
58
59
60
61
62
63
64
65
66
67
68
            kNumMaxTopK * sizeof(int) +             // TopK 索引缓冲区。存储每个 token 的 top-k 专家索引
            kNumMaxTopK * sizeof(float) +           // TopK 权重缓冲区。存储每个 token 的 top-k 专家权重
            kNumMaxScales * sizeof(float);          // Scale 缓冲区。存储每个 token 的量化缩放因子

        // 计算每个 NVL channel 的控制信息所需的字节数,存储每个 NVL channel 的前缀索引信息,用于快速定位数据(nvl_channel_prefix_start、nvl_channel_prefix_end 等)
        size_t num_single_nvl_control_bytes = (2 * num_rdma_ranks + 3) * sizeof(int);

        // NVL 数据总的字节数
        size_t num_bytes = (num_single_nvl_bag_bytes * num_max_nvl_chunked_recv_tokens + num_single_nvl_control_bytes) * num_channels * num_nvl_ranks;

        // 128 字节对齐,匹配 GPU 缓存行大小,优化内存访问。
Chenggang Zhao's avatar
Chenggang Zhao committed
69
70
71
72
73
        num_bytes = ((num_bytes + 127) / 128) * 128;
        return num_bytes;
    }

    size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
lijian6's avatar
lijian6 committed
74
#ifndef DISABLE_ROCSHMEM
Chenggang Zhao's avatar
Chenggang Zhao committed
75
76
77
78
79
80
        // Legacy mode
        if (num_ranks <= NUM_MAX_NVL_PEERS)
            return 0;

        // Below are some assumptions
        // TODO: add assertions
lijian6's avatar
lijian6 committed
81
        constexpr int kNumMaxTopK   = 128;
Chenggang Zhao's avatar
Chenggang Zhao committed
82
83
        constexpr int kNumMaxScales = 128;
        EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
84
        EP_HOST_ASSERT(num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0);
Chenggang Zhao's avatar
Chenggang Zhao committed
85
        const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
        const int num_channels   = num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;

        // 计算每个rdma通信数据包的数据量
        size_t num_single_rdma_bag_bytes = 
            hidden_bytes +                          // 数据缓冲区。存储实际的 token 数据(x 张量),对应代码中的 rdma_channel_data
            internode::get_source_meta_bytes() +    // 源元数据缓冲区。存储每个 token 的源信息(SourceMeta)
            kNumMaxTopK * sizeof(int) +             // 存储每个 token 的 top-k 专家索引。对应 topk_idx 数据
            kNumMaxTopK * sizeof(float) +           // 存储每个 token 的 top-k 专家权重。对应 topk_weights 数据
            kNumMaxScales * sizeof(float) +         // 存储每个 token 的缩放因子(x_scales)
            sizeof(int4);                           // 预留空间用于内存对齐和未来扩展
        
        // 计算每个 RDMA channel 的控制信息(起始/结束索引)所需的字节数,对应代码中的 rdma_channel_meta
        size_t num_single_rdma_control_bytes = (NUM_MAX_NVL_PEERS * 2 + 4) * sizeof(int);

        // RDMA 数据总的字节数
        size_t num_bytes = (num_single_rdma_bag_bytes * num_max_rdma_chunked_recv_tokens + num_single_rdma_control_bytes) *
            num_channels * num_rdma_ranks * 2;
Chenggang Zhao's avatar
Chenggang Zhao committed
103

104
        // 128 字节对齐(缓存行对齐),优化内存访问性能
Chenggang Zhao's avatar
Chenggang Zhao committed
105
106
        num_bytes = ((num_bytes + 127) / 128) * 128;
        return num_bytes;
107
#else
lijian6's avatar
lijian6 committed
108
109
        EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install "
                                  "rocSHMEM by following docs/install_dependencies.md");
110
#endif
Chenggang Zhao's avatar
Chenggang Zhao committed
111
112
113
114
115
116
    }
};

struct LowLatencyBuffer {
    int num_clean_int = 0;

117
118
119
    void* dispatch_rdma_send_buffer = nullptr;
    void* dispatch_rdma_recv_data_buffer = nullptr;
    int64_t* dispatch_rdma_recv_count_buffer = nullptr;
Chenggang Zhao's avatar
Chenggang Zhao committed
120

121
122
123
    void* combine_rdma_send_buffer = nullptr;
    void* combine_rdma_recv_data_buffer = nullptr;
    int64_t* combine_rdma_recv_flag_buffer = nullptr;
Chenggang Zhao's avatar
Chenggang Zhao committed
124

125
126
    void* combine_rdma_send_buffer_data_start = nullptr;
    size_t num_bytes_per_combine_msg = 0;
127

128
    std::pair<int64_t*, int> clean_meta() {
Chenggang Zhao's avatar
Chenggang Zhao committed
129
130
131
132
133
134
        EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
        return {dispatch_rdma_recv_count_buffer, num_clean_int};
    }
};

struct LowLatencyLayout {
lijian6's avatar
lijian6 committed
135
    size_t           total_bytes = 0;
Chenggang Zhao's avatar
Chenggang Zhao committed
136
137
    LowLatencyBuffer buffers[2];

lijian6's avatar
lijian6 committed
138
139
140
    template <typename out_ptr_t = void *, typename count_ptr_t = uint8_t *,
              typename in_ptr_t = void *>
    out_ptr_t advance(const in_ptr_t &ptr, size_t count) {
Chenggang Zhao's avatar
Chenggang Zhao committed
141
142
143
        return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count);
    }

lijian6's avatar
lijian6 committed
144
    LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
145
                     int num_ranks, int num_experts, bool enable_dispatch_ll_layered=false, int quant_group_size=0) {
146
        const int num_scales = hidden / QUANTIZATION_GROUPSIZE;
Chenggang Zhao's avatar
Chenggang Zhao committed
147

148
149
        const int num_nodes = num_ranks / NUM_MAX_NVL_PEERS;  // 计算结点数

Chenggang Zhao's avatar
Chenggang Zhao committed
150
151
152
153
154
155
        // Dispatch and combine layout:
        //  - 2 symmetric odd/even send buffer
        //  - 2 symmetric odd/even receive buffers
        //  - 2 symmetric odd/even signaling buffers

        // Message sizes
lijian6's avatar
lijian6 committed
156
157
158
159
        // NOTES: you should add a control `int4` for combine messages if you want to do data
        // transformation
        EP_HOST_ASSERT(num_scales * sizeof(float) <= static_cast<size_t>(hidden));
        size_t num_bytes_per_dispatch_msg =
160
161
            sizeof(int4) + std::max(hidden * sizeof(hip_bfloat16), hidden +
            (quant_group_size == 0 ? 4 : num_scales) * sizeof(float));   // 应该是1,但是代码中为了满足int4对齐
lishen's avatar
lishen committed
162
163

        // 与internode_ll::combine 中的 num_bytes_per_slot 相等
164
165
166
        size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16) + 
                                           (enable_dispatch_ll_layered ? 0 : // 即enable_combine_overlap==true,执行函数combine_sbo
                                           num_scales * sizeof(__hip_bfloat162));
Chenggang Zhao's avatar
Chenggang Zhao committed
167
168

        // Send buffer
lijian6's avatar
lijian6 committed
169
170
171
172
        size_t dispatch_send_buffer_bytes =
            num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
        size_t combine_send_buffer_bytes =
            num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
Chenggang Zhao's avatar
Chenggang Zhao committed
173
174
175
176
177
178
        size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);
        EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0);
        total_bytes += send_buffer_bytes * 2;

        // Symmetric receive buffers
        // TODO: optimize memory usages
lijian6's avatar
lijian6 committed
179
180
181
182
183
184
        size_t dispatch_recv_data_buffer_bytes =
            num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
        size_t combine_recv_buffer_bytes =
            num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
        size_t recv_buffer_bytes =
            std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
Chenggang Zhao's avatar
Chenggang Zhao committed
185
186
187
188
        EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
        total_bytes += recv_buffer_bytes * 2;

        // Symmetric signaling buffers
189
        size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t);
190
191
192
193
        if (enable_dispatch_ll_layered) {
            dispatch_recv_count_buffer_bytes +=
                NUM_MAX_NVL_PEERS * num_nodes * num_max_dispatch_tokens_per_rank * sizeof(int) + NUM_MAX_NVL_PEERS * sizeof(int);
        }
194
195
        size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
        size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
lishen's avatar
lishen committed
196
197
        size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128);
        total_bytes += signaling_buffer_bytes_aligned * 2;
Chenggang Zhao's avatar
Chenggang Zhao committed
198
199
200
201

        // Assign pointers
        // NOTES: we still leave some space for distinguishing dispatch/combine buffer,
        // so you may see some parameters are duplicated
lijian6's avatar
lijian6 committed
202
        for (int i = 0; i < 2; ++i) {
Chenggang Zhao's avatar
Chenggang Zhao committed
203
            buffers[i] = {
204
205
                static_cast<int>(signaling_buffer_bytes / sizeof(int64_t)),
                // dispatch:send_buffer + recv_buffer + recv_count
lishen's avatar
lishen committed
206
207
208
                advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
                advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i),
                advance<int64_t*>(rdma_buffer, signaling_buffer_bytes_aligned * i),
209
                // combine:send_buffer + recv_buffer + recv_count
lishen's avatar
lishen committed
210
211
212
                advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
                advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i),
                advance<int64_t*>(rdma_buffer, signaling_buffer_bytes_aligned * i),
213
                // combine_rdma_send_buffer_data_start
lishen's avatar
lishen committed
214
                advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
215
216
217
                //
                num_bytes_per_combine_msg
            };
Chenggang Zhao's avatar
Chenggang Zhao committed
218
219
220
221
        }
    }
};

lijian6's avatar
lijian6 committed
222
inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden,
223
224
                                             int num_ranks, int num_experts, 
                                             bool enable_dispatch_ll_layered=false, int quant_group_size=0) {
lijian6's avatar
lijian6 committed
225
    auto num_bytes =
226
227
        LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, 
        enable_dispatch_ll_layered, quant_group_size)
lijian6's avatar
lijian6 committed
228
229
230
            .total_bytes;
    return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) *
           NUM_BUFFER_ALIGNMENT_BYTES;
Chenggang Zhao's avatar
Chenggang Zhao committed
231
232
233
}

} // namespace deep_ep