deep_ep.hpp 9.1 KB
Newer Older
Chenggang Zhao's avatar
Chenggang Zhao committed
1
2
3
4
5
6
7
8
#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/types.h>
#include <cstdint>
#include <functional>
#include <optional>
#include <tuple>
#include <vector>

lijian6's avatar
lijian6 committed
9
#include "kernels/configs.cuh"
lijian6's avatar
lijian6 committed
10
#include "kernels/exception.cuh"
Chenggang Zhao's avatar
Chenggang Zhao committed
11
12
13
14
15
16
17
18
19
20
#include "config.hpp"
#include "event.hpp"

namespace deep_ep {

struct Buffer {
    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8");

private:
    // Low-latency mode buffer
lijian6's avatar
lijian6 committed
21
22
    int  low_latency_buffer_idx = 0;
    bool low_latency_mode       = false;
Chenggang Zhao's avatar
Chenggang Zhao committed
23
24
25

    // NVLink Buffer
    int64_t num_nvl_bytes;
lijian6's avatar
lijian6 committed
26
27
    void   *buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
    void  **buffer_ptrs_gpu                = nullptr;
Chenggang Zhao's avatar
Chenggang Zhao committed
28

lishen's avatar
lishen committed
29
30
31
    void* nvl_buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
    void** nvl_buffer_ptrs_gpu = nullptr;

lijian6's avatar
lijian6 committed
32
    // DUSHMEM Buffer
Chenggang Zhao's avatar
Chenggang Zhao committed
33
    int64_t num_rdma_bytes;
lijian6's avatar
lijian6 committed
34
    void   *rdma_buffer_ptr = nullptr;
Chenggang Zhao's avatar
Chenggang Zhao committed
35

lishen's avatar
lishen committed
36
37
38
39
    // Shrink mode buffer
    bool enable_shrink = false;
    int* mask_buffer_ptr = nullptr;
    int* sync_buffer_ptr = nullptr;
40

Chenggang Zhao's avatar
Chenggang Zhao committed
41
    // Device info and communication
lijian6's avatar
lijian6 committed
42
43
44
45
46
    int               device_id;
    int               num_device_sms;
    int               rank, rdma_rank, nvl_rank;
    int               num_ranks, num_rdma_ranks, num_nvl_ranks;
    hipIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
Chenggang Zhao's avatar
Chenggang Zhao committed
47
48

    // Stream for communication
lijian6's avatar
lijian6 committed
49
    at::hip::HIPStreamMasqueradingAsCUDA comm_stream;
Chenggang Zhao's avatar
Chenggang Zhao committed
50

lijian6's avatar
lijian6 committed
51
    // After IPC/DUSHMEM synchronization, this flag will be true
Chenggang Zhao's avatar
Chenggang Zhao committed
52
53
    bool available = false;

54
55
56
57
58
    // Whether explicit `destroy()` is required.
    bool explicitly_destroy;
    // After `destroy()` be called, this flag will be true
    bool destroyed = false;

59
    // Barrier signals
lijian6's avatar
lijian6 committed
60
61
    int  *barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
    int **barrier_signal_ptrs_gpu                = nullptr;
Chenggang Zhao's avatar
Chenggang Zhao committed
62
63

    // Workspace
lijian6's avatar
lijian6 committed
64
    void *workspace = nullptr;
Chenggang Zhao's avatar
Chenggang Zhao committed
65
66

    // Host-side MoE info
lijian6's avatar
lijian6 committed
67
68
    volatile int *moe_recv_counter        = nullptr;
    int          *moe_recv_counter_mapped = nullptr;
Chenggang Zhao's avatar
Chenggang Zhao committed
69
70

    // Host-side expert-level MoE info
lijian6's avatar
lijian6 committed
71
72
    volatile int *moe_recv_expert_counter        = nullptr;
    int          *moe_recv_expert_counter_mapped = nullptr;
Chenggang Zhao's avatar
Chenggang Zhao committed
73
74

    // Host-side RDMA-level MoE info
lijian6's avatar
lijian6 committed
75
76
77
    volatile int *moe_recv_rdma_counter        = nullptr;
    int          *moe_recv_rdma_counter_mapped = nullptr;

Chenggang Zhao's avatar
Chenggang Zhao committed
78
public:
lijian6's avatar
lijian6 committed
79
    Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
80
           bool low_latency_mode, bool explicitly_destroy, bool enable_shrink);
Chenggang Zhao's avatar
Chenggang Zhao committed
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97

    ~Buffer() noexcept(false);

    bool is_available() const;

    bool is_internode_available() const;

    int get_num_rdma_ranks() const;

    int get_rdma_rank() const;

    int get_root_rdma_rank(bool global) const;

    int get_local_device_id() const;

