welford.cu 19.4 KB
Newer Older
jjsjann123's avatar
jjsjann123 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
#include <iostream>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>


#include <cuda.h>
#include <cuda_runtime.h>


#include <vector>


// Returns the largest power of two that is strictly smaller than n when n is
// itself a power of two, and the largest power of two below n otherwise
// (8 -> 4, 9 -> 8). n == 1 yields 0.
// NOTE(review): n <= 0 produces an invalid shift amount; callers presumably
// pass n >= 1 -- confirm.
__device__ __forceinline__ int lastpow2(int n)
{
  // isolate the highest set bit of n
  const int top_bit = 1 << (31 - __clz(n));
  return (top_bit == n) ? (top_bit >> 1) : top_bit;
}

// Host helper: smallest power of two strictly greater than n
// (0 -> 1, 4 -> 8, 5 -> 8).
// The highest set bit of n is propagated into every lower bit, then 1 is added.
// For n >= 2^31 the unsigned addition wraps and the result is 0.
__host__ __forceinline__ int h_next_pow2(unsigned int n) {
    for (unsigned int shift = 1; shift <= 16; shift <<= 1) {
        n |= (n >> shift);
    }
    return n + 1;
}

// Host helper: largest power of two that does not exceed n
// (1 -> 1, 5 -> 4, 8 -> 8). Returns 0 for n == 0.
__host__ __forceinline__ int h_last_pow2(unsigned int n) {
    // propagate the highest set bit into every lower bit ...
    for (unsigned int shift = 1; shift <= 16; shift <<= 1) {
        n |= (n >> shift);
    }
    // ... then drop everything below it, leaving only the top bit
    return n - (n >> 1);
}


#define WARP_SIZE 32

// Sums `val` across the 32 lanes of a warp with a shuffle-down tree.
// All 32 lanes are named in the shuffle mask; after the last step lane 0
// holds the total over the warp (other lanes hold partial sums).
template<typename T>
__device__ __forceinline__ T warp_reduce_sum(T val)
{
  #pragma unroll
  for(int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
    const auto from_upper_lane = __shfl_down_sync(0xffffffff, val, offset);
    val = val + from_upper_lane;
  }
  return val;
}

// Sums `val` over every thread of a (blockDim.x x blockDim.y) block.
//   x   : scratch buffer with one T slot per warp (callers in this file pass
//         shared memory)
//   val : this thread's contribution
// The block total is returned by thread 0; other threads return intermediate
// values. Contains a __syncthreads() when the block has more than 32 threads,
// so every thread of the block must call it.
template<typename T>
__device__ __forceinline__ T reduce_block(T *x, T val)
{
  const int linear_tid = threadIdx.y*blockDim.x + threadIdx.x;
  const int num_threads = blockDim.x * blockDim.y;
  const int warp_idx = linear_tid / WARP_SIZE;
  const int lane_idx = linear_tid % WARP_SIZE;

  if (num_threads > 32) {
    // stage 1: reduce inside each warp, lane 0 publishes the warp's partial sum
    val = warp_reduce_sum(val);
    if (lane_idx == 0) {
      x[warp_idx] = val;
    }

    __syncthreads();

    // stage 2: the first (num_threads / WARP_SIZE) threads pick up one partial each
    if (linear_tid < num_threads / WARP_SIZE) {
      val = x[lane_idx];
    } else {
      val = T(0);
    }
  }

  // final reduction is done by the first warp only
  if (warp_idx == 0) {
    val = warp_reduce_sum(val);
  }

  return val;
}

#define TILE_W 32
#define MAX_BLOCK_SIZE 256

// Merges per-lane Welford statistics across a warp.
//   mean : running mean of this lane's samples
//   m2n  : running sum of squared deviations from the mean
//   num  : number of samples behind mean / m2n
// Each step pulls the triple held `offset` lanes above and folds it into the
// local one with the pairwise mean / M2 update; lane 0 ends up with the merged
// statistics of the warp. Partners that carry no samples (num == 0) are skipped.
template<typename T>
__device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
{
  #pragma unroll
  for(int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
    // all three shuffles are issued before branching so every lane takes part
    const auto other_num = __shfl_down_sync(0xffffffff, num, offset);
    const auto other_mean = __shfl_down_sync(0xffffffff, mean, offset);
    const auto other_m2n = __shfl_down_sync(0xffffffff, m2n, offset);
    if (other_num == 0) {
      continue;  // nothing to merge from that lane
    }
    const auto dif_mean = mean - other_mean;
    mean = (other_mean * other_num + mean * num) / (num + other_num);
    m2n += other_m2n + dif_mean*dif_mean*num*other_num/(other_num+num);
    num += other_num;
  }
}

