// multi_tensor_axpby_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

// Helpers for ILP-wide vectorized memory access.
// Reports whether `p` lies on an (ILP * sizeof(T))-byte boundary, i.e. whether
// the address is a whole multiple of the size of ILP consecutive elements.
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  const uint64_t address = (uint64_t)p;
  const size_t vector_bytes = ILP*sizeof(T);
  return (address % vector_bytes) == 0;
}

// Copies ILP consecutive elements of type T from `src` to `dst` with a single
// assignment of an ILP-element-wide storage type.
//
// Both offsets count ILP-element vectors, not individual elements: the pointers
// are reinterpreted as arrays of that wide type before being indexed.  The wide
// type carries an alignment of ILP * alignof(T), so both addressed locations
// must satisfy it (see is_aligned).
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  using vec_t = typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type;
  vec_t* dst_vec = reinterpret_cast<vec_t*>(dst);
  vec_t* src_vec = reinterpret_cast<vec_t*>(src);
  dst_vec[dst_offset] = src_vec[src_offset];
}

// Functor handed to multi_tensor_apply by multi_tensor_axpby_cuda.
// Evaluates out = a*x + b*y for the chunk assigned to the calling block.
//
// x_t, y_t and out_t are the element types used for tl.addresses[0] (x),
// tl.addresses[1] (y) and tl.addresses[2] (out).  Inputs are converted to
// float, combined in float, and the result is converted to out_t on store.
template<typename x_t, typename y_t, typename out_t>
struct AxpbyFunctor
{
   // chunk_size:   number of elements per chunk; also the stride used to find
   //               the start of this block's chunk inside its tensor.
   // noop_gmem:    flag that is set to 1 when a checked input element is not
   //               finite.  It is never read here (the early exit below is
   //               intentionally commented out).
   // tl:           lookup tables indexed by blockIdx.x (block_to_tensor,
   //               block_to_chunk) plus per-tensor sizes and base addresses.
   // a, b:         scale factors applied to x and y respectively.
   // arg_to_check: which inputs are tested with isfinite:
   //               -1 -> both x and y, 0 -> x only, 1 -> y only.
   //               Any other value disables the check.
   __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<3>& tl,
    float a,
    float b,
    int arg_to_check)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    // Which tensor, and which chunk of that tensor, this block works on.
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    // Advance all three base pointers to the first element of this chunk.
    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx*chunk_size;

    y_t* y = (y_t*)tl.addresses[1][tensor_loc];
    y += chunk_idx*chunk_size;

    out_t* out = (out_t*)tl.addresses[2][tensor_loc];
    out += chunk_idx*chunk_size;

    // From here on n is the number of elements between the chunk start and the
    // end of the tensor; the loops below additionally cap at chunk_size.
    n -= chunk_idx*chunk_size;

    bool finite = true;
    // Per-thread staging buffers.  load_store reinterprets them as one
    // ILP-wide vector whose type is aligned to ILP*alignof(T), so request that
    // alignment explicitly rather than relying on where the compiler happens
    // to place the arrays.
    alignas(ILP*alignof(x_t)) x_t r_x[ILP];
    alignas(ILP*alignof(y_t)) y_t r_y[ILP];
    alignas(ILP*alignof(out_t)) out_t r_out[ILP];

    // to make things simple, we put aligned case in a different code path
    if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x) && is_aligned(y) && is_aligned(out))
    {
      // i_start counts ILP-wide vectors, hence the i_start*ILP bounds.
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(r_x, x, 0 , i_start);
        load_store(r_y, y, 0 , i_start);
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_out[ii] = a*static_cast<float>(r_x[ii]) + b*static_cast<float>(r_y[ii]);
          if(arg_to_check == -1)
            finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii]));
          if(arg_to_check == 0)
            finite = finite && isfinite(r_x[ii]);
          if(arg_to_check == 1)
            finite = finite && isfinite(r_y[ii]);
        }
        // store
        load_store(out, r_out, i_start , 0);
      }
    }
    else
    {
      // Non-divergent exit condition for __syncthreads, not necessary here
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
      {
        // Element-wise loads.  Slots that fall outside the chunk keep the
        // value 0, which is finite and therefore cannot trip the check below.
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_x[ii] = 0;
          r_y[ii] = 0;
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            r_x[ii] = x[i];
            r_y[ii] = y[i];
          }
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_out[ii] = a*static_cast<float>(r_x[ii]) + b*static_cast<float>(r_y[ii]);
          if(arg_to_check == -1)
            finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii]));
          if(arg_to_check == 0)
            finite = finite && isfinite(r_x[ii]);
          if(arg_to_check == 1)
            finite = finite && isfinite(r_y[ii]);
        }
        // see note in multi_tensor_scale_kernel.cu
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
            out[i] = r_out[ii];
        }
      }
    }
    if(!finite)
      *noop_gmem = 1; // Blindly fire off a write.  These will race but that's ok.
  }
};

// Host entry point: runs AxpbyFunctor (out = a*x + b*y) over lists of tensors
// via multi_tensor_apply.
//
// chunk_size:   forwarded unchanged to multi_tensor_apply.
// noop_flag:    forwarded to multi_tensor_apply; presumably the tensor behind
//               the functor's noop_gmem pointer -- confirm in
//               multi_tensor_apply.cuh.
// tensor_lists: three lists; list 0 supplies the x element type, list 1 the y
//               element type and list 2 the out element type.  Only the first
//               tensor of each list is inspected for its dtype.
//               NOTE(review): [k][0] is indexed before multi_tensor_apply is
//               called, so every list is assumed to be non-empty here.
// a, b:         scale factors forwarded to the functor.
// arg_to_check: forwarded to the functor (-1: check x and y, 0: x, 1: y).
void multi_tensor_axpby_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float a,
  float b,
  int arg_to_check)
{
  using namespace at;
  // The x, y and out element types are each dispatched independently, so the
  // functor is instantiated for every combination the macro supports.
  // If build times suffer, think about where to put this dispatch,
  // and what logic should be moved out of multi_tensor_apply.

  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
      DISPATCH_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
           multi_tensor_apply<3>(
             BLOCK_SIZE,
             chunk_size,
             noop_flag,
             tensor_lists,
             AxpbyFunctor<scalar_t_0, scalar_t_1, scalar_t_2>(),
             a,
             b,
             arg_to_check); )))

  // Surface any error recorded by the CUDA runtime so far (e.g. a bad launch
  // configuration) without forcing a device synchronization.
  AT_CUDA_CHECK(cudaGetLastError());

  // AT_CUDA_CHECK(cudaDeviceSynchronize());
}