/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/spmm.cu
 * @brief SPMM C APIs and definitions.
 */
#include <dgl/array.h>

#include "./spmm.cuh"
#include "./ge_spmm.cuh"
#include "./functor.cuh"
#include "../../runtime/cuda/cuda_common.h"

namespace dgl {

using namespace cuda;

namespace aten {
/**
 * @brief CUDA implementation of g-SpMM on Csr format.
 * @note use cusparse if the reduce operator is `sum` and there is
 *       no broadcast, use dgl's kernel in other cases.
 * @param op Binary operator name (e.g. "mul", "copy_lhs"); dispatched via
 *        SWITCH_OP on the general-kernel path.
 * @param reduce Reduce operator name: "sum", "max" or "min"; anything else
 *        is a fatal error.
 * @param bcast Broadcast metadata forwarded to the general kernel.
 * @param csr The sparse matrix in CSR format.
 * @param ufeat Source-node feature array.
 * @param efeat Edge feature array (ignored for "copy_lhs").
 * @param out Output feature array.
 * @param out_aux Auxiliary outputs; elements 0 and 1 receive the arg-arrays
 *        for "max"/"min" reduction.
 */
template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const CSRMatrix& csr,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
  // "Scalar" edge feature: exactly one value per nonzero of the CSR matrix.
  bool is_scalar_efeat = efeat.NumElements() == csr.indices->shape[0];

  if (reduce == "sum") {
    // Hint for cusparse_available: nnz exceeds the dense element count.
    bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
    if (op == "copy_lhs" && cusparse_available<DType, IdType>(more_nnz)) {
      // cusparse path without edge weights.
      // Flatten every feature dimension after the first into one length.
      int64_t x_length = 1;
      for (int i = 1; i < ufeat->ndim; ++i)
        x_length *= ufeat->shape[i];
      CusparseCsrmm2<DType, IdType>(
          ufeat->ctx, csr,
          static_cast<DType*>(ufeat->data),
          nullptr,  // no edge weights
          static_cast<DType*>(out->data),
          x_length);
    } else if (op == "mul" && is_scalar_efeat &&
               cusparse_available<DType, IdType>(more_nnz)) {
      // cusparse path with scalar edge weights.
      int64_t x_length = 1;
      for (int i = 1; i < ufeat->ndim; ++i)
        x_length *= ufeat->shape[i];
      if (!IsNullArray(csr.data)) {
        // Permute edge weights to follow the CSR data (edge-id) mapping.
        efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
      }
      CusparseCsrmm2<DType, IdType>(
          ufeat->ctx, csr,
          static_cast<DType*>(ufeat->data),
          static_cast<DType*>(efeat->data),
          static_cast<DType*>(out->data),
          x_length);
    } else {  // general kernel
      SWITCH_OP(op, Op, {
        cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
            bcast, csr, ufeat, efeat, out, NullArray(), NullArray());
      });
    }
  } else if (reduce == "max") {
    SWITCH_OP(op, Op, {
      cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
          bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else if (reduce == "min") {
    SWITCH_OP(op, Op, {
      cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
          bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else {
    LOG(FATAL) << "Not implemented";
  }
}

/**
 * @brief CUDA implementation of g-SpMM on Coo format.
 * @param op Binary operator name, dispatched via SWITCH_OP.
 * @param reduce Reduce operator name: "sum", "max" or "min"; anything else
 *        is a fatal error.
 * @param bcast Broadcast metadata forwarded to the kernel.
 * @param coo The sparse matrix in COO format.
 * @param ufeat Source-node feature array.
 * @param efeat Edge feature array.
 * @param out Output feature array.
 * @param out_aux Auxiliary outputs; elements 0 and 1 receive the arg-arrays
 *        for "max"/"min" reduction.
 */
template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const COOMatrix& coo,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux) {
  if (reduce == "sum") {
    // NOTE(review): the extra `true` template argument on the reducer is the
    // COO-specific reducer variant — confirm its semantics in reduce functor.
    SWITCH_OP(op, Op, {
      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> >(
          bcast, coo, ufeat, efeat, out, NullArray(), NullArray());
    });
  } else if (reduce == "max") {
    SWITCH_OP(op, Op, {
      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Max<IdType, DType, true> >(
          bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else if (reduce == "min") {
    SWITCH_OP(op, Op, {
      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Min<IdType, DType, true> >(
          bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
    });
  } else {
    LOG(FATAL) << "Not implemented";
  }
}

// Explicit instantiations of SpMMCsr for every supported index type
// (int32/int64) and data type (half, bfloat16 when enabled, float, double).
template void SpMMCsr<kDGLCUDA, int32_t, __half>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, __half>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
#if BF16_ENABLED
template void SpMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
#endif  // BF16_ENABLED
template void SpMMCsr<kDGLCUDA, int32_t, float>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, float>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int32_t, double>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, double>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

// Explicit instantiations of SpMMCoo for every supported index type
// (int32/int64) and data type (half, bfloat16 when enabled, float, double).
template void SpMMCoo<kDGLCUDA, int32_t, __half>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, __half>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
#if BF16_ENABLED
template void SpMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
#endif  // BF16_ENABLED
template void SpMMCoo<kDGLCUDA, int32_t, float>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, float>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int32_t, double>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, double>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

}  // namespace aten
}  // namespace dgl