// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2020 by Contributors
 * @file kernel/cpu/segment_reduce.cc
 * @brief Segment reduce C APIs and definitions.
 */
#include "segment_reduce.h"

#include <dgl/array.h>

#include <string>

#include "spmm_binary_ops.h"

namespace dgl {
namespace aten {

/** @brief Segment Reduce operator. */
template <int XPU, typename IdType, typename DType>
void SegmentReduce(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg) {
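  // Dispatch on the reduction name: "sum" accumulates features within each
  // segment, while "max"/"min" additionally record the winning positions in
  // `arg` (presumably consumed by BackwardSegmentCmp in the backward pass).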
  if (op == "sum") {
    cpu::SegmentSum<IdType, DType>(feat, offsets, out);
  } else if (op == "max" || op == "min") {
    if (op == "max") {
      cpu::SegmentCmp<IdType, DType, cpu::op::Max<DType>>(
          feat, offsets, out, arg);
    } else {
      cpu::SegmentCmp<IdType, DType, cpu::op::Min<DType>>(
          feat, offsets, out, arg);
    }
  } else {
    LOG(FATAL) << "Unsupported reduce function " << op;
  }
}

/** @brief Scatter Add.*/
template <int XPU, typename IdType, typename DType>
void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
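  // Forward to the CPU kernel, which presumably accumulates each row of
  // `feat` into `out` at the row given by `idx` (out[idx[i]] += feat[i]).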
  cpu::ScatterAdd<IdType, DType>(feat, idx, out);
}

/** @brief Update gradients for reduce operator max/min on heterogeneous
 * graph.*/
template <int XPU, typename IdType, typename DType>
void UpdateGradMinMax_hetero(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
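  // Presumably routes each gradient in `feat` back through the argmax/argmin
  // positions recorded per edge type (`idx`, `idx_etype`) into `out`.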
  cpu::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
}

/** @brief Backward function of segment cmp.*/
template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
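  // Presumably scatters the incoming gradient `feat` to the positions saved
  // in `arg` by the forward max/min reduction, writing the result into `out`.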
  cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
}

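// Explicit instantiations for the index/data type combinations dispatched at
// runtime: int32/int64 indices with bfloat16, float, or double features.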
template void SegmentReduce<kDGLCPU, int32_t, BFloat16>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, BFloat16>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCPU, int32_t, float>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, float>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCPU, int32_t, double>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, double>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);

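// BFloat16 is instantiated only for SegmentReduce and BackwardSegmentCmp;
// the BF16 specializations of ScatterAdd and UpdateGradMinMax_hetero below
// abort with LOG(FATAL) since no BF16 CPU kernel is provided for them.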
template <>
void ScatterAdd<kDGLCPU, int32_t, BFloat16>(
    NDArray feat, NDArray idx, NDArray out) {
  LOG(FATAL) << "Unsupported CPU kernel for ScatterAdd for BF16.";
}
template <>
void ScatterAdd<kDGLCPU, int64_t, BFloat16>(
    NDArray feat, NDArray idx, NDArray out) {
  LOG(FATAL) << "Unsupported CPU kernel for ScatterAdd for BF16.";
}
template void ScatterAdd<kDGLCPU, int32_t, float>(
    NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, float>(
    NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCPU, int32_t, double>(
    NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, double>(
    NDArray feat, NDArray idx, NDArray out);

template <>
void UpdateGradMinMax_hetero<kDGLCPU, int32_t, BFloat16>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
  LOG(FATAL) << "Unsupported CPU kernel for UpdateGradMinMax_hetero for BF16.";
}
template <>
void UpdateGradMinMax_hetero<kDGLCPU, int64_t, BFloat16>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
  LOG(FATAL) << "Unsupported CPU kernel for UpdateGradMinMax_hetero for BF16.";
}
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, float>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, float>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, double>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, double>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);

template void BackwardSegmentCmp<kDGLCPU, int32_t, BFloat16>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, BFloat16>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, float>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, float>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, double>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, double>(
    NDArray feat, NDArray arg, NDArray out);

}  // namespace aten
}  // namespace dgl