/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/sddmm.cuh
 * \brief SDDMM CUDA kernel function header.
 */
#ifndef DGL_ARRAY_CUDA_SDDMM_CUH_
#define DGL_ARRAY_CUDA_SDDMM_CUH_

#include <dgl/bcast.h>
#include "macro.cuh"
#include "atomic.cuh"
#include "functor.cuh"
#include "./utils.h"
#include "../selector.h"
#include "../../runtime/cuda/cuda_common.h"

namespace dgl {

using namespace cuda;

namespace aten {
namespace cuda {

/*!
 * \brief CUDA kernel of g-SDDMM on Coo format.
 * \note It uses an edge-parallel strategy: different threadblocks (on the
 *       y-axis) are responsible for the computation on different edges, while
 *       threadblocks on the x-axis are responsible for different positions in
 *       the feature dimension.
 */
template <typename Idx, typename DType, typename BinaryOp,
          bool UseBcast = false, bool UseIdx = false,
          int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCooKernel(
  const DType* __restrict__ lhs,
  const DType* __restrict__ rhs,
  DType* __restrict__ out,
  const Idx* __restrict__ row,
  const Idx* __restrict__ col,
  const Idx* __restrict__ edge_map,
  int64_t N, int64_t M, int64_t E, int64_t reduce_size,
  const int64_t* __restrict__ lhs_off,
  const int64_t* __restrict__ rhs_off,
  int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
  // SDDMM with COO.
  Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
  const Idx stride_y = blockDim.y * gridDim.y;
  while (ty < E) {
    const Idx src = _ldg(row + ty);
    const Idx dst = _ldg(col + ty);
    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
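    // Selector maps the compile-time target to the indexing entity; the
    // target convention assumed here is 0 = source node, 1 = edge,
    // 2 = destination node (hence the defaults LhsTarget=0, RhsTarget=2).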
    const DType* lhsoff = BinaryOp::use_lhs ?
      (lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len): nullptr;
    const DType* rhsoff = BinaryOp::use_rhs ?
      (rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len): nullptr;
    DType* outoff = out + eid * out_len;
    int tx = blockIdx.x * blockDim.x + threadIdx.x;
    const int stride_x = blockDim.x * gridDim.x;
    while (tx < out_len) {
      const Idx lhs_add = UseBcast ? lhs_off[tx] : tx;
      const Idx rhs_add = UseBcast ? rhs_off[tx] : tx;
      DType val = BinaryOp::Call(
          lhsoff + lhs_add * reduce_size,
          rhsoff + rhs_add * reduce_size,
          reduce_size);
      outoff[tx] = val;
      tx += stride_x;
    }
    ty += stride_y;
  }
}
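
// Note on reduce_size: BinaryOp::Call consumes reduce_size consecutive
// elements from each operand, so ops that reduce along the feature axis
// (e.g. a dot product) run with reduce_size > 1, while purely element-wise
// ops use reduce_size == 1.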

// Binary search the row_offsets to find the source node of the edge id.
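// For example, with indptr = {0, 2, 5} (row 0 owns edges 0-1 and row 1 owns
// edges 2-4), BinarySearchSrc(indptr, 3, 3) returns 1.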
template <typename Idx>
__device__ __forceinline__ Idx BinarySearchSrc(const Idx *array, Idx length, Idx eid) {
  Idx lo = 0, hi = length - 1;
  while (lo < hi) {
    Idx mid = (lo + hi) >> 1;
    if (_ldg(array + mid) <= eid) {
      lo = mid + 1;
    } else {
      hi = mid;
    }
  }
  // INVARIANT: lo == hi
  if (_ldg(array + hi) == eid) {
    return hi;
  } else {
    return hi - 1;
  }
}

/*!
 * \brief CUDA kernel of g-SDDMM on Csr format.
 * \note It uses an edge-parallel strategy: different threadblocks (on the
 *       y-axis) are responsible for the computation on different edges, while
 *       threadblocks on the x-axis are responsible for different positions in
 *       the feature dimension.
 *       To efficiently find the source node index of a given edge in Csr
 *       format, it binary-searches indptr (time complexity O(log N)); the
 *       destination node index is read directly from indices.
 */
template <typename Idx, typename DType, typename BinaryOp,
          bool UseBcast = false, bool UseIdx = false,
          int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCsrKernel(
  const DType* __restrict__ lhs,
  const DType* __restrict__ rhs,
  DType* __restrict__ out,
  const Idx* __restrict__ indptr,
  const Idx* __restrict__ indices,
  const Idx* __restrict__ edge_map,
  int64_t N, int64_t M, int64_t E, int64_t reduce_size,
  const int64_t* __restrict__ lhs_off,
  const int64_t* __restrict__ rhs_off,
  int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
  // SDDMM with Csr.
  Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
  const Idx stride_y = blockDim.y * gridDim.y;
  while (ty < E) {
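    // ty enumerates nonzero positions of the CSR structure; the owning row
    // (the source node) is recovered by binary search over indptr, while the
    // destination is read directly from indices.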
    const Idx src = BinarySearchSrc<Idx>(indptr, N + 1, ty);
    const Idx dst = _ldg(indices + ty);
    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
    int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
    const int64_t stride_x = blockDim.x * gridDim.x;
    const DType* lhsoff = BinaryOp::use_lhs ?
      (lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len): nullptr;
    const DType* rhsoff = BinaryOp::use_rhs ?
      (rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len): nullptr;
    DType* outoff = out + eid * out_len;
    while (tx < out_len) {
      const Idx lhs_add = UseBcast ? lhs_off[tx] : tx;
      const Idx rhs_add = UseBcast ? rhs_off[tx] : tx;
      DType val = BinaryOp::Call(
          lhsoff + lhs_add * reduce_size,
          rhsoff + rhs_add * reduce_size,
          reduce_size);
      outoff[tx] = val;
      tx += stride_x;
    }
    ty += stride_y;
  }
}

/*!
 * \brief CUDA implementation of g-SDDMM on Coo format.
 * \param bcast Broadcast information.
 * \param coo The Coo matrix.
 * \param lhs The left hand side operand feature.
 * \param rhs The right hand side operand feature.
 * \param out The result feature on edges.
 */
template <typename Idx, typename DType, typename Op,
          int LhsTarget = 0, int RhsTarget = 2>
void SDDMMCoo(
    const BcastOff& bcast,
    const COOMatrix& coo,
    NDArray lhs,
    NDArray rhs,
    NDArray out) {
  const Idx *row = coo.row.Ptr<Idx>();
  const Idx *col = coo.col.Ptr<Idx>();
  const Idx *edge_map = coo.data.Ptr<Idx>();
  const DType *lhs_data = lhs.Ptr<DType>();
  const DType *rhs_data = rhs.Ptr<DType>();
  DType *out_data = out.Ptr<DType>();
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();

  int64_t *lhs_off = nullptr, *rhs_off = nullptr;
  int64_t len = bcast.out_len,
          lhs_len = bcast.lhs_len,
          rhs_len = bcast.rhs_len;
  int64_t reduce_dim = bcast.reduce_size;

  const int64_t nnz = coo.row->shape[0];
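  // Launch configuration: the x-dimension of the grid covers the output
  // feature axis (len) and the y-dimension covers the edges (nnz).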
  const int ntx = FindNumThreads(len);
  const int nty = CUDA_MAX_NUM_THREADS / ntx;
  const int nbx = (len + ntx - 1) / ntx;
  const int nby = FindNumBlocks<'y'>((nnz + nty - 1) / nty);
  //LOG(INFO) << "nblks=(" << nbx << ", " << nby << ") nthrs=(" << ntx << ", " << nty << ")";
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  const bool use_idx = !IsNullArray(coo.data);

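  // BCAST_IDX_CTX_SWITCH (defined in macro.cuh) lowers the runtime flags
  // bcast.use_bcast and use_idx into the compile-time UseBcast/UseIdx
  // template arguments and binds lhs_off/rhs_off for the chosen variant.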
  BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
    SDDMMCooKernel<Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>
      <<<nblks, nthrs, 0, thr_entry->stream>>>(
        lhs_data, rhs_data, out_data,
        row, col, edge_map,
        coo.num_rows, coo.num_cols, nnz, reduce_dim,
        lhs_off, rhs_off,
        lhs_len, rhs_len, len
      );
  });
}
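
// A minimal usage sketch (hedged: "Op" stands for one of the binary functors
// provided by functor.cuh, such as an element-wise multiply or a dot product;
// ufeat, vfeat and out are assumed to be NDArrays of matching shape):
//
//   SDDMMCoo<int64_t, float, Op, 0, 2>(bcast, coo, ufeat, vfeat, out);
//
// computes out[e] = Op(ufeat[src(e)], vfeat[dst(e)]) for every edge e.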

/*!
 * \brief CUDA implementation of g-SDDMM on Csr format.
 * \param bcast Broadcast information.
 * \param csr The Csr matrix.
 * \param lhs The left hand side operand feature.
 * \param rhs The right hand side operand feature.
 * \param out The result feature on edges.
 */
template <typename Idx, typename DType, typename Op,
          int LhsTarget = 0, int RhsTarget = 2>
void SDDMMCsr(
    const BcastOff& bcast,
    const CSRMatrix& csr,
    NDArray lhs,
    NDArray rhs,
    NDArray out) {
  const Idx *indptr = csr.indptr.Ptr<Idx>();
  const Idx *indices = csr.indices.Ptr<Idx>();
  const Idx *edge_map = csr.data.Ptr<Idx>();
  const DType *lhs_data = lhs.Ptr<DType>();
  const DType *rhs_data = rhs.Ptr<DType>();
  DType *out_data = out.Ptr<DType>();
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  int64_t N = csr.num_rows, M = csr.num_cols, E = csr.indices->shape[0];

  int64_t *lhs_off = nullptr, *rhs_off = nullptr;
  int64_t len = bcast.out_len,
          lhs_len = bcast.lhs_len,
          rhs_len = bcast.rhs_len;
  int64_t reduce_dim = bcast.reduce_size;

  const int ntx = FindNumThreads(len);
  const int nty = CUDA_MAX_NUM_THREADS / ntx;
  const int nbx = (len + ntx - 1) / ntx;
  const int nby = FindNumBlocks<'y'>((E + nty - 1) / nty);
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  const bool use_idx = !IsNullArray(csr.data);

  BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
    SDDMMCsrKernel<Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>
      <<<nblks, nthrs, 0, thr_entry->stream>>>(
        lhs_data, rhs_data, out_data,
        indptr, indices, edge_map,
        N, M, E, reduce_dim,
        lhs_off, rhs_off,
        lhs_len, rhs_len, len
      );
  });
}
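
// SDDMMCsr mirrors SDDMMCoo; the only structural difference is that source
// node ids are recovered from indptr by binary search rather than read from
// an explicit row array.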

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_SDDMM_CUH_