/**
 *  Copyright (c) 2020 by Contributors
 * @file aten/cpu/sddmm.cc
 * @brief SDDMM C APIs and definitions.
 */
#include "./sddmm.h"
#include <dgl/array.h>

namespace dgl {
namespace aten {

12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
/**
 * @brief Dispatch a runtime rhs target id to a compile-time constant.
 *
 * Expands @p __VA_ARGS__ with `constexpr int RhsTarget` bound to 0, 1 or 2
 * according to @p rhs_target. Any other value is a caller error and is
 * reported through LOG(FATAL); previously it was only logged at INFO level
 * and the kernel was silently skipped, leaving the output unwritten.
 * @p rhs_target may be evaluated more than once, so pass a side-effect-free
 * expression.
 * NOTE(review): 0/1/2 presumably denote src/edge/dst targets — confirm
 * against the caller's target mapping.
 */
#define SWITCH_RHS(rhs_target, RhsTarget, ...)                        \
  do {                                                                \
    if ((rhs_target) == 0) {                                          \
      constexpr int RhsTarget = 0;                                    \
      { __VA_ARGS__ }                                                 \
    } else if ((rhs_target) == 1) {                                   \
      constexpr int RhsTarget = 1;                                    \
      { __VA_ARGS__ }                                                 \
    } else if ((rhs_target) == 2) {                                   \
      constexpr int RhsTarget = 2;                                    \
      { __VA_ARGS__ }                                                 \
    } else {                                                          \
      LOG(FATAL) << "Invalid rhs target: " << (rhs_target);           \
    }                                                                 \
  } while (0)

/**
 * @brief Dispatch both operand targets to compile-time constants.
 *
 * Binds `constexpr int LhsTarget` from @p lhs_target (0, 1 or 2) and then
 * delegates to SWITCH_RHS for @p rhs_target, so @p __VA_ARGS__ is
 * instantiated for each valid (LhsTarget, RhsTarget) pair. An invalid lhs
 * target is reported through LOG(FATAL) rather than silently ignored.
 * Both target arguments may be evaluated more than once.
 */
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...)\
  do {                                                                  \
    if ((lhs_target) == 0) {                                            \
      constexpr int LhsTarget = 0;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else if ((lhs_target) == 1) {                                     \
      constexpr int LhsTarget = 1;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else if ((lhs_target) == 2) {                                     \
      constexpr int LhsTarget = 2;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else {                                                            \
      LOG(FATAL) << "Invalid lhs target: " << (lhs_target);             \
    }                                                                   \
  } while (0)


45
/** @brief Generalized SDDMM on Csr format. */
46
template <int XPU, typename IdType, typename DType>
47
48
49
void SDDMMCsr(const std::string& op,
              const BcastOff& bcast,
              const CSRMatrix& csr,
50
51
52
53
54
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target) {
55
56
57
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
58
    });
59
60
61
  });
}

62
/** @brief Generalized SDDMM on Csr format with Heterograph support. */
63
template <int XPU, typename IdType, typename DType>
64
65
66
67
68
69
70
71
72
73
void SDDMMCsrHetero(const std::string& op,
              const BcastOff& bcast,
              const std::vector<CSRMatrix>& vec_csr,
              const std::vector<NDArray>& vec_lhs,
              const std::vector<NDArray>& vec_rhs,
              std::vector<NDArray> vec_out,
              int lhs_target,
              int rhs_target,
              const std::vector<dgl_type_t>& lhs_nid,
              const std::vector<dgl_type_t>& rhs_nid) {
74
75
76
77
78
79
80
81
82
83
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      /* Call  SDDMM for each relation type */
      for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {
        CSRMatrix csr = vec_csr[etype];
        NDArray lhs = vec_lhs[lhs_nid[etype]];
        NDArray rhs = vec_rhs[rhs_nid[etype]];
        NDArray out = vec_out[etype];
        cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
      }
84
85
86
87
    });
  });
}

88
// Explicit CPU instantiations of SDDMMCsr for every supported
// (IdType, DType) combination.
template void SDDMMCsr<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
104

105
// Explicit CPU instantiations of SDDMMCsrHetero for every supported
// (IdType, DType) combination. Parameter names mirror the definition: the
// last two arguments are per-relation indices into vec_lhs / vec_rhs.
template void SDDMMCsrHetero<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
133

134
/** @brief Generalized SDDMM on Coo format. */
135
template <int XPU, typename IdType, typename DType>
136
137
138
void SDDMMCoo(const std::string& op,
              const BcastOff& bcast,
              const COOMatrix& coo,
139
140
141
142
143
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target) {
144
145
146
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
147
    });
148
149
150
  });
}

151
/** @brief Generalized SDDMM on Coo format with Heterograph support. */
152
template <int XPU, typename IdType, typename DType>
153
154
155
156
157
158
159
160
161
162
void SDDMMCooHetero(const std::string& op,
              const BcastOff& bcast,
              const std::vector<COOMatrix>& vec_coo,
              const std::vector<NDArray>& vec_lhs,
              const std::vector<NDArray>& vec_rhs,
              std::vector<NDArray> vec_out,
              int lhs_target,
              int rhs_target,
              const std::vector<dgl_type_t>& lhs_nid,
              const std::vector<dgl_type_t>& rhs_nid) {
163
164
165
166
167
168
169
170
171
172
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      /* Call  SDDMM for each relation type */
      for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {
        COOMatrix coo = vec_coo[etype];
        NDArray lhs = vec_lhs[lhs_nid[etype]];
        NDArray rhs = vec_rhs[rhs_nid[etype]];
        NDArray out = vec_out[etype];
        cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
      }
173
174
175
176
    });
  });
}

177
// Explicit CPU instantiations of SDDMMCoo for every supported
// (IdType, DType) combination.
template void SDDMMCoo<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);

194
// Explicit CPU instantiations of SDDMMCooHetero for every supported
// (IdType, DType) combination. Parameter names mirror the definition: the
// last two arguments are per-relation indices into vec_lhs / vec_rhs.
template void SDDMMCooHetero<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCooHetero<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCooHetero<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCooHetero<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs,
    std::vector<NDArray> vec_out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
222
223
224

}  // namespace aten
}  // namespace dgl