/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/sddmm.cu
 * \brief SDDMM C APIs and definitions.
 */
#include <dgl/array.h>
#include "./sddmm.cuh"
#include "./functor.cuh"

namespace dgl {
namespace aten {

/*!
 * \brief CUDA implementation of g-SDDMM on Csr format.
 */
16
template <int XPU, typename IdType, typename DType>
17
18
19
void SDDMMCsr(const std::string& op,
              const BcastOff& bcast,
              const CSRMatrix& csr,
20
21
22
23
24
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target) {
25
26
27
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
28
    });
29
30
31
  });
}

32

33
34
35
/*!
 * \brief CUDA implementation of g-SDDMM on Coo format.
 */
36
template <int XPU, typename IdType, typename DType>
37
38
39
void SDDMMCoo(const std::string& op,
              const BcastOff& bcast,
              const COOMatrix& coo,
40
41
42
43
44
              NDArray lhs,
              NDArray rhs,
              NDArray out,
              int lhs_target,
              int rhs_target) {
45
46
47
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
48
    });
49
50
51
  });
}

// Explicit instantiations of SDDMMCsr for kDGLCUDA: each index type in
// {int32_t, int64_t} paired with each value type in
// {__half, __nv_bfloat16, float, double}. The __nv_bfloat16 variants are
// only emitted when BF16_ENABLED evaluates to true.
template void SDDMMCsr<kDGLCUDA, int32_t, __half>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, __half>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
#if BF16_ENABLED
template void SDDMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
#endif  // BF16_ENABLED
template void SDDMMCsr<kDGLCUDA, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);

// Explicit instantiations of SDDMMCoo for kDGLCUDA: each index type in
// {int32_t, int64_t} paired with each value type in
// {__half, __nv_bfloat16, float, double}. The __nv_bfloat16 variants are
// only emitted when BF16_ENABLED evaluates to true.
template void SDDMMCoo<kDGLCUDA, int32_t, __half>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, __half>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
#if BF16_ENABLED
template void SDDMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
#endif  // BF16_ENABLED
template void SDDMMCoo<kDGLCUDA, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);

}  // namespace aten
}  // namespace dgl