// multi_tensor_sgd_kernel.cu
// Fused SGD update applied across multiple tensor buffers.
// Contributors (per the blame view this file was recovered from): Simon Layton, Michael Carilli.
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"

#include <assert.h>
#include <cuda_runtime.h>

#define BLOCK_SIZE 512
#define ILP 4

/**
 * Perform fused SGD on multiple buffers.
 *
 * Functor invoked once per thread block (presumably by multi_tensor_apply — the
 * caller is not visible here). Each block looks up which tensor and which chunk
 * of `chunk_size` elements it owns via tl.block_to_tensor / tl.block_to_chunk,
 * then updates that chunk in place.
 *
 * Template parameters:
 *   N        : number of tensor lists in `tl` (3, or 4 when an fp16 copy of the
 *              updated weights is also written)
 *   T_grad   : element type of the gradients
 *   T_weight : element type of the weights AND of the momentum buffers
 *
 * Tensor lists:
 *   tl[0] : gradients        (T_grad, read only)
 *   tl[1] : weights          (T_weight, updated in place)
 *   tl[2] : momentum buffers (T_weight, updated in place; only written when momentum != 0)
 *   tl[3] : fp16 weights     (at::Half, written only when N == 4)
 *
 * Scalars:
 *   chunk_size        : number of elements handled per chunk / per block
 *   noop_gmem         : if *noop_gmem is nonzero, the functor returns immediately
 *   wd                : weight_decay
 *   momentum          : momentum factor (0 disables the momentum path entirely)
 *   dampening         : momentum dampening
 *   lr                : learning rate
 *   nesterov          : enable nesterov momentum
 *   first_run         : when true, momentum buffers are initialised to the
 *                       (weight-decayed) gradient instead of being blended
 *   wd_after_momentum : apply weight decay _after_ momentum instead of before
 *
 * Per-element update, with all arithmetic carried out in fp32:
 *   g = grad
 *   if (wd != 0 && !wd_after_momentum) g += wd * w
 *   if (momentum != 0) {
 *     m = first_run ? g : momentum * m + (1 - dampening) * g
 *     g = nesterov ? g + momentum * m : m
 *   }
 *   if (wd != 0 && wd_after_momentum) g += wd * w
 *   w = w - lr * g            // rounded to T_weight exactly once
 **/
template<int N, typename T_grad, typename T_weight>
struct SGDFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<N>& tl,
    float wd,
    float momentum,
    float dampening,
    float lr,
    bool nesterov,
    bool first_run,
    bool wd_after_momentum)
  {
    // Early exit if we don't need to do anything
    if (*noop_gmem) return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    // Offset every buffer to the start of this block's chunk.
    T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
    grad_in += chunk_idx*chunk_size;

    T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
    weight_in += chunk_idx*chunk_size;

    T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
    mom_in += chunk_idx*chunk_size;

    at::Half *model_weights_out = nullptr;
    if(N == 4)
    {
      model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
      model_weights_out += chunk_idx*chunk_size;
    }

    // Elements of this tensor remaining from the start of this chunk; combined
    // with chunk_size below to bound every access.
    n -= chunk_idx*chunk_size;

    float incoming_grads[ILP];
    float incoming_weights[ILP];
    float incoming_moms[ILP];

    // Each iteration covers blockDim.x*ILP elements; thread t handles
    // elements i_start + t + ii*blockDim.x for ii in [0, ILP).
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      // Issue all ILP loads before any compute so their latencies overlap.
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        incoming_grads[ii] = 0;
        incoming_weights[ii] = 0;
        incoming_moms[ii] = 0;
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          incoming_grads[ii] = static_cast<float>(grad_in[i]);
          incoming_weights[ii] = static_cast<float>(weight_in[i]);
          incoming_moms[ii] = static_cast<float>(mom_in[i]);
        }
      }

      // The stores depend on the loads above but not on each other, so there
      // is little memory-level benefit to unrolling here; unrolling still
      // exposes compute ILP.
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          // apply weight decay before momentum if necessary
          if(wd != 0.f && !wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];

          if(momentum != 0.f)
          {
            if(!first_run)
              incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
            else // initialize momentums to current incoming grads
              incoming_moms[ii] = incoming_grads[ii];

            if(nesterov)
              incoming_grads[ii] += momentum * incoming_moms[ii];
            else
              incoming_grads[ii] = incoming_moms[ii];
          }

          // Apply WD after momentum if desired
          if(wd != 0.f && wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];

          // Compute the updated weight in fp32 from the value already held in
          // registers, and round to T_weight exactly once. Writing
          // `weight_in[i] += -lr * g` instead would re-read global memory and,
          // for T_weight == at::Half, would round the update -lr*g to fp16
          // before adding it (double rounding; small updates can flush to 0).
          float new_weight = incoming_weights[ii] - lr * incoming_grads[ii];
          weight_in[i] = static_cast<T_weight>(new_weight);

          // if necessary, write out an fp16 copy of the weights
          if(N == 4)
            model_weights_out[i] = static_cast<at::Half>(new_weight);

          // also write out the new momentum
          if(momentum != 0.f)
            mom_in[i] = static_cast<T_weight>(incoming_moms[ii]);
        }
      }
    }
  }
};

