/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/segment_reduce.h
 * \brief Segment reduce kernel function header.
 */
#ifndef DGL_ARRAY_CPU_SEGMENT_REDUCE_H_
#define DGL_ARRAY_CPU_SEGMENT_REDUCE_H_

#include <dgl/array.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/base_heterograph.h>
#include <algorithm>
#include <string>
#include <vector>

namespace dgl {
namespace aten {
namespace cpu {

/*!
 * \brief CPU kernel of segment sum.
 * \param feat The input tensor.
 * \param offsets The offset tensor storing the ranges of segments.
 * \param out The output tensor.
 */
template <typename IdType, typename DType>
void SegmentSum(NDArray feat, NDArray offsets, NDArray out) {
  int n = out->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i)
    dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* offsets_data = offsets.Ptr<IdType>();
  DType *out_data = out.Ptr<DType>();
  runtime::parallel_for(0, n, [=](int b, int e) {
    for (auto i = b; i < e; ++i) {
      for (IdType j = offsets_data[i]; j < offsets_data[i + 1]; ++j) {
        for (int k = 0; k < dim; ++k) {
          out_data[i * dim + k] += feat_data[j * dim + k];
        }
      }
    }
  });
}
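
// Illustrative sketch (not part of the upstream header): the semantics of
// SegmentSum above, assuming `out` has been zero-initialized by the caller.
//
//   feat    = [[1, 2], [3, 4], [5, 6]]   // 3 rows, dim = 2
//   offsets = [0, 2, 3]                  // segment 0 -> rows [0, 2), segment 1 -> row [2, 3)
//   SegmentSum<int64_t, float>(feat, offsets, out);
//   // out == [[1 + 3, 2 + 4], [5, 6]] == [[4, 6], [5, 6]]
//
// The kernel accumulates with `+=` and never resets `out`, so zero-filling it
// beforehand is the caller's responsibility.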

/*!
 * \brief CPU kernel of segment min/max.
 * \param feat The input tensor.
 * \param offsets The offset tensor storing the ranges of segments.
 * \param out The output tensor.
 * \param arg An auxiliary tensor storing the argmin/argmax information
 *        used in the backward phase.
 */
template <typename IdType, typename DType, typename Cmp>
void SegmentCmp(NDArray feat, NDArray offsets,
                NDArray out, NDArray arg) {
  int n = out->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i)
    dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* offsets_data = offsets.Ptr<IdType>();
  DType *out_data = out.Ptr<DType>();
  IdType *arg_data = arg.Ptr<IdType>();
  std::fill(out_data, out_data + out.NumElements(), Cmp::zero);
  std::fill(arg_data, arg_data + arg.NumElements(), -1);
  runtime::parallel_for(0, n, [=](int b, int e) {
    for (auto i = b; i < e; ++i) {
      for (IdType j = offsets_data[i]; j < offsets_data[i + 1]; ++j) {
        for (int k = 0; k < dim; ++k) {
          const DType val = feat_data[j * dim + k];
          if (Cmp::Call(out_data[i * dim + k], val)) {
            out_data[i * dim + k] = val;
            arg_data[i * dim + k] = j;
          }
        }
      }
    }
  });
}
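
// Illustrative sketch (not part of the upstream header): assuming a
// hypothetical max functor `CmpMax` with `CmpMax::zero == -inf` and
// `CmpMax::Call(a, b)` returning true when `a < b`, SegmentCmp computes a
// per-segment max and records which input row produced it:
//
//   feat    = [[1, 5], [3, 4], [2, 6]]
//   offsets = [0, 2, 3]
//   SegmentCmp<int64_t, float, CmpMax>(feat, offsets, out, arg);
//   // out == [[3, 5], [2, 6]], arg == [[1, 0], [2, 2]]
//
// `arg` holds the winning row index per output element (-1 for empty
// segments) and is consumed later by BackwardSegmentCmp.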

/*!
 * \brief CPU kernel of the Scatter Add operator (on the first dimension).
 * \note math equation: out[idx[i], *] += feat[i, *]
 * \param feat The input tensor.
 * \param idx The indices tensor.
 * \param out The output tensor.
 */
template <typename IdType, typename DType>
void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
  int n = feat->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i)
    dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* idx_data = idx.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
#pragma omp parallel for
  for (int i = 0; i < n; ++i) {
    const int write_row = idx_data[i];
    for (int k = 0; k < dim; ++k) {
#pragma omp atomic
      out_data[write_row * dim + k] += feat_data[i * dim + k];
    }
  }
}
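
// Illustrative sketch (not part of the upstream header): ScatterAdd routes row
// i of `feat` into row idx[i] of a zero-initialized `out`, resolving
// collisions through the atomic `+=` above:
//
//   feat = [[1, 1], [2, 2], [3, 3]]
//   idx  = [0, 1, 0]
//   ScatterAdd<int64_t, float>(feat, idx, out);
//   // out == [[1 + 3, 1 + 3], [2, 2]] == [[4, 4], [2, 2]]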

/*!
 * \brief CPU kernel to update gradients for the max/min reduce op.
 * \param graph The input heterogeneous graph.
 * \param op The binary operator, could be `copy_u` or `copy_e`.
 * \param list_feat List of the input tensors.
 * \param list_idx  List of the indices tensors.
 * \param list_idx_types List of the node- or edge-type tensors.
 * \param list_out List of the output tensors.
 */
template <typename IdType, typename DType>
void UpdateGradMinMax_hetero(HeteroGraphPtr graph,
                       const std::string& op,
                       const std::vector<NDArray>& list_feat,
                       const std::vector<NDArray>& list_idx,
                       const std::vector<NDArray>& list_idx_types,
                       std::vector<NDArray>* list_out) {
  if (op == "copy_lhs" || op == "copy_rhs") {
    std::vector<std::vector<dgl_id_t>> src_dst_ntypes(
        graph->NumVertexTypes(), std::vector<dgl_id_t>());

    for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
      auto pair = graph->meta_graph()->FindEdge(etype);
      const dgl_id_t dst_ntype = pair.first;  // graph is reversed
      const dgl_id_t src_ntype = pair.second;
      auto same_src_dst_ntype = std::find(std::begin(src_dst_ntypes[dst_ntype]),
        std::end(src_dst_ntypes[dst_ntype]), src_ntype);
      // If op is "copy_lhs", a relation type whose (src, dst) node-type pair has
      // already been processed is skipped, so each node-type pair is updated only once.
      if (op == "copy_lhs" && same_src_dst_ntype != std::end(src_dst_ntypes[dst_ntype]))
        continue;
      src_dst_ntypes[dst_ntype].push_back(src_ntype);
      const DType* feat_data = list_feat[dst_ntype].Ptr<DType>();
      const IdType* idx_data = list_idx[dst_ntype].Ptr<IdType>();
      const IdType* idx_type_data = list_idx_types[dst_ntype].Ptr<IdType>();
      int type = (op == "copy_lhs") ? src_ntype : etype;
      DType* out_data = (*list_out)[type].Ptr<DType>();
      int dim = 1;
      for (int i = 1; i < (*list_out)[type]->ndim; ++i)
        dim *= (*list_out)[type]->shape[i];
      int n = list_feat[dst_ntype]->shape[0];
#pragma omp parallel for
      for (int i = 0; i < n; ++i) {
        for (int k = 0; k < dim; ++k) {
          if (type == idx_type_data[i * dim + k]) {
            const int write_row = idx_data[i * dim + k];
#pragma omp atomic
            out_data[write_row * dim + k] += feat_data[i * dim + k];  // feat = dZ
          }
        }
      }
    }
  } else {
    LOG(FATAL) << "Unsupported binary operator: " << op;
  }
}
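
// Illustrative note (not part of the upstream header): for every relation type
// (the graph is reversed, so `pair.first` is the destination node type), the
// kernel above scatter-adds the incoming gradient `list_feat[dst_ntype]` into
// the gradient buffer selected by `type` (the source node type for `copy_lhs`,
// the edge type for `copy_rhs`), but only at positions where the recorded arg
// type matches, i.e. only the rows that actually produced the min/max receive
// a gradient contribution.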

/*!
 * \brief CPU kernel of backward phase of segment min/max.
 * \note math equation: out[arg[i, k], k] = feat[i, k]
 * \param feat The input tensor.
 * \param arg The argmin/argmax tensor.
 * \param out The output tensor.
 */
template <typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
  int n = feat->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i)
    dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* arg_data = arg.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
  runtime::parallel_for(0, n, [=](int b, int e) {
    for (auto i = b; i < e; ++i) {
      for (int k = 0; k < dim; ++k) {
        int write_row = arg_data[i * dim + k];
        if (write_row >= 0)
          out_data[write_row * dim + k] = feat_data[i * dim + k];
      }
    }
  });
}
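
// Illustrative sketch (not part of the upstream header), continuing the
// SegmentCmp example above; `grad` and `dfeat` are stand-in names for the
// `feat` and `out` arguments of BackwardSegmentCmp:
//
//   grad = [[10, 20], [30, 40]]   // gradient w.r.t. the reduced output
//   arg  = [[1, 0], [2, 2]]       // produced by SegmentCmp
//   BackwardSegmentCmp<int64_t, float>(grad, arg, dfeat);
//   // dfeat (zero-initialized, same shape as the forward input)
//   //   == [[0, 20], [10, 0], [30, 40]]
//
// Entries of `arg` equal to -1 (empty segments) are skipped.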

}  // namespace cpu
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CPU_SEGMENT_REDUCE_H_