Extending PyTorch-Encoding
==========================

In this note we discuss extending the PyTorch-Encoding package, that is,
extending :mod:`torch.nn` and
:mod:`torch.autograd` with a custom CUDA backend.

Torch C and CUDA Backend
------------------------

Consider, as an example, the residual operation (over a mini-batch):

.. math::
    r_{ik} = x_i - c_k

where the inputs are :math:`X=\{x_1,...,x_N\}` and :math:`C=\{c_1,...,c_K\}`, and the output is :math:`R=\{r_{ik}\}`.
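
The same computation can be written directly in PyTorch with ``expand``; the following
reference implementation is only a minimal sketch (it is not part of the package), but it is
handy for checking the CUDA path implemented below::

    import torch

    def residual_reference(X, C):
        # X: (B x N x D), C: (K x D)  ->  R: (B x N x K x D)
        B, N, D = X.size()
        K = C.size(0)
        return X.view(B, N, 1, D).expand(B, N, K, D) \
               - C.view(1, 1, K, D).expand(B, N, K, D)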


- Add the CUDA kernel function and expose a C API in the generic file ``encoding/kernel/generic/encoding_kernel.c``, using the Torch generic-file mechanism::

    __global__ void Encoding_(Residual_Forward_kernel) (
        THCDeviceTensor<real, 4> R,
        THCDeviceTensor<real, 3> X,
        THCDeviceTensor<real, 2> D)
    /*
     * residual forward kernel function
     */
    {
        /* declarations of the variables */
        int b, k, d, i, K;
        /* Get the index and channels */ 
        b = blockIdx.z;
        d = blockIdx.x * blockDim.x + threadIdx.x;
        i = blockIdx.y * blockDim.y + threadIdx.y;
        K = R.getSize(2);
        /* boundary check for output */
        if (d >= X.getSize(2) || i >= X.getSize(1))    return;
        /* main operation */
        for(k=0; k<K; k++) {
            R[b][i][k][d] = X[b][i][d].ldg() - D[k][d].ldg();
        }
    }

    void Encoding_(Residual_Forward)(
        THCState *state, THCTensor *R_, THCTensor *X_, THCTensor *D_)
    /*
     * residual forward 
     */
    {
        /* Check the GPU index and tensor dims*/
        THCTensor_(checkGPU)(state, 3, R_, X_, D_); 
        if (THCTensor_(nDimension)(state, R_) != 4 ||
            THCTensor_(nDimension)(state, X_) != 3 ||
            THCTensor_(nDimension)(state, D_) != 2) {
            THError("Encoding: incorrect input dims. \n");
        }
        /* Device tensors */
        THCDeviceTensor<real, 4> R = devicetensor<4>(state, R_);
        THCDeviceTensor<real, 3> X = devicetensor<3>(state, X_);
        THCDeviceTensor<real, 2> D = devicetensor<2>(state, D_);
        /* kernel function */
        cudaStream_t stream = THCState_getCurrentStream(state);
        dim3 threads(16, 16);
        dim3 blocks(X.getSize(2)/16+1, X.getSize(1)/16+1, 
                    X.getSize(0));
        Encoding_(Residual_Forward_kernel)<<<blocks, threads, 0, stream>>>(R, X, D);
        THCudaCheck(cudaGetLastError());
    }

- Add the corresponding function header to ``encoding/kernel/generic/encoding_kernel.h``::

    void Encoding_(Residual_Forward)(
        THCState *state, THCTensor *R_, THCTensor *X_, THCTensor *D_);

- Add a CFFI function to ``encoding/src/generic/encoding_generic.c``, which calls the C API we just wrote::

    int Encoding_(residual_forward)(THCTensor *R, THCTensor *X, THCTensor *D)
    /*
     * Residual operation
     */
    {
        Encoding_(Residual_Forward)(state, R, X, D);
        /* C function return number of the outputs */
        return 0;
    }

- Add the corresponding function header to ``encoding/src/encoding_lib.h``::
    
    int Encoding_Float_residual_forward(THCudaTensor *R, THCudaTensor *X, 
        THCudaTensor *D);

- Finally, call this function from Python::

    import torch
    from torch.autograd import Function
    # encoding_lib is the compiled CFFI extension exposing the C functions above

    class residual(Function):
        def forward(self, X, C):
            # X: (B x N x D), C: (K x D), R: (B x N x K x D)
            B, N, D = X.size()
            K = C.size(0)
            with torch.cuda.device_of(X):
                R = X.new(B,N,K,D)
            if isinstance(X, torch.cuda.FloatTensor):
                with torch.cuda.device_of(X):
                    encoding_lib.Encoding_Float_residual_forward(R, X, C)
            elif isinstance(X, torch.cuda.DoubleTensor):
                with torch.cuda.device_of(X):
                    encoding_lib.Encoding_Double_residual_forward(R, X, C)
            else:
                raise RuntimeError('Unimplemented data type!')
            return R

- Note that this is just an example; you also need to implement the backward function for the ``residual`` operation (see the sketch below).
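
Since :math:`r_{ik} = x_i - c_k`, the backward pass reduces to two sums:
:math:`\partial L/\partial x_i = \sum_k \partial L/\partial r_{ik}` and
:math:`\partial L/\partial c_k = -\sum_i \partial L/\partial r_{ik}` (the latter also summed over the
mini-batch). The following pure-PyTorch sketch of these reductions is only an illustration of the
math; in practice you would write a CUDA backward kernel and a CFFI wrapper following the same
steps as the forward pass, and return the gradients from ``residual.backward``::

    import torch

    def residual_backward_reference(gradR):
        # gradR: (B x N x K x D)  ->  (gradX, gradC)
        B, N, K, D = gradR.size()
        # dL/dX[b,i,d] =  sum_k gradR[b,i,k,d]
        gradX = gradR.sum(2).view(B, N, D)
        # dL/dC[k,d]   = -sum over b and i of gradR[b,i,k,d]
        gradC = (-gradR).sum(1).sum(0).view(K, D)
        return gradX, gradC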