// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2020 by Contributors
 * @file aten/cpu/sddmm.cc
 * @brief SDDMM C APIs and definitions.
 */
#include "sddmm.h"

#include <string>
#include <vector>

#include <dgl/array.h>

namespace dgl {
namespace aten {

/**
 * @brief Dispatch a runtime rhs target to a compile-time constant.
 *
 * Expands the trailing body with `constexpr int RhsTarget` bound to 0, 1 or 2,
 * whichever equals the runtime value of `rhs_target`. Any other value is a
 * caller error. Previously it was only logged at INFO level and the body was
 * skipped silently; it is now reported through LOG(FATAL) so the failure
 * cannot go unnoticed.
 */
#define SWITCH_RHS(rhs_target, RhsTarget, ...)              \
  do {                                                      \
    if ((rhs_target) == 0) {                                \
      constexpr int RhsTarget = 0;                          \
      { __VA_ARGS__ }                                       \
    } else if ((rhs_target) == 1) {                         \
      constexpr int RhsTarget = 1;                          \
      { __VA_ARGS__ }                                       \
    } else if ((rhs_target) == 2) {                         \
      constexpr int RhsTarget = 2;                          \
      { __VA_ARGS__ }                                       \
    } else {                                                \
      LOG(FATAL) << "Invalid rhs target: " << (rhs_target); \
    }                                                       \
  } while (0)

/**
 * @brief Dispatch a pair of runtime (lhs, rhs) targets to compile-time
 *        constants.
 *
 * `lhs_target` selects `constexpr int LhsTarget` in {0, 1, 2}. SWITCH_RHS then
 * selects `RhsTarget`, so the trailing body is instantiated for all nine
 * combinations. An lhs target outside {0, 1, 2} is a caller error. It is
 * reported through LOG(FATAL) rather than logged at INFO level and silently
 * skipped.
 */
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...) \
  do {                                                                   \
    if ((lhs_target) == 0) {                                             \
      constexpr int LhsTarget = 0;                                       \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \
    } else if ((lhs_target) == 1) {                                      \
      constexpr int LhsTarget = 1;                                       \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \
    } else if ((lhs_target) == 2) {                                      \
      constexpr int LhsTarget = 2;                                       \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \
    } else {                                                             \
      LOG(FATAL) << "Invalid lhs target: " << (lhs_target);              \
    }                                                                    \
  } while (0)

46
/** @brief Generalized SDDMM on Csr format. */
47
template <int XPU, typename IdType, typename DType>
48
49
50
void SDDMMCsr(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
51
52
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
53
54
      cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
          bcast, csr, lhs, rhs, out);
55
    });
56
57
58
  });
}

59
/** @brief Generalized SDDMM on Csr format with Heterograph support. */
60
template <int XPU, typename IdType, typename DType>
61
62
63
64
65
66
void SDDMMCsrHetero(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid) {
67
68
69
70
71
72
73
74
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      /* Call  SDDMM for each relation type */
      for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {
        CSRMatrix csr = vec_csr[etype];
        NDArray lhs = vec_lhs[lhs_nid[etype]];
        NDArray rhs = vec_rhs[rhs_nid[etype]];
        NDArray out = vec_out[etype];
75
76
        cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
            bcast, csr, lhs, rhs, out);
77
      }
78
79
80
81
    });
  });
}

// Explicit instantiations of SDDMMCsr for the CPU backend:
// IdType in {int32_t, int64_t} x DType in {BFloat16, float, double}.
template void SDDMMCsr<kDGLCPU, int32_t, BFloat16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, BFloat16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);

// Explicit instantiations of SDDMMCsrHetero for the CPU backend:
// IdType in {int32_t, int64_t} x DType in {BFloat16, float, double}.
// Parameter names mirror the template definition.
template void SDDMMCsrHetero<kDGLCPU, int32_t, BFloat16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, BFloat16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);

138
/** @brief Generalized SDDMM on Coo format. */
139
template <int XPU, typename IdType, typename DType>
140
141
142
void SDDMMCoo(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
143
144
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
145
146
      cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
          bcast, coo, lhs, rhs, out);
147
    });
148
149
150
  });
}

151
/** @brief Generalized SDDMM on Coo format with Heterograph support. */
152
template <int XPU, typename IdType, typename DType>
153
154
155
156
157
158
void SDDMMCooHetero(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid) {
159
160
161
162
163
164
165
166
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      /* Call  SDDMM for each relation type */
      for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {
        COOMatrix coo = vec_coo[etype];
        NDArray lhs = vec_lhs[lhs_nid[etype]];
        NDArray rhs = vec_rhs[rhs_nid[etype]];
        NDArray out = vec_out[etype];
167
168
        cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
            bcast, coo, lhs, rhs, out);
169
      }
170
171
172
173
    });
  });
}

// Explicit instantiations of SDDMMCoo for the CPU backend:
// IdType in {int32_t, int64_t} x DType in {BFloat16, float, double}.
template void SDDMMCoo<kDGLCPU, int32_t, BFloat16>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, BFloat16>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);

// Explicit instantiations of SDDMMCooHetero for the CPU backend:
// IdType in {int32_t, int64_t} x DType in {BFloat16, float, double}.
// Parameter names mirror the template definition.
template void SDDMMCooHetero<kDGLCPU, int32_t, BFloat16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCooHetero<kDGLCPU, int64_t, BFloat16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCooHetero<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCooHetero<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCooHetero<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);
template void SDDMMCooHetero<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid);

}  // namespace aten
}  // namespace dgl