// Block-wide merge of Welford statistics (mean, m2n, num).
//   x          : scratch holding two T values per warp (mean, m2n interleaved)
//   count      : scratch holding one int per warp (sample count)
//   mean/m2n/num : in: this thread's statistics; out: merged result in thread 0
//   block_size : number of threads in the block
//   thread_id  : linear thread index inside the block
// Contains a __syncthreads() when block_size > 32, so every thread of the
// block must call it.
template <typename T>
__device__ void welford_reduce_mean_m2n(
      T* __restrict__ x,
      int* __restrict__ count,
      T &mean,
      T &m2n,
      int &num,
      int block_size,
      int thread_id)
{
  const int lane = thread_id % WARP_SIZE;
  const int wid = thread_id / WARP_SIZE;
  const bool first_warp = (wid == 0);

  if (block_size > 32) {
    // stage 1: merge inside each warp, lane 0 publishes the warp's triple
    warp_reduce_mean_m2n(mean, m2n, num);
    if (lane == 0) {
      x[wid*2] = mean;
      x[wid*2+1] = m2n;
      count[wid] = num;
    }
    __syncthreads();

    // stage 2: first warp reloads one triple per warp; spare lanes carry
    // empty statistics (num == 0) which the merge skips
    if (first_warp) {
      const bool has_partial = thread_id < block_size / WARP_SIZE;
      mean = has_partial ? x[lane*2] : T(0);
      m2n = has_partial ? x[lane*2+1] : T(0);
      num = has_partial ? count[lane] : int(0);
    }
  }

  if (first_warp) {
    warp_reduce_mean_m2n(mean, m2n, num);
  }
}

// Product of all dimensions after the first two (batch, feature) of `input`.
// Reads input.size(2) unconditionally, so the tensor needs at least 3 dims.
__host__ int get_tensor_spatial_size(const at::Tensor& input)
{
  auto spatial = input.size(2);
  int dim = 3;
  while (dim < input.ndimension()) {
    spatial *= input.size(dim);
    ++dim;
  }
  return spatial;
}

// Scalar type used for accumulation: Half is widened to Float, every other
// type is returned unchanged.
__host__ at::ScalarType promote_scalartype(const at::Tensor& input)
{
  const auto dtype = input.type().scalarType();
  if (dtype == at::ScalarType::Half) {
    return at::ScalarType::Float;
  }
  return dtype;
}

// Size in bytes of one element of `input`. With accumulation == true the size
// of the promoted accumulation type (see promote_scalartype) is returned.
__host__ size_t get_element_data_size(const at::Tensor& input, bool accumulation = false)
{
  if (accumulation) {
    return at::elementSize(promote_scalartype(input));
  }
  return at::elementSize(input.type().scalarType());
}


