"tests/vscode:/vscode.git/clone" did not exist on "a1051f0095c43218636f7be7d66d80b705439e6f"
sddmm.cu 7.57 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/sddmm.cu
 * \brief SDDMM C APIs and definitions.
 */
#include <dgl/array.h>
#include "./sddmm.cuh"
#include "./functor.cuh"

namespace dgl {
namespace aten {

/*!
 * \brief Dispatch on a binary operator name for SpMM/SDDMM.
 *
 * Binds the type alias \p Op to the functor from cuda::binary that matches
 * \p op and expands the trailing code block with that alias in scope.
 * A type named \c DType must already be in scope at the expansion site
 * (in this file it is provided by SWITCH_BITS).
 *
 * Recognised names: "add", "sub", "mul", "div", "copy_lhs", "copy_rhs",
 * "dot". Any other value is reported with LOG(FATAL).
 *
 * \p op is compared with operator== against string literals, so it is
 * expected to be a std::string (as in the callers below).
 */
#define SWITCH_OP(op, Op, ...)                                            \
  do {                                                                    \
    if ((op) == "add") {                                                  \
      typedef cuda::binary::Add<DType> Op;                                \
      { __VA_ARGS__ }                                                     \
    } else if ((op) == "sub") {                                           \
      typedef cuda::binary::Sub<DType> Op;                                \
      { __VA_ARGS__ }                                                     \
    } else if ((op) == "mul") {                                           \
      typedef cuda::binary::Mul<DType> Op;                                \
      { __VA_ARGS__ }                                                     \
    } else if ((op) == "div") {                                           \
      typedef cuda::binary::Div<DType> Op;                                \
      { __VA_ARGS__ }                                                     \
    } else if ((op) == "copy_lhs") {                                      \
      typedef cuda::binary::CopyLhs<DType> Op;                            \
      { __VA_ARGS__ }                                                     \
    } else if ((op) == "copy_rhs") {                                      \
      typedef cuda::binary::CopyRhs<DType> Op;                            \
      { __VA_ARGS__ }                                                     \
    } else if ((op) == "dot") {                                           \
      typedef cuda::binary::Dot<DType> Op;                                \
      { __VA_ARGS__ }                                                     \
    } else {                                                              \
      LOG(FATAL) << "Unsupported SpMM/SDDMM binary operator: " << (op);   \
    }                                                                     \
  } while (0)
/*!
 * \brief Turn a runtime rhs target index into the compile-time constant
 *        \p RhsTarget and expand the trailing code block with it.
 *
 * Only the values 0, 1 and 2 are accepted. Any other value is reported with
 * LOG(FATAL) rather than merely logged. Silently skipping the expansion
 * would mean the kernel never runs and the output is left unwritten with no
 * error surfaced to the caller.
 */
#define SWITCH_RHS(rhs_target, RhsTarget, ...)                        \
  do {                                                                \
    if ((rhs_target) == 0) {                                          \
      constexpr int RhsTarget = 0;                                    \
      { __VA_ARGS__ }                                                 \
    } else if ((rhs_target) == 1) {                                   \
      constexpr int RhsTarget = 1;                                    \
      { __VA_ARGS__ }                                                 \
    } else if ((rhs_target) == 2) {                                   \
      constexpr int RhsTarget = 2;                                    \
      { __VA_ARGS__ }                                                 \
    } else {                                                          \
      LOG(FATAL) << "Invalid rhs target: " << (rhs_target);           \
    }                                                                 \
  } while (0)

/*!
 * \brief Turn runtime lhs/rhs target indices into the compile-time constants
 *        \p LhsTarget and \p RhsTarget and expand the trailing code block.
 *
 * The lhs index is resolved here; the rhs index is resolved by SWITCH_RHS.
 * Only 0, 1 and 2 are accepted for lhs. Any other value is reported with
 * LOG(FATAL) instead of being merely logged, so an invalid target cannot
 * silently skip the kernel and leave the output unwritten.
 */
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...)\
  do {                                                                  \
    if ((lhs_target) == 0) {                                            \
      constexpr int LhsTarget = 0;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else if ((lhs_target) == 1) {                                     \
      constexpr int LhsTarget = 1;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else if ((lhs_target) == 2) {                                     \
      constexpr int LhsTarget = 2;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else {                                                            \
      LOG(FATAL) << "Invalid lhs target: " << (lhs_target);             \
    }                                                                   \
  } while (0)

/*!
 * \brief CUDA implementation of g-SDDMM on Csr format.
 *
 * Resolves the runtime dispatch parameters into template arguments and
 * forwards to cuda::SDDMMCsr. The parameters resolved are the bit width
 * (via SWITCH_BITS -> DType), the operator name (via SWITCH_OP -> Op) and
 * the lhs/rhs target indices (via SWITCH_TARGET -> LhsTarget/RhsTarget).
 *
 * \tparam XPU    Device type; not referenced in the body.
 * \tparam IdType Index type of the sparse matrix.
 * \tparam bits   Selects DType through SWITCH_BITS (presumably the bit width
 *                of the floating-point feature type — confirm in sddmm.cuh).
 * \param op         Binary operator name; see SWITCH_OP for accepted values.
 * \param bcast      Broadcast offset information passed through unchanged.
 * \param csr        Sparse structure in CSR format.
 * \param lhs        Left operand array.
 * \param rhs        Right operand array.
 * \param out        Output array, presumably filled by cuda::SDDMMCsr.
 * \param lhs_target Target index for lhs; must be 0, 1 or 2
 *                   (NOTE(review): presumably src/edge/dst — confirm).
 * \param rhs_target Target index for rhs; must be 0, 1 or 2.
 */
template <int XPU, typename IdType, int bits>
void SDDMMCsr(const std::string& op,
              const BcastOff& bcast,
              const CSRMatrix& csr,
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
            bcast, csr, lhs, rhs, out);
      });
    });
  });
}

