// api.cuh
#pragma once

#include <vector>

namespace deep_ep {

// Intranode runtime
namespace intranode {

// Declaration only; the definition is not part of this header.
// No `__global__`/`__device__` qualifier is present, so this is an ordinary host function
// that receives the stream it should work on.
// NOTE(review): presumably enqueues a barrier among `num_ranks` peers on `stream`, with
// `barrier_signal_ptrs` holding one signal buffer per peer — confirm against the definition.
void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream);

} // namespace intranode

// Internode runtime
namespace internode {

// All functions below are declarations only; their definitions live elsewhere.

// Returns an identifier as a vector of bytes.
// NOTE(review): presumably the value that is later handed to init() as
// `root_unique_id_val` — confirm against the definition.
std::vector<uint8_t> get_unique_id();

// Takes the root identifier bytes plus this process's `rank`, the total `num_ranks`
// and the `low_latency_mode` switch; returns an int.
// NOTE(review): the meaning of the returned int is not visible here — check the definition.
int init(const std::vector<uint8_t>& root_unique_id_val,
         int rank,
         int num_ranks,
         bool low_latency_mode);

// Returns a pointer for a request of `size` with the given `alignment`.
// NOTE(review): units (bytes?) and the memory space of the result are not visible here.
void* alloc(size_t size, size_t alignment);

// Counterpart of alloc() by name; takes the pointer to release.
// NOTE(review): presumably only valid for pointers obtained from alloc() — confirm.
void free(void* ptr);

// Takes no arguments and no stream, unlike intranode::barrier.
// NOTE(review): presumably blocks until all peers arrive — confirm.
void barrier();

// NOTE(review): presumably tears down whatever init() set up — confirm.
void finalize();

} // namespace internode

// Layout kernels
namespace layout {

// Declaration only; the definition is not part of this header.
// `topk_idx` is the only pointer taken as `const`; the four non-const pointers
// (`num_tokens_per_rank`, `num_tokens_per_rdma_rank`, `num_tokens_per_expert`,
// `is_token_in_rank`) may be written through.
// NOTE(review): presumably those four are outputs derived from `topk_idx`, and all
// pointers are device memory consumed on `stream` — confirm against the definition.
void get_dispatch_layout(const int64_t* topk_idx,
                         int* num_tokens_per_rank,
                         int* num_tokens_per_rdma_rank,
                         int* num_tokens_per_expert,
                         bool* is_token_in_rank,
                         int num_tokens,
                         int num_topk,
                         int num_ranks,
                         int num_experts,
                         cudaStream_t stream);

} // namespace layout

// Intranode kernels
namespace intranode {

// Everything in this namespace is a declaration only; definitions live elsewhere.
// None of the functions carries a `__global__`/`__device__` qualifier, so they are ordinary
// host functions that receive the stream to work on. Pointers taken as `const` are
// read-only through this interface; non-const pointers may be written through.
// NOTE(review): whether each pointer refers to device, mapped-host or peer memory is not
// visible from this header — confirm against the definitions before relying on it.

// NOTE(review): presumably publishes per-rank / per-expert token counts to peers and
// fills `channel_prefix_matrix` and `rank_prefix_matrix_copy` — confirm.
// This overload takes `num_ranks` in the first parameter row rather than next to `rank`.
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
                     const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
                     int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
                     int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
                     void** buffer_ptrs, int** barrier_signal_ptrs, int rank,
                     cudaStream_t stream, int num_sms);

// Variant of the notify step that takes an existing `rank_prefix_matrix` as read-only input.
// NOTE(review): presumably used when the dispatch layout is reused from a previous call — confirm.
void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
                            void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks,
                            cudaStream_t stream);

// The `recv_*` pointers and `send_head` are non-const; `x`, `x_scales`, `topk_idx`,
// `topk_weights`, `is_token_in_rank` and `channel_prefix_matrix` are const inputs.
// NOTE(review): `hidden_int4` presumably counts the hidden dimension in 16-byte units, and
// the two `scale_*_stride` values presumably describe the layout of `x_scales` — confirm.
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
              int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
              const bool* is_token_in_rank, const int* channel_prefix_matrix,
              int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
              int scale_token_stride, int scale_hidden_stride,
              void** buffer_ptrs, int rank, int num_ranks,
              cudaStream_t stream, int num_sms,
              int num_max_send_tokens, int num_recv_buffer_tokens);

// NOTE(review): presumably the combine-side counterpart of cached_notify_dispatch;
// `send_head` is non-const here, so it may be modified — confirm.
void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
                           int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream);

// `type` selects the element type of the untyped (`void*`) buffers.
// NOTE(review): which `cudaDataType_t` values are accepted, and whether `bias_0` / `bias_1`
// may be null, is not visible here — confirm against the definition.
void combine(cudaDataType_t type,
             void* recv_x, float* recv_topk_weights,
             const void* x, const float* topk_weights,
             const void* bias_0, const void* bias_1,
             const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
             int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
             void** buffer_ptrs, int rank, int num_ranks,
             cudaStream_t stream, int num_sms,
             int num_max_send_tokens, int num_recv_buffer_tokens);

} // namespace intranode

