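"""RNN_tests.py: numerically compare apex.RNN against torch.nn's RNNs.

For every combination of RNN type (LSTM, GRU, ReLU, Tanh) and size parameters,
build both implementations with identical weights, run a forward and backward
pass on the same input, and report any element-wise mismatch in the outputs,
hidden states, or input gradients.
"""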
import itertools

import torch
import torch.nn as nn
from torch.autograd import Variable

import apex
from apex.RNN.models import bidirectionalRNN, stackedRNN

# Disable cuDNN so torch.nn's RNNs run their native implementation, giving a
# like-for-like comparison against apex's cell-based RNNs.
torch.backends.cudnn.enabled = False

batch_first = False    # not implemented in apex.RNN yet
dropout = 0.0          # unclear how to validate nondeterministic dropout, so keep it off
bidirectional = False  # True works, but the definition differs from PyTorch's

rnn_types = ['LSTM', 'GRU', 'ReLU', 'Tanh']
sizes = [8, 4, 2]

seq_sizes = sizes
hidden_sizes = sizes
inp_sizes = sizes
batch_sizes = sizes
num_layers_list = sizes

biases = [True]

def copy_param_set(pyt_rnn, my_rnn, layer=0, reverse=False):
    # Find the apex cell corresponding to this layer (and direction).
    if isinstance(my_rnn, bidirectionalRNN):
        rnn = my_rnn.fwd.rnns[layer] if not reverse else my_rnn.bckwrd.rnns[layer]
    elif isinstance(my_rnn, stackedRNN):
        rnn = my_rnn.rnns[layer]
    else:
        raise RuntimeError("Unexpected RNN type: " + type(my_rnn).__name__)

    # apex parameter names; drop the biases when bias=False.
    param_names = ['w_ih', 'w_hh', 'b_ih', 'b_hh']
    if not hasattr(rnn, 'b_hh'):
        param_names = param_names[:2]
    my_params = [getattr(rnn, param_name) for param_name in param_names]

    # Matching PyTorch parameter names, e.g. weight_ih_l0 or weight_ih_l0_reverse.
    param_names = ['weight_ih_', 'weight_hh_', 'bias_ih_', 'bias_hh_']
    reverse_str = '_reverse' if reverse else ''
    if not hasattr(pyt_rnn, 'bias_hh_l0'):
        param_names = param_names[:2]
    pyt_params = [getattr(pyt_rnn, param_name + 'l' + str(layer) + reverse_str)
                  for param_name in param_names]

    # Copy apex's weights into the PyTorch module so both nets start identical.
    for pyt_param, my_param in zip(pyt_params, my_params):
        pyt_param.data.copy_(my_param.data)

def copy_all_params(pyt_rnn, my_rnn):
    # num_layers and bidirectional are globals set by the test loop below.
    for layer in range(num_layers):
        copy_param_set(pyt_rnn, my_rnn, layer)
        if bidirectional:
            copy_param_set(pyt_rnn, my_rnn, layer, reverse=True)


num_failures = 0

def compare_variables(v1, v2, msg, params):
    # Flag any element-wise difference beyond a small absolute tolerance.
    global num_failures
    diff = float((v1.data - v2.data).abs().max())
    if diff > 1e-5:
        num_failures += 1
        print("Error of", diff, "found for", msg, "for case:", str(params))
    
def compare_tuple_variables(t1, t2, msg, params):
    for var1, var2 in zip(t1, t2):
        compare_variables(var1, var2, msg, params)

def maybe_compare(v1, v2, msg, params):
    # Outputs are single Variables; hidden states may be tuples (e.g. LSTM's (h, c)).
    if isinstance(v1, Variable) and isinstance(v2, Variable):
        compare_variables(v1, v2, msg, params)
    else:
        compare_tuple_variables(v1, v2, msg, params)

# Run every combination of RNN type, shape, and depth.
product = list(itertools.product(rnn_types, seq_sizes, hidden_sizes, inp_sizes,
                                 batch_sizes, num_layers_list, biases))

for test_case in product:
    rnn_type, seq_size, hidden_size, inp_size, batch_size, num_layers, bias = test_case

    # Random GPU input, shared by both implementations.
    inp = torch.cuda.FloatTensor(seq_size, batch_size, inp_size).uniform_()

    # ReLU and Tanh map onto nn.RNN with the matching nonlinearity; LSTM and GRU
    # map directly onto the nn module of the same name.
    if rnn_type == 'ReLU' or rnn_type == 'Tanh':
        pytorch_rnn = nn.RNN(inp_size, hidden_size, num_layers, bias, batch_first,
                             dropout, bidirectional, nonlinearity=rnn_type.lower()).cuda()
    else:
        pytorch_rnn = getattr(nn, rnn_type)(inp_size, hidden_size, num_layers, bias,
                                            batch_first, dropout, bidirectional).cuda()
    my_rnn = getattr(apex.RNN.models, rnn_type)(inp_size, hidden_size, num_layers, bias,
                                                batch_first, dropout, bidirectional).cuda()
    
    # Synchronize weights so any output difference reflects the math, not the init.
    copy_all_params(pytorch_rnn, my_rnn)

    # Both Variables wrap the same tensor, so both models see identical inputs
    # while accumulating separate gradients.
    pyt_inp = Variable(inp, requires_grad=True)
    my_inp = Variable(inp, requires_grad=True)

    my_out, my_hiddens = my_rnn(my_inp)
    pyt_out, pyt_hiddens = pytorch_rnn(pyt_inp)

    # Reduce to a scalar so backward() populates the input gradients.
    pyt_out.sum().backward()
    my_out.sum().backward()

    maybe_compare(pyt_out, my_out, "out", test_case)

    # If there's only one hidden state, PyTorch returns it bare rather than in a
    # tuple; apex always returns a tuple, so wrap PyTorch's to match.
    if not isinstance(pyt_hiddens, tuple):
        pyt_hiddens = (pyt_hiddens,)

    try:
        for i, (pyt_hid, my_hid) in enumerate(zip(pyt_hiddens, my_hiddens)):
            maybe_compare(pyt_hid, my_hid, "hx_" + str(i), test_case)
    except ValueError:
        # The hidden-state structures didn't line up element-by-element;
        # fall back to comparing them wholesale.
        maybe_compare(pyt_hiddens, my_hiddens, "hx_0", test_case)
        
        
    # Input gradients should match as well.
    maybe_compare(pyt_inp.grad, my_inp.grad, "inp.grad", test_case)

print("Test passed.")