// Welford kernel: per-feature mean, unbiased variance and biased variance.
//
// Layout: each block handles the feature blockIdx.x. threadIdx.y strides over
// the batch dimension, threadIdx.x strides over the spatial dimension; input is
// indexed as a contiguous [bs][fs][ss] array.
//   input          : values to reduce
//   out_mean       : [fs] mean per feature
//   out_var        : [fs] m2n / (count - 1)
//   out_var_biased : [fs] m2n / count
//   bs, fs, ss     : batch, feature and spatial extents
// Scratch: 160 ints of static shared memory -- the first 32 hold per-warp
// counts, the rest is reinterpreted as accscalar_t for per-warp (mean, m2n).
template <typename scalar_t, typename accscalar_t, typename outscalar_t>
__global__ void welford_kernel(
      const scalar_t* __restrict__ input,
      outscalar_t* __restrict__ out_mean,
      outscalar_t* __restrict__ out_var,
      outscalar_t* __restrict__ out_var_biased,
      const int bs,
      const int fs,
      const int ss) {
  static __shared__ int s_mem[160];
  accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];

  const int threads_in_block = blockDim.x * blockDim.y;
  const int linear_tid = threadIdx.y*blockDim.x + threadIdx.x;

  // per-thread running statistics
  int n_seen = 0;
  accscalar_t running_mean = accscalar_t(0);
  accscalar_t running_m2n = accscalar_t(0);

  for (int n = threadIdx.y; n < bs; n += blockDim.y) {
    const int row_base = blockIdx.x*ss + n*ss*fs;
    // sequential Welford update over this thread's share of the row
    for (int s = threadIdx.x; s < ss; s += blockDim.x) {
      ++n_seen;
      const auto sample = static_cast<accscalar_t>(input[s+row_base]);
      const auto delta_old = sample - running_mean;
      const auto next_mean = running_mean + delta_old / n_seen;
      running_m2n = running_m2n + (sample - next_mean) * delta_old;
      running_mean = next_mean;
    }
  }

  // merge the per-thread statistics; thread 0 ends up with the block result
  welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, running_mean, running_m2n, n_seen, threads_in_block, linear_tid);

  if (linear_tid == 0) {
    out_mean[blockIdx.x] = static_cast<outscalar_t>(running_mean);
    out_var[blockIdx.x] = static_cast<outscalar_t>(running_m2n/(n_seen-1));
    out_var_biased[blockIdx.x] = static_cast<outscalar_t>(running_m2n/n_seen);
  }
}

// Elementwise batch-norm forward:
//   out = weight * (input - mean) * rsqrt(var + eps) + shift
//
// Layout: blockIdx.x selects the feature, blockIdx.y the batch entry
// (gridDim.x is used as the feature count); threads of a block stride over the
// ss spatial elements of that (batch, feature) slice.
//   mean, var     : per-feature statistics
//   weight, shift : per-feature scale and offset
//   ss            : spatial extent
//   eps           : added to var before the reciprocal square root
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
__global__ void batchnorm_forward_kernel(
      const scalar_t* __restrict__ input,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ var,
      const layerscalar_t* __restrict__ weight,
      const layerscalar_t* __restrict__ shift,
      scalar_t* __restrict__ out,
      const int ss,
      const float eps) {
  const int slice_base = blockIdx.x*ss + blockIdx.y*gridDim.x*ss;

  // per-feature constants, loaded once per block
  const auto feat_mean = mean[blockIdx.x];
  const auto feat_inv_std = static_cast<accscalar_t>(rsqrt(var[blockIdx.x] + eps));
  const auto feat_scale = static_cast<accscalar_t>(weight[blockIdx.x]);
  const auto feat_shift = static_cast<accscalar_t>(shift[blockIdx.x]);

  for (int s = threadIdx.x; s < ss; s += blockDim.x) {
    const int idx = slice_base+s;
    const auto x = static_cast<accscalar_t>(input[idx]);
    out[idx] = static_cast<scalar_t>(feat_scale * (x - feat_mean ) * feat_inv_std + feat_shift);
  }
}

