/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/spmm.cu
 * @brief SPMM C APIs and definitions.
 */
#include <dgl/array.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./functor.cuh"
#include "./ge_spmm.cuh"
#include "./spmm.cuh"

namespace dgl {

using namespace cuda;

namespace aten {

/**
 * @brief CUDA implementation of g-SpMM on Csr format.
 * @note Use cuSPARSE if the reduce operator is `sum` and there is
 *       no broadcast; use DGL's kernel in other cases.
 */
template <int XPU, typename IdType, typename DType>
void SpMMCsr(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux) {
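  // `is_scalar_efeat` holds when efeat stores exactly one value per nonzero
  // of `csr`; only that layout (or no efeat at all, i.e. copy_lhs) can take
  // the cuSPARSE path below.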
  bool is_scalar_efeat = efeat.NumElements() == csr.indices->shape[0];
  bool use_efeat = op != "copy_lhs";

  if (reduce == "sum") {
    bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
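    // `more_nnz` can only be true if the CSR carries duplicate (row, col)
    // entries, e.g. for multigraphs; cusparse_available() takes this into
    // account when deciding whether cuSPARSE can handle the matrix.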
    if (op == "copy_lhs" && cusparse_available<DType, IdType>(more_nnz)) {
      // cusparse
      int64_t x_length = 1;
      for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];
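      // x_length is the flattened per-node feature width. Passing nullptr as
      // the edge-value array tells CusparseCsrmm2 there are no edge weights
      // (copy_lhs), so each nonzero contributes ufeat unscaled.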
      CusparseCsrmm2<DType, IdType>(
          ufeat->ctx, csr, static_cast<DType*>(ufeat->data), nullptr,
          static_cast<DType*>(out->data), x_length);
    } else if (
        op == "mul" && is_scalar_efeat &&
        cusparse_available<DType, IdType>(more_nnz)) {
      // cusparse
      int64_t x_length = 1;
      for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];
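      // When csr.data is present it maps each CSR nonzero back to its
      // original edge id; gather efeat into nonzero order so the edge
      // weights line up with the values cuSPARSE iterates over.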
      if (!IsNullArray(csr.data)) {
        efeat = IndexSelect(efeat, csr.data);
      }
      CusparseCsrmm2<DType, IdType>(
          ufeat->ctx, csr, static_cast<DType*>(ufeat->data),
          static_cast<DType*>(efeat->data), static_cast<DType*>(out->data),
          x_length);
    } else {  // general kernel
      SWITCH_OP(op, Op, {
        cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
            bcast, csr, ufeat, efeat, out, NullArray(), NullArray());
      });
    }
  } else if (reduce == "max") {
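    // For max (and min below), out_aux[0] / out_aux[1] receive the argmax
    // node / edge indices, which the backward pass uses to route gradients.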
    SWITCH_OP(op, Op, {
      cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
          bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else if (reduce == "min") {
    SWITCH_OP(op, Op, {
      cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
          bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else {
    LOG(FATAL) << "Not implemented";
  }
}
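
// Example of how the dispatch above resolves (hypothetical call site; the
// argument names are illustrative only, not part of the API):
//   SpMMCsr<kDGLCUDA, int32_t, float>(
//       "mul", "sum", bcast, csr, node_feat, edge_weight, out, aux);
// reaches the cuSPARSE branch when edge_weight holds one scalar per edge and
// cusparse_available() approves, and the general kernel otherwise.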

/**
 * @brief CUDA implementation of g-SpMM on Coo format.
 */
template <int XPU, typename IdType, typename DType>
void SpMMCoo(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux) {
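  // The trailing `true` on each reducer selects its atomic variant: the COO
  // kernel parallelizes over edges, so several threads may update the same
  // output row concurrently.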
  if (reduce == "sum") {
    SWITCH_OP(op, Op, {
      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> >(
          bcast, coo, ufeat, efeat, out, NullArray(), NullArray());
    });
  } else if (reduce == "max") {
    SWITCH_OP(op, Op, {
      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Max<IdType, DType, true> >(
          bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else if (reduce == "min") {
    SWITCH_OP(op, Op, {
      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Min<IdType, DType, true> >(
          bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else {
    LOG(FATAL) << "Not implemented";
  }
}

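// Explicit instantiations for every supported index type / data type
// combination, so the templates above are compiled into the library.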
template void SpMMCsr<kDGLCUDA, int32_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
#if BF16_ENABLED
template void SpMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
#endif  // BF16_ENABLED
template void SpMMCsr<kDGLCUDA, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

template void SpMMCoo<kDGLCUDA, int32_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
#if BF16_ENABLED
template void SpMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
#endif  // BF16_ENABLED
template void SpMMCoo<kDGLCUDA, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

}  // namespace aten
}  // namespace dgl