// activation_kernel.cu
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>

#include <cuda_runtime_api.h>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include "common.h"


namespace {

template<typename T>
inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
  // Wrap the raw device pointers for use with thrust algorithms
  thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
  thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);

  // Where the stored output is negative (i.e. the input was negative),
  // scale the incoming gradient by the slope: dL/dx = slope * dL/dz.
  thrust::transform_if(th_dz, th_dz + count, th_z, th_dz,
                       [slope] __device__ (const T& dz) { return dz * slope; },
                       [] __device__ (const T& z) { return z < 0; });
  // Invert the forward activation in place: dividing negative outputs
  // by the slope restores the original pre-activation input.
  thrust::transform_if(th_z, th_z + count, th_z,
                       [slope] __device__ (const T& z) { return z / slope; },
                       [] __device__ (const T& z) { return z < 0; });
}
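
// For illustration only (a sketch, not used by the dispatch below): the
// same backward pass as a single hand-written CUDA kernel, fusing the
// gradient scaling and the in-place inversion into one pass over the data.
template<typename T>
__global__ void leaky_relu_backward_kernel(T *z, T *dz, float slope, int64_t count) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < count && z[i] < 0) {
    dz[i] *= slope;  // dL/dx = slope * dL/dz on the negative side
    z[i] /= slope;   // recover the pre-activation input from the stored output
  }
}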

} // namespace

// In-place forward: overwrites z with leaky_relu(z).
void LeakyRelu_Forward_CUDA(at::Tensor z, float slope) {
  at::leaky_relu_(z, slope);
}

// In-place backward: on entry z holds the forward output; on exit dz holds
// the gradient w.r.t. the input and z is restored to the input.
void LeakyRelu_Backward_CUDA(at::Tensor z, at::Tensor dz, float slope) {
  int64_t count = z.numel();

  AT_DISPATCH_FLOATING_TYPES(z.scalar_type(), "LeakyRelu_Backward_CUDA", ([&] {
    leaky_relu_backward_impl<scalar_t>(z.data_ptr<scalar_t>(), dz.data_ptr<scalar_t>(), slope, count);
  }));
  /*
  // Equivalent ATen formulation, disabled: numerically unstable after scaling.
  at::leaky_relu_(z, 1.0 / slope);
  at::leaky_relu_backward(dz, z, slope);
  */
}
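
/*
  Binding sketch (an assumption, not part of this file): these functions are
  presumably declared in a shared header and exposed to Python from a separate
  translation unit, along the lines of:

    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
      m.def("leaky_relu_forward", &LeakyRelu_Forward_CUDA, "In-place LeakyReLU forward (CUDA)");
      m.def("leaky_relu_backward", &LeakyRelu_Backward_CUDA, "In-place LeakyReLU backward (CUDA)");
    }

  The Python-side names here are hypothetical.
*/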