gather_kernel.cu 6.77 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
rusty1s's avatar
rusty1s committed
4
#include <torch/extension.h>
rusty1s's avatar
rusty1s committed
5
6

#include "compat.cuh"
rusty1s's avatar
rusty1s committed
7
#include "indptr.cuh"
rusty1s's avatar
rusty1s committed
8
9
10

// Threads per block used by every kernel launch in this file.
#define THREADS 256
// Number of THREADS-sized blocks needed to cover TB * N threads (ceil-div).
// Both the arguments and the whole expansion are parenthesized so compound
// arguments (e.g. `BLOCKS(1, a + b)`) and surrounding operators
// (e.g. `2 * BLOCKS(...)`) evaluate as intended.
#define BLOCKS(TB, N) (((TB) * (N) + THREADS - 1) / THREADS)
rusty1s's avatar
rusty1s committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

// Broadcasts one scalar per CSR row to every output slot of that row (the
// K == 1 case of gather_csr_cuda).
//
// Launch layout: 1-D grid of 1-D blocks. TB consecutive threads cooperate on
// one row, so at least TB * N threads must be launched; surplus threads exit.
// Row `r` covers output positions [indptr[r], indptr[r + 1]) inside its batch,
// and each batch of indptr.size(-1) - 1 rows owns a slab of E outputs.
// No shared memory is used.
template <typename scalar_t, int TB>
__global__ void
gather_csr_kernel(const scalar_t *src_data,
                  const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
                  scalar_t *out_data, size_t N, size_t E) {

  // 64-bit flat index so large launches cannot wrap.
  int64_t thread_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  int64_t row_idx = thread_idx / TB;
  int lane_idx = thread_idx % TB;

  if (row_idx < (int64_t)N) {
    int offset = IndexPtrToOffset<int64_t>::get(static_cast<int>(row_idx),
                                                indptr_info);
    // indptr entries are int64: keep them 64-bit so E >= 2^31 cannot truncate.
    int64_t row_start = __ldg(indptr_info.data + offset);
    int64_t row_end = __ldg(indptr_info.data + offset +
                            indptr_info.strides[indptr_info.dims - 1]);
    scalar_t val = __ldg(src_data + row_idx);

    // Start of this row's batch slab in `out` (E entries per batch).
    int64_t out_offset =
        (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) *
        (int64_t)E;
    for (int64_t out_idx = row_start + lane_idx; out_idx < row_end;
         out_idx += TB) {
      out_data[out_offset + out_idx] = val; // "Mostly" coalesced.
    }
  }
}