    pybind11::bytearray get_local_ipc_handle() const;

lijian6's avatar
lijian6 committed
98
    pybind11::bytearray get_local_dushmem_unique_id() const;
Chenggang Zhao's avatar
Chenggang Zhao committed
99

lijian6's avatar
lijian6 committed
100
101
    torch::Tensor get_local_buffer_tensor(const pybind11::object &dtype, int64_t offset,
                                          bool use_rdma_buffer) const;
Chenggang Zhao's avatar
Chenggang Zhao committed
102

Shangyan Zhou's avatar
Shangyan Zhou committed
103
104
    torch::Stream get_comm_stream() const;

lijian6's avatar
lijian6 committed
105
106
107
    void sync(const std::vector<int>                                &device_ids,
              const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
              const std::optional<pybind11::bytearray>              &root_unique_id_opt);
Chenggang Zhao's avatar
Chenggang Zhao committed
108

109
110
    void destroy();

lijian6's avatar
lijian6 committed
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
    std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor,
               std::optional<EventHandle>>
    get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
                        std::optional<EventHandle> &previous_event, bool async,
                        bool allocate_on_comm_stream);

    std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
               std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
               torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
               std::optional<EventHandle>>
    intranode_dispatch(const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
                       const std::optional<torch::Tensor> &topk_idx,
                       const std::optional<torch::Tensor> &topk_weights,
                       const std::optional<torch::Tensor> &num_tokens_per_rank,
                       const torch::Tensor                &is_token_in_rank,
                       const std::optional<torch::Tensor> &num_tokens_per_expert,
                       int                                 cached_num_recv_tokens,
                       const std::optional<torch::Tensor> &cached_rank_prefix_matrix,
                       const std::optional<torch::Tensor> &cached_channel_prefix_matrix,
                       int expert_alignment, int num_worst_tokens, const Config &config,
                       std::optional<EventHandle> &previous_event, bool async,
                       bool allocate_on_comm_stream);
Chenggang Zhao's avatar
Chenggang Zhao committed
133
134

    std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
lijian6's avatar
lijian6 committed
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
    intranode_combine(const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
                      const std::optional<torch::Tensor> &bias_0,
                      const std::optional<torch::Tensor> &bias_1, const torch::Tensor &src_idx,
                      const torch::Tensor &rank_prefix_matrix,
                      const torch::Tensor &channel_prefix_matrix, const torch::Tensor &send_head,
                      const Config &config, std::optional<EventHandle> &previous_event, bool async,
                      bool allocate_on_comm_stream);

    std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
               std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
               std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>,
               torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
               std::optional<torch::Tensor>, std::optional<EventHandle>>
    internode_dispatch(const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
                       const std::optional<torch::Tensor> &topk_idx,
                       const std::optional<torch::Tensor> &topk_weights,
                       const std::optional<torch::Tensor> &num_tokens_per_rank,
                       const std::optional<torch::Tensor> &num_tokens_per_rdma_rank,
                       const torch::Tensor                &is_token_in_rank,
                       const std::optional<torch::Tensor> &num_tokens_per_expert,
Chenggang Zhao's avatar
Chenggang Zhao committed
155
                       int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
lijian6's avatar
lijian6 committed
156
157
158
159
160
161
162
                       const std::optional<torch::Tensor> &cached_rdma_channel_prefix_matrix,
                       const std::optional<torch::Tensor> &cached_recv_rdma_rank_prefix_sum,
                       const std::optional<torch::Tensor> &cached_gbl_channel_prefix_matrix,
                       const std::optional<torch::Tensor> &cached_recv_gbl_rank_prefix_sum,
                       int expert_alignment, const Config &config,
                       std::optional<EventHandle> &previous_event, bool async,
                       bool allocate_on_comm_stream);
Chenggang Zhao's avatar
Chenggang Zhao committed
163
164

    std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
lijian6's avatar
lijian6 committed
165
166
167
168
169
170
171
172
173
174
175
176
    internode_combine(
        const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
        const std::optional<torch::Tensor> &bias_0, const std::optional<torch::Tensor> &bias_1,
        const torch::Tensor &src_meta, const torch::Tensor &is_combined_token_in_rank,
        const torch::Tensor &rdma_channel_prefix_matrix, const torch::Tensor &rdma_rank_prefix_sum,
        const torch::Tensor &gbl_channel_prefix_matrix, const torch::Tensor &combined_rdma_head,
        const torch::Tensor &combined_nvl_head, const Config &config,
        std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream);

    void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden,
                                  int num_experts);

lishen's avatar
lishen committed
177
178
179
    std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
    low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                         int num_max_dispatch_tokens_per_rank, int num_experts,
lishen's avatar
lishen committed
180
                         bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
lishen's avatar
lishen committed
181
                         bool async, bool return_recv_hook);
Chenggang Zhao's avatar
Chenggang Zhao committed
182
183

    std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
lishen's avatar
lishen committed
184
185
    low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
                        const torch::Tensor& src_info, const torch::Tensor& layout_range,
lishen's avatar
lishen committed
186
                        const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
lishen's avatar
lishen committed
187
                        int num_max_dispatch_tokens_per_rank, int num_experts,
lijian6's avatar
lijian6 committed
188
                        bool zero_copy, bool async, bool return_recv_hook,
lishen's avatar
lishen committed
189
                        const std::optional<torch::Tensor>& out = std::nullopt);
190

lishen's avatar
lishen committed
191
    torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
Chenggang Zhao's avatar
Chenggang Zhao committed
192
193
194
};

} // namespace deep_ep