// Backward BN reduction kernel: calculates grad_bias and grad_weight as well
// as the intermediate results needed to calculate grad_input.
// grad_input is computed in two steps to support sync BN, which requires an
// all-reduce of the intermediate results across processes.
//
// Layout: each block handles the feature blockIdx.x. threadIdx.y strides over
// the batch dimension, threadIdx.x over the spatial dimension; input and
// grad_output are indexed as contiguous [bs][fs][ss] arrays.
//   mean, var   : per-feature statistics (read at blockIdx.x)
//   mean_dy     : out, sum(grad_output) / (bs * ss)
//   mean_dy_xmu : out, sum(grad_output * (input - mean)) / (bs * ss)
//   grad_weight : out, sum(grad_output * (input - mean)) / sqrt(var + eps)
//   grad_bias   : out, sum(grad_output)
// Scratch: 64 ints of static shared memory, reinterpreted as accscalar_t and
// used as the per-warp buffer of reduce_block.
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
__global__ void reduce_bn_kernel(
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ grad_output,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ var,
      accscalar_t* __restrict__ mean_dy,
      accscalar_t* __restrict__ mean_dy_xmu,
      layerscalar_t* __restrict__ grad_weight,
      layerscalar_t* __restrict__ grad_bias,
      const int bs,
      const int fs,
      const int ss,
      const float eps) {
  static __shared__ int s_mem[64];
  // number of elements reduced per feature
  int total_item_num = bs * ss;

  // linear index of this thread inside the 2D block
  int thread_id = threadIdx.y*blockDim.x + threadIdx.x;

  auto r_mean = mean[blockIdx.x];
  // 1 / sqrt(var + eps), applied to grad_weight below
  auto factor = accscalar_t(1.0) / (accscalar_t)sqrt(var[blockIdx.x] + eps);

  // Kahan (compensated) sum: the *_c variables carry the rounding error of the
  // previous addition and are subtracted from the next term.
  accscalar_t sum_dy = 0.0;
  accscalar_t sum_dy_xmu = 0.0;
  accscalar_t sum_dy_c = 0.0;
  accscalar_t sum_dy_xmu_c = 0.0;
  for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) {
    int input_base = blockIdx.x*ss + batch_id*ss*fs;
    for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) {
      auto e_grad = static_cast<accscalar_t>(grad_output[offset+input_base]);
      auto e_input = static_cast<accscalar_t>(input[offset+input_base]);
      // calculating sum_dy
      auto sum_dy_y = e_grad - sum_dy_c;
      auto sum_dy_t = sum_dy + sum_dy_y;
      sum_dy_c = (sum_dy_t - sum_dy) - sum_dy_y;
      sum_dy = sum_dy_t;

      // calculating sum_dy_xmu
      auto sum_dy_xmu_y = e_grad * (e_input - r_mean) - sum_dy_xmu_c;
      auto sum_dy_xmu_t = sum_dy_xmu + sum_dy_xmu_y;
      sum_dy_xmu_c = (sum_dy_xmu_t - sum_dy_xmu) - sum_dy_xmu_y;
      sum_dy_xmu = sum_dy_xmu_t;
    }
  }

  // Block-wide sums; the same shared buffer is reused for both reductions, so
  // a barrier separates the reads of the first from the writes of the second.
  sum_dy = reduce_block((accscalar_t*)s_mem, sum_dy);
  __syncthreads();
  sum_dy_xmu = reduce_block((accscalar_t*)s_mem, sum_dy_xmu);
  
  // thread 0 holds the block totals and writes the per-feature results
  if (thread_id == 0) {
    grad_bias[blockIdx.x] = static_cast<layerscalar_t>(sum_dy);
    grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
    mean_dy[blockIdx.x] = sum_dy / total_item_num;
    mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
  }
}

// Elementwise batch-norm backward:
//   grad_input = (grad_output - mean_dy
//                 - (input - mean) * mean_dy_xmu / (var + eps))
//                * weight / sqrt(var + eps)
//
// Layout: blockIdx.x selects the feature, blockIdx.y the batch entry
// (gridDim.x is used as the feature count); threads of a block stride over the
// ss spatial elements of that (batch, feature) slice.
//   mean, var            : per-feature statistics
//   weight               : per-feature scale
//   mean_dy, mean_dy_xmu : per-feature reduction results
//   ss                   : spatial extent
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
__global__ void batchnorm_backward_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ var,
      const layerscalar_t* __restrict__ weight,
      const accscalar_t* __restrict__ mean_dy,
      const accscalar_t* __restrict__ mean_dy_xmu,
      scalar_t* __restrict__ grad_input,
      const int ss,
      const float eps) {
  const int slice_base = blockIdx.x*ss + blockIdx.y*gridDim.x*ss;

  // per-feature constants, loaded once per block
  const auto feat_mean = static_cast<accscalar_t>(mean[blockIdx.x]);
  const auto feat_mean_dy = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
  const auto var_plus_eps = static_cast<accscalar_t>(var[blockIdx.x]) + eps;
  const auto out_scale = static_cast<accscalar_t>(weight[blockIdx.x]) / sqrt(var_plus_eps);
  // (var + eps) / mean_dy_xmu: dividing by this multiplies by mean_dy_xmu / (var + eps)
  const auto xmu_divisor = var_plus_eps / static_cast<accscalar_t>(mean_dy_xmu[blockIdx.x]);

  for (int s = threadIdx.x; s < ss; s += blockDim.x) {
    const int idx = slice_base+s;
    const auto dy = static_cast<accscalar_t>(grad_output[idx]);
    const auto x = static_cast<accscalar_t>(input[idx]);
    grad_input[idx] = (dy - feat_mean_dy - (x - feat_mean) / xmu_divisor) * out_scale;
  }
}

// Parallel Welford kernel: merges the per-process mean / biased variance of a
// feature into one mean, unbiased variance and biased variance.
//
// Layout: each block handles the feature blockIdx.x; thread t loads the
// statistics at index blockIdx.x * ns + t, i.e. one thread per process.
//   mean, var_biased : [fs][ns] per-process statistics
//   out_mean / out_var / out_var_biased : [fs] merged results
//   ns    : number of processes contributing statistics
//   fs    : feature count (not read by the kernel)
//   numel : sample count behind every (mean, var_biased) pair
// NOTE(review): the merge uses warp shuffles with a full 32-lane mask; confirm
// the behaviour when blockDim.x is not a multiple of 32.
template <typename scalar_t, typename accscalar_t>
__global__ void welford_kernel_parallel(
      const scalar_t* __restrict__ mean,
      const scalar_t* __restrict__ var_biased,
      scalar_t* __restrict__ out_mean,
      scalar_t* __restrict__ out_var,
      scalar_t* __restrict__ out_var_biased,
      const int ns,
      const int fs,
      const int numel) {
  // first 32 ints: per-warp counts; remainder: per-warp (mean, m2n) pairs
  static __shared__ int s_mem[160];
  accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];

  const int threads_in_block = blockDim.x;
  const int tid = threadIdx.x;
  const int src = blockIdx.x*ns + threadIdx.x;

  // load this process's statistics; biased variance * numel gives back m2n
  accscalar_t merged_mean = static_cast<accscalar_t>(mean[src]);
  accscalar_t merged_m2n = static_cast<accscalar_t>(var_biased[src]) * numel;
  int merged_count = numel;

  __syncthreads();

  welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, merged_mean, merged_m2n, merged_count, threads_in_block, tid);

  // thread 0 holds the merged result
  if (tid == 0) {
    out_mean[blockIdx.x] = static_cast<scalar_t>(merged_mean);
    out_var[blockIdx.x] = static_cast<scalar_t>(merged_m2n/(merged_count-1));
    out_var_biased[blockIdx.x] = static_cast<scalar_t>(merged_m2n/merged_count);
  }
}
  

// Launches welford_kernel on the current CUDA stream.
//   input : tensor with dims (batch, feature, spatial...)
// Returns {mean, unbiased variance, biased variance}, each of shape [feature]
// and stored in the accumulation type (Half input -> Float results).
// Launch shape: one block per feature, TILE_W threads along the spatial axis
// and up to MAX_BLOCK_SIZE / TILE_W thread rows along the batch axis.
std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
  const auto batch_size = input.size(0);
  const auto feature_size = input.size(1);
  auto space_size = get_tensor_spatial_size(input);

  // all three outputs share the promoted dtype
  const auto stat_options = input.options().dtype(promote_scalartype(input));
  at::Tensor out_var = at::empty({feature_size}, stat_options);
  at::Tensor out_var_biased = at::empty({feature_size}, stat_options);
  at::Tensor out_mean = at::empty({feature_size}, stat_options);

  const int max_rows = int(MAX_BLOCK_SIZE / TILE_W);
  const int rows = min(h_last_pow2(batch_size), max_rows);
  const dim3 block(TILE_W, rows);
  const dim3 grid(feature_size);

  // the kernel's reduction scratch is static shared memory, so no dynamic
  // shared memory is requested at launch
  auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_kernel", ([&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    welford_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
        input.data<scalar_t>(),
        out_mean.data<accscalar_t>(),
        out_var.data<accscalar_t>(),
        out_var_biased.data<accscalar_t>(),
        batch_size,
        feature_size,
        space_size);
  }));

  return {out_mean, out_var, out_var_biased};
}

