/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/spmm_hetero.cu
 * @brief SpMM C APIs and definitions for heterogeneous graphs.
 */
#include <dgl/array.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./functor.cuh"
#include "./ge_spmm.cuh"
#include "./spmm.cuh"

namespace dgl {

using namespace cuda;

namespace aten {

/**
 * @brief CUDA implementation of g-SpMM on CSR format for heterogeneous graphs.
 * @note uses cuSPARSE if the reduce operator is `sum` and there is
 *       no broadcast; falls back to DGL's own kernel in all other cases.
 */
template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat,
    const std::vector<NDArray>& vec_efeat, std::vector<NDArray>* vec_out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,  // ufeat node type id
    const std::vector<dgl_type_t>& out_ntids) {  // output node type id
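  // Edge features count as scalar when there is exactly one value per nonzero
  // (i.e. per edge); the first relation is used as a representative.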
  bool is_scalar_efeat =
      vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
  bool use_efeat = op != "copy_lhs";
  auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
  std::vector<DType*> trans_out((*vec_out).size(), nullptr);

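  // Before CUDA 11, the cuSPARSE csrmm2 path writes its result transposed
  // (column-major), so outputs are staged in temporary buffers and transposed
  // back into vec_out once all relations have been processed.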
  bool use_legacy_cusparsemm =
      (CUDART_VERSION < 11000) && (reduce == "sum") &&
      // legacy cuSPARSE does not care about NNZ, hence the argument "false".
      ((op == "copy_lhs" && cusparse_available<DType, IdType>(false)) ||
       (op == "mul" && is_scalar_efeat &&
        cusparse_available<DType, IdType>(false)));
  // Create temporary output buffer to store non-transposed output
  if (use_legacy_cusparsemm) {
    for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
      const int m = (*vec_out)[ntype]->shape[0];
      const int n = (*vec_out)[ntype]->shape[1];
      if (m == 0) continue;
      DType* out = static_cast<DType*>(device->AllocWorkspace(
          vec_csr[0].indptr->ctx, m * n * sizeof(DType)));
      CUDA_CALL(cudaMemset(out, 0, m * n * sizeof(DType)));
      trans_out[ntype] = out;
    }
  }
  // Check that the ufeat shapes agree across all relation types and compute
  // the feature size.
  int64_t x_length = 1;
  for (dgl_type_t etype = 0; etype < (ufeat_ntids.size() - 1); ++etype) {
    NDArray ufeat = vec_ufeat[ufeat_ntids[etype]];
    NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]];
    CHECK_EQ(ufeat->ndim, next_ufeat->ndim)
        << "Input features have different shapes";
    for (int i = 1; i < ufeat->ndim; ++i) {
      if (ufeat->shape[i] != next_ufeat->shape[i]) {
        if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1)
          LOG(FATAL) << "Homogenized message passing on heterogeneous graphs "
                        "does not support "
                     << "automatic broadcasting.  Please manually broadcast it "
                        "before calling "
                     << "message passing functions.";
        else
          LOG(FATAL) << "Input features have different shapes.";
        return;
      }

      if (etype == 0) x_length *= ufeat->shape[i];
    }
  }
  // TODO(Israt): Can python do the following initializations while creating the
  // tensors?
  if (reduce == "max" || reduce == "min") {
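    // Fill the outputs with the reduction identity and the auxiliary type-id
    // arrays with -1, a sentinel for positions that receive no message from
    // any relation.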
    const int64_t dim = bcast.out_len;
    std::vector<bool> updated((*vec_out).size(), false);
    for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
      DType* out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>();
      if (reduce == "max")
        _Fill(
            out_off, vec_csr[etype].num_rows * dim,
            cuda::reduce::Max<IdType, DType>::zero());
      else  // min
        _Fill(
            out_off, vec_csr[etype].num_rows * dim,
            cuda::reduce::Min<IdType, DType>::zero());
      const dgl_type_t dst_id = out_ntids[etype];
      if (!updated[dst_id]) {
        updated[dst_id] = true;
        if (op == "copy_lhs") {
          IdType* argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
          _Fill(
              argu_ntype, vec_csr[etype].num_rows * dim,
              static_cast<IdType>(-1));
        }
        if (op == "copy_rhs") {
          IdType* arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
          _Fill(
              arge_etype, vec_csr[etype].num_rows * dim,
              static_cast<IdType>(-1));
        }
      }
    }
  }

  cudaStream_t stream = runtime::getCurrentCUDAStream();
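  // Launch one SpMM per relation (edge type); relations that share a
  // destination node type write into the same entry of vec_out.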
  for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
    const dgl_type_t src_id = ufeat_ntids[etype];
    const dgl_type_t dst_id = out_ntids[etype];
    CSRMatrix csr = vec_csr[etype];
    if (reduce == "sum") {
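      // More nonzeros than num_rows * num_cols implies duplicate (row, col)
      // entries; cusparse_available() uses this to decide whether cuSPARSE
      // can handle the matrix.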
      bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
      /* Call SpMM for each relation type */
      if (op == "copy_lhs" &&
          cusparse_available<DType, IdType>(more_nnz)) {  // cusparse
        /* If CUDA is less than 11.0, put the output in trans_out for later
         * transposition */
        DType* out = (CUDART_VERSION < 11000)
                         ? trans_out[dst_id]
                         : static_cast<DType*>((*vec_out)[dst_id]->data);
        CusparseCsrmm2Hetero<DType, IdType>(
            csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),
            nullptr, out, x_length, stream);
      } else if (
          op == "mul" && is_scalar_efeat &&
          cusparse_available<DType, IdType>(more_nnz)) {  // cusparse
        NDArray efeat = vec_efeat[etype];
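        // csr.data, when present, maps CSR nonzero positions to edge IDs, so
        // gather the edge features into the CSR's nonzero order first.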
        if (!IsNullArray(csr.data))
          efeat = IndexSelect(efeat, csr.data);
        CusparseCsrmm2Hetero<DType, IdType>(
            csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),
            static_cast<DType*>(efeat->data),
            // TODO(Israt): Change (*vec_out) to trans_out to support CUDA
            // version < 11
            static_cast<DType*>((*vec_out)[dst_id]->data), x_length, stream);
      } else {  // general kernel
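        // Fall back to DGL's generic SpMM kernel, which handles broadcasting
        // and the remaining binary operators.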
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
        SWITCH_OP(op, Op, {
          cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType>>(
              bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(),
              NullArray());
        });
      }
    } else if (reduce == "max") {
      SWITCH_OP(op, Op, {
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
        cuda::SpMMCmpCsrHetero<
            IdType, DType, Op, cuda::reduce::Max<IdType, DType>>(
            bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
            (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
            src_id, etype);
      });
    } else if (reduce == "min") {
      SWITCH_OP(op, Op, {
        NDArray ufeat =
            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat =
            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
        cuda::SpMMCmpCsrHetero<
            IdType, DType, Op, cuda::reduce::Min<IdType, DType>>(
            bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
            (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
            src_id, etype);
      });
    } else {
      LOG(FATAL) << "Not implemented";
    }
  }

  if (use_legacy_cusparsemm) {
    // Transpose the column-major results of the legacy cuSPARSE path back
    // into the row-major output tensors.
    for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
      const int m = (*vec_out)[ntype]->shape[0];
      const int n = (*vec_out)[ntype]->shape[1];
      if (m == 0) continue;
      DType* C_data = static_cast<DType*>((*vec_out)[ntype]->data);
      _Transpose(trans_out[ntype], C_data, n, m);
      device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);
    }
  }
}

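// Explicit template instantiations for the supported index (int32/int64) and
// data (half, bfloat16, float, double) types.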
template void SpMMCsrHetero<kDGLCUDA, int32_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
#if BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
#endif  // BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,
    const std::vector<dgl_type_t>& out_ntids);

}  // namespace aten
}  // namespace dgl