// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/segment_reduce.cu
 * @brief Segment reduce C APIs and definitions.
 */
#include <dgl/array.h>
#include <dgl/base_heterograph.h>

#include "functor.cuh"
#include "segment_reduce.cuh"
#include "utils.h"

namespace dgl {

using namespace cuda;

namespace aten {

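// Dispatches a forward segment reduction to the templated CUDA/HIP kernels
// declared in segment_reduce.cuh. `op` selects the reducer ("sum", "max" or
// "min"); `offsets` marks the contiguous row ranges of `feat` that form each
// segment, and for "max"/"min" the index of the winning row is recorded in
// `arg` for use by the backward pass.
// Illustrative sketch (assumed 1-D feature, sum): feat = {1, 2, 3, 4, 5},
// offsets = {0, 2, 5}  ->  out = {1 + 2, 3 + 4 + 5} = {3, 12}.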
template <int XPU, typename IdType, typename DType>
void SegmentReduce(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg) {
  if (op == "sum") {
    cuda::SegmentReduce<IdType, DType, cuda::reduce::Sum<IdType, DType>>(
        feat, offsets, out, arg);
  } else if (op == "max") {
    cuda::SegmentReduce<IdType, DType, cuda::reduce::Max<IdType, DType>>(
        feat, offsets, out, arg);
  } else if (op == "min") {
    cuda::SegmentReduce<IdType, DType, cuda::reduce::Min<IdType, DType>>(
        feat, offsets, out, arg);
  } else {
    LOG(FATAL) << "Not implemented";
  }
}

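// Row-wise scatter-add: accumulates row i of `feat` into row idx[i] of `out`
// (out[idx[i]] += feat[i]), so repeated indices sum up. Illustrative sketch
// (assuming `out` starts at zero): feat = {1, 2, 3}, idx = {0, 0, 1}
//   ->  out = {1 + 2, 3} = {3, 3}.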
template <int XPU, typename IdType, typename DType>
void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
  cuda::ScatterAdd<IdType, DType>(feat, idx, out);
}

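// Backward helper for "min"/"max" aggregation on heterogeneous graphs:
// roughly, it routes the incoming gradients in `feat` back to the
// argmin/argmax positions recorded in `idx`, with `idx_etype` identifying
// which per-type array in `out` each position belongs to. The kernel lives
// in segment_reduce.cuh; this wrapper only forwards the call.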
template <int XPU, typename IdType, typename DType>
void UpdateGradMinMax_hetero(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
  cuda::UpdateGradMinMax_hetero<IdType, DType>(
      g, op, feat, idx, idx_etype, out);
}

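// Backward pass of segment "min"/"max": copies each gradient row of `feat`
// to the input position stored in `arg` during the forward pass (positions
// that were not selected receive no gradient; `out` is assumed to be
// zero-initialized by the caller).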
template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
  cuda::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
}

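// Explicit instantiations for the CUDA/HIP backend: 32-/64-bit index types
// combined with half, bfloat16 (when BF16_ENABLED), float and double
// feature types.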
template void SegmentReduce<kDGLCUDA, int32_t, __half>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, __half>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
#if BF16_ENABLED
template void SegmentReduce<kDGLCUDA, int32_t, __hip_bfloat16>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, __hip_bfloat16>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
#endif  // BF16_ENABLED
template void SegmentReduce<kDGLCUDA, int32_t, float>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, float>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCUDA, int32_t, double>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, double>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);

template void ScatterAdd<kDGLCUDA, int32_t, __half>(
    NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, __half>(
    NDArray feat, NDArray idx, NDArray out);
#if BF16_ENABLED
template void ScatterAdd<kDGLCUDA, int32_t, __hip_bfloat16>(
    NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, __hip_bfloat16>(
    NDArray feat, NDArray idx, NDArray out);
#endif  // BF16_ENABLED
template void ScatterAdd<kDGLCUDA, int32_t, float>(
    NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, float>(
    NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCUDA, int32_t, double>(
    NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, double>(
    NDArray feat, NDArray idx, NDArray out);

template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, __half>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, __half>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
#if BF16_ENABLED
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, __hip_bfloat16>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, __hip_bfloat16>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
#endif  // BF16_ENABLED
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, float>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, float>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, double>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, double>(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);

template void BackwardSegmentCmp<kDGLCUDA, int32_t, __half>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, __half>(
    NDArray feat, NDArray arg, NDArray out);
#if BF16_ENABLED
template void BackwardSegmentCmp<kDGLCUDA, int32_t, __hip_bfloat16>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, __hip_bfloat16>(
    NDArray feat, NDArray arg, NDArray out);
#endif  // BF16_ENABLED
template void BackwardSegmentCmp<kDGLCUDA, int32_t, float>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, float>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int32_t, double>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, double>(
    NDArray feat, NDArray arg, NDArray out);

}  // namespace aten
}  // namespace dgl