deep_ep_hip.hpp 9.31 KB
Newer Older
lijian6's avatar
lijian6 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#pragma once

#include <cstdint>
#include <functional>
#include <optional>
#include <tuple>
#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/types.h>

#include "kernels/configs.cuh"
#include "kernels/exception.cuh"
#include "config_hip.hpp"
#include "event.hpp"

namespace deep_ep {

// Communication buffer behind the dispatch/combine entry points declared below
// (intranode, internode and low-latency variants).
//
// The struct stores raw pointers (per-peer buffers, barrier signals, workspace, host-side
// counters) and has a user-declared destructor. An implicitly generated copy would therefore
// alias every one of those pointers between two objects that each run `~Buffer()`, so copy and
// move are explicitly deleted (rule of five).
// NOTE(review): the release logic lives in the out-of-line `destroy()` / `~Buffer()`; confirm
// there that nothing relied on copying a `Buffer`.
struct Buffer {
    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8");

private:
    // Low-latency mode buffer
    // NOTE(review): presumably selects which low-latency buffer is active — confirm in the .cpp.
    int  low_latency_buffer_idx = 0;
    bool low_latency_mode       = false;

    // NVLink buffer: total size, one pointer per local peer, and a pointer array
    // (presumably device-resident, per its `_gpu` suffix — confirm).
    int64_t num_nvl_bytes                  = 0;
    void   *buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
    void  **buffer_ptrs_gpu                = nullptr;

    // NVSHMEM buffer
    int64_t num_rdma_bytes  = 0;
    void   *rdma_buffer_ptr = nullptr;

    // Device info and communication (all set by the constructor; defaults only guard against
    // reading indeterminate values).
    int               device_id      = 0;
    int               num_device_sms = 0;
    int               rank           = 0;
    int               rdma_rank      = 0;
    int               nvl_rank       = 0;
    int               num_ranks      = 0;
    int               num_rdma_ranks = 0;
    int               num_nvl_ranks  = 0;
    hipIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS] = {};

    // Stream for communication
    at::hip::HIPStreamMasqueradingAsCUDA comm_stream;

    // Set to true once IPC/NVSHMEM synchronization has completed
    bool available = false;

    // Whether an explicit `destroy()` call is required
    bool explicitly_destroy = false;
    // Set to true once `destroy()` has been called
    bool destroyed = false;

    // Barrier signals: one pointer per local peer plus a pointer array
    // (presumably device-resident — confirm).
    int  *barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
    int **barrier_signal_ptrs_gpu                = nullptr;

    // Workspace
    void *workspace = nullptr;

    // Host-side MoE info. Each `volatile` pointer is paired with a `_mapped` pointer;
    // NOTE(review): presumably the device-visible alias of the same mapped host memory — confirm.
    volatile int *moe_recv_counter        = nullptr;
    int          *moe_recv_counter_mapped = nullptr;

    // Host-side expert-level MoE info
    volatile int *moe_recv_expert_counter        = nullptr;
    int          *moe_recv_expert_counter_mapped = nullptr;

    // Host-side RDMA-level MoE info
    volatile int *moe_recv_rdma_counter        = nullptr;
    int          *moe_recv_rdma_counter_mapped = nullptr;

