// multi_tensor_l2norm_kernel.cu
// Original author (per repository history): Michael Carilli.
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
5
#include <c10/cuda/CUDAGuard.h>
Michael Carilli's avatar
Michael Carilli committed
6
7
8
9
10
11
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
12
#include "multi_tensor_apply_base.cuh"
// Threads per block used when launching the multi-tensor norm functors.
// The shared-memory reduction buffers in this file are sized for 512 threads,
// so keep this value in sync with them.
#define BLOCK_SIZE 512
// Elements handled per thread per loop iteration (instruction-level
// parallelism). The aligned fast path moves ILP elements with one vector load.
#define ILP 4
// True when `p` lies on an ILP*sizeof(T)-byte boundary, i.e. when ILP
// consecutive T values starting at `p` can be moved with one vector access.
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  const uint64_t addr = reinterpret_cast<uint64_t>(p);
  const uint64_t vec_bytes = ILP * sizeof(T);
  return addr % vec_bytes == 0;
}

// Copies ILP consecutive T values as one aligned, vector-sized move.
// Both offsets count in units of ILP elements (not single elements):
// dst[dst_offset*ILP .. dst_offset*ILP+ILP) <- src[src_offset*ILP .. +ILP).
// Both pointers must be suitably aligned for the vector type (see is_aligned).
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  using VecT = typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type;
  VecT* dst_vec = reinterpret_cast<VecT*>(dst);
  const VecT* src_vec = reinterpret_cast<const VecT*>(src);
  dst_vec[dst_offset] = src_vec[src_offset];
}

// Per-chunk L2 functor for multi_tensor_apply<1>. It produces SUMS OF SQUARES;
// the square root is taken later by the cleanup kernels.
//
// Each CUDA block handles one chunk (at most chunk_size elements) of one
// tensor, selected by tl.block_to_tensor / tl.block_to_chunk. Each thread
// accumulates ILP partial sums in registers. The block reduces them with
// reduce_block_into_lanes, and thread 0 publishes the block's sum:
//   * output[blockIdx.x] += block_sum
//   * output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)
//                       * max_chunks_per_tensor + chunk_idx] = block_sum
//     (only when per_tensor is true)
// A non-finite block_sum sets *noop_gmem = 1.
//
// Assumes blockDim.x <= BLOCK_SIZE (one shared slot per thread is presumed to
// be what reduce_block_into_lanes needs) and that `output` has an entry for
// every launched block -- TODO confirm against multi_tensor_apply_base.cuh.
template<typename x_t>
struct L2NormFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<1>& tl,
    float* output,
    float* output_per_tensor,
    bool per_tensor,
    int max_chunks_per_tensor)
  {
    // Deliberately no early return on *noop_gmem == 1: infs/nans should
    // propagate into the computed norm.

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    // Advance to this block's chunk. n becomes the number of elements left in
    // the tensor from the chunk start, so it may exceed chunk_size.
    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx*chunk_size;
    n -= chunk_idx*chunk_size;

    __shared__ float s_vals[BLOCK_SIZE];

    // Explicitly zero the per-thread accumulators and the load staging buffer.
    float vals[ILP];
    x_t r_x[ILP];
    for(int i = 0; i < ILP; i++)
    {
      vals[i] = 0.f;
      r_x[i] = 0;
    }

    if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
    {
      // Fast path: each iteration moves ILP contiguous elements with a single
      // vectorized load (load_store offsets are in units of ILP elements).
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        load_store(r_x, x, 0 , i_start);
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          float next = static_cast<float>(r_x[ii]);
          vals[ii] += next*next;
        }
      }
    }
    else
    {
      // Generic path: scalar loads, each bounded by both the tensor end and
      // the chunk end.
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
      {
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            float next = static_cast<float>(x[i]);
            vals[ii] += next*next;
          }
        }
      }
    }

    float val = 0.f;
    for(int i = 0; i < ILP; i++)
      val += vals[i];

    float final = reduce_block_into_lanes(s_vals, val);

    if(threadIdx.x == 0)
    {
      if(!isfinite(final))
        *noop_gmem = 1; // Blindly fire off a write.  These will race but that's ok.
      output[blockIdx.x] += final;
      if(per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
    }
  }
};