/**
 * Host entry point: dispatch the fused SGD functor over `tensor_lists`.
 *
 * tensor_lists[0] : gradients
 * tensor_lists[1] : weights
 * tensor_lists[2] : momentum buffers
 * tensor_lists[3] : fp16 weight copies (only when 4 lists are given)
 *
 * Supported combinations of
 *   (grad_type, weight/momentum type, requires_fp16_copy):
 *   1. fp16, fp16, No   (3 lists)
 *   2. fp32, fp32, No   (3 lists)
 *   3. fp16, fp32, Yes  (4 lists)
 * These are hardcoded because most of the cross-product of cases is unwanted.
 * Any other combination raises via AT_ERROR. Dtypes are taken from the first
 * tensor of the gradient and weight lists.
 *
 * Raises (AT_ERROR) if the number of lists is not 3 or 4, or if the lists do not
 * all contain the same number of tensors. Returns without launching anything
 * when the lists are empty.
 */
void multi_tensor_sgd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float wd,
  float momentum,
  float dampening,
  float lr,
  bool nesterov,
  bool first_run,
  bool wd_after_momentum)
{
  auto num_tensors = tensor_lists.size();

  // Validate the structure BEFORE reading dtypes from tensor_lists[0][0] and
  // tensor_lists[1][0]; indexing an empty vector there is undefined behaviour.
  if(num_tensors != 3 && num_tensors != 4)
  {
    AT_ERROR("multi_tensor_sgd expects 3 tensor lists (grads, weights, momenta) ",
             "or 4 (plus fp16 weight copies). Given num_lists: ", num_tensors);
  }
  for(size_t list_idx = 1; list_idx < num_tensors; list_idx++)
  {
    if(tensor_lists[list_idx].size() != tensor_lists[0].size())
    {
      AT_ERROR("multi_tensor_sgd: all tensor lists must have the same length. ",
               "List 0 has ", tensor_lists[0].size(), " tensors, list ", list_idx,
               " has ", tensor_lists[list_idx].size());
    }
  }

  // No parameters: nothing to update, nothing to launch.
  if(tensor_lists[0].empty())
    return;

  auto grad_type = tensor_lists[0][0].scalar_type();
  auto weight_type = tensor_lists[1][0].scalar_type();

  // Case 1. fp16, fp16, fp16, No
  if(grad_type == at::ScalarType::Half &&
     weight_type == at::ScalarType::Half &&
     num_tensors == 3)
  {
    multi_tensor_apply<3>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        SGDFunctor<3, at::Half, at::Half>(),
        wd,
        momentum,
        dampening,
        lr,
        nesterov,
        first_run,
        wd_after_momentum);
  }
  // NOTE: fp16 grads with fp32 weights and NO fp16 copy (3 lists) is not
  // dispatched here; it falls through to the error below.
  // Case 2. fp32, fp32, fp32, No
  else if(grad_type == at::ScalarType::Float &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 3)
  {
    multi_tensor_apply<3>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        SGDFunctor<3, float, float>(),
        wd,
        momentum,
        dampening,
        lr,
        nesterov,
        first_run,
        wd_after_momentum);
  }
  // Case 3. fp16, fp32, fp32, Yes
  else if(grad_type == at::ScalarType::Half &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 4)
  {
    multi_tensor_apply<4>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        SGDFunctor<4, at::Half, float>(),
        wd,
        momentum,
        dampening,
        lr,
        nesterov,
        first_run,
        wd_after_momentum);
  }
  else
  {
    AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
             "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
  }

  // Surface launch-configuration errors from the kernel launch above.
  AT_CUDA_CHECK(cudaGetLastError());
}