// activation_kernel.cu — in-place leaky ReLU entry points
// (LeakyRelu_Forward_CUDA / LeakyRelu_Backward_CUDA).
#include <exception>
#include <stdexcept>
#include <vector>

#include <torch/extension.h>
#include <ATen/ATen.h>

#include <cuda_runtime_api.h>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>

#include "common.h"

using namespace std;

namespace {

// Reference element-wise implementation of the leaky ReLU backward pass.
// It is commented out and is not called by any compiled code in this file;
// it is kept to document the intended per-element rule. For every element
// whose stored activation z is negative it:
//   * scales the incoming gradient:  dz = dz * slope
//   * undoes the activation in place: z = z / slope
// Elements with z >= 0 are left untouched in both buffers.
// NOTE(review): the __device__ lambdas below would need nvcc's extended-lambda
// support (--extended-lambda / --expt-extended-lambda) if this is re-enabled —
// confirm against the extension's build flags.
//
// template<typename T>
// inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
//   // Create thrust pointers
//   thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
//   thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
//
//   thrust::transform_if(th_dz, th_dz + count, th_z, th_dz,
//                        [slope] __device__ (const T& dz) { return dz * slope; },
//                        [] __device__ (const T& z) { return z < 0; });
//   thrust::transform_if(th_z, th_z + count, th_z,
//                        [slope] __device__ (const T& z) { return z / slope; },
//                        [] __device__ (const T& z) { return z < 0; });
// }

}

// Forward activation: applies ATen's in-place leaky ReLU to `z` with the given
// negative slope. `z` is an ATen tensor handle, so the update lands in the
// storage the caller passed in; nothing is returned.
void LeakyRelu_Forward_CUDA(at::Tensor z, float slope)
{
  at::Tensor &activations = z;
  at::leaky_relu_(activations, slope);
}

// Backward pass of the in-place leaky ReLU.
//
// Expects `z` to hold the activation output y = leaky_relu(x, slope) produced
// in place by the forward pass with the same `slope` (assumption — confirm at
// the call site), and `dz` to hold dL/dy. Both tensors are updated in place:
//   * dz becomes dL/dx: entries where y < 0 are multiplied by `slope`;
//   * z  becomes the recovered input x: entries where y < 0 are divided by
//     `slope` (via leaky_relu_ with slope 1/slope).
// This is the same per-element rule as the commented-out thrust reference in
// this file, expressed with plain ATen ops. The earlier at::leaky_relu_backward
// route changed signature upstream and returned a new tensor rather than
// writing into `dz`.
//
// Throws std::invalid_argument if slope is not strictly positive. The
// activation cannot be inverted then: slope == 0 divides by zero, and a
// negative slope flips signs so y < 0 no longer marks x < 0. The check runs
// before any tensor is modified.
void LeakyRelu_Backward_CUDA(at::Tensor z, at::Tensor dz, float slope) {
  if (!(slope > 0.0f)) {
    throw std::invalid_argument(
        "LeakyRelu_Backward_CUDA: slope must be > 0 to invert the activation");
  }

  // With slope > 0 the sign is preserved, so y < 0 exactly where x < 0.
  // Build the mask before z is rewritten below.
  at::Tensor negative = z.lt(0);

  // dL/dx = dL/dy * slope on the negative side, dL/dy elsewhere.
  dz.copy_(at::where(negative, dz.mul(slope), dz));

  // Recover x from y in place: y / slope where y < 0, y unchanged elsewhere.
  // Note: dividing by a small slope amplifies rounding error in z.
  at::leaky_relu_(z, 1.0 / slope);
}