// L-infinity (max |x|) counterpart of L2NormFunctor.
// Probably better to template on the norm, but since we are not likely to
// support other norms, a separate functor is kept for simplicity.
//
// Each CUDA block handles one chunk of one tensor (via tl.block_to_tensor /
// tl.block_to_chunk). It computes the block's max |x| with
// reduce_block_into_lanes_max_op, and thread 0 publishes it:
//   * output[blockIdx.x] = max(|output[blockIdx.x]|, block_max)
//   * output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)
//                       * max_chunks_per_tensor + chunk_idx] = block_max
//     (only when per_tensor is true)
//
// NaN handling: fmaxf returns its non-NaN operand, so NaN elements are
// dropped by the per-element max below and never reach block_max. So that
// NaN input still raises the overflow flag, any thread that reads a
// non-finite element sets *noop_gmem = 1 itself.
//
// Assumes blockDim.x <= BLOCK_SIZE (size of the shared reduction buffer).
template<typename x_t>
struct MaxNormFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<1>& tl,
    float* output,
    float* output_per_tensor,
    bool per_tensor,
    int max_chunks_per_tensor)
  {
    // Deliberately no early return on *noop_gmem == 1: infs/nans should
    // propagate rather than cause work to be skipped.

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    // Advance to this block's chunk. n becomes the number of elements left in
    // the tensor from the chunk start.
    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx*chunk_size;
    n -= chunk_idx*chunk_size;

    __shared__ float s_vals[BLOCK_SIZE];

    float vals[ILP];
    x_t r_x[ILP];
    for(int i = 0; i < ILP; i++)
    {
      vals[i] = 0.f;
      r_x[i] = 0;
    }

    // Set when this thread reads any NaN/Inf element (see NaN note above).
    bool saw_nonfinite = false;

    if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
    {
      // Fast path: one vectorized load of ILP contiguous elements per step.
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        load_store(r_x, x, 0 , i_start);
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          float next = static_cast<float>(r_x[ii]);
          if(!isfinite(next))
            saw_nonfinite = true;
          vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
        }
      }
    }
    else
    {
      // Generic path: scalar loads bounded by both tensor end and chunk end.
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
      {
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            float next = static_cast<float>(x[i]);
            if(!isfinite(next))
              saw_nonfinite = true;
            vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
          }
        }
      }
    }

    if(saw_nonfinite)
      *noop_gmem = 1; // Blind racy write, like the one below; any writer wins.

    float val = 0.f;
    for(int i = 0; i < ILP; i++)
      val = fmaxf(fabsf(val), fabsf(vals[i]));

    float final = reduce_block_into_lanes_max_op(s_vals, val);

    if(threadIdx.x == 0)
    {
      if(!isfinite(final))
        *noop_gmem = 1; // Blindly fire off a write.  These will race but that's ok.
      output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
      if(per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
    }
  }
};

// Second-stage reduction for multi_tensor_l2norm_cuda.
//
// Block 0 sums the 320 per-block sums of squares in `output` and writes their
// square root to *ret. When per_tensor is true, block b sums the
// max_chunks_per_tensor partials at output_per_tensor + b*max_chunks_per_tensor
// and writes the square root to ret_per_tensor[b].
//
// Launch: 1 block, or one block per tensor when per_tensor is true.
// blockDim.x must be at least 320 (so block 0 reads every partial) and at most
// 512 (size of `vals`, presumably one slot per thread for
// reduce_block_into_lanes).
__global__ void
#ifdef __HIP_PLATFORM_HCC__
__launch_bounds__(1024)
#endif
cleanup(
  float* output,
  float* output_per_tensor,
  float* ret,
  float* ret_per_tensor,
  bool per_tensor,
  int max_chunks_per_tensor)
{
  __shared__ float vals[512];

  if(blockIdx.x == 0)
  {
    float val = 0.f;
    if(threadIdx.x < 320)
      val = output[threadIdx.x];

    float final = reduce_block_into_lanes(vals, val);

    if(threadIdx.x == 0)
      *ret = sqrtf(final);
  }

  // `vals` is reused by the per-tensor reduction below. Block 0's warp 0 may
  // still be reading it from the reduction above, and nothing here guarantees
  // that reduce_block_into_lanes ends with a barrier. So fence before any
  // thread overwrites it. Placed outside all branches so every thread of
  // every block reaches it.
  __syncthreads();

  if(per_tensor)
  {
    float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor;

    float val = 0.f;
    for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
      val += output_this_tensor[i];

    float final = reduce_block_into_lanes(vals, val);

    if(threadIdx.x == 0)
      ret_per_tensor[blockIdx.x] = sqrtf(final);
  }
}

// Second-stage reduction for multi_tensor_norm_out_cuda. It also blends each
// new norm n into the previously stored norm g:
//   norm_type == 0 (L-inf): g = alpha * g + beta * n        (n = max partial)
//   otherwise      (L2):    g = sqrt(alpha * g^2 + beta * s) (s = sum of
//                                                            squared partials)
// Block 0 blends the global norm (partials in output[0..320)) into *ret.
// When per_tensor is true, block b blends tensor b's norm (partials at
// output_per_tensor + b*max_chunks_per_tensor) into ret_per_tensor[b].
// *ret and ret_per_tensor therefore must hold the previous norms on entry.
//
// Launch: one block per tensor. blockDim.x must be at least 320 (so block 0
// reads every global partial) and at most 512 (size of `vals`).
__global__ void
#ifdef __HIP_PLATFORM_HCC__
__launch_bounds__(1024)
#endif
cleanup_v2(
  float* output,
  float* output_per_tensor,
  float* ret,
  float* ret_per_tensor,
  bool per_tensor,
  int max_chunks_per_tensor,
  int norm_type,
  float alpha,
  float beta)
{
  __shared__ float vals[512];

  if(blockIdx.x == 0)
  {
    float val = 0.f;
    if(threadIdx.x < 320)
      val = output[threadIdx.x];

    if (norm_type == 0) {
      float final = reduce_block_into_lanes_max_op(vals, val);
      if(threadIdx.x == 0)
        *ret = alpha * (*ret) + beta * final;
    }
    else {
      float final = reduce_block_into_lanes(vals, val);
      if(threadIdx.x == 0)
        *ret = sqrtf(alpha * (*ret) * (*ret) + beta * final);
    }
  }

  // `vals` is reused by the per-tensor reduction below. Block 0 may still be
  // reading it from the reduction above, and nothing here guarantees that
  // the reduce helpers end with a barrier. So fence before it is overwritten.
  // Placed outside all branches so every thread of every block reaches it.
  __syncthreads();

  if(per_tensor)
  {
    float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor;

    if (norm_type == 0) {
      float val = 0.f;
      for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
        val = fmaxf(fabsf(val), fabsf(output_this_tensor[i]));

      float final = reduce_block_into_lanes_max_op(vals, val);

      if(threadIdx.x == 0)
        ret_per_tensor[blockIdx.x] = alpha * ret_per_tensor[blockIdx.x] + beta * final;
    }
    else {
      float val = 0.f;
      for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
        val += output_this_tensor[i];

      float final = reduce_block_into_lanes(vals, val);

      if(threadIdx.x == 0)
        ret_per_tensor[blockIdx.x] = sqrtf(alpha * ret_per_tensor[blockIdx.x] * ret_per_tensor[blockIdx.x] + beta * final);
    }
  }
}
// L2 norm over all tensors in tensor_lists[0], optionally also per tensor.
//
// Two launches on the current CUDA stream:
//   1. L2NormFunctor (via multi_tensor_apply) writes per-block sums of squares.
//   2. `cleanup` reduces those partials and takes square roots.
// Returns (ret, ret_per_tensor):
//   ret            - 1-element float tensor: L2 norm over all tensors.
//   ret_per_tensor - ntensors-element float tensor of per-tensor L2 norms when
//                    per_tensor_python is true, otherwise a 0-element tensor.
// noop_flag is passed to multi_tensor_apply (presumably as the functor's
// noop_gmem, which is set to 1 when a block's partial is non-finite).
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::optional<bool> per_tensor_python)
{
  // Everything below indexes tensor_lists[0][0]. Reject empty input up front
  // instead of reading out of bounds or launching a 0-block cleanup grid.
  TORCH_CHECK(!tensor_lists.empty() && !tensor_lists[0].empty(),
              "multi_tensor_l2norm_cuda: expected a non-empty tensor list");

  bool per_tensor = per_tensor_python.value_or(false);

  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
  // One partial sum of squares per functor block. 320 is also the number of
  // partials `cleanup` reads; presumably it bounds the number of blocks
  // multi_tensor_apply launches -- TODO confirm.
  auto output = at::zeros({320}, float_options);

  at::Tensor output_per_tensor;
  at::Tensor ret_per_tensor;

  int ntensors = tensor_lists[0].size();
  int max_chunks_per_tensor = -1;

  if(per_tensor)
  {
    for(int t = 0; t < ntensors; t++)
    {
      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
      if(max_chunks_this_tensor > max_chunks_per_tensor)
        max_chunks_per_tensor = max_chunks_this_tensor;
    }
    // Zeroed: tensors with fewer chunks leave their trailing slots unwritten,
    // and cleanup sums every slot.
    output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);
    ret_per_tensor = at::empty({ntensors}, float_options);
  }
  else
  {
    ret_per_tensor = at::empty({0}, float_options);
  }

  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
    multi_tensor_apply<1>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      L2NormFunctor<scalar_t_0>(),
      output.DATA_PTR<float>(),
      per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
      per_tensor,
      max_chunks_per_tensor);)

  AT_CUDA_CHECK(cudaGetLastError());
  // AT_CUDA_CHECK(cudaDeviceSynchronize());

  // This involves one more small kernel launch, but will be negligible end to end.
  // I could get rid of it by hacking the functor + multi tensor harness with persistence
  // logic, but keeping it simple for now
  auto ret = at::empty({1}, output.options());
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  auto stream = at::cuda::getCurrentCUDAStream();
  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
    output.DATA_PTR<float>(),
    per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
    ret.DATA_PTR<float>(),
    per_tensor ? ret_per_tensor.DATA_PTR<float>() : nullptr,
    per_tensor,
    max_chunks_per_tensor);
  // Kernel launches do not return errors; surface bad-configuration errors here.
  AT_CUDA_CHECK(cudaGetLastError());

  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
