// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/spmm_hetero.cu
 * @brief SpMM C APIs and definitions.
 */
#include <dgl/array.h>

#include "../../runtime/cuda/cuda_common.h"
#include "functor.cuh"
#include "ge_spmm.cuh"
#include "spmm.cuh"

namespace dgl {

using namespace cuda;

namespace aten {

/**
 * @brief CUDA implementation of g-SpMM on Csr format.
 * @note Uses cuSPARSE when the reduce operator is `sum` and there is no
 *       broadcast; falls back to DGL's own kernel in all other cases.
 */
template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat,
    const std::vector<NDArray>& vec_efeat, std::vector<NDArray>* vec_out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,  // ufeat node type id
    const std::vector<dgl_type_t>& out_ntids) {  // output node type id
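  // An edge feature counts as "scalar" when it holds exactly one value per
  // nonzero of the CSR matrix; this is the only edge-feature layout the
  // cuSPARSE path below can consume.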
  bool is_scalar_efeat =
      vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
  bool use_efeat = op != "copy_lhs";
  auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
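  // Staging buffers (one per destination node type) that hold the
  // untransposed product on the legacy cuSPARSE path; unused otherwise.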
  std::vector<DType*> trans_out((*vec_out).size(), NULL);

  bool use_legacy_cusparsemm =
      (DTKRT_VERSION < 11000) && (reduce == "sum") &&
      // legacy cuSPARSE does not care about NNZ, hence the argument "false".
      ((op == "copy_lhs" && cusparse_available<DType, IdType>(false)) ||
       (op == "mul" && is_scalar_efeat &&
        cusparse_available<DType, IdType>(false)));
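  // Legacy (pre-11.0) cuSPARSE csrmm2 emits its result in column-major order,
  // so the product is staged in trans_out and transposed back at the end.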
  // Create temporary output buffer to store non-transposed output
  if (use_legacy_cusparsemm) {
    for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
      const int m = (*vec_out)[ntype]->shape[0];
      const int n = (*vec_out)[ntype]->shape[1];
      if (m == 0) continue;
      DType* out = static_cast<DType*>(device->AllocWorkspace(
          vec_csr[0].indptr->ctx, m * n * sizeof(DType)));
      CUDA_CALL(hipMemset(out, 0, m * n * sizeof(DType)));
      trans_out[ntype] = out;
    }
  }
  // Check that ufeat has the same shape across all relation types and compute
  // the per-node feature size.
  int64_t x_length = 1;
  for (dgl_type_t etype = 0; etype < (ufeat_ntids.size() - 1); ++etype) {
    NDArray ufeat = vec_ufeat[ufeat_ntids[etype]];
    NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]];
    CHECK_EQ(ufeat->ndim, next_ufeat->ndim)
        << "Input features have different shapes";
    for (int i = 1; i < ufeat->ndim; ++i) {
      if (ufeat->shape[i] != next_ufeat->shape[i]) {
        if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1)
          LOG(FATAL) << "Homogenized message passing on heterogeneous graphs "
                        "does not support "
                     << "automatic broadcasting.  Please manually broadcast it "
                        "before calling "
                     << "message passing functions.";
        else
          LOG(FATAL) << "Input features have different shapes.";
        return;
      }

      if (etype == 0) x_length *= ufeat->shape[i];
    }
  }
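  // For max/min reduction, seed every output with the reduction identity and
  // the arg buffers with -1 so that rows no relation ever writes to remain
  // recognizable as empty.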
  // TODO(Israt): Can python do the following initializations while creating the
  // tensors?
  if (reduce == "max" || reduce == "min") {
    const int64_t dim = bcast.out_len;
    std::vector<bool> updated((*vec_out).size(), false);
    for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
      DType* out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>();
      if (reduce == "max")
        _Fill(
            out_off, vec_csr[etype].num_rows * dim,
            cuda::reduce::Max<IdType, DType>::zero());
      else  // min
        _Fill(
            out_off, vec_csr[etype].num_rows * dim,
            cuda::reduce::Min<IdType, DType>::zero());
      const dgl_type_t dst_id = out_ntids[etype];
      if (!updated[dst_id]) {
        updated[dst_id] = true;
        if (op == "copy_lhs") {
          IdType* argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
          _Fill(
              argu_ntype, vec_csr[etype].num_rows * dim,
              static_cast<IdType>(-1));
        }
        if (op == "copy_rhs") {
          IdType* arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
          _Fill(
              arge_etype, vec_csr[etype].num_rows * dim,
              static_cast<IdType>(-1));
        }
      }
    }
  }

  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
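  // Run SpMM once per relation (edge type). Relations that share a
  // destination node type reduce into the same output tensor.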
  for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
    const dgl_type_t src_id = ufeat_ntids[etype];
    const dgl_type_t dst_id = out_ntids[etype];
    CSRMatrix csr = vec_csr[etype];
    if (reduce == "sum") {
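      // A graph with parallel edges can hold more nonzeros than a
      // (num_rows x num_cols) dense matrix has entries; pass this on so
      // cusparse_available() can reject configurations it cannot handle.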
      bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
      /* Call SpMM for each relation type */
      if (op == "copy_lhs" &&
          cusparse_available<DType, IdType>(more_nnz)) {  // cusparse
        /* If CUDA is less than 11.0, put the output in trans_out for later
         * transposition */
        DType* out = (DTKRT_VERSION < 11000)
                         ? trans_out[dst_id]
                         : static_cast<DType*>((*vec_out)[dst_id]->data);
        CusparseCsrmm2Hetero<DType, IdType>(
            csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),
            nullptr, out, x_length, stream);
      } else if (
          op == "mul" && is_scalar_efeat &&
          cusparse_available<DType, IdType>(more_nnz)) {  // cusparse
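        // csr.data, when non-null, maps each CSR nonzero back to its original
        // edge id; gather the edge features first so their order matches the
        // (possibly shuffled) nonzeros.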
        NDArray efeat = vec_efeat[etype];
        if (!IsNullArray(csr.data))
          efeat = IndexSelect(efeat, csr.data);
        CusparseCsrmm2Hetero<DType, IdType>(
            csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),
            static_cast<DType*>(efeat->data),
            // TODO(Israt): Change (*vec_out) to trans_out to support CUDA
            // version < 11
            static_cast<DType*>((*vec_out)[dst_id]->data), x_length, stream);
      } else {  // general kernel
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
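        // SWITCH_OP maps the runtime `op` string to a compile-time functor
        // type `Op`, instantiating the generic kernel once per operator.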
        SWITCH_OP(op, Op, {
          cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType>>(
              bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(),
              NullArray());
        });
      }
    } else if (reduce == "max") {
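      // out_aux layout, matching the kernel's argument order:
      // [0] arg-U, [1] arg-E, [2] arg-U node type, [3] arg-E edge type.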
      SWITCH_OP(op, Op, {
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
        cuda::SpMMCmpCsrHetero<
            IdType, DType, Op, cuda::reduce::Max<IdType, DType>>(
            bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
            (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
            src_id, etype);
      });
    } else if (reduce == "min") {
      SWITCH_OP(op, Op, {
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
        cuda::SpMMCmpCsrHetero<
            IdType, DType, Op, cuda::reduce::Min<IdType, DType>>(
            bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
            (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
            src_id, etype);
      });
    } else {
      LOG(FATAL) << "Not implemented";
    }
  }

  if (use_legacy_cusparsemm) {
    // Transpose the staged column-major results back into the row-major
    // output tensors, then release the workspace buffers.
    for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
      const int m = (*vec_out)[ntype]->shape[0];
      const int n = (*vec_out)[ntype]->shape[1];
      if (m == 0) continue;
      DType* C_data = static_cast<DType*>((*vec_out)[ntype]->data);
      _Transpose(trans_out[ntype], C_data, n, m);
      device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);
    }
  }
}

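// Explicit template instantiations for every (index type, value type)
// combination the runtime dispatcher may select.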
template void SpMMCsrHetero<kDGLCUDA, int32_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
#if BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, __hip_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, __hip_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
#endif  // BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);

}  // namespace aten
}  // namespace dgl