// scatter.h
#pragma once

// Supplies torch::Tensor and torch::optional used by every declaration below.
#include <torch/extension.h>

// Project macros used below (SCATTER_API, SCATTER_INLINE_VARIABLE); they are
// expected to be defined here since this header does not define them.
#include "macros.h"

namespace scatter {
8
SCATTER_API int64_t cuda_version() noexcept;
Matthias Fey's avatar
Matthias Fey committed
9
10
11
12
13

namespace detail {
SCATTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version();
} // namespace detail
} // namespace scatter
rusty1s's avatar
rusty1s committed
14

15
16
17
18
/// Sum-reduction member of the `scatter_*` family.
///
/// @param src           input tensor whose entries are reduced.
/// @param index         tensor selecting, for each `src` entry, its output
///                      slot along `dim` (exact rules live in the
///                      implementation — confirm).
/// @param dim           dimension the scatter operates along.
/// @param optional_out  optional output tensor; presumably written in place
///                      and returned when given — confirm.
/// @param dim_size      optional output size along `dim`; presumably inferred
///                      from `index` when absent — confirm.
/// @return the reduced tensor.
SCATTER_API torch::Tensor
scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size);

/// Mean-reduction member of the `scatter_*` family.
///
/// @param src           input tensor whose entries are reduced.
/// @param index         tensor selecting, for each `src` entry, its output
///                      slot along `dim` (exact rules live in the
///                      implementation — confirm).
/// @param dim           dimension the scatter operates along.
/// @param optional_out  optional output tensor; presumably written in place
///                      and returned when given — confirm.
/// @param dim_size      optional output size along `dim`; presumably inferred
///                      from `index` when absent — confirm.
/// @return the reduced tensor.
SCATTER_API torch::Tensor
scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
             torch::optional<torch::Tensor> optional_out,
             torch::optional<int64_t> dim_size);

/// Min-reduction member of the `scatter_*` family.
///
/// @param src           input tensor whose entries are reduced.
/// @param index         tensor selecting, for each `src` entry, its output
///                      slot along `dim` (exact rules live in the
///                      implementation — confirm).
/// @param dim           dimension the scatter operates along.
/// @param optional_out  optional output tensor; presumably written in place
///                      when given — confirm.
/// @param dim_size      optional output size along `dim`; presumably inferred
///                      from `index` when absent — confirm.
/// @return two tensors; presumably the reduced values and the positions
///         (arg-min) they were taken from — confirm.
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size);

/// Max-reduction member of the `scatter_*` family.
///
/// @param src           input tensor whose entries are reduced.
/// @param index         tensor selecting, for each `src` entry, its output
///                      slot along `dim` (exact rules live in the
///                      implementation — confirm).
/// @param dim           dimension the scatter operates along.
/// @param optional_out  optional output tensor; presumably written in place
///                      when given — confirm.
/// @param dim_size      optional output size along `dim`; presumably inferred
///                      from `index` when absent — confirm.
/// @return two tensors; presumably the reduced values and the positions
///         (arg-max) they were taken from — confirm.
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size);

/// Sum-reduction member of the `segment_*_coo` family. No `dim` argument is
/// taken; the reduction axis is fixed by the implementation (presumably derived
/// from `index`'s rank — confirm).
///
/// @param src           input tensor whose entries are reduced.
/// @param index         per-entry segment index tensor (ordering requirements,
///                      if any, live in the implementation — confirm).
/// @param optional_out  optional output tensor; presumably written in place
///                      and returned when given — confirm.
/// @param dim_size      optional output size along the reduced axis;
///                      presumably inferred from `index` when absent — confirm.
/// @return the reduced tensor.
SCATTER_API torch::Tensor
segment_sum_coo(torch::Tensor src, torch::Tensor index,
                torch::optional<torch::Tensor> optional_out,
                torch::optional<int64_t> dim_size);

/// Mean-reduction member of the `segment_*_coo` family. No `dim` argument is
/// taken; the reduction axis is fixed by the implementation (presumably derived
/// from `index`'s rank — confirm).
///
/// @param src           input tensor whose entries are reduced.
/// @param index         per-entry segment index tensor (ordering requirements,
///                      if any, live in the implementation — confirm).
/// @param optional_out  optional output tensor; presumably written in place
///                      and returned when given — confirm.
/// @param dim_size      optional output size along the reduced axis;
///                      presumably inferred from `index` when absent — confirm.
/// @return the reduced tensor.
SCATTER_API torch::Tensor
segment_mean_coo(torch::Tensor src, torch::Tensor index,
                 torch::optional<torch::Tensor> optional_out,
                 torch::optional<int64_t> dim_size);

/// Min-reduction member of the `segment_*_coo` family. No `dim` argument is
/// taken; the reduction axis is fixed by the implementation (presumably derived
/// from `index`'s rank — confirm).
///
/// @param src           input tensor whose entries are reduced.
/// @param index         per-entry segment index tensor (ordering requirements,
///                      if any, live in the implementation — confirm).
/// @param optional_out  optional output tensor; presumably written in place
///                      when given — confirm.
/// @param dim_size      optional output size along the reduced axis;
///                      presumably inferred from `index` when absent — confirm.
/// @return two tensors; presumably the reduced values and their arg-min
///         positions — confirm.
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_min_coo(torch::Tensor src, torch::Tensor index,
                torch::optional<torch::Tensor> optional_out,
                torch::optional<int64_t> dim_size);

/// Max-reduction member of the `segment_*_coo` family. No `dim` argument is
/// taken; the reduction axis is fixed by the implementation (presumably derived
/// from `index`'s rank — confirm).
///
/// @param src           input tensor whose entries are reduced.
/// @param index         per-entry segment index tensor (ordering requirements,
///                      if any, live in the implementation — confirm).
/// @param optional_out  optional output tensor; presumably written in place
///                      when given — confirm.
/// @param dim_size      optional output size along the reduced axis;
///                      presumably inferred from `index` when absent — confirm.
/// @return two tensors; presumably the reduced values and their arg-max
///         positions — confirm.
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_max_coo(torch::Tensor src, torch::Tensor index,
                torch::optional<torch::Tensor> optional_out,
                torch::optional<int64_t> dim_size);

/// Gather counterpart of the `segment_*_coo` reductions; unlike them it takes
/// no `dim_size`. Presumably reads the `src` entry selected by each `index`
/// entry — confirm against the implementation.
///
/// @param src           tensor to read from.
/// @param index         per-entry index tensor selecting from `src`.
/// @param optional_out  optional output tensor; presumably written in place
///                      and returned when given — confirm.
/// @return the gathered tensor.
SCATTER_API torch::Tensor
gather_coo(torch::Tensor src, torch::Tensor index,
           torch::optional<torch::Tensor> optional_out);

/// Sum-reduction member of the `segment_*_csr` family. Takes an `indptr`
/// tensor instead of a per-entry index, and no `dim`/`dim_size` arguments.
///
/// @param src           input tensor whose entries are reduced.
/// @param indptr        segment-offset tensor (presumably CSR-style boundary
///                      pointers — confirm against the implementation).
/// @param optional_out  optional output tensor; presumably written in place
///                      and returned when given — confirm.
/// @return the reduced tensor.
SCATTER_API torch::Tensor
segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
                torch::optional<torch::Tensor> optional_out);

/// Mean-reduction member of the `segment_*_csr` family. Takes an `indptr`
/// tensor instead of a per-entry index, and no `dim`/`dim_size` arguments.
///
/// @param src           input tensor whose entries are reduced.
/// @param indptr        segment-offset tensor (presumably CSR-style boundary
///                      pointers — confirm against the implementation).
/// @param optional_out  optional output tensor; presumably written in place
///                      and returned when given — confirm.
/// @return the reduced tensor.
SCATTER_API torch::Tensor
segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
                 torch::optional<torch::Tensor> optional_out);

/// Min-reduction member of the `segment_*_csr` family. Takes an `indptr`
/// tensor instead of a per-entry index, and no `dim`/`dim_size` arguments.
///
/// @param src           input tensor whose entries are reduced.
/// @param indptr        segment-offset tensor (presumably CSR-style boundary
///                      pointers — confirm against the implementation).
/// @param optional_out  optional output tensor; presumably written in place
///                      when given — confirm.
/// @return two tensors; presumably the reduced values and their arg-min
///         positions — confirm.
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_min_csr(torch::Tensor src, torch::Tensor indptr,
                torch::optional<torch::Tensor> optional_out);

/// Max-reduction member of the `segment_*_csr` family. Takes an `indptr`
/// tensor instead of a per-entry index, and no `dim`/`dim_size` arguments.
///
/// @param src           input tensor whose entries are reduced.
/// @param indptr        segment-offset tensor (presumably CSR-style boundary
///                      pointers — confirm against the implementation).
/// @param optional_out  optional output tensor; presumably written in place
///                      when given — confirm.
/// @return two tensors; presumably the reduced values and their arg-max
///         positions — confirm.
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_max_csr(torch::Tensor src, torch::Tensor indptr,
                torch::optional<torch::Tensor> optional_out);

/// Gather counterpart of the `segment_*_csr` reductions: same `indptr` input,
/// no `dim`/`dim_size` arguments.
///
/// @param src           tensor to read from.
/// @param indptr        segment-offset tensor (presumably CSR-style boundary
///                      pointers — confirm against the implementation).
/// @param optional_out  optional output tensor; presumably written in place
///                      and returned when given — confirm.
/// @return the gathered tensor.
SCATTER_API torch::Tensor gather_csr(
    torch::Tensor src, torch::Tensor indptr,
    torch::optional<torch::Tensor> optional_out);