/*!
 * \brief CUDA implementation of g-SDDMM on Coo format.
 *
 * Resolves the runtime dispatch parameters into template arguments and
 * forwards to cuda::SDDMMCoo. The parameters resolved are the bit width
 * (via SWITCH_BITS -> DType), the operator name (via SWITCH_OP -> Op) and
 * the lhs/rhs target indices (via SWITCH_TARGET -> LhsTarget/RhsTarget).
 *
 * \tparam XPU    Device type; not referenced in the body.
 * \tparam IdType Index type of the sparse matrix.
 * \tparam bits   Selects DType through SWITCH_BITS (presumably the bit width
 *                of the floating-point feature type — confirm in sddmm.cuh).
 * \param op         Binary operator name; see SWITCH_OP for accepted values.
 * \param bcast      Broadcast offset information passed through unchanged.
 * \param coo        Sparse structure in COO format.
 * \param lhs        Left operand array.
 * \param rhs        Right operand array.
 * \param out        Output array, presumably filled by cuda::SDDMMCoo.
 * \param lhs_target Target index for lhs; must be 0, 1 or 2
 *                   (NOTE(review): presumably src/edge/dst — confirm).
 * \param rhs_target Target index for rhs; must be 0, 1 or 2.
 */
template <int XPU, typename IdType, int bits>
void SDDMMCoo(const std::string& op,
              const BcastOff& bcast,
              const COOMatrix& coo,
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
            bcast, coo, lhs, rhs, out);
      });
    });
  });
}

// Explicit instantiations for the GPU backend. Every combination of index
// type (int32_t, int64_t) and `bits` value (16, 32, 64) is emitted for both
// the CSR and the COO entry points.
template void SDDMMCsr<kDLGPU, int32_t, 16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int64_t, 16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int32_t, 32>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int64_t, 32>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int32_t, 64>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int64_t, 64>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);

template void SDDMMCoo<kDLGPU, int32_t, 16>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int64_t, 16>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int32_t, 32>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int64_t, 32>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int32_t, 64>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int64_t, 64>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);

}  // namespace aten
}  // namespace dgl