#include "kernel_utils.cuh"

#include <ATen/ATen.h>

#ifdef VERSION_LE_04
#include "ATen/cuda/AccumulateType.cuh"
#else
#include "ATen/AccumulateType.h"
#endif

#include "ATen/cuda/CUDATensorMethods.cuh"
// #include "ATen/cuda/CUDATypeConversion.cuh"
// #include <THC/THCTensorMathReduce.cuh>

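// Reference math for the two kernels below (matching what the code computes):
// weight norm reparameterizes w = g*v/||v||, where the norm is taken over every
// dimension except `dim`.  For one "row" (one norm group), with
// r = sum_i(pLpw_i * savedv_i):
//   pLpg   = r/||v||
//   pLpv_i = g*( pLpw_i/||v|| - savedv_i*r/||v||^3 )
// Each kernel first reduces r across the block, then applies these two formulas.
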
template
  <typename scalar_t, 
   typename accscalar_t>
__global__ void weight_norm_bwd_first_dim_kernel
  (scalar_t* __restrict__ pLpv,
   scalar_t* __restrict__ pLpg,
   const scalar_t* __restrict__ pLpw,
   const scalar_t* __restrict__ savedv,
   const scalar_t* __restrict__ savedg,
   const accscalar_t* __restrict__ savedNorms,
   const int rowSize)
{
  // For now, assign one block to each row.
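  // Expected launch configuration (see weight_norm_bwd_cuda below): one block
  // per row, BLOCK threads per block, and BLOCK*sizeof(accscalar_t) bytes of
  // dynamic shared memory for the reduction.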
  const int tid = threadIdx.x;
  const int row = blockIdx.x;
  const int stride = blockDim.x;

  // Logical index offset for this flattened row
  const int rowStart = row*rowSize;

  // Hack to get around nvcc complaining when an smem array is declared with the same name
  // but different types in different kernels (in this case different instantiations)
  // extern __shared__ accscalar_t s[]; // error: declaration is incompatible with previous "s"
  extern __shared__ char buf[];
  accscalar_t* s = (accscalar_t*)buf;
  
  accscalar_t thread_sum = 0.f;
  for(int i = tid; i < rowSize; i += stride ) 
  {
    accscalar_t pLpwi = scalar_cast<accscalar_t>(pLpw[i+rowStart]); 
    accscalar_t savedvi = scalar_cast<accscalar_t>(savedv[i+rowStart]); 
    thread_sum += pLpwi*savedvi; // AccumOp, could do Kahan here
  }

  reduce_block_into_lanes(s, thread_sum, 1, ReduceAdd<accscalar_t>());
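  // After the reduction, s[0] holds the block-wide sum of thread_sum, i.e. the
  // dot product of this row of pLpw with this row of savedv.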
  accscalar_t result = s[0];

  // We could save the reciprocal of the norm instead, but the norms themselves
  // are more useful to keep around.
  // Broadcast load; could use shared memory instead.
  accscalar_t rnorm = 1.f/savedNorms[row];
  accscalar_t rnorm3 = rnorm*rnorm*rnorm;

  // Write g gradients.
  if(tid == 0)
    pLpg[row] = scalar_cast<scalar_t>(result*rnorm);

  // Broadcast load, could use shared memory instead.
  accscalar_t g_this_row = scalar_cast<accscalar_t>(savedg[row]);
   
  // Write v gradients.  We are reusing values that were loaded earlier, so there 
  // is an optimization opportunity here (store values persistently).
  for(int j = tid; j < rowSize; j += stride ) 
  {
    accscalar_t pLpwj = scalar_cast<accscalar_t>(pLpw[j+rowStart]);  
    accscalar_t savedvj = scalar_cast<accscalar_t>(savedv[j+rowStart]);  
    accscalar_t pLpvj = g_this_row*(rnorm*pLpwj - rnorm3*savedvj*result);
    pLpv[j+rowStart] = scalar_cast<scalar_t>(pLpvj);
  }
}

template 
  <typename scalar_t, 
   typename accscalar_t>
__global__ void weight_norm_bwd_last_dim_kernel
  (scalar_t* __restrict__ pLpv,
   scalar_t* __restrict__ pLpg,
   const scalar_t* __restrict__ pLpw,
   const scalar_t* __restrict__ savedv,
   const scalar_t* __restrict__ savedg,
   const accscalar_t* __restrict__ savedNorms,
   const int fast_dim_size,
   const int slower_dims_size)
{
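  // Expected launch configuration (see weight_norm_bwd_cuda below): 2-D blocks
  // of TILE_W x TILE_H threads, a grid covering fast_dim_size in chunks of
  // TILE_W, and (TILE_W*TILE_H + TILE_W)*sizeof(accscalar_t) bytes of dynamic
  // shared memory.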
  const int fast_dim_location = threadIdx.x + blockIdx.x*blockDim.x;

  extern __shared__ char buf[];
  accscalar_t* s = (accscalar_t*)buf;

  accscalar_t thread_sum = 0.f;

  int slower_dims_location = threadIdx.y;
  int currentIdx = fast_dim_location + fast_dim_size*slower_dims_location;
  if(fast_dim_location < fast_dim_size)
    while(slower_dims_location < slower_dims_size)
    {
      accscalar_t pLpwi = scalar_cast<accscalar_t>(pLpw[currentIdx]); 
      accscalar_t savedvi = scalar_cast<accscalar_t>(savedv[currentIdx]); 
      thread_sum += pLpwi*savedvi; // AccumOp, could do Kahan here
      currentIdx += blockDim.y*fast_dim_size;
      slower_dims_location += blockDim.y; 
    }

  reduce_block_into_lanes(s, thread_sum, blockDim.x, ReduceAdd<accscalar_t>()); 
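  // After the reduction, s[threadIdx.x] holds the column sum for this fast-dim
  // location: the dot product of this column of pLpw with this column of savedv.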
  accscalar_t result = s[threadIdx.x];

  // Threads past the end of the fast dimension (in the last, partially full
  // block) have nothing left to read or write.
  if(fast_dim_location >= fast_dim_size)
    return;

  // Broadcast load; could use shared memory instead.
  accscalar_t rnorm = 1.f/savedNorms[fast_dim_location];
  accscalar_t rnorm3 = rnorm*rnorm*rnorm;

  // Write g gradients.
  if(threadIdx.y == 0)
    pLpg[fast_dim_location] = scalar_cast<scalar_t>(result*rnorm);

  // Entire block pulls these values, could use shared memory instead.
  accscalar_t g_this_col = scalar_cast<accscalar_t>(savedg[fast_dim_location]);

  // Write v gradients.
  slower_dims_location = threadIdx.y;
  currentIdx = fast_dim_location + fast_dim_size*slower_dims_location;
  if(fast_dim_location < fast_dim_size)
    while(slower_dims_location < slower_dims_size)
    {
      accscalar_t pLpwj = scalar_cast<accscalar_t>(pLpw[currentIdx]);  
      accscalar_t savedvj = scalar_cast<accscalar_t>(savedv[currentIdx]);  
      accscalar_t pLpvj = g_this_col*(rnorm*pLpwj - rnorm3*savedvj*result);
      pLpv[currentIdx] = scalar_cast<scalar_t>(pLpvj);
      currentIdx += blockDim.y*fast_dim_size;
      slower_dims_location += blockDim.y; 
    } 
}

void weight_norm_bwd_cuda
  (const at::Tensor& pLpv,
   const at::Tensor& pLpg,
   const at::Tensor& pLpw,
   const at::Tensor& savedv,
   const at::Tensor& savedg,
   const at::Tensor& savedNorms,
   int dim)
{
#ifdef DEBUG_ANY
  using namespace std;
  cout << "Hello from weight_norm_bwd_cuda with pLpw.type() = " << pLpw.type() << endl;
#endif

  const int ndims = savedv.ndimension();

  if(dim == 0) 
  {
    // Find logical size of each flattened slowest-dim row
    int rowSize = 1;
    for(int i = ndims - 1; i > 0; i--)
      rowSize *= savedv.size(i);
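    // Example: for a conv weight of shape (out_channels, in_channels, kH, kW)
    // with dim == 0, rowSize is in_channels*kH*kW and there is one row (and one
    // norm) per output channel.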

    using namespace at;
    cudaStream_t stream = globalContext().getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND_HALF
      (savedv.type(), 
       "weight_norm_bwd_first_dim_kernel",  
       [&]
       {
         using cuda_scalar_t = apex::cuda::type<scalar_t>;
         USING_ACCSCALAR_T

         weight_norm_bwd_first_dim_kernel
           <<<pLpw.size(0),
              BLOCK,
              BLOCK*sizeof(accscalar_t),
              stream>>>
           (pLpv.data<cuda_scalar_t>(),
            pLpg.data<cuda_scalar_t>(),
            pLpw.data<cuda_scalar_t>(),
            savedv.data<cuda_scalar_t>(),
            savedg.data<cuda_scalar_t>(),
            savedNorms.data<accscalar_t>(),
            rowSize);
       });
  }
  else if(dim == ndims - 1)
  {
    // Precompute slower_dims_size and fast_dim_size because they involve dynamically indexing an array.
    int slower_dims_size = 1;
    for(int i = 0; i < ndims - 1; i++)
      slower_dims_size *= savedv.size(i);

    int fast_dim_size = savedv.size(ndims-1);

    using namespace at;
    cudaStream_t stream = globalContext().getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND_HALF
      (savedv.type(), 
       "weight_norm_bwd_last_dim_kernel",  
       [&]
       {
         using cuda_scalar_t = apex::cuda::type<scalar_t>;
         USING_ACCSCALAR_T

         weight_norm_bwd_last_dim_kernel
           <<<(fast_dim_size+TILE_W-1)/TILE_W,
              dim3(TILE_W,TILE_H), 
              (TILE_W*TILE_H + TILE_W)*sizeof(accscalar_t),
              stream>>>
           (pLpv.data<cuda_scalar_t>(),
            pLpg.data<cuda_scalar_t>(),
            pLpw.data<cuda_scalar_t>(),
            savedv.data<cuda_scalar_t>(),
            savedg.data<cuda_scalar_t>(),
            savedNorms.data<accscalar_t>(),
            fast_dim_size,
            slower_dims_size);
       });
  }
  // else
  // {
  //   Intermediate-dim kernel.  Error checking on the dim was already done in
  //   Module.cpp:weight_norm_bwd.  That logic could live here instead, if we
  //   include <Python.h> in both files.
  // }

  // The kernel execution is asynchronous, so this will only catch errors on the kernel launch,
  // not the kernel's execution.  Errors in kernel execution aren't guaranteed to be caught
  // until a later error check on a synchronizing CUDA call.  Unfortunately, without manually 
  // synchronizing here, this is the best we can do.
  THCudaCheck(cudaGetLastError());

#ifdef DEBUG_PROFILE
  THCudaCheck(cudaDeviceSynchronize());
#endif
}
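
// Illustrative host-side call (a minimal sketch, not part of this file; assumes
// pLpw, savedv, savedg, and savedNorms already live on the GPU, with savedNorms
// coming from the forward pass):
//
//   at::Tensor pLpv = at::empty_like(savedv);
//   at::Tensor pLpg = at::empty_like(savedg);
//   weight_norm_bwd_cuda(pLpv, pLpg, pLpw, savedv, savedg, savedNorms, /*dim=*/0);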