/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/segment_reduce.h
 * \brief Segment reduce kernel function header.
 */
#ifndef DGL_ARRAY_CPU_SEGMENT_REDUCE_H_
#define DGL_ARRAY_CPU_SEGMENT_REDUCE_H_

#include <dgl/array.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/base_heterograph.h>
#include <algorithm>
#include <string>
#include <vector>

namespace dgl {
namespace aten {
namespace cpu {

/*!
 * \brief CPU kernel of segment sum.
 * \param feat The input tensor.
 * \param offsets The offset tensor storing the ranges of segments.
 * \param out The output tensor.
 */
template <typename IdType, typename DType>
void SegmentSum(NDArray feat, NDArray offsets, NDArray out) {
  int n = out->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i)
    dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* offsets_data = offsets.Ptr<IdType>();
  DType *out_data = out.Ptr<DType>();
  runtime::parallel_for(0, n, [=](int b, int e) {
    for (auto i = b; i < e; ++i) {
      for (IdType j = offsets_data[i]; j < offsets_data[i + 1]; ++j) {
        for (int k = 0; k < dim; ++k) {
          out_data[i * dim + k] += feat_data[j * dim + k];
        }
      }
    }
  });
}
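
// Worked example (illustrative values, not part of the kernel): segment i
// covers the half-open row range [offsets[i], offsets[i + 1]) of feat, so
// with offsets = [0, 2, 5] and a 1-D feature the kernel accumulates
//   out[0] = feat[0] + feat[1]
//   out[1] = feat[2] + feat[3] + feat[4]
// The kernel only adds into out_data, so the caller is expected to pass a
// zero-initialized output tensor.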

/*!
 * \brief CPU kernel of segment min/max.
 * \param feat The input tensor.
 * \param offsets The offset tensor storing the ranges of segments.
 * \param out The output tensor.
 * \param arg An auxiliary tensor storing the argmin/argmax information
 *        used in the backward phase.
 */
template <typename IdType, typename DType, typename Cmp>
void SegmentCmp(NDArray feat, NDArray offsets,
                NDArray out, NDArray arg) {
  int n = out->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i)
    dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* offsets_data = offsets.Ptr<IdType>();
  DType *out_data = out.Ptr<DType>();
  IdType *arg_data = arg.Ptr<IdType>();
  std::fill(out_data, out_data + out.NumElements(), Cmp::zero);
  std::fill(arg_data, arg_data + arg.NumElements(), -1);
  runtime::parallel_for(0, n, [=](int b, int e) {
    for (auto i = b; i < e; ++i) {
      for (IdType j = offsets_data[i]; j < offsets_data[i + 1]; ++j) {
        for (int k = 0; k < dim; ++k) {
          const DType val = feat_data[j * dim + k];
          if (Cmp::Call(out_data[i * dim + k], val)) {
            out_data[i * dim + k] = val;
            arg_data[i * dim + k] = j;
          }
        }
      }
    }
  });
}
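
// The Cmp template parameter is expected to expose an identity value
// Cmp::zero (used to initialize the output) and a predicate
// Cmp::Call(accum, val) that returns true when val should replace the
// current accumulator. A minimal sketch of such a functor for a
// max-reduction (illustrative only; the comparators DGL actually uses are
// defined elsewhere) might look like:
//
//   template <typename DType>
//   struct MaxCmp {
//     static constexpr DType zero = std::numeric_limits<DType>::lowest();
//     static bool Call(DType accum, DType val) { return accum < val; }
//   };
//
// Entries of arg left at -1 mark empty segments with no contributing row.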

/*!
 * \brief CPU kernel of Scatter Add (on first dimension) operator.
 * \note math equation: out[idx[i], *] += feat[i, *]
 * \param feat The input tensor.
 * \param idx The indices tensor.
 * \param out The output tensor.
 */
template <typename IdType, typename DType>
void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
  int n = feat->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i)
    dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* idx_data = idx.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
#pragma omp parallel for
  for (int i = 0; i < n; ++i) {
    const int write_row = idx_data[i];
    for (int k = 0; k < dim; ++k) {
#pragma omp atomic
      out_data[write_row * dim + k] += feat_data[i * dim + k];
    }
  }
}
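
// Worked example (illustrative values): with idx = [1, 0, 1] and a 1-D
// feature, the kernel accumulates
//   out[1] += feat[0];  out[0] += feat[1];  out[1] += feat[2];
// Several rows of feat may target the same output row, which is why the
// inner update is guarded with `#pragma omp atomic`.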

/*!
 * \brief CPU kernel that updates the gradients of segment min/max on a
 *        heterogeneous graph.
 * \param graph The input heterogeneous graph.
 * \param op The binary operator, could be `copy_lhs` or `copy_rhs`.
 * \param list_feat List of the input tensors.
 * \param list_idx List of the indices tensors.
 * \param list_idx_ntypes List of the node- or edge-type tensors.
 * \param list_out List of the output tensors.
 */
template <typename IdType, typename DType>
void UpdateGradMinMax_hetero(HeteroGraphPtr graph,
                       const std::string& op,
                       const std::vector<NDArray>& list_feat,
                       const std::vector<NDArray>& list_idx,
                       const std::vector<NDArray>& list_idx_ntypes,
                       std::vector<NDArray>* list_out) {
  if (op == "copy_lhs") {
    std::vector<std::vector<dgl_id_t>> dst_src_ntids(
        graph->NumVertexTypes(), std::vector<dgl_id_t>());
    for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
      auto pair = graph->meta_graph()->FindEdge(etype);
      const dgl_id_t dst_id = pair.first;  // graph is reversed
      const dgl_id_t src_id = pair.second;
      dst_src_ntids[dst_id].push_back(src_id);  // can have duplicates. Use Hashtable to optimize.
    }
    std::vector<bool> updated(graph->NumVertexTypes());
    for (int dst_id = 0; dst_id < dst_src_ntids.size(); ++dst_id) {
      std::fill(updated.begin(), updated.end(), false);
      for (int j = 0; j < dst_src_ntids[dst_id].size(); ++j) {
        int src_id = dst_src_ntids[dst_id][j];
        if (updated[src_id]) continue;
        const DType* feat_data = list_feat[dst_id].Ptr<DType>();
        const IdType* idx_data = list_idx[dst_id].Ptr<IdType>();
        const IdType* idx_ntype_data = list_idx_ntypes[dst_id].Ptr<IdType>();
        DType* out_data = (*list_out)[src_id].Ptr<DType>();
        int dim = 1;
        for (int i = 1; i < (*list_out)[src_id]->ndim; ++i)
          dim *= (*list_out)[src_id]->shape[i];
        int n = list_feat[dst_id]->shape[0];
#pragma omp parallel for
        for (int i = 0; i < n; ++i) {
          for (int k = 0; k < dim; ++k) {
            if (src_id == idx_ntype_data[i * dim + k]) {
              const int write_row = idx_data[i * dim + k];
#pragma omp atomic
              out_data[write_row * dim + k] += feat_data[i * dim + k];  // feat = dZ
            }
          }
        }
        updated[src_id] = true;
      }
    }
  } else if (op == "copy_rhs") {
    for (dgl_type_t etid = 0; etid < graph->NumEdgeTypes(); ++etid) {
      auto pair = graph->meta_graph()->FindEdge(etid);
      const dgl_id_t dst_id = pair.first;  // graph is reversed
      const dgl_id_t src_id = pair.second;
      const DType* feat_data = list_feat[dst_id].Ptr<DType>();
      const IdType* idx_data = list_idx[dst_id].Ptr<IdType>();
      const IdType* idx_ntype_data = list_idx_ntypes[dst_id].Ptr<IdType>();
      DType* out_data = (*list_out)[etid].Ptr<DType>();
      int dim = 1;
      for (int i = 1; i < (*list_out)[etid]->ndim; ++i)
        dim *= (*list_out)[etid]->shape[i];
      int n = list_feat[dst_id]->shape[0];
#pragma omp parallel for
      for (int i = 0; i < n; ++i) {
        for (int k = 0; k < dim; ++k) {
          if (etid == idx_ntype_data[i * dim + k]) {
            const int write_row = idx_data[i * dim + k];
#pragma omp atomic
            out_data[write_row * dim + k] += feat_data[i * dim + k];  // feat = dZ
          }
        }
      }
    }
  } else {
    LOG(FATAL) << "Unsupported binary operator: " << op;
  }
}
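
// Rough sketch of the routing above (derived from the code, illustrative):
// for `copy_lhs`, the gradient rows in list_feat[dst_id] are scattered into
// the per-source-node-type buffers (*list_out)[src_id], with
// list_idx_ntypes selecting which entries belong to each source node type;
// for `copy_rhs`, the same atomic scatter-add pattern targets the
// per-edge-type buffers (*list_out)[etid] instead.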

/*!
 * \brief CPU kernel of backward phase of segment min/max.
 * \note math equation: out[arg[i, k], k] = feat[i, k]
 * \param feat The input tensor.
 * \param arg The argmin/argmax tensor.
 * \param out The output tensor.
 */
template <typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
  int n = feat->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i)
    dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* arg_data = arg.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
  runtime::parallel_for(0, n, [=](int b, int e) {
    for (auto i = b; i < e; ++i) {
      for (int k = 0; k < dim; ++k) {
        int write_row = arg_data[i * dim + k];
        if (write_row >= 0)
          out_data[write_row * dim + k] = feat_data[i * dim + k];
      }
    }
  });
}
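
// Worked example (illustrative values): if the forward SegmentCmp pass
// produced arg = [2, 4] for a 1-D feature (the max/min of segment 0 came
// from row 2 and that of segment 1 from row 4), then the incoming gradient
// is routed as out[2] = feat[0] and out[4] = feat[1]; entries whose arg
// value is -1 (empty segments) are skipped.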

}  // namespace cpu
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CPU_SEGMENT_REDUCE_H_