##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Functions for Encoding Layer"""

import torch
from torch.autograd import Function, Variable
from .._ext import encoding_lib

__all__ = ['aggregate', 'scaledL2']

class _aggregate(Function):
    """Autograd function for the aggregate operation.

    Dispatches to the float/double CUDA kernels in ``encoding_lib``;
    only CUDA tensors are supported.
    """

    @staticmethod
    def forward(ctx, A, X, C):
        # A: (B, N, K) assignment weights, X: (B, N, D) features,
        # C: (K, D) codewords  =>  E: (B, K, D) aggregated residuals.
        ctx.save_for_backward(A, X, C)
        B, _, K = A.size()
        D = X.size(2)
        with torch.cuda.device_of(A):
            E = A.new(B, K, D)
        # Pick the kernel matching the tensor precision.
        if isinstance(A, torch.cuda.FloatTensor):
            kernel = encoding_lib.Encoding_Float_aggregate_forward
        elif isinstance(A, torch.cuda.DoubleTensor):
            kernel = encoding_lib.Encoding_Double_aggregate_forward
        else:
            raise RuntimeError('Unimplemented data type!')
        with torch.cuda.device_of(A):
            kernel(E, A, X, C)
        return E

    @staticmethod
    def backward(ctx, gradE):
        A, X, C = ctx.saved_variables
        # Pre-allocate gradient buffers on A's device.
        with torch.cuda.device_of(A):
            gradA = Variable(A.data.new().resize_as_(A.data))
            gradX = Variable(A.data.new().resize_as_(X.data))
            gradC = Variable(A.data.new().resize_as_(C.data))
        # Only gradA is produced by the CUDA kernel; gradX and gradC are
        # computed below with plain tensor ops.
        if isinstance(A.data, torch.cuda.FloatTensor):
            kernel = encoding_lib.Encoding_Float_aggregate_backward
        elif isinstance(A.data, torch.cuda.DoubleTensor):
            kernel = encoding_lib.Encoding_Double_aggregate_backward
        else:
            raise RuntimeError('Unimplemented data type!')
        with torch.cuda.device_of(A.data):
            kernel(gradA.data, gradE.data, A.data, X.data, C.data)
        # dX: (B,N,K) x (B,K,D) -> (B,N,D); dC: codeword k sees -sum_i a_ik.
        gradX.data.copy_(torch.bmm(A, gradE).data)
        gradC.data.copy_((-gradE*A.sum(1).unsqueeze(2)).sum(0).data)
        return gradA, gradX, gradC

def aggregate(A, X, C):
    r"""
    Aggregate operation, aggregate the residuals of inputs (:math:`X`) with respect
    to the codewords (:math:`C`) with assignment weights (:math:`A`).

    .. math::
        e_{k} = \sum_{i=1}^{N} a_{ik} (x_i - c_k)

    Shape:
        - Input: :math:`A\in\mathcal{R}^{B\times N\times K}`
          :math:`X\in\mathcal{R}^{B\times N\times D}` :math:`C\in\mathcal{R}^{K\times D}`
          (where :math:`B` is batch, :math:`N` is total number of features,
          :math:`K` is number of codewords, :math:`D` is feature dimensions.)
        - Output: :math:`E\in\mathcal{R}^{B\times K\times D}`

    Examples:
        >>> B,N,K,D = 2,3,4,5
        >>> A = Variable(torch.cuda.DoubleTensor(B,N,K).uniform_(-0.5,0.5), requires_grad=True)
        >>> X = Variable(torch.cuda.DoubleTensor(B,N,D).uniform_(-0.5,0.5), requires_grad=True)
        >>> C = Variable(torch.cuda.DoubleTensor(K,D).uniform_(-0.5,0.5), requires_grad=True)
        >>> E = encoding.aggregate(A, X, C)

    """
    return _aggregate.apply(A, X, C)
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
83
84

class _scaledL2(Function):
    """Autograd function for the scaled-L2 distance.

    Dispatches to the float/double CUDA kernels in ``encoding_lib``;
    only CUDA tensors are supported.
    """

    @staticmethod
    def forward(ctx, X, C, S):
        # X: (B, N, D) features, C: (K, D) codewords, S: (K,) scales
        # => SL: (B, N, K) scaled squared distances.
        B, N, _ = X.size()
        K = C.size(0)
        with torch.cuda.device_of(X):
            SL = X.new(B, N, K)
        # Pick the kernel matching the tensor precision.
        if isinstance(X, torch.cuda.FloatTensor):
            kernel = encoding_lib.Encoding_Float_scaledl2_forward
        elif isinstance(X, torch.cuda.DoubleTensor):
            kernel = encoding_lib.Encoding_Double_scaledl2_forward
        else:
            raise RuntimeError('Unimplemented data type!')
        with torch.cuda.device_of(X):
            kernel(SL, X, C, S)
        # SL is saved too: the backward pass reuses it for gradS.
        ctx.save_for_backward(X, C, S, SL)
        return SL

    @staticmethod
    def backward(ctx, gradSL):
        X, C, S, SL = ctx.saved_variables
        K = C.size(0)
        # Pre-allocate gradient buffers on X's device.
        with torch.cuda.device_of(X.data):
            gradX = Variable(X.data.new().resize_as_(X.data))
            gradC = Variable(X.data.new().resize_as_(C.data))
            gradS = Variable(X.data.new().resize_as_(S.data))
        # gradX and gradC come from the CUDA kernel; gradS is computed below.
        if isinstance(X.data, torch.cuda.FloatTensor):
            kernel = encoding_lib.Encoding_Float_scaledl2_backward
        elif isinstance(X.data, torch.cuda.DoubleTensor):
            kernel = encoding_lib.Encoding_Double_scaledl2_backward
        else:
            raise RuntimeError('Unimplemented data type!')
        with torch.cuda.device_of(X.data):
            kernel(gradSL.data, gradX.data, gradC.data, X.data, C.data, S.data)
        # sl_ik = s_k * ||x_i - c_k||^2  =>  d sl_ik / d s_k = sl_ik / s_k,
        # summed over batch and feature dimensions.
        gradS.data.copy_((gradSL*(SL/S.view(1, 1, K))).sum(0).sum(0).data)
        return gradX, gradC, gradS


def scaledL2(X, C, S):
    r"""
    scaledL2 distance: squared Euclidean distance between each input and
    each codeword, scaled by the per-codeword factor :math:`s_k`.

    .. math::
        sl_{ik} = s_k \|x_i-c_k\|^2

    Shape:
        - Input: :math:`X\in\mathcal{R}^{B\times N\times D}`
          :math:`C\in\mathcal{R}^{K\times D}` :math:`S\in \mathcal{R}^K`
          (where :math:`B` is batch, :math:`N` is total number of features,
          :math:`K` is number of codewords, :math:`D` is feature dimensions.)
        - Output: :math:`SL\in\mathcal{R}^{B\times N\times K}`

    """
    return _scaledL2.apply(X, C, S)