"git@developer.sourcefind.cn:OpenDAS/llama-factory.git" did not exist on "c7d1b209a125ef2c1c082c169d6be2908a467328"
spmm_hetero.cu 11 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/spmm_hetero.cu
 * \brief SpMM C APIs and definitions for heterogeneous graphs.
 */
#include <dgl/array.h>
#include "./spmm.cuh"
#include "./ge_spmm.cuh"
#include "./functor.cuh"
#include "../../runtime/cuda/cuda_common.h"

namespace dgl {

using namespace cuda;

namespace aten {

/*!
 * \brief Decide whether the cuSPARSE-backed SpMM routine may be used.
 * \tparam bits   Width in bits of the floating-point feature type.
 * \tparam IdType Index type of the CSR matrix.
 * \param more_nnz_than_matrix_size Whether the CSR matrix stores more non-zeros
 *        than num_rows * num_cols. Only consulted on the CUDA >= 11 branch.
 * \return true if the cuSPARSE routine should be used.
 */
template <int bits, typename IdType>
inline bool cusparse_available(bool more_nnz_than_matrix_size) {
#if CUDART_VERSION < 11000
  // Legacy cuSPARSE: requires an `int` index type and a feature width above 16 bits.
  const bool has_int32_index = std::is_same<IdType, int>::value;
  return has_int32_index && (bits > 16);
#else
  // cuSPARSE's fp16 SpMM is slow, so it is temporarily disabled. A CSR matrix
  // with more NNZ than its dense size must not be handed to cuSPARSE 11.1.
  return (bits != 16) && !more_nnz_than_matrix_size;
#endif
}

/*!
 * \brief CUDA/HIP implementation of g-SpMM on heterogeneous graphs in Csr format.
 *
 * Runs one SpMM per relation (edge) type and combines the results into the
 * output tensor of that relation's destination node type.
 *
 * \note cuSPARSE is used when the reduce operator is `sum`, the op is
 *       `copy_lhs` (or `mul` with a scalar edge feature) and
 *       `cusparse_available` allows it; dgl's own kernels are used otherwise.
 *
 * \param op          Message op name (e.g. "copy_lhs", "copy_rhs", "mul").
 * \param reduce      Reduce op name: "sum", "max" or "min".
 * \param bcast       Broadcast offsets shared by all relation types.
 * \param vec_csr     One CSR matrix per relation type.
 * \param vec_ufeat   Source node features, indexed by node type id.
 * \param vec_efeat   Edge features, indexed by relation type id (may be empty).
 * \param vec_out     Output tensors, indexed by node type id.
 * \param out_aux     Auxiliary arg outputs used by "max"/"min"; entries [2] and
 *                    [3] are initialised to -1 here for "copy_lhs"/"copy_rhs".
 * \param ufeat_ntids Source node type id of each relation type.
 * \param out_ntids   Destination node type id of each relation type.
 */
template <int XPU, typename IdType, int bits>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const std::vector<CSRMatrix>& vec_csr,
             const std::vector<NDArray>& vec_ufeat,
             const std::vector<NDArray>& vec_efeat,
             std::vector<NDArray>* vec_out,
             std::vector<std::vector<NDArray>>* out_aux,
             const std::vector<dgl_type_t>& ufeat_ntids,  // ufeat node type id
             const std::vector<dgl_type_t>& out_ntids) {  // output node type id
  // Guard the lookup: the edge-feature list may be empty (see the general
  // kernel path below, which handles vec_efeat.size() == 0).
  const bool is_scalar_efeat = !vec_efeat.empty() &&
      vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
  auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);

