// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/spmm.cu
 * @brief SPMM C APIs and definitions.
 */
#include <dgl/array.h>

#include <cstdlib>

#include "../../runtime/cuda/cuda_common.h"
#include "functor.cuh"
#include "ge_spmm.cuh"
#include "spmm.cuh"

namespace dgl {

using namespace cuda;

namespace aten {

22
/**
23
24
 * @brief CUDA implementation of g-SpMM on Csr format.
 * @note use cusparse if the reduce operator is `sum` and there is
25
26
 *       no broadcast, use dgl's kernel in other cases.
 */
27
template <int XPU, typename IdType, typename DType>
28
29
30
31
void SpMMCsr(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux) {
32
33
  bool is_scalar_efeat = efeat.NumElements() == csr.indices->shape[0];
  bool use_efeat = op != "copy_lhs";
34
35
36
  bool use_deterministic_alg_only = false;
  if (NULL != std::getenv("USE_DETERMINISTIC_ALG"))
    use_deterministic_alg_only = true;
37

38
  if (reduce == "sum") {
39
    bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
40
    if (op == "copy_lhs" && cusparse_available<DType, IdType>(more_nnz)) {
41
      // cusparse
42
      int64_t x_length = 1;
43
      for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];
44
      CusparseCsrmm2<DType, IdType>(
45
          ufeat->ctx, csr, static_cast<DType*>(ufeat->data), nullptr,
46
          static_cast<DType*>(out->data), x_length, use_deterministic_alg_only);
47
48
49
    } else if (
        op == "mul" && is_scalar_efeat &&
        cusparse_available<DType, IdType>(more_nnz)) {
50
      // cusparse
51
      int64_t x_length = 1;
52
      for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];
53
      if (!IsNullArray(csr.data)) {
54
        efeat = IndexSelect(efeat, csr.data);
55
      }
56
      CusparseCsrmm2<DType, IdType>(
57
58
          ufeat->ctx, csr, static_cast<DType*>(ufeat->data),
          static_cast<DType*>(efeat->data), static_cast<DType*>(out->data),
59
          x_length, use_deterministic_alg_only);
60
    } else {  // general kernel
61
62
63
      SWITCH_OP(op, Op, {
        cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
            bcast, csr, ufeat, efeat, out, NullArray(), NullArray());
64
65
66
      });
    }
  } else if (reduce == "max") {
67
68
69
    SWITCH_OP(op, Op, {
      cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
          bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
70
71
    });
  } else if (reduce == "min") {
72
73
74
    SWITCH_OP(op, Op, {
      cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
          bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
75
76
77
78
79
80
    });
  } else {
    LOG(FATAL) << "Not implemented";
  }
}

81
/**
82
 * @brief CUDA implementation of g-SpMM on Coo format.
83
 */
84
template <int XPU, typename IdType, typename DType>
85
86
87
88
void SpMMCoo(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux) {
89
  if (reduce == "sum") {
90
    SWITCH_OP(op, Op, {
91
      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> >(
92
          bcast, coo, ufeat, efeat, out, NullArray(), NullArray());
93
94
    });
  } else if (reduce == "max") {
95
    SWITCH_OP(op, Op, {
96
      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Max<IdType, DType, true> >(
97
          bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
98
    });
99
  } else if (reduce == "min") {
100
    SWITCH_OP(op, Op, {
101
      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Min<IdType, DType, true> >(
102
          bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
103
104
105
106
107
108
    });
  } else {
    LOG(FATAL) << "Not implemented";
  }
}

// Explicit instantiations of SpMMCsr for every supported combination of
// index type (int32/int64) and data type (half, bfloat16 when enabled,
// float, double) on the CUDA/HIP device.
template void SpMMCsr<kDGLCUDA, int32_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
#if BF16_ENABLED
template void SpMMCsr<kDGLCUDA, int32_t, __hip_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, __hip_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
#endif  // BF16_ENABLED
template void SpMMCsr<kDGLCUDA, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

// Explicit instantiations of SpMMCoo, mirroring the SpMMCsr set above:
// index type (int32/int64) crossed with data type (half, bfloat16 when
// enabled, float, double).
template void SpMMCoo<kDGLCUDA, int32_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, __half>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
#if BF16_ENABLED
template void SpMMCoo<kDGLCUDA, int32_t, __hip_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, __hip_bfloat16>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
#endif  // BF16_ENABLED
template void SpMMCoo<kDGLCUDA, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

}  // namespace aten
}  // namespace dgl