// Compute and update grad norm.
// Computes the per-tensor norm n_t of every tensor in tensor_lists[0] and
// blends it in place into the previous norms stored in `out`:
//   L-2   (norm_type != 0): out[t] = sqrt(alpha * out[t]^2 + beta * n_t^2)
//   L-inf (norm_type == 0): out[t] = alpha * out[t] + beta * n_t
// `out` must be a float tensor with at least one element per tensor; element t
// is addressed as out.data_ptr<float>()[t]. noop_flag must live on the same
// device as the tensors. Work is enqueued on the current CUDA stream.
void multi_tensor_norm_out_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor out,
  const float alpha,
  const float beta,
  const int norm_type)
{
  // Everything below indexes tensor_lists[0][0]. Reject empty input up front.
  TORCH_CHECK(!tensor_lists.empty() && !tensor_lists[0].empty(),
              "multi_tensor_norm_out_cuda: expected a non-empty tensor list");

  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
  TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(), "noop flag should be on the same device as tensors");

  int ntensors = tensor_lists[0].size();
  // cleanup_v2 writes one blended norm per tensor into `out`.
  TORCH_CHECK(out.numel() >= ntensors,
              "multi_tensor_norm_out_cuda: out must have at least one element per tensor");

  // Only the per-tensor norms are kept. The global partials in `output` (and
  // the blended global value in `ret` below) are computed but discarded, so
  // `output` is left uninitialized instead of zeroed.
  auto output = at::empty({320}, float_options);

  int max_chunks_per_tensor = -1;
  for(int t = 0; t < ntensors; t++)
  {
    int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
    if(max_chunks_this_tensor > max_chunks_per_tensor)
      max_chunks_per_tensor = max_chunks_this_tensor;
  }

  // Each slot is written at most once before it is read, but it must still
  // start at zero: tensors with fewer chunks leave their trailing slots
  // unwritten, and those slots also take part in cleanup_v2's reduction.
  at::Tensor output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);

  if (norm_type == 0) {
    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
      multi_tensor_apply<1>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        MaxNormFunctor<scalar_t_0>(),
        output.DATA_PTR<float>(),
        output_per_tensor.DATA_PTR<float>(),
        true,
        max_chunks_per_tensor);)
  }
  else {
    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
      multi_tensor_apply<1>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        L2NormFunctor<scalar_t_0>(),
        output.DATA_PTR<float>(),
        output_per_tensor.DATA_PTR<float>(),
        true,
        max_chunks_per_tensor);)
  }
  AT_CUDA_CHECK(cudaGetLastError());

  // AT_CUDA_CHECK(cudaDeviceSynchronize());

  // This involves one more small kernel launch, but will be negligible end to end.
  // I could get rid of it by hacking the functor + multi tensor harness with persistence
  // logic, but keeping it simple for now.
  // `ret` is scratch for cleanup_v2's global-norm blend; its value is discarded.
  auto ret = at::empty({1}, output.options());

  // Adding the following device guard since it happens sometimes that the
  // tensors are on one device and the cuda stream is on another device which
  // results in ILLEGAL MEM ACCESS error.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  auto stream = at::cuda::getCurrentCUDAStream();
  cleanup_v2<<<ntensors, 512, 0, stream>>>(
    output.DATA_PTR<float>(),
    output_per_tensor.DATA_PTR<float>(),
    ret.DATA_PTR<float>(),
    out.DATA_PTR<float>(),
    true,
    max_chunks_per_tensor,
    norm_type,
    alpha,
    beta);
  // Kernel launches do not return errors; surface bad-configuration errors here.
  AT_CUDA_CHECK(cudaGetLastError());
}