// custom_all_reduce.h
#pragma once
/*
 * Copyright (C) Advanced Micro Devices, Inc. All rights reserved.
 * Copyright (C) 2024-2026, The vLLM team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cstdint>
#include <vector>
#include "aiter_tensor.h"
// Integer-typed handle used by the all-reduce entry points in namespace aiter.
//
// aiter::init_custom_ar() returns a value of this type, and the other
// per-instance functions take such a value back as their `_fa` argument.
// NOTE(review): presumably this carries the address of the implementation's
// all-reduce state object cast to an integer — confirm in the .cpp.
using fptr_t = std::int64_t;

namespace aiter {

27
28
29
30
fptr_t init_custom_ar(int64_t meta_ptr,
                      int64_t rank_data_ptr,
                      int64_t rank_data_sz,
                      const std::vector<int64_t>& ipc_handle_ptrs,
Xiaowei.zhang's avatar
Xiaowei.zhang committed
31
32
33
34
                      const std::vector<int64_t>& offsets,
                      int64_t rank,
                      bool fully_connected);
void all_reduce(fptr_t _fa,
35
36
37
                const aiter_tensor_t& inp,
                const aiter_tensor_t& out,
                bool use_new,
Xiaowei.zhang's avatar
Xiaowei.zhang committed
38
                bool open_fp8_quant,
39
40
41
42
43
44
45
46
47
48
49
                int64_t reg_inp_ptr,
                int64_t reg_inp_bytes);
void reduce_scatter(fptr_t _fa,
                    const aiter_tensor_t& inp,
                    const aiter_tensor_t& out,
                    int64_t reg_ptr,
                    int64_t reg_bytes);
void all_gather_reg(fptr_t _fa,
                    const aiter_tensor_t& inp,
                    const aiter_tensor_t& out,
                    int64_t dim);
Xiaowei.zhang's avatar
Xiaowei.zhang committed
50
void all_gather_unreg(fptr_t _fa,
51
52
53
54
55
                      const aiter_tensor_t& inp,
                      int64_t reg_buffer,
                      const aiter_tensor_t& out,
                      int64_t reg_bytes,
                      int64_t dim);
Xiaowei.zhang's avatar
Xiaowei.zhang committed
56
void fused_allreduce_rmsnorm(fptr_t _fa,
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
                             const aiter_tensor_t& inp,
                             const aiter_tensor_t& res_inp,
                             const aiter_tensor_t& res_out,
                             const aiter_tensor_t& out,
                             const aiter_tensor_t& w,
                             double eps,
                             int64_t reg_ptr,
                             int64_t reg_bytes,
                             bool use_1stage);
void fused_allreduce_rmsnorm_quant(fptr_t _fa,
                                   const aiter_tensor_t& inp,
                                   const aiter_tensor_t& res_inp,
                                   const aiter_tensor_t& res_out,
                                   const aiter_tensor_t& out,
                                   const aiter_tensor_t& scale_out,
                                   const aiter_tensor_t& w,
                                   double eps,
                                   int64_t reg_ptr,
                                   int64_t reg_bytes,
                                   bool use_1stage);
void fused_allreduce_rmsnorm_quant_per_group(fptr_t _fa,
                                             const aiter_tensor_t& inp,
                                             const aiter_tensor_t& res_inp,
                                             const aiter_tensor_t& res_out,
                                             const aiter_tensor_t& out,
                                             const aiter_tensor_t& scale_out,
                                             const aiter_tensor_t& w,
                                             double eps,
                                             int64_t group_size,
                                             int64_t reg_ptr,
                                             int64_t reg_bytes,
                                             bool use_1stage,
                                             int64_t bf16_out_ptr = 0);
void fused_qknorm_allreduce(fptr_t _fa,
                            const aiter_tensor_t& qkv_in,
                            const aiter_tensor_t& q_w,
                            const aiter_tensor_t& k_w,
                            const aiter_tensor_t& q_out,
                            const aiter_tensor_t& k_out,
                            const aiter_tensor_t& v_out,
                            double eps,
                            int64_t reg_ptr,
                            int64_t reg_bytes);
Xiaowei.zhang's avatar
Xiaowei.zhang committed
100
101
void dispose(fptr_t _fa);
int64_t meta_size();
102
103
104
105
106
107
108
109
110
111
112
113
void register_input_buffer(fptr_t _fa,
                           int64_t self_ptr,
                           const std::vector<int64_t>& ipc_handle_ptrs,
                           const std::vector<int64_t>& offsets);
void register_output_buffer(fptr_t _fa,
                            int64_t self_ptr,
                            const std::vector<int64_t>& ipc_handle_ptrs,
                            const std::vector<int64_t>& offsets);
int64_t get_graph_buffer_count(fptr_t _fa);
void get_graph_buffer_ipc_meta(fptr_t _fa,
                               int64_t handle_out,
                               int64_t offset_out);
Xiaowei.zhang's avatar
Xiaowei.zhang committed
114
void register_graph_buffers(fptr_t _fa,
115
116
                            const std::vector<int64_t>& handle_ptrs,
                            const std::vector<int64_t>& offset_ptrs);
Xiaowei.zhang's avatar
Xiaowei.zhang committed
117
#ifdef USE_ROCM
118
119
120
int64_t allocate_meta_buffer(int64_t size);
void free_meta_buffer(int64_t ptr);
void get_meta_buffer_ipc_handle(int64_t inp_ptr, int64_t out_handle_ptr);
Xiaowei.zhang's avatar
Xiaowei.zhang committed
121
122
123
#endif

} // namespace aiter