"docs/vscode:/vscode.git/clone" did not exist on "a71886b4cadf5b1e0c0ed464f5874ae680510f07"
sddmm.cuh 11 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/sddmm.cuh
 * \brief SDDMM CUDA kernel function header.
 */
#ifndef DGL_ARRAY_CUDA_SDDMM_CUH_
#define DGL_ARRAY_CUDA_SDDMM_CUH_

#include <dgl/bcast.h>
#include "macro.cuh"
#include "atomic.cuh"
#include "functor.cuh"
13
#include "./utils.h"
14
#include "../selector.h"
15
16
17
18
19
20
21
22
23
#include "../../runtime/cuda/cuda_common.h"

namespace dgl {

using namespace cuda;

namespace aten {
namespace cuda {

24
25
// Mask with all 32 bits set; passed as the participant mask to
// __shfl_down_sync in SDDMMCooTreeReduceKernel.
constexpr unsigned int full_mask = 0xffffffff;

26
27
28
29
30
31
32
33
/*!
 * \brief CUDA kernel of g-SDDMM on Coo format.
 * \note It uses an edge-parallel strategy: threads along the y-axis are
 *       responsible for different edges, threads along the x-axis are
 *       responsible for different positions in the output feature dimension.
 *       Both axes advance by the total number of launched threads on that
 *       axis, so every edge in [0, E) and every position in [0, out_len) is
 *       visited regardless of the grid size.
 * \param lhs Left operand data; only dereferenced when BinaryOp::use_lhs.
 * \param rhs Right operand data; only dereferenced when BinaryOp::use_rhs.
 * \param out Output data; element (eid, tx) is written at out[eid * out_len + tx].
 * \param row Per-edge index loaded as `src`.
 * \param col Per-edge index loaded as `dst`.
 * \param edge_map Maps the edge position to `eid`; only read when UseIdx.
 * \param N Not referenced by this kernel.
 * \param M Not referenced by this kernel.
 * \param E Number of edge positions to process.
 * \param reduce_size Forwarded to BinaryOp::Call and used as the per-position
 *        stride inside an operand row.
 * \param lhs_off Per-output-position offset into the lhs row; only read when UseBcast.
 * \param rhs_off Per-output-position offset into the rhs row; only read when UseBcast.
 * \param lhs_len Row stride of lhs.
 * \param rhs_len Row stride of rhs.
 * \param out_len Row stride of out and number of output positions per edge.
 */
template <typename Idx, typename DType, typename BinaryOp,
          bool UseBcast = false, bool UseIdx = false,
          int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCooKernel(
  const DType* __restrict__ lhs,
  const DType* __restrict__ rhs,
  DType* __restrict__ out,
  const Idx* __restrict__ row,
  const Idx* __restrict__ col,
  const Idx* __restrict__ edge_map,
  int64_t N, int64_t M, int64_t E, int64_t reduce_size,
  const int64_t* __restrict__ lhs_off,
  const int64_t* __restrict__ rhs_off,
  int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
  // SDDMM with COO.
  // The feature-axis start and stride do not depend on the edge, so compute
  // them once. They are 64-bit to match out_len (and the Csr kernel), which
  // avoids narrowing the index to int.
  const int64_t tx_begin =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride_x = static_cast<int64_t>(blockDim.x) * gridDim.x;
  Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
  const Idx stride_y = blockDim.y * gridDim.y;
  while (ty < E) {
    const Idx src = _ldg(row + ty);
    const Idx dst = _ldg(col + ty);
    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
    // Base of the operand row picked by Selector<...>::Call(src, eid, dst);
    // left as nullptr when the operator does not use that operand.
    const DType* lhsoff = BinaryOp::use_lhs ?
      (lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len): nullptr;
    const DType* rhsoff = BinaryOp::use_rhs ?
      (rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len): nullptr;
    DType* outoff = out + eid * out_len;
    for (int64_t tx = tx_begin; tx < out_len; tx += stride_x) {
      // With broadcasting, the operand position is looked up per output
      // position; otherwise it equals the output position.
      const Idx lhs_add = UseBcast ? lhs_off[tx] : tx;
      const Idx rhs_add = UseBcast ? rhs_off[tx] : tx;
      DType val = BinaryOp::Call(
          lhsoff + lhs_add * reduce_size,
          rhsoff + rhs_add * reduce_size,
          reduce_size);
      outoff[tx] = val;
    }
    ty += stride_y;
  }
}

75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
/*!
 * \brief CUDA kernel of SDDMM-dot on Coo format, accelerated with tree reduction.
 * \note Each (blockIdx.x, threadIdx.y) pair handles one edge, and blockIdx.y
 *       strides over the positions of the output feature dimension. Threads
 *       along threadIdx.x accumulate partial dot products over the reduce
 *       dimension in steps of 32 and then combine them with
 *       __shfl_down_sync using a full 32-lane mask, so the kernel is written
 *       for blockDim.x == 32. Lane 0 writes the final value.
 */
template <typename Idx, typename DType,
          bool UseBcast = false, bool UseIdx = false,
          int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCooTreeReduceKernel(
  const DType* __restrict__ lhs,
  const DType* __restrict__ rhs,
  DType* __restrict__ out,
  const Idx* __restrict__ row,
  const Idx* __restrict__ col,
  const Idx* __restrict__ edge_map,
  int64_t N, int64_t M, int64_t E, int64_t reduce_size,
  const int64_t* __restrict__ lhs_off,
  const int64_t* __restrict__ rhs_off,
  int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
  const Idx edge_pos = blockIdx.x * blockDim.y + threadIdx.y;
  // Positions past the last edge have no work.
  if (edge_pos >= E) {
    return;
  }
  const Idx u = _ldg(row + edge_pos);
  const Idx v = _ldg(col + edge_pos);
  const Idx e = UseIdx ? _ldg(edge_map + edge_pos) : edge_pos;
  const DType* lhs_base = lhs + Selector<LhsTarget>::Call(u, e, v) * lhs_len;
  const DType* rhs_base = rhs + Selector<RhsTarget>::Call(u, e, v) * rhs_len;
  DType* out_base = out + e * out_len;
  const int lane = threadIdx.x;  // lane < 32
  int feat = blockIdx.y;
  while (feat < out_len) {  // walk the output feature dimension
    const Idx lhs_pos = UseBcast ? __ldg(lhs_off + feat) : feat;
    const Idx rhs_pos = UseBcast ? __ldg(rhs_off + feat) : feat;
    const DType* lhs_vec = lhs_base + lhs_pos * reduce_size;
    const DType* rhs_vec = rhs_base + rhs_pos * reduce_size;
    // Per-lane partial sum over every 32nd element of the reduce dimension.
    DType acc = 0.;
    for (int k = lane; k < reduce_size; k += 32) {
      acc += lhs_vec[k] * rhs_vec[k];
    }
    // Fold the 32 partial sums down into lane 0.
#pragma unroll
    for (int delta = 16; delta > 0; delta >>= 1) {
      acc += __shfl_down_sync(full_mask, acc, delta);
    }
    if (lane == 0) {
      out_base[feat] = acc;
    }
    feat += gridDim.y;
  }
}

120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// Binary search the row_offsets to find the source node of the edge id.
//
// array  : offsets to search; the search assumes they are sorted in
//          non-decreasing order.
// length : number of entries in array.
// eid    : value to locate.
//
// For eid in [array[0], array[length - 1]) the result is the largest index i
// with array[i] <= eid. If eid equals array[length - 1] the result is
// length - 1.
template <typename Idx>
__device__ __forceinline__ Idx BinarySearchSrc(const Idx *array, Idx length, Idx eid) {
  Idx lo = 0, hi = length - 1;
  while (lo < hi) {
    // Midpoint as lo + (hi - lo) / 2: the sum lo + hi can exceed the range of
    // a 32-bit Idx when the offset array is very long.
    const Idx mid = lo + ((hi - lo) >> 1);
    if (_ldg(array + mid) <= eid) {
      lo = mid + 1;
    } else {
      hi = mid;
    }
  }
  // INVARIANT: lo == hi
  // hi is now the first index whose entry is greater than eid, or the last
  // index when no entry is greater.
  if (_ldg(array + hi) == eid) {
    return hi;
  } else {
    return hi - 1;
  }
}

/*!
 * \brief CUDA kernel of g-SDDMM on Csr format.
 * \note It uses an edge-parallel strategy: threads along the y-axis are
 *       responsible for different edges, threads along the x-axis are
 *       responsible for different positions in the output feature dimension.
 *       The destination index of an edge is read from `indices`; the source
 *       index is recovered by a binary search over `indptr`
 *       (N + 1 entries, O(log N) per edge).
 */
template <typename Idx, typename DType, typename BinaryOp,
          bool UseBcast = false, bool UseIdx = false,
          int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCsrKernel(
  const DType* __restrict__ lhs,
  const DType* __restrict__ rhs,
  DType* __restrict__ out,
  const Idx* __restrict__ indptr,
  const Idx* __restrict__ indices,
  const Idx* __restrict__ edge_map,
  int64_t N, int64_t M, int64_t E, int64_t reduce_size,
  const int64_t* __restrict__ lhs_off,
  const int64_t* __restrict__ rhs_off,
  int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
  // SDDMM with Csr.
  const Idx edge_step = blockDim.y * gridDim.y;
  for (Idx pos = blockIdx.y * blockDim.y + threadIdx.y; pos < E; pos += edge_step) {
    const Idx src = BinarySearchSrc<Idx>(indptr, N + 1, pos);
    const Idx dst = _ldg(indices + pos);
    const Idx eid = UseIdx ? _ldg(edge_map + pos) : pos;
    // Operand rows; nullptr when the operator ignores that operand.
    const DType* lhs_row = nullptr;
    if (BinaryOp::use_lhs) {
      lhs_row = lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len;
    }
    const DType* rhs_row = nullptr;
    if (BinaryOp::use_rhs) {
      rhs_row = rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len;
    }
    DType* out_row = out + eid * out_len;
    const int64_t feat_step = blockDim.x * gridDim.x;
    for (int64_t feat = blockIdx.x * blockDim.x + threadIdx.x; feat < out_len;
         feat += feat_step) {
      // Broadcasting redirects the output position to an operand position.
      const Idx lhs_pos = UseBcast ? lhs_off[feat] : feat;
      const Idx rhs_pos = UseBcast ? rhs_off[feat] : feat;
      out_row[feat] = BinaryOp::Call(
          lhs_row + lhs_pos * reduce_size,
          rhs_row + rhs_pos * reduce_size,
          reduce_size);
    }
  }
}

/*!
 * \brief CUDA implementation of g-SDDMM on Coo format.
 * \param bcast Broadcast information.
 * \param coo The Coo matrix.
 * \param lhs The left hand side operand feature.
 * \param rhs The right hand side operand feature.
 * \param out The result feature on edges.
 * \note When Op is binary::Dot<DType> and the reduce size is at least 32, the
 *       warp tree-reduction kernel is launched; otherwise the generic
 *       edge-parallel kernel is used. Kernels are launched on the stream of
 *       the current CUDA thread entry. The function returns without
 *       launching anything when there are no edges or the output feature
 *       length is zero.
 */
template <typename Idx, typename DType, typename Op,
          int LhsTarget = 0, int RhsTarget = 2>
void SDDMMCoo(
    const BcastOff& bcast,
    const COOMatrix& coo,
    NDArray lhs,
    NDArray rhs,
    NDArray out) {
  const Idx *row = coo.row.Ptr<Idx>();
  const Idx *col = coo.col.Ptr<Idx>();
  const Idx *edge_map = coo.data.Ptr<Idx>();
  const DType *lhs_data = lhs.Ptr<DType>();
  const DType *rhs_data = rhs.Ptr<DType>();
  DType *out_data = out.Ptr<DType>();
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();

  // Handed to BCAST_IDX_CTX_SWITCH together with bcast, then forwarded to the kernel.
  int64_t *lhs_off = nullptr, *rhs_off = nullptr;
  int64_t len = bcast.out_len,
          lhs_len = bcast.lhs_len,
          rhs_len = bcast.rhs_len;
  int64_t reduce_dim = bcast.reduce_size;

  const int64_t nnz = coo.row->shape[0];
  const bool use_idx = !IsNullArray(coo.data);

  // Nothing to compute for an empty edge set or an empty output feature.
  // Returning here also avoids a launch configuration with a zero-sized grid
  // dimension, which CUDA rejects as invalid.
  if (nnz == 0 || len == 0) {
    return;
  }

  if (std::is_same<Op, binary::Dot<DType> >::value && reduce_dim >= 32) {
    const int ntx = 32;  // on feature dimension
    const int nty = 8;   // on out dimension
    const int nbx = (nnz + nty - 1) / nty;
    const int nby = FindNumBlocks<'y'>(len);
    const dim3 nblks(nbx, nby);
    const dim3 nthrs(ntx, nty);
    BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
      CUDA_KERNEL_CALL((SDDMMCooTreeReduceKernel<Idx, DType, UseBcast, UseIdx, LhsTarget, RhsTarget>),
          nblks, nthrs, 0, thr_entry->stream,
          lhs_data, rhs_data, out_data,
          row, col, edge_map,
          coo.num_rows, coo.num_cols, nnz, reduce_dim,
          lhs_off, rhs_off,
          lhs_len, rhs_len, len);
    });
  } else {
    const int ntx = FindNumThreads(len);
    const int nty = CUDA_MAX_NUM_THREADS / ntx;
    const int nbx = (len + ntx - 1) / ntx;
    const int nby = FindNumBlocks<'y'>((nnz + nty - 1) / nty);
    const dim3 nblks(nbx, nby);
    const dim3 nthrs(ntx, nty);
    BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
      CUDA_KERNEL_CALL((SDDMMCooKernel<Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
          nblks, nthrs, 0, thr_entry->stream,
          lhs_data, rhs_data, out_data,
          row, col, edge_map,
          coo.num_rows, coo.num_cols, nnz, reduce_dim,
          lhs_off, rhs_off,
          lhs_len, rhs_len, len);
    });
  }
}

