"vscode:/vscode.git/clone" did not exist on "34858ae1d9e11dc51100b26ac468770c81c8afc1"
Commit 070507df authored by Rémi Louf's avatar Rémi Louf
Browse files

format utils for summarization

parent da10de84
......@@ -128,7 +128,7 @@ def build_mask(sequence, pad_token):
""" Builds the mask. The attention mechanism will only attend to positions
with value 1. """
mask = torch.ones_like(sequence)
idx_pad_tokens = (sequence == pad_token)
idx_pad_tokens = sequence == pad_token
mask[idx_pad_tokens] = 0
return mask
......
......@@ -105,9 +105,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
def test_build_mask_no_padding(self):
sequence = torch.tensor([1, 2, 3, 4])
expected = torch.tensor([1, 1, 1, 1])
np.testing.assert_array_equal(
build_mask(sequence, 0).numpy(), expected.numpy()
)
np.testing.assert_array_equal(build_mask(sequence, 0).numpy(), expected.numpy())
def test_build_mask(self):
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
......@@ -119,9 +117,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
def test_build_mask_with_padding_equal_to_one(self):
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
np.testing.assert_array_equal(
build_mask(sequence, 1).numpy(), expected.numpy()
)
np.testing.assert_array_equal(build_mask(sequence, 1).numpy(), expected.numpy())
def test_compute_token_type_ids(self):
separator = 101
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment