/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/sddmm_hetero_csr.cu
 * @brief CUDA SDDMM C APIs and definitions for heterographs in CSR format.
 */
#include <dgl/array.h>
#include "./sddmm.cuh"

namespace dgl {
namespace aten {

/**
 * @brief CUDA implementation of g-SDDMM on heterograph using Csr format.
 *
 * Dispatches on the operator name and the operand targets once, then calls
 * cuda::SDDMMCsr for every relation type, one after another in etype order.
 *
 * @tparam XPU    Device type (this file instantiates it for kDGLCUDA only).
 * @tparam IdType Index type of the CSR matrices (int32_t or int64_t here).
 * @tparam DType  Feature data type.
 * @param op         Binary operator name, dispatched through SWITCH_OP.
 * @param bcast      Broadcast offsets, shared by every relation type.
 * @param vec_csr    CSR matrices, indexed by relation type (etype).
 * @param vec_lhs    Left-hand operand arrays, selected via lhs_eid[etype].
 * @param vec_rhs    Right-hand operand arrays, selected via rhs_eid[etype].
 * @param vec_out    Output arrays, indexed by relation type (etype).
 * @param lhs_target Target code for the lhs operand, dispatched through
 *                   SWITCH_TARGET (presumably src/edge/dst -- confirm).
 * @param rhs_target Target code for the rhs operand, dispatched through
 *                   SWITCH_TARGET.
 * @param lhs_eid    For each relation type, the index into vec_lhs.
 * @param rhs_eid    For each relation type, the index into vec_rhs.
 *
 * @note The number of relation types processed is lhs_eid.size(). vec_csr,
 *       vec_out and rhs_eid are indexed with the same etype and are assumed,
 *       not checked here, to be at least that long.
 */
template <int XPU, typename IdType, typename DType>
void SDDMMCsrHetero(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid) {
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      /* Call SDDMM CUDA kernel for each relation type sequentially */
      for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
        // Bind by const reference so the CSR / NDArray handles are not
        // copied on every iteration just to be passed along.
        const CSRMatrix& csr = vec_csr[etype];
        const NDArray& lhs = vec_lhs[lhs_eid[etype]];
        const NDArray& rhs = vec_rhs[rhs_eid[etype]];
        const NDArray& out = vec_out[etype];
        cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
            bcast, csr, lhs, rhs, out);
      }
    });
  });
}

// Explicit instantiations of SDDMMCsrHetero for the CUDA device.
// Index types: int32_t, int64_t. Feature types: __half, __nv_bfloat16 (only
// when BF16_ENABLED), float, double. Parameter names mirror the primary
// template's signature.
template void SDDMMCsrHetero<kDGLCUDA, int32_t, __half>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, __half>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
#if BF16_ENABLED
template void SDDMMCsrHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
#endif  // BF16_ENABLED
template void SDDMMCsrHetero<kDGLCUDA, int32_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCsrHetero<kDGLCUDA, int32_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);

}  // namespace aten
}  // namespace dgl