/*!
 * \brief CUDA implementation of g-SDDMM on Csr format.
 * \param bcast Broadcast information.
 * \param csr The Csr matrix.
 * \param lhs The left hand side operand feature.
 * \param rhs The right hand side operand feature.
 * \param out The result feature on edges.
 * \note The kernel is launched on the stream of the current CUDA thread
 *       entry. The function returns without launching anything when there
 *       are no edges or the output feature length is zero.
 */
template <typename Idx, typename DType, typename Op,
          int LhsTarget = 0, int RhsTarget = 2>
void SDDMMCsr(
    const BcastOff& bcast,
    const CSRMatrix& csr,
    NDArray lhs,
    NDArray rhs,
    NDArray out) {
  const Idx *indptr = csr.indptr.Ptr<Idx>();
  const Idx *indices = csr.indices.Ptr<Idx>();
  const Idx *edge_map = csr.data.Ptr<Idx>();
  const DType *lhs_data = lhs.Ptr<DType>();
  const DType *rhs_data = rhs.Ptr<DType>();
  DType *out_data = out.Ptr<DType>();
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  int64_t N = csr.num_rows, M = csr.num_cols, E = csr.indices->shape[0];

  // Handed to BCAST_IDX_CTX_SWITCH together with bcast, then forwarded to the kernel.
  int64_t *lhs_off = nullptr, *rhs_off = nullptr;
  int64_t len = bcast.out_len,
          lhs_len = bcast.lhs_len,
          rhs_len = bcast.rhs_len;
  int64_t reduce_dim = bcast.reduce_size;

  // Nothing to compute for an empty edge set or an empty output feature.
  // Returning here also avoids a launch configuration with a zero-sized grid
  // dimension, which CUDA rejects as invalid.
  if (E == 0 || len == 0) {
    return;
  }

  // x-axis covers the output feature dimension, y-axis covers the edges.
  const int ntx = FindNumThreads(len);
  const int nty = CUDA_MAX_NUM_THREADS / ntx;
  const int nbx = (len + ntx - 1) / ntx;
  const int nby = FindNumBlocks<'y'>((E + nty - 1) / nty);
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  const bool use_idx = !IsNullArray(csr.data);

  BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
    CUDA_KERNEL_CALL((SDDMMCsrKernel<Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
        nblks, nthrs, 0, thr_entry->stream,
        lhs_data, rhs_data, out_data,
        indptr, indices, edge_map,
        N, M, E, reduce_dim,
        lhs_off, rhs_off,
        lhs_len, rhs_len, len);
  });
}

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif