// multi_tensor_sgd_kernel.cu
// Fused multi-tensor SGD step (originally committed by Simon Layton).
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"

#include <assert.h>
#include <cuda_runtime.h>

#define BLOCK_SIZE 512
#define ILP 4

/**
 * Perform fused SGD on multiple buffers.
 *
 * N: number of tensor lists (3, or 4 when an fp16 copy of the weights is written)
 * tl[0] : gradients
 * tl[1] : weights
 * tl[2] : momentum buffers
 * tl[3] : fp16 weights (only read when N == 4)
 * wd : weight_decay (scalar)
 * momentum : momentum (scalar); 0 disables the momentum path entirely
 * dampening : momentum dampening (scalar)
 * lr : learning rate (scalar)
 * nesterov : enable nesterov (bool)
 * first_run : momentum buffers are initialised to the incoming gradient
 *             instead of being blended with their previous contents
 * wd_after_momentum : apply weight decay _after_ momentum instead of before
 * scale : multiplier applied to every incoming gradient before it is used
 *
 * All arithmetic is done in fp32 registers; values are converted back to
 * T_weight / at::Half only when stored.
 **/
template<int N, typename T_grad, typename T_weight>
struct SGDFunctor
{
   __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<N>& tl,
    float wd,
    float momentum,
    float dampening,
    float lr,
    bool nesterov,
    bool first_run,
    bool wd_after_momentum,
    float scale)
  {
    // Early exit if we don't need to do anything
    if (*noop_gmem) return;

    // Each block reads which tensor and which chunk of it to process
    // (presumably laid out by multi_tensor_apply).
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
    grad_in += chunk_idx*chunk_size;

    T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
    weight_in += chunk_idx*chunk_size;

    T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
    mom_in += chunk_idx*chunk_size;

    at::Half *model_weights_out = nullptr;
    if(N == 4)
    {
      model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
      model_weights_out += chunk_idx*chunk_size;
    }

    // Elements remaining in this tensor from the start of this chunk.
    n -= chunk_idx*chunk_size;

    // Per-thread staging registers: each thread handles ILP elements per
    // iteration, strided by blockDim.x.
    float incoming_grads[ILP];
    float incoming_weights[ILP];
    float incoming_moms[ILP];
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        incoming_grads[ii] = 0;
        incoming_weights[ii] = 0;
        incoming_moms[ii] = 0;
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          incoming_grads[ii] = static_cast<float>(grad_in[i])*scale;
          incoming_weights[ii] = static_cast<float>(weight_in[i]);
          incoming_moms[ii] = static_cast<float>(mom_in[i]);
        }
      }

      // From a pure memory-dependency perspective there is likely no point
      // unrolling the write loop, since stores just fire once their loads
      // arrive (the stores depend on the loads, not on each other).
      // Unrolling still gives compute ILP.
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          // apply weight decay before momentum if necessary
          if(wd != 0.f && !wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];

          if(momentum != 0.f)
          {
            if(!first_run)
              incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
            else // initialize momentums to current incoming grads
              incoming_moms[ii] = incoming_grads[ii];

            if(nesterov)
              incoming_grads[ii] += momentum * incoming_moms[ii];
            else
              incoming_grads[ii] = incoming_moms[ii];
          }

          // Apply WD after momentum if desired
          if(wd != 0.f && wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];

          // Compute the new weight in fp32 from the value already held in a
          // register, then round once when storing. (Doing `weight_in[i] += ...`
          // re-reads global memory and, for fp16 weights, rounds the update to
          // fp16 before the add.)
          float new_weight = incoming_weights[ii] + (-lr * incoming_grads[ii]);
          weight_in[i] = static_cast<T_weight>(new_weight);

          // if necessary, write out an fp16 copy of the weights
          if(N == 4)
            model_weights_out[i] = static_cast<at::Half>(new_weight);

          // also write out the new momentum
          if(momentum != 0.f)
            mom_in[i] = incoming_moms[ii];
        }
      }
    }
  }
};

