Unverified Commit 6e4d370a authored by Myle Ott's avatar Myle Ott Committed by GitHub
Browse files

More updates for PyTorch (#114)

parent 3bde773d
...@@ -96,9 +96,9 @@ class LSTMEncoder(FairseqEncoder): ...@@ -96,9 +96,9 @@ class LSTMEncoder(FairseqEncoder):
def forward(self, src_tokens, src_lengths): def forward(self, src_tokens, src_lengths):
if LanguagePairDataset.LEFT_PAD_SOURCE: if LanguagePairDataset.LEFT_PAD_SOURCE:
# convert left-padding to right-padding # convert left-padding to right-padding
src_tokens.data = utils.convert_padding_direction( src_tokens = utils.convert_padding_direction(
src_tokens.data, src_tokens,
src_lengths.data, src_lengths,
self.padding_idx, self.padding_idx,
left_to_right=True, left_to_right=True,
) )
......
...@@ -289,8 +289,6 @@ def convert_padding_direction( ...@@ -289,8 +289,6 @@ def convert_padding_direction(
right_to_left=False, right_to_left=False,
left_to_right=False, left_to_right=False,
): ):
assert not isinstance(src_tokens, Variable)
assert not isinstance(src_lengths, Variable)
assert right_to_left ^ left_to_right assert right_to_left ^ left_to_right
pad_mask = src_tokens.eq(padding_idx) pad_mask = src_tokens.eq(padding_idx)
if pad_mask.max() == 0: if pad_mask.max() == 0:
......
...@@ -61,7 +61,7 @@ class TestUtils(unittest.TestCase): ...@@ -61,7 +61,7 @@ class TestUtils(unittest.TestCase):
def assertAlmostEqual(self, t1, t2): def assertAlmostEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch") self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertLess((t1 - t2).abs().max(), 1e-4) self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment