#pragma once

#include <hip/hip_runtime.h>
#include <vector>

#include "configs.cuh"

namespace deep_ep {

// Intranode runtime
namespace intranode {

lijian6's avatar
lijian6 committed
13
void barrier(int **barrier_signal_ptrs, int rank, int num_ranks, hipStream_t stream);
Chenggang Zhao's avatar
Chenggang Zhao committed
14
15
16
17
18
19
20
21

} // namespace intranode

// Internode runtime
namespace internode {

// Returns a byte vector holding a unique id. Presumably this is the bootstrap id that one
// rank generates and shares with the others before calling init() -- TODO confirm.
std::vector<uint8_t> get_unique_id();

// Initializes the internode runtime for this process.
//   root_unique_id_val: id bytes, presumably the value produced by get_unique_id() on the
//                       root rank -- confirm against the definition.
//   rank, num_ranks:    the caller's rank index and the number of participating ranks.
//   low_latency_mode:   mode flag forwarded to the implementation.
// Returns an int whose meaning is not visible from this header (NOTE(review): document it
// next to the definition).
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks,
         bool low_latency_mode);

// Allocates `size` bytes with the requested `alignment` and returns the pointer.
// NOTE(review): the memory space and failure behaviour are defined by the implementation.
void *alloc(size_t size, size_t alignment);

// Releases a pointer; presumably one previously returned by alloc() -- confirm.
void free(void *ptr);

// Barrier with no arguments; presumably spans all ranks passed to init() -- confirm.
void barrier();

// Tears down the internode runtime; presumably the counterpart of init() -- confirm.
void finalize();

} // namespace internode

// Layout kernels
namespace layout {

lijian6's avatar
lijian6 committed
38
39
40
41
void get_dispatch_layout(const int64_t *topk_idx, int *num_tokens_per_rank,
                         int *num_tokens_per_rdma_rank, int *num_tokens_per_expert,
                         bool *is_token_in_rank, int num_tokens, int num_topk, int num_ranks,
                         int num_experts, hipStream_t stream);
42
43
44

} // namespace layout

// Intranode kernels
namespace intranode {

// All functions below are host-side declarations only; their definitions are not visible in
// this header. In every signature, pointer-to-const parameters are read-only for the callee
// and the remaining pointers may be written. Parameter meanings given here are inferred from
// names and should be confirmed against the definitions.

// Notification step that precedes an intranode dispatch.
//   num_tokens_per_rank / num_tokens_per_expert: read-only int count buffers.
//   moe_recv_counter_mapped,
//   moe_recv_expert_counter_mapped: writable int counters (presumably host-mapped -- confirm).
//   moe_num_recv_tokens_per_experts: writable 64-bit buffer.
//   is_token_in_rank:               read-only bool buffer.
//   channel_prefix_matrix,
//   rank_prefix_matrix_copy:        writable int buffers.
//   num_memset_int:                 presumably the number of ints to clear -- confirm.
//   expert_alignment:               presumably an alignment applied to per-expert counts -- confirm.
//   buffer_ptrs, barrier_signal_ptrs: arrays of buffer / barrier-signal pointers.
//   rank, num_ranks, num_experts, num_tokens: caller's rank index and problem sizes.
//   stream, num_sms:                HIP stream handle and an SM count for the implementation.
void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
                     const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
                     int64_t *moe_num_recv_tokens_per_experts, int num_experts, int num_tokens,
                     const bool *is_token_in_rank, int *channel_prefix_matrix,
                     int *rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
                     void **buffer_ptrs, int **barrier_signal_ptrs, int rank, hipStream_t stream,
                     int num_sms);

// Variant of the dispatch notification that takes an existing, read-only
// `rank_prefix_matrix` (presumably cached from an earlier notify_dispatch -- confirm).
void cached_notify_dispatch(const int *rank_prefix_matrix, int num_memset_int, void **buffer_ptrs,
                            int **barrier_signal_ptrs, int rank, int num_ranks, hipStream_t stream);

// Intranode dispatch.
//   recv_x, recv_x_scales, recv_src_idx, recv_topk_idx,
//   recv_topk_weights, recv_channel_offset, send_head: writable buffers.
//   x, x_scales, topk_idx, topk_weights,
//   is_token_in_rank, channel_prefix_matrix:           read-only buffers.
//   hidden_int4: presumably the hidden size counted in int4 (16-byte) units -- confirm.
//   scale_token_stride, scale_hidden_stride: strides for `x_scales` (units not visible here).
//   num_max_send_tokens, num_recv_buffer_tokens: chunk / buffer sizing in tokens.
void dispatch(void *recv_x, float *recv_x_scales, int *recv_src_idx, int64_t *recv_topk_idx,
              float *recv_topk_weights, int *recv_channel_offset, int *send_head, const void *x,
              const float *x_scales, const int64_t *topk_idx, const float *topk_weights,
              const bool *is_token_in_rank, const int *channel_prefix_matrix, int num_tokens,
              int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
              int scale_token_stride, int scale_hidden_stride, void **buffer_ptrs, int rank,
              int num_ranks, hipStream_t stream, int num_sms, int num_max_send_tokens,
              int num_recv_buffer_tokens);

// Notification step that precedes an intranode combine; `send_head` is writable.
void cached_notify_combine(void **buffer_ptrs, int *send_head, int num_channels,
                           int num_recv_tokens, int num_memset_int, int **barrier_signal_ptrs,
                           int rank, int num_ranks, hipStream_t stream);

// Intranode combine.
//   type:                      HIP data-type enum (presumably the element type of `x` and
//                              `recv_x` -- confirm).
//   recv_x, recv_topk_weights: writable output buffers.
//   x, topk_weights, bias_0, bias_1, src_idx,
//   rank_prefix_matrix, channel_prefix_matrix: read-only buffers.
//   send_head:                 writable int buffer.
void combine(hipDataType type, void *recv_x, float *recv_topk_weights, const void *x,
             const float *topk_weights, const void *bias_0, const void *bias_1, const int *src_idx,
             const int *rank_prefix_matrix, const int *channel_prefix_matrix, int *send_head,
             int num_tokens, int num_recv_tokens, int hidden, int num_topk, void **buffer_ptrs,
             int rank, int num_ranks, hipStream_t stream, int num_sms, int num_max_send_tokens,
             int num_recv_buffer_tokens);

} // namespace intranode

// Internode kernels
namespace internode {

int get_source_meta_bytes();

lijian6's avatar
lijian6 committed
86
87
88
89
90
91
92
93
94
95
96
void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
                     const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
                     const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
                     int num_experts, const bool *is_token_in_rank, int num_tokens,
                     int num_channels, int hidden_int4, int num_scales, int num_topk,
                     int expert_alignment, int *rdma_channel_prefix_matrix,
                     int *recv_rdma_rank_prefix_sum, int *gbl_channel_prefix_matrix,
                     int *recv_gbl_rank_prefix_sum, void *rdma_buffer_ptr,
                     int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
                     int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
                     hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
Chenggang Zhao's avatar
Chenggang Zhao committed
97
98
                     bool low_latency_mode);

lijian6's avatar
lijian6 committed
99
100
101
102
103
104
105
106
107
108
109
110
111
void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
              void *recv_src_meta, const void *x, const float *x_scales, const int64_t *topk_idx,
              const float *topk_weights, int *send_rdma_head, int *send_nvl_head,
              int *recv_rdma_channel_prefix_matrix, int *recv_gbl_channel_prefix_matrix,
              const int *rdma_channel_prefix_matrix, const int *recv_rdma_rank_prefix_sum,
              const int *gbl_channel_prefix_matrix, const int *recv_gbl_rank_prefix_sum,
              const bool *is_token_in_rank, int num_tokens, int hidden_int4, int num_scales,
              int num_topk, int num_experts, int scale_token_stride, int scale_hidden_stride,
              void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
              int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
              int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
              int num_ranks, bool is_cached_dispatch, hipStream_t stream, int num_channels,
              bool low_latency_mode);
Chenggang Zhao's avatar
Chenggang Zhao committed
112
113

void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
lijian6's avatar
lijian6 committed
114
115
116
117
118
119
                   int num_ranks, int num_channels, int num_combined_tokens,
                   int *combined_rdma_head, const int *rdma_channel_prefix_matrix,
                   const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
                   int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
                   int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
                   hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
Chenggang Zhao's avatar
Chenggang Zhao committed
120
121
                   bool is_cached_dispatch, bool low_latency_mode);

lijian6's avatar
lijian6 committed
122
123
124
125
126
127
128
129
130
131
void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
             const bool *is_combined_token_in_rank, const void *x, const float *topk_weights,
             const void *bias_0, const void *bias_1, const int *combined_rdma_head,
             const int *combined_nvl_head, const void *src_meta,
             const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
             const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
             int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
             int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
             int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
             int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode);
Chenggang Zhao's avatar
Chenggang Zhao committed
132
133
134

} // namespace internode
} // namespace deep_ep