// scale_check_overflow_kernel.cu

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <assert.h>
#include <cuda_runtime.h>

#define BLOCK_SIZE 256
#define NBLOCKS 160*4
#define ILP 4
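// Launch configuration: a fixed grid of NBLOCKS blocks strides over the
// input; each block handles BLOCK_SIZE*ILP contiguous elements per chunk,
// with every thread processing ILP elements.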

// It makes sense to lock the output type to fp32 because the downscaled
// grads should be master grads (and in the case of Amp, the params and their
// gradients should always be fp32).

template<typename in_t>
__global__ void scale_reduce_overflow(in_t* in,
                                      float* out,
                                      int n,
                                      float scale,
                                      volatile int* overflow_global)
{
  __shared__ int overflow;
  float incoming_vals[ILP];

  // Non-divergent exit condition for the __syncthreads
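  // Grid-stride loop over chunks of blockDim.x*ILP elements.  Thread 0 of each
  // block re-reads the global overflow flag at the top of every chunk and
  // broadcasts it through shared memory so the whole block exits together.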
  for(int chunk_start = blockIdx.x*blockDim.x*ILP;
      chunk_start < n;
      chunk_start += gridDim.x*blockDim.x*ILP)
  {
    if(threadIdx.x == 0)
      overflow = *overflow_global;

    __syncthreads();

    if(overflow == 1)
      break;

    #pragma unroll
    for(int ii = 0; ii < ILP; ii++)
    {
      incoming_vals[ii] = 0;
      int i = chunk_start + threadIdx.x + ii*blockDim.x;
      if(i < n)
        incoming_vals[ii] = static_cast<float>(in[i]);
    }
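    // All ILP loads above land in registers before any store below is issued,
    // letting the independent memory requests overlap.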

    #pragma unroll
    for(int ii = 0; ii < ILP; ii++)
    {
      int i = chunk_start + threadIdx.x + ii*blockDim.x;
      if(i < n)
        if(isfinite(incoming_vals[ii]))
          out[i] = incoming_vals[ii]*scale;
        else
          *overflow_global = 1; // Blindly fire off a write.  These will race but that's ok.
    }    // This is NOT guaranteed to be seen immediately by thread 0 on the next iteration.
  }      // I wonder if there's a way we can rig the short-circuiting with only one syncthreads.
}        // It's possible we can just lean on the cache (no smem or syncs) and still be fast.
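
// A minimal sketch (disabled, untested) of the "lean on the cache, no smem or
// syncs" variant speculated about in the comments above: every thread polls
// the global flag directly at the top of each chunk, trading the shared-memory
// broadcast and __syncthreads() for a possibly slower short-circuit.  The name
// scale_reduce_overflow_nosync is hypothetical.
#if 0
template<typename in_t>
__global__ void scale_reduce_overflow_nosync(in_t* in,
                                             float* out,
                                             int n,
                                             float scale,
                                             volatile int* overflow_global)
{
  for(int chunk_start = blockIdx.x*blockDim.x*ILP;
      chunk_start < n;
      chunk_start += gridDim.x*blockDim.x*ILP)
  {
    // No __syncthreads(), so divergence here is harmless; threads simply stop
    // doing useful work once they observe the flag.
    if(*overflow_global == 1)
      break;

    #pragma unroll
    for(int ii = 0; ii < ILP; ii++)
    {
      int i = chunk_start + threadIdx.x + ii*blockDim.x;
      if(i < n)
      {
        float val = static_cast<float>(in[i]);
        if(isfinite(val))
          out[i] = val*scale;
        else
          *overflow_global = 1; // Racy, same as above; any nonzero write suffices.
      }
    }
  }
}
#endif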


void scale_check_overflow_cuda
  (const at::Tensor& grads,
   float scale,
   const at::Tensor& overflow_buf,
   const at::Tensor& downscaled_grads)
{
  using namespace at;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  
  int n = grads.numel();

  // Lock the output (downscaled) type to float.
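  // AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates the lambda once per
  // floating dtype (double, float, half), binding the C++ type to scalar_t,
  // so fp16 and fp32 gradients share this single launch site.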
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grads.type(),
     "scale_check_overflow_cuda",
     [&]
     {
       // using accscalar_t = acc_type<scalar_t, true>;
       scale_reduce_overflow<<<NBLOCKS, BLOCK_SIZE, 0, stream>>>
         (grads.data<scalar_t>(),
          downscaled_grads.data<float>(),
          n,
          scale,
          overflow_buf.data<int>());
     });

  AT_CUDA_CHECK(cudaGetLastError());
}
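
// A minimal usage sketch (hypothetical, not part of this file): the wrapper is
// typically exposed to Python through an extension binding and expects the
// grads tensor (fp16 or fp32), a pre-zeroed 1-element int32 overflow buffer,
// and an fp32 output tensor with the same numel, all on the same CUDA device:
//
//   at::Tensor grads    = at::randn({1024}, at::device(at::kCUDA)).to(at::kHalf);
//   at::Tensor out      = at::empty({1024}, at::device(at::kCUDA));  // fp32
//   at::Tensor overflow = at::zeros({1}, at::device(at::kCUDA).dtype(at::kInt));
//   scale_check_overflow_cuda(grads, 1.0f/128.0f, overflow, out);
//   bool found_inf = overflow.item<int>() != 0;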