// scale_check_overflow_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <assert.h>
#include <cuda_runtime.h>

#define BLOCK_SIZE 1024
#define MAX_BLOCKS 1024

// It makes sense to lock the output type to fp32 because the downscaled
// grads should be master grads (and in the case of Amp, the params and
// their gradients should always be fp32).

// Scales each element of `in` by `scale` into `out` (output locked to fp32),
// setting *overflow_global to 1 if any input value is non-finite.
//
// Launch: 1-D grid of 1-D blocks; correct for any grid size thanks to the
// grid-stride loop. The loop bound uses (i - threadIdx.x) so every thread of
// a block performs the same number of iterations, keeping the
// __syncthreads() below non-divergent.
template<typename in_t>
__global__ void scale_reduce_overflow(in_t* in,
                                      float* out,
                                      size_t n,
                                      float scale,
                                      volatile int* overflow_global)
{
  __shared__ int overflow;

  // size_t arithmetic so indexing stays correct when n exceeds 2^31.
  size_t tid = (size_t)blockIdx.x*blockDim.x + threadIdx.x;
  size_t stride = (size_t)gridDim.x*blockDim.x;

  // Non-divergent exit condition for the __syncthreads
  for(size_t i = tid; i - threadIdx.x < n; i += stride)
  {
    // One thread per block polls the global flag so the whole block can
    // stop early once any block has observed a non-finite value.
    if(threadIdx.x == 0)
      overflow = *overflow_global;

    __syncthreads();

    // Every thread in the block reads the same shared value, so this
    // break is uniform across the block (no divergent __syncthreads).
    if(overflow == 1)
      break;

    // Guard on the current index i, not the initial index tid: on strided
    // iterations i > tid, and guarding on tid allowed out-of-bounds
    // accesses to in[i]/out[i] once i passed n.
    if(i < n)
    {
      float incoming_val = static_cast<float>(in[i]);
      if(isfinite(incoming_val))
        out[i] = incoming_val*scale;
      else
        *overflow_global = 1; // Blindly fire off a write.  These will race but that's ok.
    }
  }
}

// Downscales `grads` by `scale` into `downscaled_grads` (fp32) on the
// current CUDA stream, writing 1 into `overflow_buf` (int tensor) if any
// gradient value is non-finite.
//
// NOTE(review): assumes downscaled_grads holds at least grads.numel() floats
// and overflow_buf was zeroed by the caller -- confirm against call sites.
void scale_check_overflow_cuda
  (const at::Tensor& grads, 
   float scale,
   const at::Tensor& overflow_buf,
   const at::Tensor& downscaled_grads) 
{
  using namespace at;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  size_t n = grads.numel();
  // Nothing to scale; also keeps the block computation below from producing
  // a zero-block (invalid) launch configuration.
  if(n == 0)
    return;

  // Enough blocks to cover n once, capped at MAX_BLOCKS; the kernel's
  // grid-stride loop picks up any remainder. (Previously hardcoded to 160,
  // leaving the MAX_BLOCKS define unused and over-launching for tiny n.)
  size_t blks = (n + BLOCK_SIZE - 1)/BLOCK_SIZE;
  int num_blks = blks > MAX_BLOCKS ? MAX_BLOCKS : (int)blks;

  // Lock the output (downscaled) type to float.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grads.type(),
     "scale_check_overflow_cuda",
     [&]
     {
       scale_reduce_overflow<<<num_blks, BLOCK_SIZE, 0, stream>>>
         (grads.data<scalar_t>(), 
          downscaled_grads.data<float>(),
          n, 
          scale, 
          overflow_buf.data<int>());
     });

  // Catch launch-configuration errors; execution errors surface at the
  // next synchronizing call.
  AT_CUDA_CHECK(cudaGetLastError());
}