// multi_tensor_sgd_kernel.cu — fused multi-tensor SGD update (author: Simon Layton)
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"

#include <assert.h>
#include <cuda_runtime.h>

// Block size handed to every multi_tensor_apply launch in multi_tensor_sgd_cuda
// (presumably threads per block — TODO confirm against multi_tensor_apply.cuh).
#define BLOCK_SIZE 512
// Elements each thread processes per iteration of SGDFunctor's main loop
// (thread t touches i_start + t + ii*blockDim.x for ii in [0, ILP)).
#define ILP 4

/**
 * Perform fused SGD on multiple buffers.
 *
 * N: number of tensor lists (3, or 4 when an fp16 copy of the weights is written)
 * tl[0] : gradients
 * tl[1] : weights (updated in place)
 * tl[2] : momentum buffers (updated in place when momentum != 0)
 * tl[3] : fp16 weights (only when N == 4; overwritten with the new weights)
 * wd : weight_decay (scalar)
 * momentum : momentum (scalar)
 * dampening : momentum dampening (scalar)
 * lr : learning rate (scalar)
 * nesterov : enable nesterov (bool)
 * first_run : true on the first step; momentum buffers are then initialized
 *             from the current gradient instead of being read
 *
 * Per element, following torch.optim.SGD:
 *   g = grad + wd * w
 *   if momentum != 0:
 *     buf = first_run ? g : momentum * buf + (1 - dampening) * g
 *     g   = nesterov ? g + momentum * buf : buf
 *   w = w - lr * g
 * Intermediate arithmetic is carried out in float and cast back to the
 * storage types (T_grad for gradients, T for weights/momentum) on write.
 **/
template<int N, typename T_grad, typename T>
struct SGDFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorList<N>& tl,
    float wd,
    float momentum,
    float dampening,
    float lr,
    bool nesterov,
    bool first_run)
  {
    __shared__ int noop_smem;

    // Whole-block early exit if another block has raised the noop flag.
    if(threadIdx.x == 0)
      noop_smem = *noop_gmem;
    __syncthreads();
    if(noop_smem == 1)
      return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
    grad_in += chunk_idx*chunk_size;

    T* weight_in = (T*)tl.addresses[1][tensor_loc];
    weight_in += chunk_idx*chunk_size;

    T* mom_in = (T*)tl.addresses[2][tensor_loc];
    mom_in += chunk_idx*chunk_size;

    at::Half* model_weights_out = nullptr;
    if (N == 4) {
      model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
      model_weights_out += chunk_idx*chunk_size;
    }

    // Elements remaining in this tensor from the start of this chunk.
    n -= chunk_idx*chunk_size;

    // Momentum is only read when it contributes and is not being initialized.
    const bool read_moms = (momentum != 0.f) && !first_run;

    float incoming_grads[ILP];
    float incoming_weights[ILP];
    float incoming_moms[ILP];

    // The loop bounds depend only on block-uniform values (i_start, n,
    // chunk_size), so every thread reaches the __syncthreads() inside it.
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        incoming_grads[ii] = 0.f;
        incoming_weights[ii] = 0.f;
        incoming_moms[ii] = 0.f;
        int i = i_start + threadIdx.x + ii*blockDim.x;
        // All loads must be guarded: the tail of the last chunk is partial.
        if(i < n && i < chunk_size) {
          incoming_grads[ii] = static_cast<float>(grad_in[i]);
          incoming_weights[ii] = static_cast<float>(weight_in[i]);
          if (read_moms)
            incoming_moms[ii] = static_cast<float>(mom_in[i]);
        }
      }

      // Stores depend on the loads above but not on each other; unrolling
      // still exposes compute ILP.
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size) {
          float grad = incoming_grads[ii];
          float mom = incoming_moms[ii];

          // apply weight decay
          if (wd != 0.f)
            grad += wd * incoming_weights[ii];

          if (momentum != 0.f) {
            if (first_run)
              mom = grad;  // initialize the buffer from the current gradient
            else
              mom = mom * momentum + (1.f - dampening) * grad;

            if (nesterov)
              grad += momentum * mom;
            else
              grad = mom;  // classic momentum: step along the buffer
          }

          // adjust the weight and write out
          T new_weight = static_cast<T>(incoming_weights[ii] - lr * grad);
          weight_in[i] = new_weight;

          // if necessary, write out an fp16 copy of the weights
          if (N == 4)
            model_weights_out[i] = static_cast<at::Half>(new_weight);

          // also write out the new momentum
          if (momentum != 0.f)
            mom_in[i] = static_cast<T>(mom);
        }
      }

      // *noop_gmem = 1 is NOT guaranteed to be seen immediately by thread 0;
      // this re-check only provides a best-effort block-wide short-circuit.
      if(threadIdx.x == 0)
        noop_smem = *noop_gmem;
      __syncthreads();
      if(noop_smem == 1)
        break;
    }
  }
};

