/**
 *  Copyright (c) 2020 by Contributors
 * @file aten/cpu/sddmm.cc
 * @brief SDDMM C APIs and definitions.
 */
#include "./sddmm.h"

#include <string>
#include <vector>

#include <dgl/array.h>

namespace dgl {
namespace aten {

/**
 * @brief Dispatch a runtime rhs operand target to a compile-time constant.
 *
 * Expands the trailing code block with `constexpr int RhsTarget` bound to the
 * value of `rhs_target`, which must be 0, 1 or 2. Any other value is a caller
 * error and is reported through LOG(FATAL). A plain INFO log would silently
 * skip the dispatched code and leave the caller's output untouched.
 */
#define SWITCH_RHS(rhs_target, RhsTarget, ...)              \
  do {                                                      \
    if ((rhs_target) == 0) {                                \
      constexpr int RhsTarget = 0;                          \
      { __VA_ARGS__ }                                       \
    } else if ((rhs_target) == 1) {                         \
      constexpr int RhsTarget = 1;                          \
      { __VA_ARGS__ }                                       \
    } else if ((rhs_target) == 2) {                         \
      constexpr int RhsTarget = 2;                          \
      { __VA_ARGS__ }                                       \
    } else {                                                \
      LOG(FATAL) << "Invalid rhs target: " << (rhs_target); \
    }                                                       \
  } while (0)

/**
 * @brief Dispatch runtime lhs/rhs operand targets to compile-time constants.
 *
 * Binds `constexpr int LhsTarget` to `lhs_target` (0, 1 or 2). It then
 * delegates to SWITCH_RHS, which binds `RhsTarget` and expands the trailing
 * code block. An out-of-range `lhs_target` is reported through LOG(FATAL)
 * instead of being silently skipped.
 */
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...) \
  do {                                                                   \
    if ((lhs_target) == 0) {                                             \
      constexpr int LhsTarget = 0;                                       \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \
    } else if ((lhs_target) == 1) {                                      \
      constexpr int LhsTarget = 1;                                       \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \
    } else if ((lhs_target) == 2) {                                      \
      constexpr int LhsTarget = 2;                                       \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \
    } else {                                                             \
      LOG(FATAL) << "Invalid lhs target: " << (lhs_target);              \
    }                                                                    \
  } while (0)

/** @brief Generalized SDDMM on Csr format. */
46
template <int XPU, typename IdType, typename DType>
47
48
49
void SDDMMCsr(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
50
51
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
52
53
      cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
          bcast, csr, lhs, rhs, out);
54
    });
55
56
57
  });
}

58
/** @brief Generalized SDDMM on Csr format with Heterograph support. */
59
template <int XPU, typename IdType, typename DType>
60
61
62
63
64
65
void SDDMMCsrHetero(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid) {
66
67
68
69
70
71
72
73
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      /* Call  SDDMM for each relation type */
      for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {
        CSRMatrix csr = vec_csr[etype];
        NDArray lhs = vec_lhs[lhs_nid[etype]];
        NDArray rhs = vec_rhs[rhs_nid[etype]];
        NDArray out = vec_out[etype];
74
75
        cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
            bcast, csr, lhs, rhs, out);
76
      }
77
78
79
80
    });
  });
}

81
// Explicit instantiations of SDDMMCsr for kDGLCPU with int32_t/int64_t index
// types and float/double feature types.
template void SDDMMCsr<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);

// Explicit instantiations of SDDMMCsrHetero for kDGLCPU with int32_t/int64_t
// index types and float/double feature types. Parameter names mirror the
// template definition's.
template void SDDMMCsrHetero<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);

/** @brief Generalized SDDMM on Coo format. */
120
template <int XPU, typename IdType, typename DType>
121
122
123
void SDDMMCoo(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
124
125
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
126
127
      cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
          bcast, coo, lhs, rhs, out);
128
    });
129
130
131
  });
}

132
/** @brief Generalized SDDMM on Coo format with Heterograph support. */
133
template <int XPU, typename IdType, typename DType>
134
135
136
137
138
139
void SDDMMCooHetero(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid) {
140
141
142
143
144
145
146
147
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      /* Call  SDDMM for each relation type */
      for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {
        COOMatrix coo = vec_coo[etype];
        NDArray lhs = vec_lhs[lhs_nid[etype]];
        NDArray rhs = vec_rhs[rhs_nid[etype]];
        NDArray out = vec_out[etype];
148
149
        cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
            bcast, coo, lhs, rhs, out);
150
      }
151
152
153
154
    });
  });
}

155
// Explicit instantiations of SDDMMCoo for kDGLCPU with int32_t/int64_t index
// types and float/double feature types.
template void SDDMMCoo<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);

// Explicit instantiations of SDDMMCooHetero for kDGLCPU with int32_t/int64_t
// index types and float/double feature types. Parameter names mirror the
// template definition's.
template void SDDMMCooHetero<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCooHetero<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCooHetero<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCooHetero<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);

}  // namespace aten
}  // namespace dgl