/*!
 *  Copyright (c) 2020 by Contributors
 * @file kernel/cpu/segment_reduce.cc
 * @brief CPU segment reduce C APIs and definitions.
 */
#include "./segment_reduce.h"

#include <dgl/array.h>

#include <string>
#include <vector>

#include "./spmm_binary_ops.h"

namespace dgl {
namespace aten {

14
/*! @brief Segment Reduce operator. */
15
template <int XPU, typename IdType, typename DType>
16
17
18
19
20
21
22
void SegmentReduce(
    const std::string& op,
    NDArray feat,
    NDArray offsets,
    NDArray out,
    NDArray arg) {
  if (op == "sum") {
23
    cpu::SegmentSum<IdType, DType>(feat, offsets, out);
24
  } else if (op == "max" || op == "min") {
25
    if (op == "max") {
26
27
      cpu::SegmentCmp<IdType, DType, cpu::op::Max<DType>>(
          feat, offsets, out, arg);
28
    } else {
29
30
      cpu::SegmentCmp<IdType, DType, cpu::op::Min<DType>>(
          feat, offsets, out, arg);
31
    }
32
33
34
35
36
  } else {
    LOG(FATAL) << "Unsupported reduce function " << op;
  }
}

37
/*! @brief Scatter Add.*/
38
template <int XPU, typename IdType, typename DType>
39
40
41
void ScatterAdd(NDArray feat,
                NDArray idx,
                NDArray out) {
42
  cpu::ScatterAdd<IdType, DType>(feat, idx, out);
43
44
}

45
/*! @brief Update gradients for reduce operator max/min on heterogeneous graph.*/
46
template <int XPU, typename IdType, typename DType>
47
48
49
50
51
52
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
                const std::string& op,
                const std::vector<NDArray>& feat,
                const std::vector<NDArray>& idx,
                const std::vector<NDArray>& idx_etype,
                std::vector<NDArray>* out) {
53
  cpu::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
54
55
}

56
/*! @brief Backward function of segment cmp.*/
57
template <int XPU, typename IdType, typename DType>
58
59
60
61
void BackwardSegmentCmp(
    NDArray feat,
    NDArray arg,
    NDArray out) {
62
  cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
63
64
}

65
template void SegmentReduce<kDGLCPU, int32_t, float>(
66
67
68
69
70
    const std::string &op,
    NDArray feat,
    NDArray offsets,
    NDArray out,
    NDArray arg);
71
template void SegmentReduce<kDGLCPU, int64_t, float>(
72
73
74
75
76
    const std::string &op,
    NDArray feat,
    NDArray offsets,
    NDArray out,
    NDArray arg);
77
template void SegmentReduce<kDGLCPU, int32_t, double>(
78
79
80
81
82
    const std::string &op,
    NDArray feat,
    NDArray offsets,
    NDArray out,
    NDArray arg);
83
template void SegmentReduce<kDGLCPU, int64_t, double>(
84
85
86
87
88
    const std::string &op,
    NDArray feat,
    NDArray offsets,
    NDArray out,
    NDArray arg);
89
90

// Explicit CPU instantiations of ScatterAdd for every supported
// (index type, feature type) pair: int32_t/int64_t x float/double.
template void ScatterAdd<kDGLCPU, int32_t, float>(
    NDArray feat,
    NDArray idx,
    NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, float>(
    NDArray feat,
    NDArray idx,
    NDArray out);
template void ScatterAdd<kDGLCPU, int32_t, double>(
    NDArray feat,
    NDArray idx,
    NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, double>(
    NDArray feat,
    NDArray idx,
    NDArray out);

107
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, float>(
108
109
110
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
111
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, float>(
112
113
114
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
115
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, double>(
116
117
118
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
119
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, double>(
120
121
122
123
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);

124
template void BackwardSegmentCmp<kDGLCPU, int32_t, float>(
125
126
127
    NDArray feat,
    NDArray arg,
    NDArray out);
128
template void BackwardSegmentCmp<kDGLCPU, int64_t, float>(
129
130
131
    NDArray feat,
    NDArray arg,
    NDArray out);
132
template void BackwardSegmentCmp<kDGLCPU, int32_t, double>(
133
134
135
    NDArray feat,
    NDArray arg,
    NDArray out);
136
template void BackwardSegmentCmp<kDGLCPU, int64_t, double>(
137
138
139
140
141
142
    NDArray feat,
    NDArray arg,
    NDArray out);

}  // namespace aten
}  // namespace dgl