"docs/vscode:/vscode.git/clone" did not exist on "76d14f8cb92c73ac75a1d859a088629670de4290"
spmm_hetero.cu 11.3 KB
Newer Older
1
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/spmm_hetero.cu
 * @brief SpMM C APIs and definitions for heterogeneous graphs.
 */
#include <dgl/array.h>

#include <cstdlib>

#include "../../runtime/cuda/cuda_common.h"
#include "./functor.cuh"
#include "./ge_spmm.cuh"
#include "./spmm.cuh"

namespace dgl {

using namespace cuda;

namespace aten {

/**
 * @brief CUDA implementation of g-SpMM on CSR format.
 * @note Uses cuSPARSE if the reduce operator is `sum` and there is no
 *       broadcast; falls back to DGL's own kernel in all other cases.
 */
template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat,
    const std::vector<NDArray>& vec_efeat, std::vector<NDArray>* vec_out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,  // ufeat node type id
    const std::vector<dgl_type_t>& out_ntids) {  // output node type id
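  // Each relation (edge type) `etype` reads source-node features of node
  // type ufeat_ntids[etype] and accumulates into outputs of node type
  // out_ntids[etype]. Illustrative example (hypothetical graph): with node
  // types {user = 0, item = 1} and relations {follows: user->user, buys:
  // user->item}, ufeat_ntids = {0, 0} and out_ntids = {0, 1}. A "scalar"
  // edge feature (tested just below) is one value per nonzero, i.e. an
  // edge-weight vector with no trailing feature dimension.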
  bool is_scalar_efeat =
      vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
  bool use_efeat = op != "copy_lhs";
  auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
  std::vector<DType*> trans_out((*vec_out).size(), nullptr);
  bool use_deterministic_alg_only = false;
  if (std::getenv("USE_DETERMINISTIC_ALG") != nullptr)
    use_deterministic_alg_only = true;
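  // The flag is forwarded to the CusparseCsrmm2Hetero calls below, where it
  // requests a deterministic SpMM algorithm from cuSPARSE.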

  bool use_legacy_cusparsemm =
      (CUDART_VERSION < 11000) && (reduce == "sum") &&
      // Legacy cuSPARSE does not care about NNZ, hence the argument "false".
      ((op == "copy_lhs" && cusparse_available<DType, IdType>(false)) ||
       (op == "mul" && is_scalar_efeat &&
        cusparse_available<DType, IdType>(false)));
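  // Before CUDA 11.0, cuSPARSE csrmm2 emits its result transposed; the
  // legacy path therefore accumulates into scratch buffers (trans_out) and
  // transposes them back at the end of this function.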
  // Create temporary output buffer to store non-transposed output
  if (use_legacy_cusparsemm) {
    for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
      const int m = (*vec_out)[ntype]->shape[0];
      const int n = (*vec_out)[ntype]->shape[1];
      if (m == 0) continue;
      DType* out = static_cast<DType*>(device->AllocWorkspace(
          vec_csr[0].indptr->ctx, m * n * sizeof(DType)));
      CUDA_CALL(cudaMemset(out, 0, m * n * sizeof(DType)));
      trans_out[ntype] = out;
    }
  }
  // Check that ufeat has the same shape for every relation type and compute
  // the flattened feature size.
  int64_t x_length = 1;
  for (dgl_type_t etype = 0; etype < (ufeat_ntids.size() - 1); ++etype) {
    NDArray ufeat = vec_ufeat[ufeat_ntids[etype]];
    NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]];
    CHECK_EQ(ufeat->ndim, next_ufeat->ndim)
        << "Input features have different shapes";
    for (int i = 1; i < ufeat->ndim; ++i) {
      if (ufeat->shape[i] != next_ufeat->shape[i]) {
        if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1)
          LOG(FATAL) << "Homogenized message passing on heterogeneous graphs "
                        "does not support "
                     << "automatic broadcasting.  Please manually broadcast it "
                        "before calling "
                     << "message passing functions.";
        else
          LOG(FATAL) << "Input features have different shapes.";
        return;
      }

      if (etype == 0) x_length *= ufeat->shape[i];
    }
  }
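  // Example (hypothetical shapes): a ufeat tensor of shape (N, 4, 3) yields
  // x_length = 4 * 3 = 12.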
  // TODO(Israt): Can Python do the following initializations while creating
  // the tensors?
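  // min/max outputs must start at the reduction identity, and the
  // argmin/argmax bookkeeping buffers at -1, so that destinations with no
  // incoming edges remain identifiable after the kernels run.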
  if (reduce == "max" || reduce == "min") {
    const int64_t dim = bcast.out_len;
    std::vector<bool> updated((*vec_out).size(), false);
    for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
      DType* out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>();
      if (reduce == "max")
        _Fill(
            out_off, vec_csr[etype].num_rows * dim,
            cuda::reduce::Max<IdType, DType>::zero());
      else  // min
        _Fill(
            out_off, vec_csr[etype].num_rows * dim,
            cuda::reduce::Min<IdType, DType>::zero());
      const dgl_type_t dst_id = out_ntids[etype];
      if (!updated[dst_id]) {
        updated[dst_id] = true;
        if (op == "copy_lhs") {
          IdType* argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
          _Fill(
              argu_ntype, vec_csr[etype].num_rows * dim,
              static_cast<IdType>(-1));
        }
        if (op == "copy_rhs") {
          IdType* arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
          _Fill(
              arge_etype, vec_csr[etype].num_rows * dim,
              static_cast<IdType>(-1));
        }
      }
    }
  }

  cudaStream_t stream = runtime::getCurrentCUDAStream();
  for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
    const dgl_type_t src_id = ufeat_ntids[etype];
    const dgl_type_t dst_id = out_ntids[etype];
    CSRMatrix csr = vec_csr[etype];
    if (reduce == "sum") {
      bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
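      // more_nnz flags relations with more nonzeros than an equivalent
      // dense matrix would have entries (possible with parallel edges);
      // cusparse_available() uses it to decide whether cuSPARSE handles
      // such inputs or the general kernel must be used.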
      /* Call SpMM for each relation type. */
      if (op == "copy_lhs" &&
          cusparse_available<DType, IdType>(more_nnz)) {  // cusparse
        /* With CUDA older than 11.0, put the output in trans_out for the
         * final transposition. */
        DType* out = (CUDART_VERSION < 11000)
                         ? trans_out[dst_id]
                         : static_cast<DType*>((*vec_out)[dst_id]->data);
        CusparseCsrmm2Hetero<DType, IdType>(
            csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),
            nullptr, out, x_length, stream, use_deterministic_alg_only);
      } else if (
          op == "mul" && is_scalar_efeat &&
          cusparse_available<DType, IdType>(more_nnz)) {  // cusparse
        NDArray efeat = vec_efeat[etype];
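        // When csr.data is present it permutes CSR positions relative to
        // the original edge IDs, so gather the edge features into CSR order
        // before the cuSPARSE call.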
        if (!IsNullArray(csr.data)) efeat = IndexSelect(efeat, csr.data);
        CusparseCsrmm2Hetero<DType, IdType>(
            csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),
            static_cast<DType*>(efeat->data),
            // TODO(Israt): Change (*vec_out) to trans_out to support CUDA
            // version < 11
            static_cast<DType*>((*vec_out)[dst_id]->data), x_length, stream,
            use_deterministic_alg_only);
      } else {  // general kernel
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
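        // SWITCH_OP maps the runtime `op` string to a compile-time binary
        // operator functor `Op` for the generic SpMM kernel.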
        SWITCH_OP(op, Op, {
          cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType>>(
              bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(),
              NullArray());
        });
      }
    } else if (reduce == "max") {
      SWITCH_OP(op, Op, {
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
        cuda::SpMMCmpCsrHetero<
            IdType, DType, Op, cuda::reduce::Max<IdType, DType>>(
            bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
            (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
            src_id, etype);
      });
    } else if (reduce == "min") {
      SWITCH_OP(op, Op, {
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
        cuda::SpMMCmpCsrHetero<
            IdType, DType, Op, cuda::reduce::Min<IdType, DType>>(
            bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
            (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
            src_id, etype);
      });
    } else {
      LOG(FATAL) << "Not implemented";
    }
  }

  if (use_legacy_cusparsemm) {
    // Transpose the legacy cuSPARSE results back into the row-major output
    // buffers and free the scratch space.
    for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
      const int m = (*vec_out)[ntype]->shape[0];
      const int n = (*vec_out)[ntype]->shape[1];
      if (m == 0) continue;
      DType* C_data = static_cast<DType*>((*vec_out)[ntype]->data);
      _Transpose(trans_out[ntype], C_data, n, m);
      device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);
    }
  }
}
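
// A minimal call sketch (illustrative only; `csrs`, `ufeats`, `outs`, `aux`,
// `bcast`, `ufeat_ntids`, and `out_ntids` are hypothetical names whose setup
// is omitted):
//
//   std::vector<CSRMatrix> csrs = ...;         // one CSR block per relation
//   std::vector<NDArray> ufeats = ...;         // one feature tensor per ntype
//   std::vector<NDArray> outs = ...;           // preallocated output buffers
//   std::vector<std::vector<NDArray>> aux(4);  // arg* buffers (min/max only)
//   SpMMCsrHetero<kDGLCUDA, int32_t, float>(
//       "copy_lhs", "sum", bcast, csrs, ufeats, /*vec_efeat=*/{}, &outs,
//       &aux, ufeat_ntids, out_ntids);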

template void SpMMCsrHetero<kDGLCUDA, int32_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
#if BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
#endif  // BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);

}  // namespace aten
}  // namespace dgl