bgmv_impl.cuh 10.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <cooperative_groups.h>
#include <cuda/pipeline>
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>

#include "vec_dtypes.cuh"

namespace cg = cooperative_groups;

// Shrink kernel (chosen by bgmv_kernel when feat_in >= feat_out).
// Each block produces one output element:
//   Y[batch_idx * full_y_size + y_offset + j] +=
//       sum_k W[(idx * feat_out + j) * feat_in + k] * X[batch_idx * feat_in + k]
//       * scale
// with j = blockIdx.x, batch_idx = blockIdx.y and
// idx = indicies[batch_idx] * num_layers + layer_idx.
//
// Launch layout (see bgmv_kernel): grid = (feat_out, batch_size),
// block = (tx, ty). The reduction over feat_in is split into tiles of
// tile_size = tx * ty * vec_size elements, and every thread owns vec_size
// contiguous elements of a tile. Tiles are staged in static shared memory
// through a two-stage cuda::pipeline, so the async copy of tile t overlaps
// the dot product of tile t - 1.
//
// Preconditions (guaranteed by bgmv_kernel's launch selection):
//   * feat_in % (tx * vec_size) == 0, so range checks can be done per row of
//     tx threads (one threadIdx.y value) instead of per thread;
//   * W_copy_size == vec_size * sizeof(W_T), X_copy_size ==
//     vec_size * sizeof(in_T), and the global addresses are aligned to those
//     sizes (required by cuda::aligned_size_t).
// Blocks with idx < 0 return without touching Y (presumably a "no adapter"
// sentinel in indicies -- confirm with the caller). tz is not used here.
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
          size_t W_copy_size, int tx, int ty, int tz, typename in_T,
          typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  // idx depends only on blockIdx, so the whole block leaves together and the
  // block-wide barriers below are never reached by only part of the block.
  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t j = blockIdx.x;
  constexpr size_t num_pipeline_stages = 2;
  constexpr size_t tile_size = tx * ty * vec_size;
  constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;
  __shared__ W_T W_shared[num_pipeline_stages * tile_size];
  __shared__ in_T X_shared[num_pipeline_stages * tile_size];
  // One partial sum per row of tx threads.
  __shared__ float y_warpwise[ty];

  size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  auto pipe = cuda::make_pipeline();

  // Offset of this thread's vec_size-element slice inside a tile.
  const size_t thread_offset = (threadIdx.y * tx + threadIdx.x) * vec_size;
  // Start of the W row for output j and of the X row for this batch entry.
  const W_T *W_row = W + (idx * feat_out + j) * feat_in;
  const in_T *X_row = X + batch_idx * feat_in;

  // True when this thread's row of tx threads covers elements of tile t.
  // Exact because feat_in is a multiple of tx * vec_size.
  auto row_in_range = [&](size_t t) {
    return t * tile_size + threadIdx.y * tx * vec_size < feat_in;
  };

  vec_t<in_T, vec_size> x_vec;
  vec_t<W_T, vec_size> w_vec;
  // Dot product of this thread's slice of the tile held in pipeline buffer
  // `stage`, tree-reduced across the tx threads of the row. Only the lane
  // with threadIdx.x == 0 ends up holding the complete row sum. Must be
  // called by every thread (full-mask shuffle).
  auto partial_dot = [&](size_t stage) {
    x_vec.load(X_shared + X_shared_offset[stage] + thread_offset);
    w_vec.load(W_shared + W_shared_offset[stage] + thread_offset);
    float sum = 0.f;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      sum += float(w_vec[i]) * float(x_vec[i]) * scale;
    }