// Copies each CSR row's K trailing values to every output slot of that row
// (the K > 1 case of gather_csr_cuda).
//
// Launch layout: 1-D grid of 1-D blocks with one thread per (row, feature)
// pair, i.e. at least N * K threads; surplus threads exit. Thread
// `row * K + lane` reads src[row * K + lane] and writes it to
// out[batch_slab + K * e + lane] for every e in [indptr[row], indptr[row + 1]).
// No shared memory is used.
template <typename scalar_t>
__global__ void gather_csr_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, size_t N, size_t K, size_t E) {

  // N * K can exceed 2^31. A 32-bit index would wrap negative, fail the
  // bounds check below, and leave those outputs unwritten.
  int64_t thread_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  int64_t row_idx = thread_idx / (int64_t)K;
  int64_t lane_idx = thread_idx % (int64_t)K;

  if (thread_idx < (int64_t)(N * K)) {
    int offset = IndexPtrToOffset<int64_t>::get(static_cast<int>(row_idx),
                                                indptr_info);
    int64_t row_start = __ldg(indptr_info.data + offset);
    int64_t row_end = __ldg(indptr_info.data + offset +
                            indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val = src_data[thread_idx]; // Coalesced.

    // Start of this row's batch slab in `out` (E * K entries per batch).
    int64_t out_offset =
        (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) *
        (int64_t)(E * K);
    for (int64_t out_idx = row_start; out_idx < row_end; out_idx++) {
      // "Mostly" coalesced.
      out_data[out_offset + (int64_t)K * out_idx + lane_idx] = val;
    }
  }
}
rusty1s's avatar
rusty1s committed
60

rusty1s's avatar
rusty1s committed
61
62
// Gathers CSR-segmented values of `src` on the GPU.
//
// gather_dim = indptr.dim() - 1. The leading dimensions of `src` must match
// those of `indptr`, and src.size(gather_dim) must equal
// indptr.size(gather_dim) - 1 (one src entry per CSR row). The value(s) of row
// `r`, including all trailing dimensions, are copied to output positions
// [indptr[r], indptr[r + 1]) along gather_dim.
//
// If `out_opt` is given, it must have src's rank and match `src` in every
// dimension except gather_dim. It is made contiguous (a copy if it was not),
// and the result is written into that tensor. Otherwise a new tensor is
// allocated whose gather_dim size is the last entry of `indptr`.
// Kernels are enqueued on the current CUDA stream.
torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
                              torch::optional<torch::Tensor> out_opt) {

  AT_CUDA_CHECK(cudaSetDevice(src.get_device()));
  AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
  for (int i = 0; i < indptr.dim() - 1; i++)
    AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");

  src = src.contiguous();
  auto gather_dim = indptr.dim() - 1;
  AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
             "Input mismatch");

  torch::Tensor out;
  if (out_opt.has_value()) {
    out = out_opt.value().contiguous();
    // Otherwise a lower-rank `out` would be only partially validated.
    AT_ASSERTM(out.dim() == src.dim(), "Input mismatch");
    for (int i = 0; i < out.dim(); i++)
      if (i != gather_dim)
        AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
  } else {
    // The last indptr entry is the number of gathered entries. `item()`
    // copies it to the host on the current stream, so the read is ordered
    // after whatever produced `indptr`, and no host buffer has to be managed.
    // An empty indptr (zero-sized leading dimension) yields size 0.
    auto sizes = src.sizes().vec();
    sizes[gather_dim] =
        indptr.numel() > 0 ? indptr.flatten()[-1].item<int64_t>() : 0;
    out = at::empty(sizes, src.options());
  }

  // Nothing to copy. Returning early also avoids the division by zero in `K`
  // and a zero-block kernel launch, which is an invalid configuration.
  if (src.numel() == 0 || out.numel() == 0)
    return out;

  auto N = src.size(gather_dim) * (indptr.numel() / indptr.size(-1));
  auto K = src.numel() / N;
  auto E = out.size(gather_dim);

  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    if (K == 1) {
      // 4 threads cooperate on each row.
      gather_csr_kernel<scalar_t, 4><<<BLOCKS(1, 4 * N), THREADS, 0, stream>>>(
          src_data, indptr_info, out_data, N, E);
    } else {
      // One thread per (row, trailing-feature) pair.
      gather_csr_broadcast_kernel<scalar_t>
          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
                                                     out_data, N, K, E);
    }
  });
  // Kernel launches do not return errors; surface launch failures here.
  AT_CUDA_CHECK(cudaGetLastError());

  return out;
}

// Gathers one scalar per index entry (the K == 1 case of gather_coo_cuda):
// out[e] = src[batch_slab + index[e]].
//
// Launch layout: 1-D grid of 1-D blocks with at least E threads (one per index
// entry); surplus threads exit. Each batch of index.size(-1) entries reads from
// its own slab of N src entries. Index values are not range-checked.
// No shared memory is used.
template <typename scalar_t>
__global__ void
gather_coo_kernel(const scalar_t *src_data,
                  const at::cuda::detail::TensorInfo<int64_t, int> index_info,
                  scalar_t *out_data, size_t E, size_t N) {

  int64_t row_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;

  if (row_idx < (int64_t)E) {
    // `index` is addressed through 32-bit TensorInfo, so row_idx is assumed to
    // fit in int here.
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        static_cast<int>(row_idx), index_info);
    // Keep the int64 index value and the batch offset 64-bit so large src
    // tensors cannot truncate the address.
    int64_t row = index_info.data[offset];

    int64_t src_offset =
        (row_idx / index_info.sizes[index_info.dims - 1]) * (int64_t)N;
    scalar_t val = __ldg(src_data + src_offset + row);

    out_data[row_idx] = val;
  }
}

