/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/spmm_hetero.cu
 * \brief Heterogeneous SpMM C APIs and definitions.
 */
#include <dgl/array.h>
#include "./spmm.cuh"
#include "./ge_spmm.cuh"
#include "./functor.cuh"
#include "../../runtime/cuda/cuda_common.h"

namespace dgl {

using namespace cuda;

namespace aten {

/*!
 * \brief CUDA implementation of g-SpMM on CSR format for heterogeneous graphs.
 * \note Uses cuSPARSE if the reduce operator is `sum` and there is
 *       no broadcast; uses DGL's own kernel in other cases.
 */
template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const std::vector<CSRMatrix>& vec_csr,
             const std::vector<NDArray>& vec_ufeat,
             const std::vector<NDArray>& vec_efeat,
             std::vector<NDArray>* vec_out,
             std::vector<std::vector<NDArray>>* out_aux,
             const std::vector<dgl_type_t>& ufeat_ntids,  // ufeat node type id
             const std::vector<dgl_type_t>& out_ntids) {  // output node type id
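  // The edge feature is scalar per edge when its element count equals the
  // number of nonzeros (edges) in the first relation's CSR matrix.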
  bool is_scalar_efeat = vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
  bool use_efeat = op != "copy_lhs";
  auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
  std::vector<DType*> trans_out((*vec_out).size(), NULL);

  bool use_legacy_cusparsemm =
      (CUDART_VERSION < 11000) && (reduce == "sum") &&
      // legacy cuSPARSE does not care about NNZ, hence the argument "false".
      ((op == "copy_lhs" && cusparse_available<DType, IdType>(false)) ||
        (op == "mul" && is_scalar_efeat && cusparse_available<DType, IdType>(false)));
  // Create temporary output buffers for the legacy cuSPARSE path; their
  // contents are transposed into the final output tensors at the end.
  if (use_legacy_cusparsemm) {
    for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
      const int m = (*vec_out)[ntype]->shape[0];
      const int n = (*vec_out)[ntype]->shape[1];
      if (m == 0) continue;
      DType *out = static_cast<DType*>(device->AllocWorkspace(vec_csr[0].indptr->ctx,
        m * n * sizeof(DType)));
      CUDA_CALL(cudaMemset(out, 0, m * n * sizeof(DType)));
      trans_out[ntype] = out;
    }
  }
  // Check that ufeat has a consistent shape across all relation types and
  // compute the feature size
  int64_t x_length = 1;
  for (dgl_type_t etype = 0; etype < (ufeat_ntids.size() - 1); ++etype) {
    NDArray ufeat = vec_ufeat[ufeat_ntids[etype]];
    NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]];
    CHECK_EQ(ufeat->ndim, next_ufeat->ndim) << "Input features have different shapes";
    for (int i = 1; i < ufeat->ndim; ++i) {
      if (ufeat->shape[i] != next_ufeat->shape[i]) {
        if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1)
          LOG(FATAL) <<
            "Homogenized message passing on heterogeneous graphs does not support " <<
            "automatic broadcasting.  Please manually broadcast it before calling " <<
            "message passing functions.";
        else
          LOG(FATAL) << "Input features have different shapes.";
        return;
      }

      if (etype == 0)
        x_length *= ufeat->shape[i];
    }
  }
  // TODO(Israt): Can python do the following initializations while creating the tensors?
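  // For max/min reduction, pre-fill the outputs with the reducer's identity
  // value and the arg-node-type/arg-edge-type buffers with -1 ("no message").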
  if (reduce == "max" ||  reduce == "min") {
    const int64_t dim = bcast.out_len;
    std::vector<bool> updated((*vec_out).size(), false);
    for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
      DType *out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>();
      if (reduce == "max")
        _Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Max<IdType, DType>::zero());
      else  // min
        _Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Min<IdType, DType>::zero());
      const dgl_type_t dst_id = out_ntids[etype];
      if (!updated[dst_id]) {
        updated[dst_id] = true;
        if (op == "copy_lhs") {
          IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
          _Fill(argu_ntype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1));
        }
        if (op == "copy_rhs") {
          IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
          _Fill(arge_etype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1));
        }
      }
    }
  }

  cudaStream_t stream = runtime::getCurrentCUDAStream();
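  // For each relation type, run one SpMM over that relation's CSR block and
  // write the result into the output tensor of its destination node type.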
  for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
    const dgl_type_t src_id = ufeat_ntids[etype];
    const dgl_type_t dst_id = out_ntids[etype];
    CSRMatrix csr = vec_csr[etype];
    if (reduce == "sum") {
      bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
      /* Call SpMM for each relation type */
      if (op == "copy_lhs" && cusparse_available<DType, IdType>(more_nnz)) {  // cusparse
        /* If CUDA is less than 11.0, put the output in trans_out for later transposition */
        DType *out = (CUDART_VERSION < 11000) ? trans_out[dst_id] :
          static_cast<DType*>((*vec_out)[dst_id]->data);
        CusparseCsrmm2Hetero<DType, IdType>(
            csr.indptr->ctx, csr,
            static_cast<DType*>(vec_ufeat[src_id]->data),
            nullptr,
            out,
            x_length, stream);
      } else if (op == "mul" && is_scalar_efeat &&
          cusparse_available<DType, IdType>(more_nnz)) {  // cusparse
        NDArray efeat = vec_efeat[etype];
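        // When csr.data is present it maps each nonzero to its original edge
        // id; gather the edge features into nonzero order before the SpMM.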
        if (!IsNullArray(csr.data))
          efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
        CusparseCsrmm2Hetero<DType, IdType>(
            csr.indptr->ctx, csr,
            static_cast<DType*>(vec_ufeat[src_id]->data),
            static_cast<DType*>(efeat->data),
            // TODO(Israt): Change (*vec_out) to trans_out to support CUDA version < 11
            static_cast<DType*>((*vec_out)[dst_id]->data),
            x_length, stream);
      } else {  // general kernel
        NDArray ufeat = (vec_ufeat.size() == 0) ?
          NullArray() : vec_ufeat[src_id];
        NDArray efeat = (vec_efeat.size() == 0) ?
          NullArray() : vec_efeat[etype];
        SWITCH_OP(op, Op, {
          cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
              bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(), NullArray());
        });
      }
    } else if (reduce == "max") {
        SWITCH_OP(op, Op, {
          NDArray ufeat = (vec_ufeat.size() == 0) ?
              NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ?
              NullArray() : vec_efeat[etype];
          cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
              bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
              (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
              src_id, etype);
        });
    } else if (reduce == "min") {
        SWITCH_OP(op, Op, {
          NDArray ufeat = (vec_ufeat.size() == 0) ?
              NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ?
              NullArray() : vec_efeat[etype];
          cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
              bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
              (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
              src_id, etype);
        });
    } else {
      LOG(FATAL) << "Not implemented";
    }
  }

  if (use_legacy_cusparsemm) {
    // Transpose the legacy cuSPARSE results back into the final output
    // tensors and free the temporary buffers
    for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
      const int m = (*vec_out)[ntype]->shape[0];
      const int n = (*vec_out)[ntype]->shape[1];
      if (m == 0) continue;
      DType *C_data = static_cast<DType*>((*vec_out)[ntype]->data);
      _Transpose(trans_out[ntype], C_data, n, m);
      device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);
    }
  }
}
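
// Explicit template instantiations for the supported index and data types.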
template void SpMMCsrHetero<kDGLCUDA, int32_t, __half>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, __half>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
#if BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
#endif  // BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, float>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, float>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int32_t, double>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, double>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);

}  // namespace aten
}  // namespace dgl