/*!
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/sddmm.h
 * @brief SDDMM CPU kernel function header.
 */
#ifndef DGL_ARRAY_CPU_SDDMM_H_
#define DGL_ARRAY_CPU_SDDMM_H_

#include <dgl/array.h>
#include <dgl/bcast.h>
11
#include <dgl/runtime/parallel_for.h>
12

13
#include "../selector.h"
14
15
16
17
18
19

namespace dgl {
namespace aten {
namespace cpu {

/*!
20
21
22
23
24
25
26
 * @brief CPU kernel of g-SDDMM on Csr format.
 * @param bcast Broadcast information.
 * @param csr The Csr matrix.
 * @param lhs The left hand side operand feature.
 * @param rhs The right hand size operand feature.
 * @param out The result feature on edges.
 * @note it uses node parallel strategy, different threads are responsible
27
28
 *       for the computation of different nodes.
 */
29
30
31
32
33
34
template <
    typename IdType, typename DType, typename Op, int LhsTarget = 0,
    int RhsTarget = 2>
void SDDMMCsr(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs,
    NDArray out) {
35
36
37
38
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  const IdType* edges = csr.data.Ptr<IdType>();
39
40
  const DType* X = lhs.Ptr<DType>();
  const DType* Y = rhs.Ptr<DType>();
41
42
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len, reduce_size = bcast.reduce_size;
43
  DType* O = out.Ptr<DType>();
44
45
46
47
48
  runtime::parallel_for(0, csr.num_rows, [=](IdType b, IdType e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      for (IdType j = row_start; j < row_end; ++j) {
        const IdType cid = indices[j];
49
        const IdType eid = has_idx ? edges[j] : j;
50
51
52
53
        DType* out_off = O + eid * dim;
        for (int64_t k = 0; k < dim; ++k) {
          const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
54
55
56
57
58
59
60
61
62
63
          const DType* lhs_off =
              Op::use_lhs
                  ? X + Selector<LhsTarget>::Call(rid, eid, cid) * lhs_dim +
                        lhs_add * reduce_size
                  : nullptr;
          const DType* rhs_off =
              Op::use_rhs
                  ? Y + Selector<RhsTarget>::Call(rid, eid, cid) * rhs_dim +
                        rhs_add * reduce_size
                  : nullptr;
64
65
          out_off[k] = Op::Call(lhs_off, rhs_off, reduce_size);
        }
66
67
      }
    }
68
  });
69
70
71
}

/*!
72
73
74
75
76
77
78
 * @brief CPU kernel of g-SDDMM on Coo format.
 * @param bcast Broadcast information.
 * @param coo The COO matrix.
 * @param lhs The left hand side operand feature.
 * @param rhs The right hand size operand feature.
 * @param out The result feature on edges.
 * @note it uses edge parallel strategy, different threads are responsible
79
80
 *       for the computation of different edges.
 */
81
82
83
84
85
86
template <
    typename IdType, typename DType, typename Op, int LhsTarget = 0,
    int RhsTarget = 2>
void SDDMMCoo(
    const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs,
    NDArray out) {
87
88
89
90
  const bool has_idx = !IsNullArray(coo.data);
  const IdType* row = coo.row.Ptr<IdType>();
  const IdType* col = coo.col.Ptr<IdType>();
  const IdType* edges = coo.data.Ptr<IdType>();
91
92
  const DType* X = lhs.Ptr<DType>();
  const DType* Y = rhs.Ptr<DType>();
93
94
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len, reduce_size = bcast.reduce_size;
95
96
  DType* O = out.Ptr<DType>();
#pragma omp parallel for
97
  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
98
99
    const IdType rid = row[i];
    const IdType cid = col[i];
100
    const IdType eid = has_idx ? edges[i] : i;
101
102
103
104
    DType* out_off = O + eid * dim;
    for (int64_t k = 0; k < dim; ++k) {
      const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
      const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
105
106
107
108
109
110
111
112
      const DType* lhs_off =
          Op::use_lhs ? X + Selector<LhsTarget>::Call(rid, eid, cid) * lhs_dim +
                            lhs_add * reduce_size
                      : nullptr;
      const DType* rhs_off =
          Op::use_rhs ? Y + Selector<RhsTarget>::Call(rid, eid, cid) * rhs_dim +
                            rhs_add * reduce_size
                      : nullptr;
113
114
115
116
117
118
119
      out_off[k] = Op::Call(lhs_off, rhs_off, bcast.reduce_size);
    }
  }
}

namespace op {

////////////////////////// binary operators on CPU /////////////////////////////
/*!
 * @brief Binary operator that adds two scalars.
 * @tparam DType Type of the operands and of the result.
 */
template <typename DType>
struct Add {
  /*! @brief Call dereferences its left operand pointer. */
  static constexpr bool use_lhs = true;
  /*! @brief Call dereferences its right operand pointer. */
  static constexpr bool use_rhs = true;
  /*!
   * @brief Add the values the two operand pointers refer to.
   * @param lhs_off Pointer to the left operand; only the first element is read.
   * @param rhs_off Pointer to the right operand; only the first element is
   *        read.
   * @param len Not used by this operator.
   * @return *lhs_off + *rhs_off.
   */
  inline static DType Call(
      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    return *lhs_off + *rhs_off;
  }
};

/*!
 * @brief Binary operator that subtracts the right scalar from the left one.
 * @tparam DType Type of the operands and of the result.
 */
template <typename DType>
struct Sub {
  /*! @brief Call dereferences its left operand pointer. */
  static constexpr bool use_lhs = true;
  /*! @brief Call dereferences its right operand pointer. */
  static constexpr bool use_rhs = true;
  /*!
   * @brief Subtract the value at rhs_off from the value at lhs_off.
   * @param lhs_off Pointer to the left operand; only the first element is read.
   * @param rhs_off Pointer to the right operand; only the first element is
   *        read.
   * @param len Not used by this operator.
   * @return *lhs_off - *rhs_off.
   */
  inline static DType Call(
      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    return *lhs_off - *rhs_off;
  }
};

/*!
 * @brief Binary operator that multiplies two scalars.
 * @tparam DType Type of the operands and of the result.
 */
template <typename DType>
struct Mul {
  /*! @brief Call dereferences its left operand pointer. */
  static constexpr bool use_lhs = true;
  /*! @brief Call dereferences its right operand pointer. */
  static constexpr bool use_rhs = true;
  /*!
   * @brief Multiply the values the two operand pointers refer to.
   * @param lhs_off Pointer to the left operand; only the first element is read.
   * @param rhs_off Pointer to the right operand; only the first element is
   *        read.
   * @param len Not used by this operator.
   * @return *lhs_off * *rhs_off.
   */
  inline static DType Call(
      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    return *lhs_off * *rhs_off;
  }
};

/*!
 * @brief Binary operator that divides the left scalar by the right one.
 * @tparam DType Type of the operands and of the result.
 */
template <typename DType>
struct Div {
  /*! @brief Call dereferences its left operand pointer. */
  static constexpr bool use_lhs = true;
  /*! @brief Call dereferences its right operand pointer. */
  static constexpr bool use_rhs = true;
  /*!
   * @brief Divide the value at lhs_off by the value at rhs_off.
   * @param lhs_off Pointer to the dividend; only the first element is read.
   * @param rhs_off Pointer to the divisor; only the first element is read.
   *        No check for a zero divisor is performed here.
   * @param len Not used by this operator.
   * @return *lhs_off / *rhs_off, following DType's division semantics.
   */
  inline static DType Call(
      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    return *lhs_off / *rhs_off;
  }
};

/*!
 * @brief Operator that forwards the left operand and ignores the right one.
 * @tparam DType Type of the operand and of the result.
 */
template <typename DType>
struct CopyLhs {
  /*! @brief Call dereferences its left operand pointer. */
  static constexpr bool use_lhs = true;
  /*! @brief Call never touches its right operand pointer. */
  static constexpr bool use_rhs = false;
  /*!
   * @brief Return the value the left operand pointer refers to.
   * @param lhs_off Pointer to the left operand; only the first element is read.
   * @param len Not used by this operator.
   * @return *lhs_off.
   * @note The second (unnamed) pointer argument is never dereferenced, so a
   *       null pointer is acceptable there.
   */
  inline static DType Call(
      const DType* lhs_off, const DType*, int64_t len = 1) {
    return *lhs_off;
  }
};

/*!
 * @brief Operator that forwards the right operand and ignores the left one.
 * @tparam DType Type of the operand and of the result.
 */
template <typename DType>
struct CopyRhs {
  /*! @brief Call never touches its left operand pointer. */
  static constexpr bool use_lhs = false;
  /*! @brief Call dereferences its right operand pointer. */
  static constexpr bool use_rhs = true;
  /*!
   * @brief Return the value the right operand pointer refers to.
   * @param rhs_off Pointer to the right operand; only the first element is
   *        read.
   * @param len Not used by this operator.
   * @return *rhs_off.
   * @note The first (unnamed) pointer argument is never dereferenced, so a
   *       null pointer is acceptable there.
   */
  inline static DType Call(
      const DType*, const DType* rhs_off, int64_t len = 1) {
    return *rhs_off;
  }
};

/*!
 * @brief Binary operator that computes the dot product of two vectors.
 * @tparam DType Type of the vector elements and of the result.
 */
template <typename DType>
struct Dot {
  /*! @brief Call reads from its left operand pointer. */
  static constexpr bool use_lhs = true;
  /*! @brief Call reads from its right operand pointer. */
  static constexpr bool use_rhs = true;
  /*!
   * @brief Accumulate the products of corresponding elements.
   * @param lhs_off Pointer to the first element of the left vector.
   * @param rhs_off Pointer to the first element of the right vector.
   * @param len Number of elements read from each vector.
   * @return Sum of lhs_off[l] * rhs_off[l] for l in [0, len); zero when len is
   *         not positive.
   */
  inline static DType Call(
      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    DType rst = 0;
    for (int64_t l = 0; l < len; ++l) {
      rst += lhs_off[l] * rhs_off[l];
    }
    return rst;
  }
};

/*!
 * @brief Dispatch on the name of a binary operator and run a code snippet with
 *        the matching operator struct bound to a type alias.
 * @param op Operator name; compared with == against "add", "sub", "mul",
 *        "div", "copy_lhs", "copy_rhs" and "dot".
 * @param Op Name of the alias that the snippet can use to refer to the
 *        selected operator struct.
 * @param ... Statements executed inside their own scope once the alias is
 *        defined.
 * @note The expansion refers to a type named DType, which must be in scope
 *       where the macro is used. Any other operator name is reported through
 *       LOG(FATAL).
 */
#define SWITCH_OP(op, Op, ...)                                     \
  do {                                                             \
    if ((op) == "add") {                                           \
      typedef dgl::aten::cpu::op::Add<DType> Op;                   \
      { __VA_ARGS__ }                                              \
    } else if ((op) == "sub") {                                    \
      typedef dgl::aten::cpu::op::Sub<DType> Op;                   \
      { __VA_ARGS__ }                                              \
    } else if ((op) == "mul") {                                    \
      typedef dgl::aten::cpu::op::Mul<DType> Op;                   \
      { __VA_ARGS__ }                                              \
    } else if ((op) == "div") {                                    \
      typedef dgl::aten::cpu::op::Div<DType> Op;                   \
      { __VA_ARGS__ }                                              \
    } else if ((op) == "copy_lhs") {                               \
      typedef dgl::aten::cpu::op::CopyLhs<DType> Op;               \
      { __VA_ARGS__ }                                              \
    } else if ((op) == "copy_rhs") {                               \
      typedef dgl::aten::cpu::op::CopyRhs<DType> Op;               \
      { __VA_ARGS__ }                                              \
    } else if ((op) == "dot") {                                    \
      typedef dgl::aten::cpu::op::Dot<DType> Op;                   \
      { __VA_ARGS__ }                                              \
    } else {                                                       \
      LOG(FATAL) << "Unsupported SDDMM binary operator: " << (op); \
    }                                                              \
  } while (0)

}  // namespace op

}  // namespace cpu
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CPU_SDDMM_H_