"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "2e5203be043f107eae5c1b6788584d199f403286"
Commit da10de84 authored by Rémi Louf
Browse files

fix bug with padding mask + add corresponding test

parent 3b0d2fa3
...@@ -127,9 +127,9 @@ def build_lm_labels(sequence, pad_token): ...@@ -127,9 +127,9 @@ def build_lm_labels(sequence, pad_token):
def build_mask(sequence, pad_token): def build_mask(sequence, pad_token):
""" Builds the mask. The attention mechanism will only attend to positions """ Builds the mask. The attention mechanism will only attend to positions
with value 1. """ with value 1. """
mask = sequence.clone() mask = torch.ones_like(sequence)
mask[mask != pad_token] = 1 idx_pad_tokens = (sequence == pad_token)
mask[mask == pad_token] = 0 mask[idx_pad_tokens] = 0
return mask return mask
......
...@@ -116,6 +116,13 @@ class SummarizationDataProcessingTest(unittest.TestCase): ...@@ -116,6 +116,13 @@ class SummarizationDataProcessingTest(unittest.TestCase):
build_mask(sequence, 23).numpy(), expected.numpy() build_mask(sequence, 23).numpy(), expected.numpy()
) )
def test_build_mask_with_padding_equal_to_one(self):
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
np.testing.assert_array_equal(
build_mask(sequence, 1).numpy(), expected.numpy()
)
def test_compute_token_type_ids(self): def test_compute_token_type_ids(self):
separator = 101 separator = 101
batch = torch.tensor( batch = torch.tensor(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment