"sgl-kernel/vscode:/vscode.git/clone" did not exist on "704ced1b2ec45a87bb42494fe9da18fa5004e546"
sddmm_hetero_coo.hip 4.19 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/sddmm_hetero_coo.hip
 * @brief g-SDDMM on heterographs stored in COO format (GPU backend).
 */
#include <dgl/array.h>
8

sangwzh's avatar
sangwzh committed
9
#include "sddmm.cuh"
10
11
12
13

namespace dgl {
namespace aten {

14
/**
15
 * @brief CUDA implementation of g-SDDMM on heterograph using
16
17
    Csr format.
 */
18
template <int XPU, typename IdType, typename DType>
19
20
21
22
23
24
void SDDMMCooHetero(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid) {
25
26
27
28
29
30
31
32
33
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      /* Call SDDMM CUDA kernel for each relation type sequentially */
      for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
        COOMatrix coo = vec_coo[etype];
        NDArray lhs = vec_lhs[lhs_eid[etype]];
        NDArray rhs = vec_rhs[rhs_eid[etype]];
        NDArray out = vec_out[etype];
        cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
34
            bcast, coo, lhs, rhs, out);
35
      }
36
37
38
39
    });
  });
}

// Explicit instantiations of SDDMMCooHetero for half precision, for both
// 32-bit and 64-bit index types. Parameter names mirror the definition.
template void SDDMMCooHetero<kDGLCUDA, int32_t, __half>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, __half>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
#if BF16_ENABLED
// bfloat16 instantiations are compiled only when BF16_ENABLED is set.
// Parameter names mirror the definition.
template void SDDMMCooHetero<kDGLCUDA, int32_t, __hip_bfloat16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, __hip_bfloat16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
#endif  // BF16_ENABLED
// Explicit instantiations for single and double precision, for both 32-bit
// and 64-bit index types. Parameter names mirror the definition.
template void SDDMMCooHetero<kDGLCUDA, int32_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCooHetero<kDGLCUDA, int32_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);

}  // namespace aten
}  // namespace dgl