// multi_tensor_sgd_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"
#include "compat.h"

#include <assert.h>
#include <cuda_runtime.h>
#include <vector>

// Threads per block, passed as the block size to every multi_tensor_apply
// launch in this file.
#define BLOCK_SIZE 1024
// Number of elements each thread loads and updates per iteration of
// SGDFunctor's main loop (the loop advances by blockDim.x * ILP elements).
#define ILP 4

/**
 * Perform fused SGD on multiple buffers.
 *
 * N: number of tensor lists (3, or 4 when an fp16 copy of the weights is written)
 * tl[0] : gradients (read as T_grad)
 * tl[1] : weights (read and updated in place as T_weight)
 * tl[2] : momentum buffers (T_weight; read/written only when momentum != 0)
 * tl[3] : fp16 weights (written as at::Half, only when N == 4)
 * wd : weight_decay (scalar)
 * momentum : momentum (scalar)
 * dampening : momentum dampening (scalar)
 * lr : learning rate (scalar)
 * nesterov : enable nesterov (bool)
 * first_run : when true (and momentum != 0) the momentum buffer is initialized
 *             to the current (decayed) gradient instead of being read and updated
 * wd_after_momentum : apply weight decay _after_ momentum instead of before
 * scale : factor every incoming gradient is multiplied by before any other math
 *
 * Per element, with g = scale * grad and w = weight:
 *   if wd && !wd_after_momentum: g += wd * w
 *   if momentum: m = first_run ? g : momentum * m + (1 - dampening) * g
 *                g = nesterov ? g + momentum * m : m
 *   if wd && wd_after_momentum:  g += wd * w
 *   w = w - lr * g
 *
 * Each block processes chunk tl.block_to_chunk[blockIdx.x] (at most chunk_size
 * elements) of tensor tl.block_to_tensor[blockIdx.x]; each thread handles up to
 * ILP elements per loop iteration, spaced blockDim.x apart. All arithmetic is
 * done in fp32 and each result is rounded to its storage type exactly once.
 * The whole block does nothing if *noop_gmem is non-zero.
 **/
template<int N, typename T_grad, typename T_weight>
struct SGDFunctor
{
   __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<N>& tl,
    float wd,
    float momentum,
    float dampening,
    float lr,
    bool nesterov,
    bool first_run,
    bool wd_after_momentum,
    float scale)
  {
    // Early exit if we don't need to do anything
    if (*noop_gmem) return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
    grad_in += chunk_idx*chunk_size;

    T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
    weight_in += chunk_idx*chunk_size;

    T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
    mom_in += chunk_idx*chunk_size;

    at::Half *model_weights_out = nullptr;
    if(N == 4)
    {
      model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
      model_weights_out += chunk_idx*chunk_size;
    }

    // Elements remaining in this tensor, counted from the start of this chunk.
    n -= chunk_idx*chunk_size;

    // The stored momentum is only consumed when momentum is enabled and this
    // is not the first run; otherwise skip the global load entirely.
    const bool load_moms = (momentum != 0.f) && !first_run;

    float incoming_grads[ILP];
    float incoming_weights[ILP];
    float incoming_moms[ILP];
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        incoming_grads[ii] = 0;
        incoming_weights[ii] = 0;
        incoming_moms[ii] = 0;
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          incoming_grads[ii] = static_cast<float>(grad_in[i])*scale;
          incoming_weights[ii] = static_cast<float>(weight_in[i]);
          if(load_moms)
            incoming_moms[ii] = static_cast<float>(mom_in[i]);
        }
      }

      // note for clarification to future michael:
      // From a pure memory dependency perspective, there's likely no point unrolling
      // the write loop, since writes just fire off once their LDGs arrive.
      // Put another way, the STGs are dependent on the LDGs, but not on each other.
      // There is still compute ILP benefit from unrolling the loop though.
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          // apply weight decay before momentum if necessary
          if(wd != 0.f && !wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];

          if(momentum != 0.f)
          {
            if(!first_run)
              incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
            else // initialize momentums to current incoming grads
              incoming_moms[ii] = incoming_grads[ii];

            if(nesterov)
              incoming_grads[ii] += momentum * incoming_moms[ii];
            else
              incoming_grads[ii] = incoming_moms[ii];
          }

          // Apply WD after momentum if desired
          if(wd != 0.f && wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];

          // Compute the new weight in fp32 and round to the storage type once.
          // A compound `weight_in[i] += float` on at::Half / at::BFloat16
          // storage can round the update itself to the reduced type before
          // adding, and would force a re-read of weight_in[i] for the copy.
          const T_weight new_weight =
              static_cast<T_weight>(incoming_weights[ii] + (-lr * incoming_grads[ii]));
          weight_in[i] = new_weight;

          // if necessary, write out an fp16 copy of the (stored) weights
          if(N == 4)
            model_weights_out[i] = static_cast<at::Half>(new_weight);

          // also write out the new momentum
          if(momentum != 0.f)
            mom_in[i] = static_cast<T_weight>(incoming_moms[ii]);
        }
      }
    }
  }
};

