/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/sddmm_hetero_coo.cu
 * @brief SDDMM C APIs and definitions for heterographs stored in COO format.
 */
#include <dgl/array.h>

#include <string>
#include <vector>

#include "./sddmm.cuh"

namespace dgl {
namespace aten {

/**
 * @brief CUDA implementation of g-SDDMM on a heterograph whose relation
 *        graphs are stored in COO format.
 *
 * The operator name and the operand targets are dispatched to compile-time
 * parameters once; then cuda::SDDMMCoo is called for each relation type in
 * turn (a host-side loop; relation types are processed sequentially).
 *
 * @param op Name of the binary operator, resolved by SWITCH_OP.
 * @param bcast Broadcast information passed unchanged to every per-relation
 *        call.
 * @param vec_coo COO matrix of each relation type, indexed by etype.
 * @param vec_lhs Left-hand operand arrays; relation etype reads
 *        vec_lhs[lhs_eid[etype]].
 * @param vec_rhs Right-hand operand arrays; relation etype reads
 *        vec_rhs[rhs_eid[etype]].
 * @param vec_out Output arrays, indexed by etype. Taken by value; results
 *        presumably reach the caller because NDArray is a shared handle to
 *        its buffer (TODO confirm).
 * @param lhs_target Selects where lhs is read from, resolved by SWITCH_TARGET
 *        (presumably source node / edge / destination node — confirm).
 * @param rhs_target Same as lhs_target, for rhs.
 * @param lhs_eid Per-relation index into vec_lhs. Its size determines how
 *        many relation types are processed.
 * @param rhs_eid Per-relation index into vec_rhs. Like vec_coo and vec_out,
 *        it must have at least lhs_eid.size() entries (no bounds check here).
 */
template <int XPU, typename IdType, typename DType>
void SDDMMCooHetero(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid) {
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      /* Call SDDMM CUDA kernel for each relation type sequentially */
      for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
        COOMatrix coo = vec_coo[etype];
        NDArray lhs = vec_lhs[lhs_eid[etype]];
        NDArray rhs = vec_rhs[rhs_eid[etype]];
        NDArray out = vec_out[etype];
        cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
            bcast, coo, lhs, rhs, out);
      }
    });
  });
}

// Explicit instantiations of SDDMMCooHetero for CUDA, one per supported
// (IdType, DType) pair. The __nv_bfloat16 variants are compiled only when
// BF16_ENABLED is set.
template void SDDMMCooHetero<kDGLCUDA, int32_t, __half>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, __half>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
#if BF16_ENABLED
template void SDDMMCooHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
#endif  // BF16_ENABLED
template void SDDMMCooHetero<kDGLCUDA, int32_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCooHetero<kDGLCUDA, int32_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
    const std::vector<dgl_type_t>& rhs_eid);

}  // namespace aten
}  // namespace dgl