// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/sddmm.cu
 * @brief SDDMM (sampled dense-dense matrix multiplication) C APIs and
 *        definitions.
 */
#include <dgl/array.h>
8

sangwzh's avatar
sangwzh committed
9
10
#include "functor.cuh"
#include "sddmm.cuh"
11
12
13
14

namespace dgl {
namespace aten {

15
/**
16
 * @brief CUDA implementation of g-SDDMM on Csr format.
17
 */
18
template <int XPU, typename IdType, typename DType>
19
20
21
void SDDMMCsr(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
22
23
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
24
25
      cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
          bcast, csr, lhs, rhs, out);
26
    });
27
28
29
  });
}

30
/**
31
 * @brief CUDA implementation of g-SDDMM on Coo format.
32
 */
33
template <int XPU, typename IdType, typename DType>
34
35
36
void SDDMMCoo(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
37
38
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
39
40
      cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
          bcast, coo, lhs, rhs, out);
41
    });
42
43
44
  });
}

45
46
template void SDDMMCsr<kDGLCUDA, int32_t, __half>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
47
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
48
49
template void SDDMMCsr<kDGLCUDA, int64_t, __half>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
50
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
51
#if BF16_ENABLED
sangwzh's avatar
sangwzh committed
52
template void SDDMMCsr<kDGLCUDA, int32_t, __hip_bfloat16>(
53
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
54
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
sangwzh's avatar
sangwzh committed
55
template void SDDMMCsr<kDGLCUDA, int64_t, __hip_bfloat16>(
56
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
57
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
58
59
#endif  // BF16_ENABLED
template void SDDMMCsr<kDGLCUDA, int32_t, float>(
60
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
61
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
62
template void SDDMMCsr<kDGLCUDA, int64_t, float>(
63
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
64
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
65
template void SDDMMCsr<kDGLCUDA, int32_t, double>(
66
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
67
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
68
template void SDDMMCsr<kDGLCUDA, int64_t, double>(
69
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
70
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
71

72
73
template void SDDMMCoo<kDGLCUDA, int32_t, __half>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
74
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
75
76
template void SDDMMCoo<kDGLCUDA, int64_t, __half>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
77
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
78
#if BF16_ENABLED
sangwzh's avatar
sangwzh committed
79
template void SDDMMCoo<kDGLCUDA, int32_t, __hip_bfloat16>(
80
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
81
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
sangwzh's avatar
sangwzh committed
82
template void SDDMMCoo<kDGLCUDA, int64_t, __hip_bfloat16>(
83
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
84
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
85
86
#endif  // BF16_ENABLED
template void SDDMMCoo<kDGLCUDA, int32_t, float>(
87
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
88
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
89
template void SDDMMCoo<kDGLCUDA, int64_t, float>(
90
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
91
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
92
template void SDDMMCoo<kDGLCUDA, int32_t, double>(
93
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
94
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
95
template void SDDMMCoo<kDGLCUDA, int64_t, double>(
96
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
97
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
98
99
100

}  // namespace aten
}  // namespace dgl