// Launch SGDFunctor<N, T_grad, T_weight> over tensor_lists through
// multi_tensor_apply, using BLOCK_SIZE threads per block.
template<int N, typename T_grad, typename T_weight>
static void launch_sgd_functor(
  int chunk_size,
  at::Tensor& noop_flag,
  std::vector<std::vector<at::Tensor>>& tensor_lists,
  float wd,
  float momentum,
  float dampening,
  float lr,
  bool nesterov,
  bool first_run,
  bool wd_after_momentum,
  float scale)
{
  multi_tensor_apply<N>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      SGDFunctor<N, T_grad, T_weight>(),
      wd,
      momentum,
      dampening,
      lr,
      nesterov,
      first_run,
      wd_after_momentum,
      scale);
}

/**
 * Host entry point for fused multi-tensor SGD.
 *
 * tensor_lists must hold 3 lists (gradients, weights, momentum buffers) or 4
 * (plus fp16 weight copies). Every list must be non-empty; all tensors in a list
 * must share one dtype, and momentum buffers must match the weight dtype,
 * because the kernel reinterprets each list's storage as a single element type.
 * noop_flag must live on the same device as the tensors.
 *
 * Throws c10::Error (via TORCH_CHECK) on malformed input or on an unsupported
 * (gradient dtype, weight dtype, list count) combination, and checks
 * cudaGetLastError() after the launch.
 */
void multi_tensor_sgd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float wd,
  float momentum,
  float dampening,
  float lr,
  bool nesterov,
  bool first_run,
  bool wd_after_momentum,
  float scale)
{
  const auto num_tensors = tensor_lists.size();

  // tensor_lists[k][0] is indexed below, so validate the list structure first.
  TORCH_CHECK(num_tensors == 3 || num_tensors == 4,
              "multi_tensor_sgd expects 3 or 4 tensor lists, got ", num_tensors);
  for(const auto& list : tensor_lists)
    TORCH_CHECK(!list.empty(), "multi_tensor_sgd expects every tensor list to be non-empty.");

  const auto grad_type = tensor_lists[0][0].scalar_type();
  const auto weight_type = tensor_lists[1][0].scalar_type();

  // The kernel casts every tensor in a list to the same element type, so a
  // dtype mismatch inside a list would be silently misinterpreted.
  for(const auto& t : tensor_lists[0])
    TORCH_CHECK(t.scalar_type() == grad_type,
                "All gradient tensors must have the same dtype.");
  for(const auto& t : tensor_lists[1])
    TORCH_CHECK(t.scalar_type() == weight_type,
                "All weight tensors must have the same dtype.");
  for(const auto& t : tensor_lists[2])
    TORCH_CHECK(t.scalar_type() == weight_type,
                "Momentum buffers must have the same dtype as the weights.");

  if(num_tensors == 4)
    for(const auto& t : tensor_lists[3])
      TORCH_CHECK(t.scalar_type() == at::ScalarType::Half,
                  "Additional output tensors should always be fp16.");

  TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors");

  // Supported combinations, in terms of
  // grad_type, param_type (== momentum_type), requires_fp16_copy:
  // 1. fp16, fp16, No
  // 2. fp32, fp32, No
  // 3. fp16, fp32, Yes
  // 4. fp32, fp32, Yes   // this is the materialize_master_grads=True case
  // 5. bf16, bf16, No
  // 6. bf16, fp32, Yes   // the extra output list is fp16 (checked above)
  // fp16 grads with fp32 weights and no fp16 copy is not dispatched.
  // It's easier to hardcode these possibilities than to use switches etc. to
  // handle the cross-product of cases where we don't want the majority of them.
  if(grad_type == at::ScalarType::Half &&
     weight_type == at::ScalarType::Half &&
     num_tensors == 3)
  {
    launch_sgd_functor<3, at::Half, at::Half>(
        chunk_size, noop_flag, tensor_lists, wd, momentum, dampening, lr,
        nesterov, first_run, wd_after_momentum, scale);
  }
  else if(grad_type == at::ScalarType::Float &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 3)
  {
    launch_sgd_functor<3, float, float>(
        chunk_size, noop_flag, tensor_lists, wd, momentum, dampening, lr,
        nesterov, first_run, wd_after_momentum, scale);
  }
  else if(grad_type == at::ScalarType::Half &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 4)
  {
    launch_sgd_functor<4, at::Half, float>(
        chunk_size, noop_flag, tensor_lists, wd, momentum, dampening, lr,
        nesterov, first_run, wd_after_momentum, scale);
  }
  else if(grad_type == at::ScalarType::Float &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 4)
  {
    launch_sgd_functor<4, float, float>(
        chunk_size, noop_flag, tensor_lists, wd, momentum, dampening, lr,
        nesterov, first_run, wd_after_momentum, scale);
  }
  else if(grad_type == at::ScalarType::BFloat16 &&
          weight_type == at::ScalarType::BFloat16 &&
          num_tensors == 3)
  {
    launch_sgd_functor<3, at::BFloat16, at::BFloat16>(
        chunk_size, noop_flag, tensor_lists, wd, momentum, dampening, lr,
        nesterov, first_run, wd_after_momentum, scale);
  }
  else if(grad_type == at::ScalarType::BFloat16 &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 4)
  {
    launch_sgd_functor<4, at::BFloat16, float>(
        chunk_size, noop_flag, tensor_lists, wd, momentum, dampening, lr,
        nesterov, first_run, wd_after_momentum, scale);
  }
  else
  {
    TORCH_CHECK(false, "multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
                "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
  }

  AT_CUDA_CHECK(cudaGetLastError());
}