#pragma unroll
    for (size_t offset = tx / 2; offset > 0; offset /= 2) {
      sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    return sum;
  };

  // Prologue: stage tile 0 into buffer 0. Guarded exactly like the in-loop
  // copies so that, when feat_in < tile_size, surplus rows of threads do not
  // read past the end of the W and X rows.
  pipe.producer_acquire();
  if (row_in_range(0)) {
    cuda::memcpy_async(W_shared + W_shared_offset[0] + thread_offset,
                       W_row + thread_offset,
                       cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
    cuda::memcpy_async(X_shared + X_shared_offset[0] + thread_offset,
                       X_row + thread_offset,
                       cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
  }
  pipe.producer_commit();

  float y = 0.f;
  size_t tile_idx;

#pragma unroll
  for (tile_idx = 1; tile_idx < num_tiles; ++tile_idx) {
    // Producer: async copy of tile tile_idx into the other buffer.
    size_t copy_idx = tile_idx % num_pipeline_stages;
    pipe.producer_acquire();
    if (row_in_range(tile_idx)) {
      cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] + thread_offset,
                         W_row + tile_idx * tile_size + thread_offset,
                         cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
      cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] + thread_offset,
                         X_row + tile_idx * tile_size + thread_offset,
                         cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
    }
    pipe.producer_commit();

    // Consumer: tile tile_idx - 1. Only the last tile can be partial, so
    // every tile consumed inside this loop is full.
    size_t compute_idx = (tile_idx - 1) % num_pipeline_stages;
    pipe.consumer_wait();
    // The pipeline is per thread; the barrier makes every thread's copies
    // visible before anyone reads the shared tile.
    block.sync();
    float sum = partial_dot(compute_idx);
    // Only lane 0 of the row holds the full sum. Letting every lane store to
    // the same slot would leave an undefined winner.
    if (threadIdx.x == 0) {
      y_warpwise[threadIdx.y] = sum;
    }
    block.sync();
#pragma unroll
    for (size_t i = 0; i < ty; ++i) {
      y += y_warpwise[i];
    }

    // Everyone must finish reading y_warpwise and the tile before they are
    // overwritten by the next iteration.
    block.sync();
    pipe.consumer_release();
  }

  // Drain: consume the last (possibly partial) tile. Rows outside feat_in
  // never had data copied, so their contribution is replaced by zero.
  size_t compute_idx = (tile_idx - 1) % num_pipeline_stages;
  pipe.consumer_wait();
  block.sync();
  float sum = partial_dot(compute_idx);
  if (threadIdx.x == 0) {
    y_warpwise[threadIdx.y] = row_in_range(tile_idx - 1) ? sum : 0.f;
  }
  block.sync();
#pragma unroll
  for (size_t i = 0; i < ty; ++i) {
    y += y_warpwise[i];
  }

  block.sync();
  pipe.consumer_release();

  // One thread accumulates the block's result into Y.
  if (block.thread_rank() == 0) {
    Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
  }
}

// Expand kernel (chosen by bgmv_kernel when feat_in < feat_out).
// Each group of tx threads that shares (threadIdx.y, threadIdx.z) produces
// one output element:
//   Y[batch * full_y_size + y_offset + row] +=
//       sum_k W[(idx * feat_out + row) * feat_in + k] * X[batch * feat_in + k]
//       * scale
// with idx = indicies[batch] * num_layers + layer_idx.
//
// Launch layout (see bgmv_kernel): grid = (feat_out / (ty * tz), batch_size),
// block = (tx, ty, tz) with tx * vec_size == feat_in. One group of tx threads
// therefore spans a whole input row, and each thread loads vec_size
// contiguous elements straight from global memory (no shared memory).
// Block blockIdx.x covers output rows
// [blockIdx.x * ty * tz, (blockIdx.x + 1) * ty * tz).
// Batch entries with idx < 0 are skipped (presumably a "no adapter" sentinel
// in indicies -- confirm with the caller).
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
          typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  const size_t batch = blockIdx.y;
  const int64_t lora = indicies[batch] * num_layers + layer_idx;
  if (lora < 0) {
    return;
  }

  auto cta = cg::this_thread_block();
  constexpr size_t rows_per_block = tz * ty;
  const size_t block_row0 = blockIdx.x * rows_per_block;

  // All groups read the same X row; thread x owns elements
  // [x * vec_size, (x + 1) * vec_size).
  vec_t<in_T, vec_size> xv;
  xv.load(X + batch * feat_in + threadIdx.x * vec_size);

  // Since tx * vec_size == feat_in, thread_rank() * vec_size walks this
  // block's ty * tz consecutive W rows contiguously.
  vec_t<W_T, vec_size> wv;
  wv.load(W + (lora * feat_out + block_row0) * feat_in +
          cta.thread_rank() * vec_size);

  float partial = 0.f;
#pragma unroll
  for (size_t k = 0; k < vec_size; ++k) {
    partial += float(wv[k]) * float(xv[k]) * scale;
  }

  // Tree-reduce across the tx threads of this row, then broadcast the total
  // from the group's first lane.
  auto row_group = cg::tiled_partition<tx>(cta);