    bool use_default_stream_as_comm_stream = false;

public:
    Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
           bool low_latency_mode, bool explicitly_destroy, bool use_default_stream_as_comm_stream);

    ~Buffer() noexcept(false);

    // Non-copyable and non-movable: the raw pointer members must have exactly one owner
    // (see the note above the struct).
    Buffer(const Buffer &)            = delete;
    Buffer &operator=(const Buffer &) = delete;
    Buffer(Buffer &&)                 = delete;
    Buffer &operator=(Buffer &&)      = delete;

    bool is_available() const;

    bool is_internode_available() const;

    int get_num_rdma_ranks() const;

    int get_rdma_rank() const;

    int get_root_rdma_rank(bool global) const;

    int get_local_device_id() const;

    pybind11::bytearray get_local_ipc_handle() const;

    pybind11::bytearray get_local_nvshmem_unique_id() const;

    torch::Tensor get_local_buffer_tensor(const pybind11::object &dtype, int64_t offset,
                                          bool use_rdma_buffer) const;

    torch::Stream get_comm_stream() const;

    void sync(const std::vector<int>                                &device_ids,
              const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
              const std::optional<pybind11::bytearray>              &root_unique_id_opt);

    void destroy();

    std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor,
               std::optional<EventHandle>>
    get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
                        std::optional<EventHandle> &previous_event, bool async,
                        bool allocate_on_comm_stream);

    std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
               std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
               torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
               std::optional<EventHandle>>
    intranode_dispatch(const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
                       const std::optional<torch::Tensor> &topk_idx,
                       const std::optional<torch::Tensor> &topk_weights,
                       const std::optional<torch::Tensor> &num_tokens_per_rank,
                       const torch::Tensor                &is_token_in_rank,
                       const std::optional<torch::Tensor> &num_tokens_per_expert,
                       int                                 cached_num_recv_tokens,
                       const std::optional<torch::Tensor> &cached_rank_prefix_matrix,
                       const std::optional<torch::Tensor> &cached_channel_prefix_matrix,
                       int expert_alignment, int num_worst_tokens, const Config &config,
                       std::optional<EventHandle> &previous_event, bool async,
                       bool allocate_on_comm_stream);

    std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
    intranode_combine(const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
                      const std::optional<torch::Tensor> &bias_0,
                      const std::optional<torch::Tensor> &bias_1, const torch::Tensor &src_idx,
                      const torch::Tensor &rank_prefix_matrix,
                      const torch::Tensor &channel_prefix_matrix, const torch::Tensor &send_head,
                      const Config &config, std::optional<EventHandle> &previous_event, bool async,
                      bool allocate_on_comm_stream);

    std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
               std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
               std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>,
               torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
               std::optional<torch::Tensor>, std::optional<EventHandle>>
    internode_dispatch(const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
                       const std::optional<torch::Tensor> &topk_idx,
                       const std::optional<torch::Tensor> &topk_weights,
                       const std::optional<torch::Tensor> &num_tokens_per_rank,
                       const std::optional<torch::Tensor> &num_tokens_per_rdma_rank,
                       const torch::Tensor                &is_token_in_rank,
                       const std::optional<torch::Tensor> &num_tokens_per_expert,
                       int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
                       const std::optional<torch::Tensor> &cached_rdma_channel_prefix_matrix,
                       const std::optional<torch::Tensor> &cached_recv_rdma_rank_prefix_sum,
                       const std::optional<torch::Tensor> &cached_gbl_channel_prefix_matrix,
                       const std::optional<torch::Tensor> &cached_recv_gbl_rank_prefix_sum,
                       int expert_alignment, const Config &config,
                       std::optional<EventHandle> &previous_event, bool async,
                       bool allocate_on_comm_stream);

    std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
    internode_combine(
        const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
        const std::optional<torch::Tensor> &bias_0, const std::optional<torch::Tensor> &bias_1,
        const torch::Tensor &src_meta, const torch::Tensor &is_combined_token_in_rank,
        const torch::Tensor &rdma_channel_prefix_matrix, const torch::Tensor &rdma_rank_prefix_sum,
        const torch::Tensor &gbl_channel_prefix_matrix, const torch::Tensor &combined_rdma_head,
        const torch::Tensor &combined_nvl_head, const Config &config,
        std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream);

    void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden,
                                  int num_experts);

    std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor,
               torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
    low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_idx,
                         const std::optional<torch::Tensor> &cumulative_local_expert_recv_stats,
                         const std::optional<torch::Tensor> &dispatch_wait_recv_cost_stats,
                         int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8,
                         bool round_scale, bool use_ue8m0, bool async, bool return_recv_hook);

    std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
    low_latency_combine(const torch::Tensor &x, const torch::Tensor &topk_idx,
                        const torch::Tensor &topk_weights, const torch::Tensor &src_info,
                        const torch::Tensor                &layout_range,
                        const std::optional<torch::Tensor> &combine_wait_recv_cost_stats,
                        int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt,
                        bool zero_copy, bool async, bool return_recv_hook,
                        const std::optional<torch::Tensor> &out = std::nullopt);

    torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank,
                                                      int hidden, int num_experts) const;
};

} // namespace deep_ep