// multi_tensor_scale_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <sstream>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

// Passed as the first argument to multi_tensor_apply<2> in
// multi_tensor_scale_cuda; presumably the number of threads per CUDA block
// (TODO confirm against multi_tensor_apply.cuh).
#define BLOCK_SIZE 512
// Number of elements each thread loads (and then stores) per iteration of
// ScaleFunctor's main loop; a thread's ILP elements are blockDim.x apart.
#define ILP 4

// Elementwise functor applied by multi_tensor_apply<2>.
//
// Each CUDA block handles one chunk of one tensor:
//   tl.block_to_tensor[blockIdx.x] selects the tensor.
//   tl.block_to_chunk[blockIdx.x] selects the chunk.
// For that chunk, the functor reads in_t elements from list 0
// (tl.addresses[0]), converts them to float, multiplies by `scale`, and
// writes them as out_t to list 1 (tl.addresses[1]).
//
// Template parameters:
//   in_t  - element type of the input tensors.
//   out_t - element type of the output tensors.
//
// Parameters:
//   chunk_size - number of elements per chunk; each block processes at most
//                this many elements, starting at chunk_idx * chunk_size.
//   noop_gmem  - flag set to 1 when any input element is non-finite.
//   tl         - tensor list metadata (addresses, sizes, block mapping).
//   scale      - multiplier, applied in float precision.
//
// Non-finite inputs (inf/nan) are still scaled and written, so they propagate
// to the output. They also set *noop_gmem = 1. Many threads may store to the
// flag concurrently, but every writer stores the same value, so the race is
// benign.
template<typename in_t, typename out_t>
struct ScaleFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<2>& tl,
    float scale)
  {
    // Intentionally no early exit when the flag is already set, so that
    // infs/nans keep propagating to the output:
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    // Advance both pointers to the first element of this block's chunk.
    in_t* in = (in_t*)tl.addresses[0][tensor_loc];
    in += chunk_idx*chunk_size;

    out_t* out = (out_t*)tl.addresses[1][tensor_loc];
    out += chunk_idx*chunk_size;

    // Elements remaining in the tensor from the start of this chunk. The
    // loops below clamp the index to both n and chunk_size, so the last
    // (partial) chunk is handled.
    n -= chunk_idx*chunk_size;

    // No __syncthreads() is used, so threads may leave the loop independently.
    float incoming_vals[ILP];
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      // Phase 1: issue all ILP loads before any store. Adjacent threads read
      // adjacent elements. Out-of-range slots stay 0 and are never stored.
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        incoming_vals[ii] = 0;
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
          incoming_vals[ii] = static_cast<float>(in[i]);
      }

      // Phase 2: scale and store. From a pure memory-dependency view, the
      // stores depend on the loads but not on each other. Unrolling still
      // gives compute ILP.
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          // The value is written unconditionally so that infs/nans propagate.
          out[i] = static_cast<out_t>(incoming_vals[ii]*scale);
          // The check is on the *input*, not on the scaled result.
          if(!isfinite(incoming_vals[ii]))
            *noop_gmem = 1; // Blindly fire off a write. These race, but that's ok.
        }
      }
    }
  }
};

// Host entry point: for each pair (tensor_lists[0][k], tensor_lists[1][k]),
// writes in * scale into the output tensor, using ScaleFunctor through
// multi_tensor_apply<2>.
//
// Parameters:
//   chunk_size   - elements per chunk handed to ScaleFunctor.
//   noop_flag    - flag tensor forwarded to multi_tensor_apply. ScaleFunctor
//                  sets it to 1 on a non-finite input (presumably this tensor
//                  backs its noop_gmem pointer; confirm in
//                  multi_tensor_apply.cuh). It is never cleared here.
//   tensor_lists - {inputs, outputs}, two lists of equal length.
//                  Input dtype: float, double or half (dispatched on
//                  tensor_lists[0][0]).
//                  Output dtype: float or half, taken from tensor_lists[1][0].
//                  All outputs are presumably expected to share that dtype
//                  (TODO confirm).
//   scale        - multiplier, applied in float precision.
//
// Errors: raises via AT_ERROR if tensor_lists does not hold exactly two lists,
// if the lists differ in length, or if the output dtype is unsupported.
// Two empty lists are a no-op.
// No device synchronization is performed; only cudaGetLastError() is checked.
void multi_tensor_scale_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float scale)
{
  using namespace at;

  // Validate the structure before indexing tensor_lists[i][0] below.
  // Indexing an empty vector would be undefined behavior.
  if(tensor_lists.size() != 2)
  {
    AT_ERROR("multi_tensor_scale_cuda expects exactly 2 tensor lists (inputs, outputs)");
  }
  if(tensor_lists[0].empty() && tensor_lists[1].empty())
    return; // Nothing to scale.
  if(tensor_lists[0].size() != tensor_lists[1].size())
  {
    AT_ERROR("multi_tensor_scale_cuda expects input and output lists of equal length");
  }

  // The input type is dispatched by the macro. The output type (Half or Float)
  // is selected by the switch inside the lambda.
  // If build times suffer, think about where to put this dispatch,
  // and what logic should be moved out of multi_tensor_apply.
  // NOTE(review): newer ATen versions deprecate passing .type() to
  // AT_DISPATCH_*; consider .scalar_type() if the minimum supported PyTorch
  // version allows it.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[0][0].type(),
     "multi_tensor_scale_cuda",
     [&]
     {
       // using accscalar_t = acc_type<scalar_t, true>;
       switch(tensor_lists[1][0].scalar_type())
       {
         case at::ScalarType::Half:
           multi_tensor_apply<2>(
             BLOCK_SIZE,
             chunk_size,
             noop_flag,
             tensor_lists,
             ScaleFunctor<scalar_t, at::Half>(),
             scale);
           break;
         case at::ScalarType::Float:
           multi_tensor_apply<2>(
             BLOCK_SIZE,
             chunk_size,
             noop_flag,
             tensor_lists,
             ScaleFunctor<scalar_t, float>(),
             scale);
           break;
         default:
         {
           // Stringstream is used so operator<< formats the dtype.
           std::stringstream ss;
           ss << "multi_tensor_scale_cuda not implemented for output type = "
              << tensor_lists[1][0].dtype();
           AT_ERROR(ss.str().c_str());
         }
       }
     });

  // Surfaces launch-configuration errors. Asynchronous execution errors
  // appear at a later synchronizing call.
  AT_CUDA_CHECK(cudaGetLastError());

  // AT_CUDA_CHECK(cudaDeviceSynchronize());
}