#pragma unroll
  for (int delta = tx / 2; delta > 0; delta /= 2) {
    partial += row_group.shfl_down(partial, delta);
  }
  partial = row_group.shfl(partial, 0);

  if (threadIdx.x == 0) {
    const size_t row = block_row0 + threadIdx.z * ty + threadIdx.y;
    Y[batch * full_y_size + y_offset + row] += static_cast<out_T>(partial);
  }
}

// Host-side launcher for a batched, index-gathered matrix-vector product.
// The work is enqueued asynchronously on PyTorch's current CUDA stream; the
// host is not synchronized.
//
// For every batch entry b with idx = indicies[b] * num_layers + layer_idx
// >= 0, and every output row r in [0, feat_out):
//   Y[b * full_y_size + y_offset + r] +=
//       scale * sum_k W[(idx * feat_out + r) * feat_in + k] * X[b * feat_in + k]
// Entries with idx < 0 leave Y untouched.
//
// feat_in < feat_out dispatches to bgmv_expand_kernel; otherwise
// bgmv_shrink_kernel is used. The thread-block shape is picked entirely at
// compile time from feat_in / feat_out, and the static_asserts reject
// unsupported shape pairs.
//
// batch_size == 0 is a no-op. If the launch is rejected (e.g. batch_size
// exceeds the gridDim.y limit), a c10::Error is thrown via
// C10_CUDA_KERNEL_LAUNCH_CHECK.
template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale) {
  // gridDim.y == 0 is an invalid launch configuration. Without this guard
  // the resulting error would surface in some unrelated later CUDA call.
  if (batch_size == 0) {
    return;
  }

  constexpr size_t vec_size = 8;
  constexpr int tz = 4;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if constexpr (feat_in < feat_out) {
    static_assert(feat_in % vec_size == 0);
    // One row of tx threads spans the whole input vector.
    constexpr int tx = feat_in / vec_size;

    static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
                  (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
                  (8 % tx == 0 && feat_out % (8 / tx * tz) == 0));

    // Every branch is `if constexpr`, so only the selected kernel is
    // instantiated. This avoids dead kernels with ty == 0 and a compile-time
    // `feat_out / 0` grid size in the discarded branches.
    if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
      constexpr int ty = 32 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
      constexpr int ty = 16 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else {
      constexpr int ty = 8 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  } else {
    static_assert(feat_in % (vec_size * 32) == 0 ||
                  feat_in % (vec_size * 16) == 0 ||
                  feat_in % (vec_size * 8) == 0);

    // One block per output element. Prefer the widest per-thread vector that
    // still keeps feat_in a multiple of tx * vec_size (the shrink kernel's
    // per-row range check relies on that).
    if constexpr (feat_in % (vec_size * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
                         vec_size * sizeof(W_T), tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
      constexpr int tx = 16;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  }

  // Report a rejected launch (bad grid/block configuration) here instead of
  // letting it be misattributed to whatever CUDA call runs next.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Explicitly instantiates bgmv_kernel for one (feat_in, feat_out) shape and
// one (input, output, weight) element-type triple. The template body is then
// compiled in the translation unit that expands this macro.
#define INST_BGMV(FEAT_IN, FEAT_OUT, IN_T, OUT_T, WT)                          \
  template void bgmv_kernel<FEAT_IN, FEAT_OUT, IN_T, OUT_T, WT>(               \
      OUT_T *__restrict__ Y, const IN_T *__restrict__ X,                       \
      const WT *__restrict__ W, const int64_t *__restrict__ indicies,          \
      int64_t y_offset, int64_t full_y_size, int64_t batch_size,               \
      int64_t num_layers, int64_t layer_idx, float scale);

// Instantiates both directions of a (NARROW, WIDE) size pair:
// NARROW -> WIDE and WIDE -> NARROW. Assuming NARROW < WIDE, these select
// the expand and shrink paths of bgmv_kernel respectively.
#define INST_BGMV_TWOSIDE(IN_T, OUT_T, WT, NARROW, WIDE)                       \
  INST_BGMV(NARROW, WIDE, IN_T, OUT_T, WT)                                     \
  INST_BGMV(WIDE, NARROW, IN_T, OUT_T, WT)