// scale_check_overflow_kernel.cu

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <assert.h>
#include <cuda_runtime.h>

#define BLOCK_SIZE 1024
#define NBLOCKS 160

// It makes sense to lock the output type to fp32 because the downscaled
// grads should be master grads (and in the case of Amp, the params and their
// gradients should always be fp32).

// This can be optimized with ILP but it's fine for now
// (an illustrative ILP variant is sketched after this kernel).
template<typename in_t>
__global__ void scale_reduce_overflow(in_t* in,
                                      float* out,
                                      int n,
                                      float scale,
                                      volatile int* overflow_global)
{
  __shared__ int overflow;

  int tid = blockIdx.x*blockDim.x + threadIdx.x;
  int stride = gridDim.x*blockDim.x;

  // Non-divergent exit condition for the __syncthreads
  for(int i = tid; i - threadIdx.x < n; i += stride)
  {
    if(threadIdx.x == 0)
      overflow = *overflow_global;

    __syncthreads();

    if(overflow == 1)
      break;

    if(i < n)
    {
      float incoming_val = static_cast<float>(in[i]);
      if(isfinite(incoming_val))
        out[i] = incoming_val*scale;
      else
        *overflow_global = 1; // Blindly fire off a write.  These will race but that's ok.
        // This is NOT guaranteed to be seen immediately by thread 0 on the next iteration.
        // I wonder if there's a way we can rig the short-circuiting with only one syncthreads.
        // It's possible we can just lean on the cache (no smem or syncs) and still be fast
        // (a cache-only variant is sketched after this kernel).
    }
  }
}
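
// Illustrative ILP variant of the kernel above, not part of the original
// implementation: each thread loads ILP_FACTOR independent elements per loop
// trip so the memory requests can be in flight simultaneously. The unroll
// factor and any speedup are untested assumptions, and the overflow
// short-circuit is simplified to a per-trip poll of the global flag.
#define ILP_FACTOR 4

template<typename in_t>
__global__ void scale_reduce_overflow_ilp(in_t* in,
                                          float* out,
                                          int n,
                                          float scale,
                                          volatile int* overflow_global)
{
  int tid = blockIdx.x*blockDim.x + threadIdx.x;
  int stride = gridDim.x*blockDim.x;

  for(int i = tid; i < n; i += stride*ILP_FACTOR)
  {
    if(*overflow_global == 1)
      break;

    // Issue the independent loads first...
    float vals[ILP_FACTOR];
    #pragma unroll
    for(int j = 0; j < ILP_FACTOR; j++)
    {
      int idx = i + j*stride;
      vals[j] = idx < n ? static_cast<float>(in[idx]) : 0.f;
    }

    // ...then do the dependent checks and stores.
    #pragma unroll
    for(int j = 0; j < ILP_FACTOR; j++)
    {
      int idx = i + j*stride;
      if(idx < n)
      {
        if(isfinite(vals[j]))
          out[idx] = vals[j]*scale;
        else
          *overflow_global = 1;
      }
    }
  }
}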
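
// Illustrative cache-only variant, following the speculation in the comments
// above: no shared memory and no __syncthreads(); every thread polls the
// volatile flag directly each iteration. Whether this wins over the
// shared-memory broadcast is an open question, not a measured result.
template<typename in_t>
__global__ void scale_reduce_overflow_nosmem(in_t* in,
                                             float* out,
                                             int n,
                                             float scale,
                                             volatile int* overflow_global)
{
  int tid = blockIdx.x*blockDim.x + threadIdx.x;
  int stride = gridDim.x*blockDim.x;

  for(int i = tid; i < n; i += stride)
  {
    // volatile keeps the read from being hoisted into a register, but
    // visibility across blocks is still only best-effort.
    if(*overflow_global == 1)
      break;

    float incoming_val = static_cast<float>(in[i]);
    if(isfinite(incoming_val))
      out[i] = incoming_val*scale;
    else
      *overflow_global = 1;
  }
}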

void scale_check_overflow_cuda
  (const at::Tensor& grads,
   float scale,
   const at::Tensor& overflow_buf,
   const at::Tensor& downscaled_grads)
{
  using namespace at;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  
  int n = grads.numel();

  // Lock the output (downscaled) type to float.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grads.type(),
     "scale_check_overflow_cuda",
     [&]
     {
       // using accscalar_t = acc_type<scalar_t, true>;
       scale_reduce_overflow<<<NBLOCKS, BLOCK_SIZE, 0, stream>>>
         (grads.data<scalar_t>(),
          downscaled_grads.data<float>(),
          n,
          scale,
          overflow_buf.data<int>());
     });

  AT_CUDA_CHECK(cudaGetLastError());
}
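
// Hypothetical usage from ATen-based C++ code (not part of the original
// file). The tensor shapes, the 1/128 scale, and the fp16 grads are
// illustrative assumptions; overflow_buf must be a zeroed int tensor on the
// same device, and downscaled_grads must be fp32 with at least n elements.
//
//   at::Tensor grads = at::randn({1 << 20},
//                                at::device(at::kCUDA).dtype(at::kHalf));
//   at::Tensor downscaled = at::empty({1 << 20},
//                                     at::device(at::kCUDA).dtype(at::kFloat));
//   at::Tensor overflow_buf = at::zeros({1},
//                                       at::device(at::kCUDA).dtype(at::kInt));
//
//   scale_check_overflow_cuda(grads, 1.f/128.f, overflow_buf, downscaled);
//   bool overflowed = (overflow_buf.item<int>() == 1);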