// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/sddmm_hetero_csr.hip
 * @brief g-SDDMM on heterographs in CSR format (GPU implementation).
 */
#include <dgl/array.h>
8

sangwzh's avatar
sangwzh committed
9
#include "sddmm.cuh"
10
11
12
13

namespace dgl {
namespace aten {

14
/**
15
 * @brief CUDA implementation of g-SDDMM on heterograph using Csr format.
16
 */
17
template <int XPU, typename IdType, typename DType>
18
19
20
21
22
23
void SDDMMCsrHetero(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid) {
24
25
26
27
28
29
30
31
32
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      /* Call SDDMM CUDA kernel for each relation type sequentially */
      for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
        CSRMatrix csr = vec_csr[etype];
        NDArray lhs = vec_lhs[lhs_eid[etype]];
        NDArray rhs = vec_rhs[rhs_eid[etype]];
        NDArray out = vec_out[etype];
        cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
33
            bcast, csr, lhs, rhs, out);
34
      }
35
36
37
38
    });
  });
}

// Explicit instantiations of SDDMMCsrHetero for device type kDGLCUDA and every
// supported (IdType, DType) pair. Parameter names mirror those of the
// SDDMMCsrHetero definition. The bfloat16 variants are compiled only when
// BF16_ENABLED is set.
template void SDDMMCsrHetero<kDGLCUDA, int32_t, __half>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, __half>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
#if BF16_ENABLED
template void SDDMMCsrHetero<kDGLCUDA, int32_t, __hip_bfloat16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, __hip_bfloat16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
#endif  // BF16_ENABLED
template void SDDMMCsrHetero<kDGLCUDA, int32_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCsrHetero<kDGLCUDA, int32_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);

}  // namespace aten
}  // namespace dgl