/**
 * Host entry point: validate the tensor lists and dispatch SGDFunctor via
 * multi_tensor_apply for one of the supported dtype / list-count combinations.
 *
 * tensor_lists: [grads, weights, momentum_buffers] or
 *               [grads, weights, momentum_buffers, fp16_weight_copies]
 * Remaining scalars are forwarded unchanged to SGDFunctor.
 *
 * Raises (via AT_CHECK / AT_ERROR) on a wrong number of lists, mismatched
 * list lengths, non-fp16 copy tensors, or an unsupported dtype combination.
 */
void multi_tensor_sgd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float wd,
  float momentum,
  float dampening,
  float lr,
  bool nesterov,
  bool first_run,
  bool wd_after_momentum,
  float scale)
{
  auto num_tensors = tensor_lists.size();

  // Validate the list count before indexing tensor_lists[1] / tensor_lists[3].
  AT_CHECK(num_tensors == 3 || num_tensors == 4,
           "multi_tensor_sgd expects 3 or 4 tensor lists, got ", num_tensors);
  for(size_t l = 1; l < num_tensors; l++)
    AT_CHECK(tensor_lists[l].size() == tensor_lists[0].size(),
             "multi_tensor_sgd: tensor list ", l, " has ", tensor_lists[l].size(),
             " tensors but list 0 has ", tensor_lists[0].size());

  // Nothing to update; also avoids reading tensor_lists[0][0] of an empty list.
  if(tensor_lists[0].empty())
    return;

  auto grad_type = tensor_lists[0][0].scalar_type();
  auto weight_type = tensor_lists[1][0].scalar_type();

  if(num_tensors == 4)
    for(size_t i = 0; i < tensor_lists[3].size(); i++)
        AT_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
                 "Additional output tensors should always be fp16.");

  // We have 4 supported combinations to handle here, in terms of
  // grad_type, param_type, momentum_type, requires_fp16_copy
  // 1. fp16, fp16, fp16, No
  // 2. fp32, fp32, fp32, No
  // 3. fp16, fp32, fp32, Yes
  // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
  // It's easier to hardcode these possibilities than to use
  // switches etc. to handle the cross-product of cases where
  // we don't want the majority of them.
  // Note: fp16 grads with fp32 weights are only accepted together with an
  // fp16 weight-copy list (case 3).

  // Case 1. fp16, fp16, fp16, No
  if(grad_type == at::ScalarType::Half &&
     weight_type == at::ScalarType::Half &&
     num_tensors == 3)
  {
    multi_tensor_apply<3>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        SGDFunctor<3, at::Half, at::Half>(),
        wd,
        momentum,
        dampening,
        lr,
        nesterov,
        first_run,
        wd_after_momentum,
        scale);
  }
  // Case 2. fp32, fp32, fp32, No
  else if(grad_type == at::ScalarType::Float &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 3)
  {
    multi_tensor_apply<3>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        SGDFunctor<3, float, float>(),
        wd,
        momentum,
        dampening,
        lr,
        nesterov,
        first_run,
        wd_after_momentum,
        scale);
  }
  // Case 3. fp16, fp32, fp32, Yes
  else if(grad_type == at::ScalarType::Half &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 4)
  {
    multi_tensor_apply<4>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        SGDFunctor<4, at::Half, float>(),
        wd,
        momentum,
        dampening,
        lr,
        nesterov,
        first_run,
        wd_after_momentum,
        scale);
  }
  // Case 4. fp32, fp32, fp32, Yes
  else if(grad_type == at::ScalarType::Float &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 4)
  {
    multi_tensor_apply<4>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        SGDFunctor<4, float, float>(),
        wd,
        momentum,
        dampening,
        lr,
        nesterov,
        first_run,
        wd_after_momentum,
        scale);
  }
  else
  {
    AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
             "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
  }

  // Surface launch-configuration errors from the kernel launched above.
  AT_CUDA_CHECK(cudaGetLastError());
}