/*!
 *  Copyright (c) 2020 by Contributors
 * \file aten/cpu/sddmm.cc
 * \brief SDDMM C APIs and definitions.
 */
#include "./sddmm.h"
#include <dgl/array.h>

namespace dgl {
namespace aten {

/*!
 * \brief Dispatch on the runtime value `rhs_target`, binding the
 *        compile-time constant `RhsTarget` to 0, 1 or 2 inside the body.
 *
 * Any other value aborts via LOG(FATAL). Merely logging would silently skip
 * the body (e.g. the SDDMM kernel call) and let the caller continue as if the
 * computation had happened.
 *
 * NOTE(review): the meaning of targets 0/1/2 is defined by the kernels in
 * ./sddmm.h (presumably source node / edge / destination node); confirm there.
 */
#define SWITCH_RHS(rhs_target, RhsTarget, ...)                        \
  do {                                                                \
    if ((rhs_target) == 0) {                                          \
      constexpr int RhsTarget = 0;                                    \
      { __VA_ARGS__ }                                                 \
    } else if ((rhs_target) == 1) {                                   \
      constexpr int RhsTarget = 1;                                    \
      { __VA_ARGS__ }                                                 \
    } else if ((rhs_target) == 2) {                                   \
      constexpr int RhsTarget = 2;                                    \
      { __VA_ARGS__ }                                                 \
    } else {                                                          \
      LOG(FATAL) << "Invalid rhs target: " << (rhs_target);           \
    }                                                                 \
  } while (0)

/*!
 * \brief Dispatch on the runtime pair (`lhs_target`, `rhs_target`), binding
 *        the compile-time constants `LhsTarget` and `RhsTarget` (each 0, 1
 *        or 2) inside the body.
 *
 * The rhs dimension is delegated to SWITCH_RHS. An invalid `lhs_target`
 * aborts via LOG(FATAL) instead of silently skipping the body.
 */
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...)\
  do {                                                                  \
    if ((lhs_target) == 0) {                                            \
      constexpr int LhsTarget = 0;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else if ((lhs_target) == 1) {                                     \
      constexpr int LhsTarget = 1;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else if ((lhs_target) == 2) {                                     \
      constexpr int LhsTarget = 2;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else {                                                            \
      LOG(FATAL) << "Invalid lhs target: " << (lhs_target);             \
    }                                                                   \
  } while (0)

/*!
 * \brief Dispatch on the runtime value `bits`, binding the type alias
 *        `DType` inside the body.
 *
 * 16 and 32 both bind `DType` to `float`; 64 binds it to `double`; any other
 * value aborts via LOG(FATAL).
 *
 * NOTE(review): 16-bit requests are computed with `float`. Confirm that no
 * caller hands real fp16 buffers to the CPU path; otherwise the data would be
 * read with the wrong element size.
 */
#define SWITCH_BITS(bits, DType, ...)                                 \
  do {                                                                \
    if ((bits) == 16 || (bits) == 32) {                               \
      typedef float DType;                                            \
      { __VA_ARGS__ }                                                 \
    } else if ((bits) == 64) {                                        \
      typedef double DType;                                           \
      { __VA_ARGS__ }                                                 \
    } else {                                                          \
      LOG(FATAL) << "Data type not recognized with bits " << (bits);  \
    }                                                                 \
  } while (0)

57

58
/*! \brief Generalized SDDMM on Csr format. */
59
template <int XPU, typename IdType, int bits>
60
61
62
void SDDMMCsr(const std::string& op,
              const BcastOff& bcast,
              const CSRMatrix& csr,
63
64
65
66
67
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target) {
68
69
70
71
72
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
      });
73
    });
74
75
76
  });
}

77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
/*! \brief Generalized SDDMM on Csr format with Heterograph support. */
template <int XPU, typename IdType, int bits>
void SDDMMCsrHetero(const std::string& op,
              const BcastOff& bcast,
              const std::vector<CSRMatrix>& vec_csr,
              const std::vector<NDArray>& vec_lhs,
              const std::vector<NDArray>& vec_rhs,
              std::vector<NDArray> vec_out,
              int lhs_target,
              int rhs_target,
              const std::vector<dgl_type_t>& lhs_nid,
              const std::vector<dgl_type_t>& rhs_nid) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        /* Call  SDDMM for each relation type */
        for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {
          CSRMatrix csr = vec_csr[etype];
          NDArray lhs = vec_lhs[lhs_nid[etype]];
          NDArray rhs = vec_rhs[rhs_nid[etype]];
          NDArray out = vec_out[etype];
          cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
        }
      });
    });
  });
}

// Explicit CPU instantiations of SDDMMCsr for
// IdType in {int32_t, int64_t} x bits in {16, 32, 64}.
template void SDDMMCsr<kDLCPU, int32_t, 16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, 16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int32_t, 32>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, 32>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int32_t, 64>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, 64>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);

// Explicit CPU instantiations of SDDMMCsrHetero for
// IdType in {int32_t, int64_t} x bits in {16, 32, 64}.
// Parameter names mirror the definition above.
template void SDDMMCsrHetero<kDLCPU, int32_t, 16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDLCPU, int32_t, 32>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 32>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDLCPU, int32_t, 64>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 64>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
172

173
/*! \brief Generalized SDDMM on Coo format. */
174
template <int XPU, typename IdType, int bits>
175
176
177
void SDDMMCoo(const std::string& op,
              const BcastOff& bcast,
              const COOMatrix& coo,
178
179
180
181
182
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target) {
183
184
185
186
187
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
      });
188
    });
189
190
191
  });
}

// Explicit CPU instantiations of SDDMMCoo for
// IdType in {int32_t, int64_t} x bits in {16, 32, 64}.
template void SDDMMCoo<kDLCPU, int32_t, 16>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 16>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int32_t, 32>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 32>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int32_t, 64>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 64>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);

}  // namespace aten
}  // namespace dgl