/*!
 *  Copyright (c) 2020 by Contributors
 * \file aten/cpu/sddmm.cc
 * \brief SDDMM C APIs and definitions.
 */
#include "./sddmm.h"
#include <dgl/array.h>

namespace dgl {
namespace aten {

/*!
 * \brief Turn a runtime rhs target into a compile-time constant.
 *
 * For a target value of 0, 1 or 2 the macro declares
 * `constexpr int RhsTarget` with that value and expands the trailing
 * statements with the constant in scope, so they can use it as a template
 * argument.
 *
 * Any other value is a caller error.  It is reported through LOG(FATAL)
 * rather than LOG(INFO): with an informational log the statements were
 * silently skipped and the caller carried on as if the work had been done.
 *
 * \note `rhs_target` is evaluated once per comparison; pass an expression
 *       without side effects.
 */
#define SWITCH_RHS(rhs_target, RhsTarget, ...)                        \
  do {                                                                \
    if ((rhs_target) == 0) {                                          \
      constexpr int RhsTarget = 0;                                    \
      { __VA_ARGS__ }                                                 \
    } else if ((rhs_target) == 1) {                                   \
      constexpr int RhsTarget = 1;                                    \
      { __VA_ARGS__ }                                                 \
    } else if ((rhs_target) == 2) {                                   \
      constexpr int RhsTarget = 2;                                    \
      { __VA_ARGS__ }                                                 \
    } else {                                                          \
      LOG(FATAL) << "Invalid rhs target: " << (rhs_target);           \
    }                                                                 \
  } while (0)

/*!
 * \brief Turn runtime lhs/rhs targets into compile-time constants.
 *
 * For an lhs target of 0, 1 or 2 the macro declares
 * `constexpr int LhsTarget` with that value and then hands the rhs target
 * and the trailing statements to SWITCH_RHS, which declares `RhsTarget` the
 * same way.  The statements therefore see both constants and can use them as
 * template arguments.
 *
 * Any other lhs value is a caller error.  It is reported through LOG(FATAL)
 * rather than LOG(INFO): with an informational log the statements were
 * silently skipped and the caller carried on as if the work had been done.
 *
 * \note `lhs_target` is evaluated once per comparison; pass an expression
 *       without side effects.
 */
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...)\
  do {                                                                  \
    if ((lhs_target) == 0) {                                            \
      constexpr int LhsTarget = 0;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else if ((lhs_target) == 1) {                                     \
      constexpr int LhsTarget = 1;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else if ((lhs_target) == 2) {                                     \
      constexpr int LhsTarget = 2;                                      \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                   \
    } else {                                                            \
      LOG(FATAL) << "Invalid lhs target: " << (lhs_target);             \
    }                                                                   \
  } while (0)


/*! \brief Generalized SDDMM on Csr format. */
template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op,
              const BcastOff& bcast,
              const CSRMatrix& csr,
50
51
52
53
54
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target) {
55
  SWITCH_OP(op, Op, {
56
57
58
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
    });
59
60
61
62
63
  });
}

// Explicit instantiations of SDDMMCsr for kDLCPU, covering every combination
// of 32/64-bit index type and float/double data type.  The template is
// defined in this translation unit, so these are the instantiations it
// emits for other translation units to link against.
template void SDDMMCsr<kDLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);

/*! \brief Generalized SDDMM on Coo format. */
template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op,
              const BcastOff& bcast,
              const COOMatrix& coo,
84
85
86
87
88
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target) {
89
  SWITCH_OP(op, Op, {
90
91
92
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
    });
93
94
95
96
97
  });
}

// Explicit instantiations of SDDMMCoo for kDLCPU, covering every combination
// of 32/64-bit index type and float/double data type.  The template is
// defined in this translation unit, so these are the instantiations it
// emits for other translation units to link against.
template void SDDMMCoo<kDLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);

}  // namespace aten
}  // namespace dgl