  // Check that the source features of all relation types share the same
  // trailing shape. This runs before any workspace allocation so that a
  // fatal error here cannot leak the temporary buffers allocated below.
  // `etype + 1 < size()` avoids the unsigned underflow of `size() - 1`
  // when there are no relation types.
  if (!vec_ufeat.empty()) {
    for (dgl_type_t etype = 0; etype + 1 < ufeat_ntids.size(); ++etype) {
      NDArray ufeat = vec_ufeat[ufeat_ntids[etype]];
      NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]];
      CHECK_EQ(ufeat->ndim, next_ufeat->ndim) << "Input features have different shapes";
      for (int i = 1; i < ufeat->ndim; ++i) {
        if (ufeat->shape[i] != next_ufeat->shape[i]) {
          if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1) {
            LOG(FATAL) <<
              "Homogenized message passing on heterogeneous graphs does not support " <<
              "automatic broadcasting.  Please manually broadcast it before calling " <<
              "message passing functions.";
          } else {
            LOG(FATAL) << "Input features have different shapes.";
          }
          return;
        }
      }
    }
  }

  // Feature length (product of all non-leading dims) passed to the cuSPARSE
  // routine. It is taken from the first relation's source features so that
  // it is also correct when there is only a single relation type.
  int64_t x_length = 1;
  if (!ufeat_ntids.empty() && ufeat_ntids[0] < vec_ufeat.size()) {
    NDArray first_ufeat = vec_ufeat[ufeat_ntids[0]];
    for (int i = 1; i < first_ufeat->ndim; ++i)
      x_length *= first_ufeat->shape[i];
  }

  SWITCH_BITS(bits, DType, {
    std::vector<DType*> trans_out(vec_out->size(), nullptr);

    const bool use_legacy_cusparsemm =
        (CUDART_VERSION_LT_11000) && (reduce == "sum") &&
        // legacy cuSPARSE does not care about NNZ, hence the argument "false".
        ((op == "copy_lhs" && cusparse_available<bits, IdType>(false)) ||
         (op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>(false)));
    // The legacy cuSPARSE path writes into zero-initialised temporary buffers
    // that are transposed into vec_out at the end of this function.
    if (use_legacy_cusparsemm) {
      for (dgl_type_t ntype = 0; ntype < vec_out->size(); ++ntype) {
        const int64_t m = (*vec_out)[ntype]->shape[0];
        const int64_t n = (*vec_out)[ntype]->shape[1];
        if (m == 0) continue;
        // Size in size_t so that m * n cannot overflow an int.
        const size_t nbytes =
            static_cast<size_t>(m) * static_cast<size_t>(n) * sizeof(DType);
        DType *out = static_cast<DType*>(
            device->AllocWorkspace(vec_csr[0].indptr->ctx, nbytes));
        CUDA_CALL(hipMemset(out, 0, nbytes));
        trans_out[ntype] = out;
      }
    }

    // TODO(Israt): Can python do the following initializations while creating the tensors?
    if (reduce == "max" || reduce == "min") {
      const int64_t dim = bcast.out_len;
      // Several relation types may share a destination node type; initialise
      // each destination's output (and arg buffers) only once.
      std::vector<bool> updated(vec_out->size(), false);
      for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
        const dgl_type_t dst_id = out_ntids[etype];
        if (updated[dst_id]) continue;
        updated[dst_id] = true;
        const int64_t len = vec_csr[etype].num_rows * dim;
        DType *out_off = (*vec_out)[dst_id].Ptr<DType>();
        if (reduce == "max")
          _Fill(out_off, len, cuda::reduce::Max<IdType, DType>::zero());
        else  // min
          _Fill(out_off, len, cuda::reduce::Min<IdType, DType>::zero());
        if (op == "copy_lhs") {
          IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
          _Fill(argu_ntype, len, static_cast<IdType>(-1));
        }
        if (op == "copy_rhs") {
          IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
          _Fill(arge_etype, len, static_cast<IdType>(-1));
        }
      }
    }

    hipStream_t stream = runtime::getCurrentCUDAStream();
    for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
      const dgl_type_t src_id = ufeat_ntids[etype];
      const dgl_type_t dst_id = out_ntids[etype];
      CSRMatrix csr = vec_csr[etype];
      if (reduce == "sum") {
        const bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
        // Destination for both cuSPARSE paths: with legacy cuSPARSE the result
        // goes to trans_out for later transposition, otherwise straight to the
        // output tensor. Both paths must agree, because the final transposition
        // overwrites vec_out with trans_out in legacy mode.
        DType *cusparse_out = (CUDART_VERSION_LT_11000) ? trans_out[dst_id] :
            static_cast<DType*>((*vec_out)[dst_id]->data);
        /* Call SpMM for each relation type */
        if (op == "copy_lhs" && cusparse_available<bits, IdType>(more_nnz)) {  // cusparse
          CusparseCsrmm2Hetero<DType, IdType>(
              csr.indptr->ctx, csr,
              static_cast<DType*>(vec_ufeat[src_id]->data),
              nullptr,
              cusparse_out,
              x_length, stream);
        } else if (op == "mul" && is_scalar_efeat &&
            cusparse_available<bits, IdType>(more_nnz)) {  // cusparse
          NDArray efeat = vec_efeat[etype];
          // Reorder edge features to CSR order when the CSR carries edge ids.
          if (!IsNullArray(csr.data))
            efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
          CusparseCsrmm2Hetero<DType, IdType>(
              csr.indptr->ctx, csr,
              static_cast<DType*>(vec_ufeat[src_id]->data),
              static_cast<DType*>(efeat->data),
              cusparse_out,
              x_length, stream);
        } else {  // general kernel
          NDArray ufeat = (vec_ufeat.size() == 0) ?
            NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ?
            NullArray() : vec_efeat[etype];
          SWITCH_OP(op, Op, {
            cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
                bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(), NullArray());
          });
        }
      } else if (reduce == "max") {
        SWITCH_OP(op, Op, {
          NDArray ufeat = (vec_ufeat.size() == 0) ?
              NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ?
              NullArray() : vec_efeat[etype];
          cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
              bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
              (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
              src_id, etype);
        });
      } else if (reduce == "min") {
        SWITCH_OP(op, Op, {
          NDArray ufeat = (vec_ufeat.size() == 0) ?
              NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ?
              NullArray() : vec_efeat[etype];
          cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
              bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
              (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
              src_id, etype);
        });
      } else {
        LOG(FATAL) << "Not implemented";
      }
    }

    if (use_legacy_cusparsemm) {
      // Transpose the temporary results into the real outputs and release them.
      for (dgl_type_t ntype = 0; ntype < vec_out->size(); ++ntype) {
        const int64_t m = (*vec_out)[ntype]->shape[0];
        const int64_t n = (*vec_out)[ntype]->shape[1];
        if (m == 0) continue;
        DType *C_data = static_cast<DType*>((*vec_out)[ntype]->data);
        _Transpose(trans_out[ntype], C_data, static_cast<int>(n), static_cast<int>(m));
        device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);
      }
    }
  });
}

// Explicit instantiations of SpMMCsrHetero for the ROCm device, covering
// 32/64-bit index types and 16/32/64-bit feature types.
template void SpMMCsrHetero<kDLROCM, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLROCM, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLROCM, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLROCM, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLROCM, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLROCM, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);



}  // namespace aten
}  // namespace dgl