Commit ae0c05d9 authored by Sam Gross's avatar Sam Gross Committed by Myle Ott
Browse files

Fix call ordering to ATen addmm and sum (#22)

parent 7aba6084
......@@ -67,7 +67,7 @@ extern "C" void TemporalConvolutionTBC_forward(
auto W = weight[k];
auto I = input.narrow(0, iShift, t).view({t * batchSize, inputPlanes});
auto O = output.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
at::addmm_out(1, O, 1, I, W, O);
O.addmm_(I, W);
}
}
}
......@@ -108,7 +108,7 @@ extern "C" void TemporalConvolutionTBC_backward(
if (t > 0) {
auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
auto dI = dInput.narrow(0, iShift, t).view({t * batchSize, inputPlanes});
at::addmm_out(1, dI, 1, dO, weight[k].t(), dI);
dI.addmm_(dO, weight[k].t());
}
}
......@@ -121,10 +121,10 @@ extern "C" void TemporalConvolutionTBC_backward(
auto dW = dWeight[k];
auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
auto I = input.narrow(0, iShift, t).view({t * batchSize, inputPlanes}).t();
at::addmm_out(1, dW, 1, I, dO, dW);
dW.addmm_(I, dO);
}
}
auto tmp = dOutput.sum(0, false);
at::sum_out(tmp, 0, dBias);
dBias.assign_(tmp.sum(0));
}
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
import unittest
from fairseq.modules import ConvTBC
import torch.nn as nn
from torch.autograd import Variable, gradcheck
class TestConvTBC(unittest.TestCase):
def test_convtbc(self):
# ksz, in_channels, out_channels
conv_tbc = ConvTBC(4, 5, kernel_size=3, padding=1)
# out_channels, in_channels, ksz
conv1d = nn.Conv1d(4, 5, kernel_size=3, padding=1)
conv_tbc.weight.data.copy_(conv1d.weight.data.transpose(0, 2))
conv_tbc.bias.data.copy_(conv1d.bias.data)
input_tbc = Variable(torch.randn(7, 2, 4), requires_grad=True)
input1d = Variable(input_tbc.data.transpose(0, 1).transpose(1, 2), requires_grad=True)
output_tbc = conv_tbc(input_tbc)
output1d = conv1d(input1d)
self.assertAlmostEqual(output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data)
grad_tbc = torch.randn(output_tbc.size())
grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous()
output_tbc.backward(grad_tbc)
output1d.backward(grad1d)
self.assertAlmostEqual(conv_tbc.weight.grad.data.transpose(0, 2), conv1d.weight.grad.data)
self.assertAlmostEqual(conv_tbc.bias.grad.data, conv1d.bias.grad.data)
self.assertAlmostEqual(input_tbc.grad.data.transpose(0, 1).transpose(1, 2), input1d.grad.data)
def assertAlmostEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertLess((t1 - t2).abs().max(), 1e-4)
if __name__ == '__main__':
    # Run the ConvTBC tests when this file is executed directly.
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment