scatter.h 3 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
#pragma once

#include <torch/extension.h>

Matthias Fey's avatar
Matthias Fey committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// SCATTER_INLINE_VARIABLE: specifier placed in front of a variable that is
// defined in this header. It expands to:
//   * `inline`                  when inline variables are available (feature
//                               test macro present, or C++17 and later);
//   * `__declspec(selectany)`   otherwise, when building with MSVC;
//   * `__attribute__((weak))`   otherwise.
#if defined(__cpp_inline_variables) || __cplusplus >= 201703L
#define SCATTER_INLINE_VARIABLE inline
#elif defined(_MSC_VER)
#define SCATTER_INLINE_VARIABLE __declspec(selectany)
#else
#define SCATTER_INLINE_VARIABLE __attribute__((weak))
#endif

namespace scatter {
// Declared here only; no definition is visible in this header. Returns an
// int64_t and is marked noexcept.
// NOTE(review): presumably reports the CUDA version the library was built
// against — confirm against the definition.
int64_t cuda_version() noexcept;

namespace detail {
// Header-defined variable whose initializer calls cuda_version().
// SCATTER_INLINE_VARIABLE supplies the specifier (inline / selectany / weak)
// that allows this definition to appear in every including translation unit.
// NOTE(review): this looks like a way to force every includer to reference
// cuda_version() at link time — confirm intent.
SCATTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version();
} // namespace detail
} // namespace scatter
rusty1s's avatar
rusty1s committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77

// Declaration only; the definition is not in this header.
// Takes `src` and `index` tensors, an integer `dim`, an optional output tensor
// `optional_out` and an optional `dim_size`; returns a single tensor.
// NOTE(review): presumably sums the entries of `src` grouped by `index` along
// `dim` — confirm against the implementation.
torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
                          torch::optional<torch::Tensor> optional_out,
                          torch::optional<int64_t> dim_size);

// Declaration only; the definition is not in this header.
// Same parameter list as scatter_sum (`src`, `index`, `dim`, optional
// `optional_out`, optional `dim_size`); returns a single tensor.
// NOTE(review): presumably averages the entries of `src` grouped by `index`
// along `dim` — confirm against the implementation.
torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
                           torch::optional<torch::Tensor> optional_out,
                           torch::optional<int64_t> dim_size);

// Declaration only; the definition is not in this header.
// Takes `src`, `index`, `dim`, optional `optional_out` and optional
// `dim_size`; returns a tuple of two tensors.
// NOTE(review): presumably (per-group minimum values, positions of those
// minima) — confirm the meaning and order of the tuple elements.
std::tuple<torch::Tensor, torch::Tensor>
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size);

// Declaration only; the definition is not in this header.
// Takes `src`, `index`, `dim`, optional `optional_out` and optional
// `dim_size`; returns a tuple of two tensors.
// NOTE(review): presumably (per-group maximum values, positions of those
// maxima) — confirm the meaning and order of the tuple elements.
std::tuple<torch::Tensor, torch::Tensor>
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size);

// Declaration only; the definition is not in this header.
// Takes `src` and `index` tensors, an optional output tensor `optional_out`
// and an optional `dim_size`; returns a single tensor. Unlike the scatter_*
// declarations there is no `dim` parameter.
// NOTE(review): presumably a segment-wise sum with segments given by
// COO-style indices in `index` — confirm against the implementation.
torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index,
                              torch::optional<torch::Tensor> optional_out,
                              torch::optional<int64_t> dim_size);

// Declaration only; the definition is not in this header.
// Takes `src`, `index`, optional `optional_out` and optional `dim_size`;
// returns a single tensor.
// NOTE(review): presumably a segment-wise mean with segments given by
// COO-style indices in `index` — confirm against the implementation.
torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index,
                               torch::optional<torch::Tensor> optional_out,
                               torch::optional<int64_t> dim_size);

// Declaration only; the definition is not in this header.
// Takes `src`, `index`, optional `optional_out` and optional `dim_size`;
// returns a tuple of two tensors.
// NOTE(review): presumably (per-segment minimum values, positions of those
// minima) — confirm the meaning and order of the tuple elements.
std::tuple<torch::Tensor, torch::Tensor>
segment_min_coo(torch::Tensor src, torch::Tensor index,
                torch::optional<torch::Tensor> optional_out,
                torch::optional<int64_t> dim_size);

// Declaration only; the definition is not in this header.
// Takes `src`, `index`, optional `optional_out` and optional `dim_size`;
// returns a tuple of two tensors.
// NOTE(review): presumably (per-segment maximum values, positions of those
// maxima) — confirm the meaning and order of the tuple elements.
std::tuple<torch::Tensor, torch::Tensor>
segment_max_coo(torch::Tensor src, torch::Tensor index,
                torch::optional<torch::Tensor> optional_out,
                torch::optional<int64_t> dim_size);

// Declaration only; the definition is not in this header.
// Takes `src` and `index` tensors and an optional output tensor
// `optional_out`; returns a single tensor.
// NOTE(review): presumably gathers entries of `src` according to COO-style
// indices in `index` — confirm against the implementation.
torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
                         torch::optional<torch::Tensor> optional_out);

// Declaration only; the definition is not in this header.
// Takes `src` and `indptr` tensors and an optional output tensor
// `optional_out`; returns a single tensor.
// NOTE(review): presumably a segment-wise sum with segment boundaries given
// by CSR-style offsets in `indptr` — confirm against the implementation.
torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
                              torch::optional<torch::Tensor> optional_out);

// Declaration only; the definition is not in this header.
// Takes `src` and `indptr` tensors and an optional output tensor
// `optional_out`; returns a single tensor.
// NOTE(review): presumably a segment-wise mean with segment boundaries given
// by CSR-style offsets in `indptr` — confirm against the implementation.
torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
                               torch::optional<torch::Tensor> optional_out);

// Declaration only; the definition is not in this header.
// Takes `src`, `indptr` and an optional output tensor `optional_out`;
// returns a tuple of two tensors.
// NOTE(review): presumably (per-segment minimum values, positions of those
// minima) — confirm the meaning and order of the tuple elements.
std::tuple<torch::Tensor, torch::Tensor>
segment_min_csr(torch::Tensor src, torch::Tensor indptr,
                torch::optional<torch::Tensor> optional_out);

// Declaration only; the definition is not in this header.
// Takes `src`, `indptr` and an optional output tensor `optional_out`;
// returns a tuple of two tensors.
// NOTE(review): presumably (per-segment maximum values, positions of those
// maxima) — confirm the meaning and order of the tuple elements.
std::tuple<torch::Tensor, torch::Tensor>
segment_max_csr(torch::Tensor src, torch::Tensor indptr,
                torch::optional<torch::Tensor> optional_out);

// Declaration only; the definition is not in this header.
// Takes `src` and `indptr` tensors and an optional output tensor
// `optional_out`; returns a single tensor.
// NOTE(review): presumably gathers entries of `src` according to CSR-style
// offsets in `indptr` — confirm against the implementation.
torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
                         torch::optional<torch::Tensor> optional_out);