// Gathers K trailing values per index entry (the K > 1 case of
// gather_coo_cuda): out[e * K + c] = src[batch_slab + K * index[e] + c].
//
// Launch layout: 1-D grid of 1-D blocks with one thread per output element,
// i.e. at least E * K threads; surplus threads exit. Each batch of
// index.size(-1) entries reads from its own slab of N * K src entries.
// Index values are not range-checked. No shared memory is used.
template <typename scalar_t>
__global__ void gather_coo_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, size_t E, size_t K, size_t N) {

  // E * K can exceed 2^31 (e.g. 10M entries x 256 features). A 32-bit index
  // would wrap negative, fail the bounds check below, and leave those outputs
  // uninitialized.
  int64_t thread_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  int64_t row_idx = thread_idx / (int64_t)K;
  int64_t col_idx = thread_idx % (int64_t)K;

  if (thread_idx < (int64_t)(E * K)) {
    // `index` is addressed through 32-bit TensorInfo, so row_idx (< E) is
    // assumed to fit in int here.
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        static_cast<int>(row_idx), index_info);
    int64_t row = index_info.data[offset];

    int64_t src_offset =
        (row_idx / index_info.sizes[index_info.dims - 1]) * (int64_t)(N * K);
    scalar_t val =
        __ldg(src_data + src_offset + (int64_t)K * row + col_idx);

    out_data[thread_idx] = val;
  }
}

rusty1s's avatar
rusty1s committed
156
157
158
159
// Gathers `src` slices selected by a COO-style `index` on the GPU.
//
// gather_dim = index.dim() - 1. The leading dimensions of `src` must match
// those of `index`. For every position `e` of `index`, the slice of `src` at
// index[e] along gather_dim (including all trailing dimensions) is copied to
// position `e` of the output. Index values are not range-checked.
//
// If `out_opt` is given, it must have src's rank, match `index` in its first
// index.dim() dimensions, and match `src` in the remaining ones. It is made
// contiguous (a copy if it was not), and the result is written into that
// tensor. Otherwise a new tensor of that shape is allocated.
// Kernels are enqueued on the current CUDA stream.
torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
                              torch::optional<torch::Tensor> out_opt) {

  AT_CUDA_CHECK(cudaSetDevice(src.get_device()));

  AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
  for (int i = 0; i < index.dim() - 1; i++)
    AT_ASSERTM(src.size(i) == index.size(i), "Input mismatch");

  src = src.contiguous();
  auto gather_dim = index.dim() - 1;

  torch::Tensor out;
  if (out_opt.has_value()) {
    out = out_opt.value().contiguous();
    AT_ASSERTM(out.dim() == src.dim(), "Input mismatch");
    for (int i = 0; i < index.dim(); i++)
      AT_ASSERTM(out.size(i) == index.size(i), "Input mismatch");
    // Trailing dimensions begin right after gather_dim, i.e. at index.dim().
    for (int i = index.dim(); i < src.dim(); i++)
      AT_ASSERTM(out.size(i) == src.size(i), "Input mismatch");
  } else {
    auto sizes = src.sizes().vec();
    sizes[gather_dim] = index.size(gather_dim);
    out = torch::empty(sizes, src.options());
  }

  // An empty output means there is nothing to gather. Returning early also
  // avoids the division by zero in `K` (E == 0) and a zero-block kernel
  // launch, which is an invalid configuration.
  if (out.numel() == 0)
    return out;

  auto E = index.numel();
  auto K = out.numel() / E;
  auto N = src.size(gather_dim);

  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    if (K == 1) {
      // One thread per index entry.
      gather_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>(
          src_data, index_info, out_data, E, N);
    } else {
      // One thread per output element.
      gather_coo_broadcast_kernel<scalar_t>
          <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(src_data, index_info,
                                                     out_data, E, K, N);
    }
  });
  // Kernel launches do not return errors; surface launch failures here.
  AT_CUDA_CHECK(cudaGetLastError());

  return out;
}