// multi_tensor_scale_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

// Another possibility:
// #include <torch/all.h>

#include <assert.h>

// std::stringstream is heavyweight, but it lets us rely on operator<< for printing dtypes.
#include <sstream>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"
// Threads per block passed to multi_tensor_apply when launching ScaleFunctor.
#define BLOCK_SIZE 1024

// Elements processed per thread per loop iteration (instruction-level
// parallelism). Also the vector width, in elements, used by is_aligned() and
// load_store() for the vectorized fast path.
#define ILP 4

// Returns true when `p` sits on a boundary suitable for one ILP-wide access of
// T, i.e. its address is a multiple of ILP * sizeof(T) bytes.
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  const uint64_t vec_bytes = ILP * sizeof(T);
  const uint64_t addr = reinterpret_cast<uint64_t>(p);
  return addr % vec_bytes == 0;
}

// Copies ILP consecutive T values as a single vector-sized unit:
//   dst[dst_offset*ILP .. +ILP) = src[src_offset*ILP .. +ILP)
// Offsets are in units of ILP elements, not single elements. Both pointers
// must be aligned to ILP * alignof(T) (see is_aligned()).
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  using VecT = typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type;
  VecT* dst_vec = reinterpret_cast<VecT*>(dst);
  const VecT* src_vec = reinterpret_cast<const VecT*>(src);
  dst_vec[dst_offset] = src_vec[src_offset];
}

// Per-chunk functor applied by multi_tensor_apply<2>.
//
// For the (tensor, chunk) pair assigned to this CUDA block (looked up through
// tl.block_to_tensor / tl.block_to_chunk indexed by blockIdx.x), it:
//   - reads up to `chunk_size` elements from the tensor in list 0,
//   - multiplies each by `scale` in float arithmetic,
//   - writes the result (converted to out_t) into the matching tensor in list 1.
// If any *input* element is inf/nan, it sets *noop_gmem = 1. Outputs that
// overflow only on conversion to out_t are not flagged.
// No shared memory and no __syncthreads() are used.
template<typename in_t, typename out_t>
struct ScaleFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<2>& tl,
    float scale)
  {
    // Deliberately no early exit on an already-raised flag: this kernel should
    // still propagate infs/nans into the output.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    in_t* in = (in_t*)tl.addresses[0][tensor_loc];
    in += chunk_idx*chunk_size;

    out_t* out = (out_t*)tl.addresses[1][tensor_loc];
    out += chunk_idx*chunk_size;

    // Elements remaining in this tensor from the start of this chunk. This may
    // exceed chunk_size, so both bounds are checked below.
    n -= chunk_idx*chunk_size;

    bool finite = true;
    // load_store() reinterprets these arrays as an ILP-wide storage type aligned
    // to ILP*alignof(T), so they must carry that alignment themselves.
    alignas(ILP*alignof(in_t)) in_t r_in[ILP];
    alignas(ILP*alignof(out_t)) out_t r_out[ILP];

    // Fast path: the chunk is a whole number of aligned ILP-wide vectors, so
    // every load/store is a single vector access.
    if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out))
    {
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(r_in, in, 0, i_start);
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_out[ii] = static_cast<float>(r_in[ii]) * scale;
          finite = finite && isfinite(static_cast<float>(r_in[ii]));
        }
        // store
        load_store(out, r_out, i_start, 0);
      }
    }
    else
    {
      // General path: scalar accesses with per-element bounds checks. Within
      // one iteration, consecutive threads touch consecutive elements.
      // Non-divergent exit condition for __syncthreads, not necessary here
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
      {
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          // Out-of-range slots stay 0 so they cannot trip the finiteness check.
          r_in[ii] = 0;
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
            r_in[ii] = in[i];
        }
        // From a pure memory-dependency perspective there is likely no point
        // unrolling the write loop, since writes just fire off once their LDGs
        // arrive: the STGs depend on the LDGs, but not on each other. There is
        // still a compute ILP benefit from unrolling, though.
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_out[ii] = static_cast<float>(r_in[ii]) * scale;
          finite = finite && isfinite(static_cast<float>(r_in[ii]));
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
            out[i] = r_out[ii];
        }
      }
    }

    if(!finite)
      *noop_gmem = 1; // Blindly fire off a write.  These will race but that's ok.
  }
};

// Host entry point. Computes tensor_lists[1][k] = tensor_lists[0][k] * scale for
// every k, using ScaleFunctor through multi_tensor_apply<2>.
//
// noop_flag is handed to multi_tensor_apply; presumably it reaches the functor
// as noop_gmem, which is set to 1 if any input element is inf/nan.
// TODO(review): confirm against multi_tensor_apply.cuh.
//
// Input and output dtypes are dispatched independently via
// DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16. Each is keyed off the first tensor of
// its list, so each list is assumed to be dtype-homogeneous.
//
// Throws (TORCH_CHECK) if tensor_lists does not hold exactly two non-empty lists.
void multi_tensor_scale_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float scale)
{
  using namespace at;

  // The dispatch below reads tensor_lists[0][0] and tensor_lists[1][0]; reject
  // malformed input here instead of indexing out of bounds.
  TORCH_CHECK(tensor_lists.size() == 2,
              "multi_tensor_scale_cuda: expected 2 tensor lists (inputs, outputs), got ",
              tensor_lists.size());
  TORCH_CHECK(!tensor_lists[0].empty() && !tensor_lists[1].empty(),
              "multi_tensor_scale_cuda: input and output tensor lists must be non-empty");

  // Input (scalar_t_0) and output (scalar_t_1) types are dispatched separately,
  // so e.g. half inputs can be scaled into float outputs.
  // If build times suffer, think about where to put this dispatch,
  // and what logic should be moved out of multi_tensor_apply.
  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
      multi_tensor_apply<2>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        ScaleFunctor<scalar_t_0, scalar_t_1>(),
        scale); ))

  // Surfaces launch errors only; in-kernel faults appear at the next sync.
  AT_CUDA_CHECK(cudaGetLastError());

  // AT_CUDA_CHECK(cudaDeviceSynchronize());
}