// Launches batchnorm_forward_kernel on the current CUDA stream and returns
// out = weight * (input - mean) / sqrt(var + eps) + shift, shaped like input.
//   input         : tensor with dims (batch, feature, spatial...)
//   mean, var     : per-feature statistics in the accumulation type
//   weight, shift : per-feature scale / offset; either the same dtype as input,
//                   or Float when input is Half
//   eps           : added to var before the reciprocal square root
// Launch shape: grid (feature, batch), one 1-D block per (batch, feature) slice.
at::Tensor batchnorm_forward_CUDA(
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor var,
    const at::Tensor weight,
    const at::Tensor shift,
    const float eps) {
  const auto batch_size = input.size(0);
  const auto feature_size = input.size(1);
  at::Tensor out = at::empty_like(input);

  auto space_size = get_tensor_spatial_size(input);

  // h_next_pow2(1) / 4 == 0, which would request a zero-thread block (an
  // invalid launch that leaves `out` uninitialised) when the spatial size is 1.
  // Clamp to at least one warp; surplus threads skip the kernel's strided loop.
  int block = max(WARP_SIZE, min(MAX_BLOCK_SIZE, h_next_pow2(space_size)/4));
  // TODO(jie): should I do 1 block per feature?
  const dim3 grid(feature_size, batch_size);
  auto stream = at::cuda::getCurrentCUDAStream();

  if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
    // mixed precision: Half activations with Float weight / shift
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      batchnorm_forward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
          input.data<scalar_t>(),
          mean.data<accscalar_t>(),
          var.data<accscalar_t>(),
          weight.data<accscalar_t>(),
          shift.data<accscalar_t>(),
          out.data<scalar_t>(),
          space_size,
          eps);
    }));
  } else {
    // uniform precision: weight / shift must share the input dtype
    AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()"); 
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
          input.data<scalar_t>(),
          mean.data<accscalar_t>(),
          var.data<accscalar_t>(),
          weight.data<scalar_t>(),
          shift.data<scalar_t>(),
          out.data<scalar_t>(),
          space_size,
          eps);
    }));
  }
  return out;
}

// Launches reduce_bn_kernel on the current CUDA stream.
//   grad_output, input : tensors with dims (batch, feature, spatial...)
//   mean, var          : per-feature statistics in the accumulation type
//   weight             : per-feature scale; either the same dtype as input, or
//                        Float when input is Half (only its dtype/options are
//                        used here, to allocate grad_weight / grad_bias)
//   eps                : added to var before the square root
// Returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}, each of shape
// [feature]; the first two use mean's options, the last two weight's options.
// Launch shape: one block per feature, TILE_W threads along the spatial axis
// and up to MAX_BLOCK_SIZE / TILE_W thread rows along the batch axis.
std::vector<at::Tensor> reduce_bn_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor var,
    const at::Tensor weight,
    const float eps)
{
  const auto batch_size = input.size(0);
  const auto feature_size = input.size(1);

  at::Tensor mean_dy = at::empty({feature_size}, mean.options());
  at::Tensor mean_dy_xmu = at::empty({feature_size}, mean.options());
  at::Tensor grad_weight = at::empty({feature_size}, weight.options());
  at::Tensor grad_bias = at::empty({feature_size}, weight.options());

  auto space_size = get_tensor_spatial_size(input);

  int block_x = TILE_W;
  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / block_x));
  const dim3 block(block_x, block_y);
  const dim3 grid(feature_size);
  // the kernel's reduction scratch for sum_dy / sum_dy_xmu is static shared
  // memory, so no dynamic shared memory is requested at launch
  auto stream = at::cuda::getCurrentCUDAStream();

  if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
    // mixed precision: Half activations with Float weight gradients
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      reduce_bn_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
          input.data<scalar_t>(),
          grad_output.data<scalar_t>(),
          mean.data<accscalar_t>(),
          var.data<accscalar_t>(),
          mean_dy.data<accscalar_t>(),
          mean_dy_xmu.data<accscalar_t>(),
          grad_weight.data<accscalar_t>(),
          grad_bias.data<accscalar_t>(),
          batch_size,
          feature_size,
          space_size,
          eps);
    }));
  } else {
    // uniform precision: weight must share the input dtype
    AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()"); 
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      reduce_bn_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
          input.data<scalar_t>(),
          grad_output.data<scalar_t>(),
          mean.data<accscalar_t>(),
          var.data<accscalar_t>(),
          mean_dy.data<accscalar_t>(),
          mean_dy_xmu.data<accscalar_t>(),
          grad_weight.data<scalar_t>(),
          grad_bias.data<scalar_t>(),
          batch_size,
          feature_size,
          space_size,
          eps);
    }));
  }
  
  return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
}

// Launches batchnorm_backward_kernel on the current CUDA stream and returns
// grad_input, shaped like input.
//   grad_output, input   : tensors with dims (batch, feature, spatial...)
//   mean, var            : per-feature statistics in the accumulation type
//   weight               : per-feature scale; either the same dtype as input,
//                          or Float when input is Half
//   mean_dy, mean_dy_xmu : per-feature reduction results (see reduce_bn_CUDA)
//   eps                  : added to var before the square root
// Launch shape: grid (feature, batch), one 1-D block per (batch, feature) slice.
at::Tensor batchnorm_backward_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor var,
    const at::Tensor weight,
    const at::Tensor mean_dy,
    const at::Tensor mean_dy_xmu,
    const float eps) {
  const auto batch_size = input.size(0);
  const auto feature_size = input.size(1);

  at::Tensor grad_input = at::empty_like(input);

  auto space_size = get_tensor_spatial_size(input);

  // h_next_pow2(1) / 4 == 0, which would request a zero-thread block (an
  // invalid launch that leaves `grad_input` uninitialised) when the spatial
  // size is 1. Clamp to at least one warp; surplus threads skip the kernel's
  // strided loop.
  int block = max(WARP_SIZE, min(MAX_BLOCK_SIZE, h_next_pow2(space_size)/4));
  // TODO(jie): should I do 1 block per feature?
  const dim3 grid(feature_size, batch_size);
  auto stream = at::cuda::getCurrentCUDAStream();

  if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
    // mixed precision: Half activations with Float weight
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      batchnorm_backward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
          grad_output.data<scalar_t>(),
          input.data<scalar_t>(),
          mean.data<accscalar_t>(),
          var.data<accscalar_t>(),
          weight.data<accscalar_t>(),
          mean_dy.data<accscalar_t>(),
          mean_dy_xmu.data<accscalar_t>(),
          grad_input.data<scalar_t>(),
          space_size,
          eps);
    }));
  } else {
    // uniform precision: weight must share the input dtype
    AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()"); 
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      batchnorm_backward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
          grad_output.data<scalar_t>(),
          input.data<scalar_t>(),
          mean.data<accscalar_t>(),
          var.data<accscalar_t>(),
          weight.data<scalar_t>(),
          mean_dy.data<accscalar_t>(),
          mean_dy_xmu.data<accscalar_t>(),
          grad_input.data<scalar_t>(),
          space_size,
          eps);
    }));
  }
  
  return grad_input;
}

// Launches welford_kernel_parallel on the current CUDA stream to merge
// per-process statistics.
//   mean_feature_nodes : [feature, world] per-process means
//   var_biased         : per-process biased variances, indexed the same way
//   numel              : sample count behind every (mean, var_biased) pair
// Returns {mean, unbiased variance, biased variance}, each of shape [feature]
// with var_biased's options.
// Launch shape: one block per feature with one thread per process, so the
// world size must be a valid CUDA block size.
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased, int numel) {
  const auto feature_size = mean_feature_nodes.size(0);
  const auto world_size = mean_feature_nodes.size(1);

  // world_size becomes blockDim.x. Outside [1, 1024] the launch configuration
  // is invalid and the kernel would silently not run, returning uninitialised
  // outputs; 1024 threads (32 warps) is also what the kernel's static shared
  // scratch is sized for. Fail loudly instead.
  // NOTE(review): the kernel's warp shuffles use a full 32-lane mask; confirm
  // the behaviour when world_size is not a multiple of 32.
  AT_CHECK(world_size >= 1 && world_size <= 1024,
           "welford_parallel_CUDA: world size must be in [1, 1024], got ", world_size);

  at::Tensor out_var = at::empty({feature_size}, var_biased.options());
  at::Tensor out_var_biased = at::empty_like(out_var);
  at::Tensor out_mean = at::empty_like(out_var);

  // TODO(jie): tile this for memory coalescing!
  const dim3 block(world_size);
  const dim3 grid(feature_size);
  // the kernel's reduction scratch for mean, var and element counts is static
  // shared memory, so no dynamic shared memory is requested at launch
  auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(mean_feature_nodes.type(), "welford_parallel_kernel", ([&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    welford_kernel_parallel<scalar_t, accscalar_t><<<grid, block, 0, stream>>>(
        mean_feature_nodes.data<scalar_t>(),
        var_biased.data<scalar_t>(),
        out_mean.data<scalar_t>(),
        out_var.data<scalar_t>(),
        out_var_biased.data<scalar_t>(),
        world_size,
        feature_size,
        numel);
  }));

  return {out_mean, out_var, out_var_biased};
}