rectify.py 1.63 KB
Newer Older
Hang Zhang's avatar
Hang Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Rectify function"""
import torch
from torch.autograd import Function

Hang Zhang's avatar
Hang Zhang committed
13
14
15
from encoding import cpu
if torch.cuda.device_count() > 0:
    from encoding import gpu
Hang Zhang's avatar
Hang Zhang committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

# Only the functional entry point is public; the autograd Function subclass
# that implements it stays private to this module.
__all__ = ["rectify"]

def _conv_rectify(out, x, kernel_size, stride, padding, dilation, average):
    """Run the native ``conv_rectify`` kernel on the backend matching ``x``.

    The kernel is expected to write into ``out`` in place (``forward`` marks
    its ``y`` argument dirty and returns it). Any return value is discarded.
    """
    # ``gpu`` is only imported when CUDA devices exist; a CUDA tensor implies
    # that, and the conditional expression only evaluates ``gpu`` in that case.
    backend = gpu if x.is_cuda else cpu
    backend.conv_rectify(out, x, kernel_size, stride, padding, dilation, average)


class _rectify(Function):
    """Autograd Function wrapping the native in-place ``conv_rectify`` kernel.

    ``forward`` passes the tensor ``y`` and the reference tensor ``x`` to the
    kernel, which overwrites ``y``. ``backward`` applies the same kernel to the
    incoming gradient and returns it as ``y``'s gradient. ``x`` and all window
    configuration arguments receive ``None``.

    NOTE(review): the kernel source is not visible here. It presumably rescales
    each output position according to how its (dilated) window overlaps the
    unpadded extent of ``x``. Confirm against the native implementation.
    """

    @staticmethod
    def forward(ctx, y, x, kernel_size, stride, padding, dilation, average):
        """Rectify ``y`` in place and return it.

        Args:
            y: tensor modified in place by the kernel; marked dirty.
            x: tensor passed to the kernel alongside ``y``; saved for backward.
            kernel_size: per-dimension window sizes (zipped with ``dilation``).
            stride: forwarded unchanged to the kernel.
            padding: forwarded unchanged to the kernel.
            dilation: per-dimension dilations, same length as ``kernel_size``.
            average: flag forwarded unchanged to the kernel.

        Returns:
            ``y`` itself, after in-place modification.
        """
        ctx.save_for_backward(x)
        # The kernel receives the dilated window extent. A size-k window with
        # dilation d spans k + (k - 1) * (d - 1) positions; for k == 3 this is
        # 3 + 2 * (d - 1).
        kernel_size = [k + (k - 1) * (d - 1) for k, d in zip(kernel_size, dilation)]
        # Keep the exact kernel arguments so backward can replay them.
        ctx.kernel_size = kernel_size
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.average = average
        _conv_rectify(y, x, kernel_size, stride, padding, dilation, average)
        # y was written in place, so autograd must be told before it is returned.
        ctx.mark_dirty(y)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        """Apply the forward rectification to ``grad_y`` and return it as y's gradient.

        Returns one gradient per ``forward`` input: ``grad_y`` for ``y`` and
        ``None`` for ``x``, ``kernel_size``, ``stride``, ``padding``,
        ``dilation`` and ``average``.
        """
        x, = ctx.saved_tensors
        # NOTE(review): this writes into the incoming gradient buffer. If
        # autograd hands over a non-contiguous or expanded (stride-0) grad_y,
        # an in-place kernel may misbehave. Confirm the native kernel copes.
        _conv_rectify(grad_y, x, ctx.kernel_size, ctx.stride,
                      ctx.padding, ctx.dilation, ctx.average)
        return grad_y, None, None, None, None, None, None

# Functional entry point:
#   rectify(y, x, kernel_size, stride, padding, dilation, average)
# modifies ``y`` in place and returns it (autograd-aware).
rectify = _rectify.apply