/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/sddmm.cu
 * @brief SDDMM C APIs and definitions.
 */
#include <dgl/array.h>
7

8
#include "./functor.cuh"
9
#include "./sddmm.cuh"
10
11
12
13

namespace dgl {
namespace aten {

14
/**
 * @brief CUDA implementation of g-SDDMM on Csr format.
 *
 * Host-side dispatcher: it turns the runtime operator name and the two
 * runtime target selectors into compile-time template arguments and then
 * forwards every array argument, unchanged, to cuda::SDDMMCsr.
 *
 * @tparam XPU    Device type tag. It is not referenced in the body; the
 *                explicit instantiations in this file use kDGLCUDA.
 * @tparam IdType Index type forwarded to cuda::SDDMMCsr.
 * @tparam DType  Value type forwarded to cuda::SDDMMCsr.
 *
 * @param op         Operator name; SWITCH_OP binds the matching type to `Op`.
 * @param bcast      Broadcast information, forwarded as-is.
 * @param csr        Sparse matrix in CSR format, forwarded as-is.
 * @param lhs        Left-hand operand array, forwarded as-is.
 * @param rhs        Right-hand operand array, forwarded as-is.
 * @param out        Output array, forwarded as-is.
 * @param lhs_target Runtime selector that SWITCH_TARGET binds to `LhsTarget`.
 * @param rhs_target Runtime selector that SWITCH_TARGET binds to `RhsTarget`.
 *
 * @note SWITCH_OP and SWITCH_TARGET are defined outside this file
 *       (presumably in the included .cuh headers -- confirm there which
 *       operator names and target values are accepted).
 */
template <int XPU, typename IdType, typename DType>
void SDDMMCsr(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
  // Outer switch: operator name -> `Op` type.
  SWITCH_OP(op, Op, {
    // Inner switch: (lhs_target, rhs_target) -> compile-time constants.
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
          bcast, csr, lhs, rhs, out);
    });
  });
}

29
/**
 * @brief CUDA implementation of g-SDDMM on Coo format.
 *
 * Host-side dispatcher: it turns the runtime operator name and the two
 * runtime target selectors into compile-time template arguments and then
 * forwards every array argument, unchanged, to cuda::SDDMMCoo.
 *
 * @tparam XPU    Device type tag. It is not referenced in the body; the
 *                explicit instantiations in this file use kDGLCUDA.
 * @tparam IdType Index type forwarded to cuda::SDDMMCoo.
 * @tparam DType  Value type forwarded to cuda::SDDMMCoo.
 *
 * @param op         Operator name; SWITCH_OP binds the matching type to `Op`.
 * @param bcast      Broadcast information, forwarded as-is.
 * @param coo        Sparse matrix in COO format, forwarded as-is.
 * @param lhs        Left-hand operand array, forwarded as-is.
 * @param rhs        Right-hand operand array, forwarded as-is.
 * @param out        Output array, forwarded as-is.
 * @param lhs_target Runtime selector that SWITCH_TARGET binds to `LhsTarget`.
 * @param rhs_target Runtime selector that SWITCH_TARGET binds to `RhsTarget`.
 *
 * @note SWITCH_OP and SWITCH_TARGET are defined outside this file
 *       (presumably in the included .cuh headers -- confirm there which
 *       operator names and target values are accepted).
 */
template <int XPU, typename IdType, typename DType>
void SDDMMCoo(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
  // Outer switch: operator name -> `Op` type.
  SWITCH_OP(op, Op, {
    // Inner switch: (lhs_target, rhs_target) -> compile-time constants.
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
          bcast, coo, lhs, rhs, out);
    });
  });
}

44
45
// Explicit instantiations of SDDMMCsr for kDGLCUDA. The template body lives
// in this translation unit, so each supported combination is instantiated
// here: index type {int32_t, int64_t} x value type {__half, float, double},
// plus __nv_bfloat16 when BF16_ENABLED evaluates to non-zero.
template void SDDMMCsr<kDGLCUDA, int32_t, __half>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, __half>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
#if BF16_ENABLED
template void SDDMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
#endif  // BF16_ENABLED
template void SDDMMCsr<kDGLCUDA, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
70

71
72
// Explicit instantiations of SDDMMCoo for kDGLCUDA. The template body lives
// in this translation unit, so each supported combination is instantiated
// here: index type {int32_t, int64_t} x value type {__half, float, double},
// plus __nv_bfloat16 when BF16_ENABLED evaluates to non-zero.
template void SDDMMCoo<kDGLCUDA, int32_t, __half>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, __half>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
#if BF16_ENABLED
template void SDDMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
#endif  // BF16_ENABLED
template void SDDMMCoo<kDGLCUDA, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
97
98
99

}  // namespace aten
}  // namespace dgl