Commit 070507df authored by Rémi Louf
Browse files

format utils for summarization

parent da10de84
...@@ -128,7 +128,7 @@ def build_mask(sequence, pad_token): ...@@ -128,7 +128,7 @@ def build_mask(sequence, pad_token):
""" Builds the mask. The attention mechanism will only attend to positions """ Builds the mask. The attention mechanism will only attend to positions
with value 1. """ with value 1. """
mask = torch.ones_like(sequence) mask = torch.ones_like(sequence)
idx_pad_tokens = (sequence == pad_token) idx_pad_tokens = sequence == pad_token
mask[idx_pad_tokens] = 0 mask[idx_pad_tokens] = 0
return mask return mask
......
...@@ -105,9 +105,7 @@ class SummarizationDataProcessingTest(unittest.TestCase): ...@@ -105,9 +105,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
def test_build_mask_no_padding(self): def test_build_mask_no_padding(self):
sequence = torch.tensor([1, 2, 3, 4]) sequence = torch.tensor([1, 2, 3, 4])
expected = torch.tensor([1, 1, 1, 1]) expected = torch.tensor([1, 1, 1, 1])
np.testing.assert_array_equal( np.testing.assert_array_equal(build_mask(sequence, 0).numpy(), expected.numpy())
build_mask(sequence, 0).numpy(), expected.numpy()
)
def test_build_mask(self): def test_build_mask(self):
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23]) sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
...@@ -119,9 +117,7 @@ class SummarizationDataProcessingTest(unittest.TestCase): ...@@ -119,9 +117,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
def test_build_mask_with_padding_equal_to_one(self): def test_build_mask_with_padding_equal_to_one(self):
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1]) sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0]) expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
np.testing.assert_array_equal( np.testing.assert_array_equal(build_mask(sequence, 1).numpy(), expected.numpy())
build_mask(sequence, 1).numpy(), expected.numpy()
)
def test_compute_token_type_ids(self): def test_compute_token_type_ids(self):
separator = 101 separator = 101
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment