// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/spmm_hetero.cu
 * @brief SPMM C APIs and definitions.
 */
#include <dgl/array.h>

#include <cstdlib>

#include "../../runtime/cuda/cuda_common.h"
#include "functor.cuh"
#include "ge_spmm.cuh"
#include "spmm.cuh"

namespace dgl {

using namespace cuda;

namespace aten {

/**
 * @brief CUDA implementation of g-SpMM on Csr format.
 * @note cuSPARSE is used when the reduce operator is `sum` and there is
 *       no broadcast; DGL's own kernel handles all other cases.
 */
template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat,
    const std::vector<NDArray>& vec_efeat, std::vector<NDArray>* vec_out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,  // ufeat node type id
    const std::vector<dgl_type_t>& out_ntids) {  // output node type id
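  // Edge features count as scalar when they hold one value per nonzero
  // (checked on the first relation type).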
  bool is_scalar_efeat =
      vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
  bool use_efeat = op != "copy_lhs";
  auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
  std::vector<DType*> trans_out((*vec_out).size(), NULL);
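  // Setting the USE_DETERMINISTIC_ALG environment variable restricts the
  // cuSPARSE calls below to deterministic algorithms.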
  bool use_deterministic_alg_only = false;
  if (NULL != std::getenv("USE_DETERMINISTIC_ALG"))
    use_deterministic_alg_only = true;

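  // The legacy (pre-11.0) cuSPARSE csrmm2 path writes its output in
  // column-major order, so it needs a temporary buffer that is transposed
  // back at the end of this function.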
  bool use_legacy_cusparsemm =
      (DTKRT_VERSION < 11000) && (reduce == "sum") &&
      // legacy cuSPARSE does not care about NNZ, hence the argument "false".
      ((op == "copy_lhs" && cusparse_available<DType, IdType>(false)) ||
       (op == "mul" && is_scalar_efeat &&
        cusparse_available<DType, IdType>(false)));
  // Create temporary output buffer to store non-transposed output
  if (use_legacy_cusparsemm) {
    for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
      const int m = (*vec_out)[ntype]->shape[0];
      const int n = (*vec_out)[ntype]->shape[1];
      if (m == 0) continue;
      DType* out = static_cast<DType*>(device->AllocWorkspace(
          vec_csr[0].indptr->ctx, m * n * sizeof(DType)));
      CUDA_CALL(hipMemset(out, 0, m * n * sizeof(DType)));
      trans_out[ntype] = out;
    }
  }
  // Check shape of ufeat for all relation type and compute feature size
  int64_t x_length = 1;
  for (dgl_type_t etype = 0; etype < (ufeat_ntids.size() - 1); ++etype) {
    NDArray ufeat = vec_ufeat[ufeat_ntids[etype]];
    NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]];
    CHECK_EQ(ufeat->ndim, next_ufeat->ndim)
        << "Input features have different shapes";
    for (int i = 1; i < ufeat->ndim; ++i) {
      if (ufeat->shape[i] != next_ufeat->shape[i]) {
        if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1)
          LOG(FATAL) << "Homogenized message passing on heterogeneous graphs "
                        "does not support "
                     << "automatic broadcasting.  Please manually broadcast it "
                        "before calling "
                     << "message passing functions.";
        else
          LOG(FATAL) << "Input features have different shapes.";
        return;
      }

      if (etype == 0) x_length *= ufeat->shape[i];
    }
  }
  // TODO(Israt): Can python do the following initializations while creating the
  // tensors?
  if (reduce == "max" || reduce == "min") {
    const int64_t dim = bcast.out_len;
    std::vector<bool> updated((*vec_out).size(), false);
    for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
      DType* out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>();
      if (reduce == "max")
        _Fill(
            out_off, vec_csr[etype].num_rows * dim,
            cuda::reduce::Max<IdType, DType>::zero());
      else  // min
        _Fill(
            out_off, vec_csr[etype].num_rows * dim,
            cuda::reduce::Min<IdType, DType>::zero());
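      // Initialize the arg-aux buffers with -1 only once per destination
      // node type, since several relation types may share one output.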
      const dgl_type_t dst_id = out_ntids[etype];
      if (!updated[dst_id]) {
        updated[dst_id] = true;
        if (op == "copy_lhs") {
          IdType* argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
          _Fill(
              argu_ntype, vec_csr[etype].num_rows * dim,
              static_cast<IdType>(-1));
        }
        if (op == "copy_rhs") {
          IdType* arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
          _Fill(
              arge_etype, vec_csr[etype].num_rows * dim,
              static_cast<IdType>(-1));
        }
      }
    }
  }

  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
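  // Dispatch SpMM once per relation type: use cuSPARSE when it applies and
  // fall back to DGL's generic kernels otherwise.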
  for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
    const dgl_type_t src_id = ufeat_ntids[etype];
    const dgl_type_t dst_id = out_ntids[etype];
    CSRMatrix csr = vec_csr[etype];
    if (reduce == "sum") {
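      // A multigraph can hold more nonzeros than num_rows * num_cols;
      // cusparse_available() rejects the cuSPARSE path in that case.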
      bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
      /* Call SpMM for each relation type */
      if (op == "copy_lhs" &&
          cusparse_available<DType, IdType>(more_nnz)) {  // cusparse
        /* If CUDA is less than 11.0, put the output in trans_out for later
         * transposition */
        DType* out = (DTKRT_VERSION < 11000)
                         ? trans_out[dst_id]
                         : static_cast<DType*>((*vec_out)[dst_id]->data);
        CusparseCsrmm2Hetero<DType, IdType>(
            csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),
            nullptr, out, x_length, stream, use_deterministic_alg_only);
      } else if (
          op == "mul" && is_scalar_efeat &&
          cusparse_available<DType, IdType>(more_nnz)) {  // cusparse
        NDArray efeat = vec_efeat[etype];
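        // csr.data, when present, maps nonzeros to edge ids; gather the edge
        // features into nonzero order before calling cuSPARSE.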
        if (!IsNullArray(csr.data)) efeat = IndexSelect(efeat, csr.data);
        CusparseCsrmm2Hetero<DType, IdType>(
            csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),
            static_cast<DType*>(efeat->data),
            // TODO(Israt): Change (*vec_out) to trans_out to support CUDA
            // version < 11
            static_cast<DType*>((*vec_out)[dst_id]->data), x_length, stream,
            use_deterministic_alg_only);
      } else {  // general kernel
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
        SWITCH_OP(op, Op, {
          cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType>>(
              bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(),
              NullArray());
        });
      }
    } else if (reduce == "max") {
      SWITCH_OP(op, Op, {
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
        cuda::SpMMCmpCsrHetero<
            IdType, DType, Op, cuda::reduce::Max<IdType, DType>>(
            bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
            (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
            src_id, etype);
      });
    } else if (reduce == "min") {
      SWITCH_OP(op, Op, {
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
        cuda::SpMMCmpCsrHetero<
            IdType, DType, Op, cuda::reduce::Min<IdType, DType>>(
            bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
            (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
            src_id, etype);
      });
    } else {
      LOG(FATAL) << "Not implemented";
    }
  }

  if (use_legacy_cusparsemm) {
    // transpose output
    for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
      const int m = (*vec_out)[ntype]->shape[0];
      const int n = (*vec_out)[ntype]->shape[1];
      if (m == 0) continue;
      DType* C_data = static_cast<DType*>((*vec_out)[ntype]->data);
      _Transpose(trans_out[ntype], C_data, n, m);
      device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);
    }
  }
}

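// Explicit template instantiations for all supported id and data types.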
template void SpMMCsrHetero<kDGLCUDA, int32_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
#if BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, __hip_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, __hip_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
#endif  // BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);

}  // namespace aten
}  // namespace dgl