sddmm.h 7.93 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/sddmm.h
 * \brief SDDMM CPU kernel function header.
 */
#ifndef DGL_ARRAY_CPU_SDDMM_H_
#define DGL_ARRAY_CPU_SDDMM_H_

#include <dgl/array.h>
#include <dgl/bcast.h>
11
#include <dgl/runtime/parallel_for.h>
12
#include "../selector.h"
13
14
15
16
17
18
19
20
21

namespace dgl {
namespace aten {
namespace cpu {

/*!
 * \brief CPU kernel of g-SDDMM on Csr format.
 * \param bcast Broadcast information.
 * \param csr The Csr matrix.
22
23
 * \param lhs The left hand side operand feature.
 * \param rhs The right hand size operand feature.
24
25
26
27
 * \param out The result feature on edges.
 * \note it uses node parallel strategy, different threads are responsible
 *       for the computation of different nodes.
 */
28
29
template <typename IdType, typename DType, typename Op,
          int LhsTarget = 0, int RhsTarget = 2>
30
31
void SDDMMCsr(const BcastOff& bcast,
              const CSRMatrix& csr,
32
              NDArray lhs, NDArray rhs, NDArray out) {
33
34
35
36
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  const IdType* edges = csr.data.Ptr<IdType>();
37
38
  const DType* X = lhs.Ptr<DType>();
  const DType* Y = rhs.Ptr<DType>();
39
40
41
42
43
  const int64_t dim = bcast.out_len,
                lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len,
                reduce_size = bcast.reduce_size;
  DType* O = out.Ptr<DType>();
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
  runtime::parallel_for(0, csr.num_rows, [=](IdType b, IdType e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      for (IdType j = row_start; j < row_end; ++j) {
        const IdType cid = indices[j];
        const IdType eid = has_idx? edges[j] : j;
        DType* out_off = O + eid * dim;
        for (int64_t k = 0; k < dim; ++k) {
          const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* lhs_off = Op::use_lhs
            ? X + Selector<LhsTarget>::Call(rid, eid, cid) * lhs_dim + lhs_add * reduce_size
            : nullptr;
          const DType* rhs_off = Op::use_rhs
            ? Y + Selector<RhsTarget>::Call(rid, eid, cid) * rhs_dim + rhs_add * reduce_size
            : nullptr;
          out_off[k] = Op::Call(lhs_off, rhs_off, reduce_size);
        }
62
63
      }
    }
64
  });
65
66
67
68
69
70
}

/*!
 * \brief CPU kernel of g-SDDMM on Coo format.
 * \param bcast Broadcast information.
 * \param coo The COO matrix.
71
72
 * \param lhs The left hand side operand feature.
 * \param rhs The right hand size operand feature.
73
74
75
76
 * \param out The result feature on edges.
 * \note it uses edge parallel strategy, different threads are responsible
 *       for the computation of different edges.
 */
77
78
template <typename IdType, typename DType, typename Op,
          int LhsTarget = 0, int RhsTarget = 2>
79
80
void SDDMMCoo(const BcastOff& bcast,
              const COOMatrix& coo,
81
              NDArray lhs, NDArray rhs, NDArray out) {
82
83
84
85
  const bool has_idx = !IsNullArray(coo.data);
  const IdType* row = coo.row.Ptr<IdType>();
  const IdType* col = coo.col.Ptr<IdType>();
  const IdType* edges = coo.data.Ptr<IdType>();
86
87
  const DType* X = lhs.Ptr<DType>();
  const DType* Y = rhs.Ptr<DType>();
88
89
90
91
92
93
  const int64_t dim = bcast.out_len,
                lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len,
                reduce_size = bcast.reduce_size;
  DType* O = out.Ptr<DType>();
#pragma omp parallel for
94
  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
95
96
97
98
99
100
101
    const IdType rid = row[i];
    const IdType cid = col[i];
    const IdType eid = has_idx? edges[i] : i;
    DType* out_off = O + eid * dim;
    for (int64_t k = 0; k < dim; ++k) {
      const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
      const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
102
103
104
105
      const DType* lhs_off = Op::use_lhs ?
        X + Selector<LhsTarget>::Call(rid, eid, cid) * lhs_dim + lhs_add * reduce_size : nullptr;
      const DType* rhs_off = Op::use_rhs ?
        Y + Selector<RhsTarget>::Call(rid, eid, cid) * rhs_dim + rhs_add * reduce_size : nullptr;
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
      out_off[k] = Op::Call(lhs_off, rhs_off, bcast.reduce_size);
    }
  }
}

namespace op {

//////////////////////////////// binary operators on CPU ////////////////////////////////
// Each operator is a stateless functor consumed by the SDDMM kernels:
//   * use_lhs / use_rhs tell the kernel whether to compute the operand pointer; when a
//     flag is false the kernel passes nullptr for that operand.
//   * Call(lhs_off, rhs_off, len) produces one output element. Only Dot reads `len`
//     elements; the element-wise operators read a single value from each pointer and
//     ignore `len` (left unnamed to avoid unused-parameter warnings).

template <typename DType>
struct Add {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  inline static DType Call(const DType* lhs_off, const DType* rhs_off, int64_t /*len*/ = 1) {
    return *lhs_off + *rhs_off;
  }
};

template <typename DType>
struct Sub {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  inline static DType Call(const DType* lhs_off, const DType* rhs_off, int64_t /*len*/ = 1) {
    return *lhs_off - *rhs_off;
  }
};

template <typename DType>
struct Mul {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  inline static DType Call(const DType* lhs_off, const DType* rhs_off, int64_t /*len*/ = 1) {
    return *lhs_off * *rhs_off;
  }
};

template <typename DType>
struct Div {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  inline static DType Call(const DType* lhs_off, const DType* rhs_off, int64_t /*len*/ = 1) {
    return *lhs_off / *rhs_off;
  }
};

// Returns the lhs value; the rhs pointer is never dereferenced (use_rhs is false).
template <typename DType>
struct CopyLhs {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = false;
  inline static DType Call(const DType* lhs_off, const DType* /*rhs_off*/,
                           int64_t /*len*/ = 1) {
    return *lhs_off;
  }
};

// Returns the rhs value; the lhs pointer is never dereferenced (use_lhs is false).
template <typename DType>
struct CopyRhs {
  static constexpr bool use_lhs = false;
  static constexpr bool use_rhs = true;
  inline static DType Call(const DType* /*lhs_off*/, const DType* rhs_off,
                           int64_t /*len*/ = 1) {
    return *rhs_off;
  }
};

// Inner product of `len` contiguous elements starting at lhs_off and rhs_off.
template <typename DType>
struct Dot {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  inline static DType Call(const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    DType rst = 0;
    for (int64_t l = 0; l < len; ++l) {
      rst += lhs_off[l] * rhs_off[l];
    }
    return rst;
  }
};

/*!
 * \brief Dispatch a runtime operator name to one of the functors above.
 *
 * Runs __VA_ARGS__ in the branch whose name matches `op`, with `Op` typedef'ed to the
 * corresponding functor. The expansion site must have a type named `DType` in scope.
 * Unknown names are reported through LOG(FATAL).
 */
#define SWITCH_OP(op, Op, ...)                                      \
  do {                                                              \
    if ((op) == "add") {                                            \
      typedef dgl::aten::cpu::op::Add<DType> Op;                    \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "sub") {                                     \
      typedef dgl::aten::cpu::op::Sub<DType> Op;                    \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "mul") {                                     \
      typedef dgl::aten::cpu::op::Mul<DType> Op;                    \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "div") {                                     \
      typedef dgl::aten::cpu::op::Div<DType> Op;                    \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "copy_lhs") {                                \
      typedef dgl::aten::cpu::op::CopyLhs<DType> Op;                \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "copy_rhs") {                                \
      typedef dgl::aten::cpu::op::CopyRhs<DType> Op;                \
      { __VA_ARGS__ }                                               \
    } else if ((op) == "dot") {                                     \
      typedef dgl::aten::cpu::op::Dot<DType> Op;                    \
      { __VA_ARGS__ }                                               \
    } else {                                                        \
      LOG(FATAL) << "Unsupported SDDMM binary operator: " << (op);  \
    }                                                               \
  } while (0)

}  // namespace op

}  // namespace cpu
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CPU_SDDMM_H_