// custom_all_reduce.h
#pragma once
/*
 * Copyright (C) 2024-2025, The vLLM team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/extension.h>

// all reduce
// Integer handle type: returned by aiter::init_custom_ar and accepted as the
// `_fa` argument by the other functions declared in this header.
// NOTE(review): presumably carries a pointer to the native all-reduce object
// stored as an int64_t — confirm against the implementation file.
using fptr_t = int64_t;

namespace aiter {

fptr_t init_custom_ar(torch::Tensor& meta,
                      torch::Tensor& rank_data,
                      const std::vector<torch::Tensor>& handles,
                      const std::vector<int64_t>& offsets,
                      int64_t rank,
                      bool fully_connected);
void all_reduce(fptr_t _fa,
                torch::Tensor& inp,
                torch::Tensor& out,
                bool open_fp8_quant,
                std::optional<torch::Tensor> reg_buffer);
void all_gather_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_gather_unreg(fptr_t _fa,
                      torch::Tensor& inp,
                      torch::Tensor& reg_buffer,
                      torch::Tensor& out);
void fused_allreduce_rmsnorm(fptr_t _fa,
                torch::Tensor& inp,
                torch::Tensor& res_inp,
                torch::Tensor& res_out,
                torch::Tensor& out,
                torch::Tensor& w,
                float eps,
                std::optional<torch::Tensor> reg_buffer);

void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa,
                     torch::Tensor& t,
                     const std::vector<torch::Tensor>& handles,
                     const std::vector<int64_t>& offsets);
std::tuple<torch::Tensor, torch::Tensor> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa,
                            const std::vector<torch::Tensor>& handles,
                            const std::vector<torch::Tensor>& offsets);
#ifdef USE_ROCM
torch::Tensor allocate_meta_buffer(int64_t size);
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
#endif

} // namespace aiter