/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/segment_reduce.h
 * @brief Segment reduce kernel function header.
 */
#ifndef DGL_ARRAY_CPU_SEGMENT_REDUCE_H_
#define DGL_ARRAY_CPU_SEGMENT_REDUCE_H_

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/runtime/parallel_for.h>

#include <algorithm>
#include <string>
#include <vector>

namespace dgl {
namespace aten {
namespace cpu {

/**
 * @brief CPU kernel of segment sum.
 * @param feat The input tensor.
 * @param offsets The offset tensor storing the ranges of segments.
 * @param out The output tensor.
 */
template <typename IdType, typename DType>
void SegmentSum(NDArray feat, NDArray offsets, NDArray out) {
  int n = out->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* offsets_data = offsets.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
  runtime::parallel_for(0, n, [=](int b, int e) {
    for (auto i = b; i < e; ++i) {
      for (IdType j = offsets_data[i]; j < offsets_data[i + 1]; ++j) {
        for (int k = 0; k < dim; ++k) {
          out_data[i * dim + k] += feat_data[j * dim + k];
        }
      }
    }
  });
}
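
// Example (illustrative): with offsets = [0, 2, 5] and dim = 1, rows 0-1 of
// feat are accumulated into out[0] and rows 2-4 into out[1]. The kernel only
// performs `+=`, so it assumes the caller passes a zero-initialized `out`.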

/**
 * @brief CPU kernel of segment min/max.
 * @param feat The input tensor.
 * @param offsets The offset tensor storing the ranges of segments.
 * @param out The output tensor.
 * @param arg An auxiliary tensor storing the argmin/argmax information
 *        used in the backward phase.
 */
template <typename IdType, typename DType, typename Cmp>
void SegmentCmp(NDArray feat, NDArray offsets, NDArray out, NDArray arg) {
  int n = out->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* offsets_data = offsets.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
  IdType* arg_data = arg.Ptr<IdType>();
  std::fill(out_data, out_data + out.NumElements(), Cmp::zero);
  std::fill(arg_data, arg_data + arg.NumElements(), -1);
  runtime::parallel_for(0, n, [=](int b, int e) {
    for (auto i = b; i < e; ++i) {
      for (IdType j = offsets_data[i]; j < offsets_data[i + 1]; ++j) {
        for (int k = 0; k < dim; ++k) {
          const DType val = feat_data[j * dim + k];
          if (Cmp::Call(out_data[i * dim + k], val)) {
            out_data[i * dim + k] = val;
            arg_data[i * dim + k] = j;
          }
        }
      }
    }
  });
}
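
// The Cmp template argument is assumed to provide an identity value
// `Cmp::zero` (used to initialize `out`) and a predicate `Cmp::Call(accum,
// val)` that returns true when `val` should replace the current accumulator.
// A minimal max-comparator sketch (illustrative only; the functors DGL
// actually uses are defined elsewhere):
//
//   template <typename DType>
//   struct Max {
//     static constexpr DType zero = std::numeric_limits<DType>::lowest();
//     static bool Call(DType accum, DType val) { return accum < val; }
//   };
//
// (A real definition would also need <limits> for std::numeric_limits.)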

/**
 * @brief CPU kernel of the Scatter Add (on first dimension) operator.
 * @note Math equation: out[idx[i], *] += feat[i, *]
 * @param feat The input tensor.
 * @param idx The indices tensor.
 * @param out The output tensor.
 */
template <typename IdType, typename DType>
void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
  int n = feat->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* idx_data = idx.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
#pragma omp parallel for
  for (int i = 0; i < n; ++i) {
    const int write_row = idx_data[i];
    for (int k = 0; k < dim; ++k) {
#pragma omp atomic
      out_data[write_row * dim + k] += feat_data[i * dim + k];
    }
  }
}
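
// Example (illustrative): with idx = [1, 0, 1] and dim = 2, feat rows 0 and 2
// are accumulated into out row 1 and feat row 1 into out row 0; the atomic
// pragma guards concurrent updates to the same output element. Like
// SegmentSum, this kernel only performs `+=` and assumes `out` has been
// initialized (typically zero-filled) by the caller.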

/**
 * @brief CPU kernel to update gradients for the reduce ops max/min.
 * @param graph The input heterogeneous graph.
 * @param op The binary operator; could be `copy_lhs` or `copy_rhs`.
 * @param list_feat List of the input tensors.
 * @param list_idx List of the indices tensors.
 * @param list_idx_types List of the node- or edge-type tensors.
 * @param list_out List of the output tensors.
 */
template <typename IdType, typename DType>
void UpdateGradMinMax_hetero(
    HeteroGraphPtr graph, const std::string& op,
    const std::vector<NDArray>& list_feat, const std::vector<NDArray>& list_idx,
    const std::vector<NDArray>& list_idx_types,
    std::vector<NDArray>* list_out) {
  if (op == "copy_lhs" || op == "copy_rhs") {
    std::vector<std::vector<dgl_id_t>> src_dst_ntypes(
        graph->NumVertexTypes(), std::vector<dgl_id_t>());
    for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
      auto pair = graph->meta_graph()->FindEdge(etype);
      const dgl_id_t dst_ntype = pair.first;  // graph is reversed
      const dgl_id_t src_ntype = pair.second;
      auto same_src_dst_ntype = std::find(
          std::begin(src_dst_ntypes[dst_ntype]),
          std::end(src_dst_ntypes[dst_ntype]), src_ntype);
      // If op is "copy_lhs", a relation type whose source and destination node
      // types are the same is updated only once.
      if (op == "copy_lhs" &&
          same_src_dst_ntype != std::end(src_dst_ntypes[dst_ntype]))
        continue;
      src_dst_ntypes[dst_ntype].push_back(src_ntype);
      const DType* feat_data = list_feat[dst_ntype].Ptr<DType>();
      const IdType* idx_data = list_idx[dst_ntype].Ptr<IdType>();
      const IdType* idx_type_data = list_idx_types[dst_ntype].Ptr<IdType>();
      int type = (op == "copy_lhs") ? src_ntype : etype;
      DType* out_data = (*list_out)[type].Ptr<DType>();
      int dim = 1;
      for (int i = 1; i < (*list_out)[type]->ndim; ++i)
        dim *= (*list_out)[type]->shape[i];
      int n = list_feat[dst_ntype]->shape[0];
#pragma omp parallel for
      for (int i = 0; i < n; ++i) {
        for (int k = 0; k < dim; ++k) {
          if (type == idx_type_data[i * dim + k]) {
            const int write_row = idx_data[i * dim + k];
#pragma omp atomic
            out_data[write_row * dim + k] +=
                feat_data[i * dim + k];  // feat = dZ
          }
        }
      }
    }
  } else {
    LOG(FATAL) << "Unsupported binary operator: " << op;
  }
}
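
// Note: the meta graph traversed above is the reversed (gradient) graph, so
// pair.first is treated as the destination node type. For op == "copy_lhs"
// the gradient rows are scattered into the per-node-type output
// (*list_out)[src_ntype]; for op == "copy_rhs" they go into the per-edge-type
// output (*list_out)[etype]. Entries whose recorded type in idx_type_data
// does not match `type` are skipped.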

/**
 * @brief CPU kernel of the backward phase of segment min/max.
 * @note Math equation: out[arg[i, k], k] = feat[i, k]
 * @param feat The input tensor.
 * @param arg The argmin/argmax tensor.
 * @param out The output tensor.
 */
template <typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
  int n = feat->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* arg_data = arg.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
  runtime::parallel_for(0, n, [=](int b, int e) {
    for (auto i = b; i < e; ++i) {
      for (int k = 0; k < dim; ++k) {
        int write_row = arg_data[i * dim + k];
        if (write_row >= 0)
          out_data[write_row * dim + k] = feat_data[i * dim + k];
      }
    }
  });
}
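
// Example (illustrative): arg[i, k] holds the row index j recorded by
// SegmentCmp for output position (i, k), or -1 if the segment was empty, so
// each gradient entry feat[i, k] is routed back to exactly one row of `out`.
// Positions never selected as argmin/argmax are left untouched, so the caller
// is assumed to pass a zero-initialized `out`.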

}  // namespace cpu
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CPU_SEGMENT_REDUCE_H_