// Launches SGDFunctor<N, T_grad, T> over every tensor in tensor_lists via
// multi_tensor_apply, forwarding the SGD hyper-parameters unchanged.
template<int N, typename T_grad, typename T>
static void launch_multi_tensor_sgd(
  int chunk_size,
  at::Tensor& noop_flag,
  std::vector<std::vector<at::Tensor>>& tensor_lists,
  float wd,
  float momentum,
  float dampening,
  float lr,
  bool nesterov,
  bool first_run)
{
  multi_tensor_apply<N>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      SGDFunctor<N, T_grad, T>(),
      wd,
      momentum,
      dampening,
      lr,
      nesterov,
      first_run);
}

/**
 * Host entry point for the fused multi-tensor SGD step.
 *
 * tensor_lists: [grads, weights, momentum buffers] or
 *               [grads, weights, momentum buffers, fp16 weight copies].
 * Supported (grad, weight/momentum, fp16-copy) combinations:
 *   1. fp16, fp16, fp16, No
 *   2. fp16, fp32, fp32, No
 *   3. fp16, fp32, fp32, Yes
 *   4. fp32, fp32, fp32, No
 * Hard-coding these is simpler than dispatching over the full cross-product,
 * most of which is not wanted. Any other combination raises via AT_ERROR.
 */
void multi_tensor_sgd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float wd,
  float momentum,
  float dampening,
  float lr,
  bool nesterov,
  bool first_run)
{
  auto num_tensors = tensor_lists.size();

  if (num_tensors != 3 && num_tensors != 4) {
    AT_ERROR("multi_tensor_sgd expects 3 or 4 tensor lists ",
             "(grads, weights, momentums[, fp16 weights]), got ", num_tensors);
  }

  // An empty parameter group has nothing to update (and [0][0] would be UB).
  if (tensor_lists[0].empty())
    return;
  if (tensor_lists[1].empty())
    AT_ERROR("multi_tensor_sgd: gradient list is non-empty but weight list is empty");

  // Gradients live in list 0, weights in list 1.
  auto grad_type = tensor_lists[0][0].scalar_type();
  auto weight_type = tensor_lists[1][0].scalar_type();

  const bool half_grads = grad_type == at::ScalarType::Half;
  const bool float_grads = grad_type == at::ScalarType::Float;
  const bool half_weights = weight_type == at::ScalarType::Half;
  const bool float_weights = weight_type == at::ScalarType::Float;

  if (half_grads && half_weights && num_tensors == 3) {
    // Case 1. fp16, fp16, fp16, No
    launch_multi_tensor_sgd<3, at::Half, at::Half>(
        chunk_size, noop_flag, tensor_lists,
        wd, momentum, dampening, lr, nesterov, first_run);
  }
  else if (half_grads && float_weights && num_tensors == 3) {
    // Case 2. fp16, fp32, fp32, No
    launch_multi_tensor_sgd<3, at::Half, float>(
        chunk_size, noop_flag, tensor_lists,
        wd, momentum, dampening, lr, nesterov, first_run);
  }
  else if (half_grads && float_weights && num_tensors == 4) {
    // Case 3. fp16, fp32, fp32, Yes
    launch_multi_tensor_sgd<4, at::Half, float>(
        chunk_size, noop_flag, tensor_lists,
        wd, momentum, dampening, lr, nesterov, first_run);
  }
  else if (float_grads && float_weights && num_tensors == 3) {
    // Case 4. fp32, fp32, fp32, No
    launch_multi_tensor_sgd<3, float, float>(
        chunk_size, noop_flag, tensor_lists,
        wd, momentum, dampening, lr, nesterov, first_run);
  }
  else {
    AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
             "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
  }

  // Surface launch-configuration errors from the kernel launch above.
  AT_CUDA_CHECK(cudaGetLastError());
}