// Internode kernels
namespace internode {

// Everything in this namespace is a declaration only; definitions live elsewhere.
// None of the functions carries a `__global__`/`__device__` qualifier, so they are ordinary
// host functions. Pointers taken as `const` are read-only through this interface;
// non-const pointers may be written through.
// NOTE(review): by their names, `rdma_*` parameters presumably refer to the RDMA path and
// `nvl_*` / `buffer_ptrs` to the NVLink path — confirm against the definitions.

// Returns an int.
// NOTE(review): presumably the per-token byte size of the buffers passed as
// `recv_src_meta` / `src_meta` below — confirm.
int get_source_meta_bytes();

// NOTE(review): presumably exchanges token counts and fills the four non-const
// prefix outputs (`rdma_channel_prefix_matrix`, `recv_rdma_rank_prefix_sum`,
// `gbl_channel_prefix_matrix`, `recv_gbl_rank_prefix_sum`) — confirm.
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
                     const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
                     const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
                     const bool* is_token_in_rank, int num_tokens, int num_channels,
                     int hidden_int4, int num_scales, int num_topk, int expert_alignment,
                     int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
                     int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
                     void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
                     void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
                     int** barrier_signal_ptrs, int rank,
                     cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
                     bool low_latency_mode);

// The prefix matrices / sums that notify_dispatch takes as non-const are taken as
// `const` here, while the `recv_*_channel_prefix_matrix` pair is non-const.
// NOTE(review): `is_cached_dispatch` presumably selects the reuse-previous-layout path — confirm.
void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
              const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
              int* send_rdma_head, int* send_nvl_head,
              int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
              const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
              const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
              const bool* is_token_in_rank,
              int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
              int scale_token_stride, int scale_hidden_stride,
              void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
              void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
              int rank, int num_ranks, bool is_cached_dispatch,
              cudaStream_t stream, int num_channels, bool low_latency_mode);

// `combined_rdma_head` and `combined_nvl_head` are non-const here and `const` in combine().
// NOTE(review): presumably this call prepares those head arrays for a later combine() — confirm.
void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                   int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head,
                   const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
                   void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
                   void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
                   int** barrier_signal_ptrs, int rank, cudaStream_t stream,
                   int64_t num_rdma_bytes, int64_t num_nvl_bytes,
                   bool is_cached_dispatch, bool low_latency_mode);

// `type` selects the element type of the untyped (`void*`) buffers.
// NOTE(review): accepted `cudaDataType_t` values and nullability of `bias_0` / `bias_1`
// are not visible here — confirm against the definition.
void combine(cudaDataType_t type,
             void* combined_x, float* combined_topk_weights,
             const bool* is_combined_token_in_rank,
             const void* x, const float* topk_weights,
             const void* bias_0, const void* bias_1,
             const int* combined_rdma_head, const int* combined_nvl_head,
             const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
             int num_tokens, int num_combined_tokens, int hidden, int num_topk,
             void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
             void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
             int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode);

} // namespace internode

// Internode low-latency kernels
namespace internode_ll {

// Everything in this namespace is a declaration only; definitions live elsewhere.
// None of the functions carries a `__global__`/`__device__` qualifier, so they are ordinary
// host functions that receive the stream to work on.

// Takes two (pointer, int count) pairs.
// NOTE(review): presumably clears `num_clean_int_0` ints at `clean_0` and
// `num_clean_int_1` ints at `clean_1` on `stream` — confirm the fill value and units.
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
                              int* clean_1, int num_clean_int_1,
                              cudaStream_t stream);

// `x` and `topk_idx` are the only const pointers; every other pointer may be written through.
// `packed_recv_x_scales` is untyped (`void*`) here, unlike the `float*` scales in the
// intranode/internode dispatch declarations.
// NOTE(review): `use_fp8`, `round_scale` and `use_ue8m0` presumably control how payload and
// scales are encoded, and `phases` is presumably a bit mask of stages to run — confirm.
// NOTE(review): `next_clean` / `num_next_clean_int` presumably name a region to reset for
// the next call — confirm.
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
              int* packed_recv_src_info, int64_t* packed_recv_layout_range,
              int* packed_recv_count,
              int* cumulative_local_expert_recv_stats,
              void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
              const void* x, const int64_t* topk_idx,
              int* next_clean, int num_next_clean_int,
              int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks,
              bool use_fp8, bool round_scale, bool use_ue8m0,
              void* workspace, int num_device_sms,
              cudaStream_t stream, int phases);

// Takes `src_info` and `layout_range` as const inputs.
// NOTE(review): these presumably are the `packed_recv_src_info` /
// `packed_recv_layout_range` buffers filled by dispatch() — confirm.
// NOTE(review): `zero_copy` presumably indicates the payload is already in `rdma_send_x`
// so no staging copy from `x` is needed — confirm against the definition.
void combine(void* combined_x,
             void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
             const void* x, const int64_t* topk_idx, const float* topk_weights,
             const int* src_info, const int64_t* layout_range,
             int* next_clean, int num_next_clean_int,
             int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
             int num_topk, int num_experts, int rank, int num_ranks,
             void* workspace, int num_device_sms,
             cudaStream_t stream, int phases, bool zero_copy);

} // namespace internode_ll

} // namespace deep_ep