Unverified Commit 4a18337b authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Honor existing attention mask in tokenzier.pad (#13926)

* Honor existing attention mask in tokenzier.pad

* Fix initialization of attention mask

* Roll the implem on all subclasses

* Fix tests
parent 3c0c699f
......@@ -267,30 +267,31 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) < max_length
if return_attention_mask and "attention_mask" not in processed_features:
processed_features["attention_mask"] = np.ones(len(required_input), dtype=np.int32)
if needs_to_be_padded:
difference = max_length - len(required_input)
if self.padding_side == "right":
if return_attention_mask:
attention_mask = np.zeros(max_length, dtype=np.int32)
attention_mask[: len(required_input)] = 1
processed_features["attention_mask"] = attention_mask
processed_features["attention_mask"] = np.pad(
processed_features["attention_mask"], (0, difference)
)
padding_shape = ((0, difference), (0, 0)) if self.feature_size > 1 else (0, difference)
processed_features[self.model_input_names[0]] = np.pad(
required_input, padding_shape, "constant", constant_values=self.padding_value
)
elif self.padding_side == "left":
if return_attention_mask:
attention_mask = np.zeros(max_length, dtype=np.int32)
attention_mask[-len(required_input) :] = 1
processed_features["attention_mask"] = attention_mask
processed_features["attention_mask"] = np.pad(
processed_features["attention_mask"], (difference, 0)
)
padding_shape = ((difference, 0), (0, 0)) if self.feature_size > 1 else (difference, 0)
processed_features[self.model_input_names[0]] = np.pad(
required_input, padding_shape, "constant", constant_values=self.padding_value
)
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
elif return_attention_mask and "attention_mask" not in processed_features:
processed_features["attention_mask"] = np.ones(len(required_input), dtype=np.int32)
return processed_features
......
......@@ -1232,11 +1232,15 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer):
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
# Initialize attention mask if not present.
if return_attention_mask and "attention_mask" not in encoded_inputs:
encoded_inputs["attention_mask"] = [1] * len(required_input)
if needs_to_be_padded:
difference = max_length - len(required_input)
if self.padding_side == "right":
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
......@@ -1250,7 +1254,7 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer):
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
elif self.padding_side == "left":
if return_attention_mask:
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
"token_type_ids"
......@@ -1264,8 +1268,6 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer):
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
elif return_attention_mask and "attention_mask" not in encoded_inputs:
encoded_inputs["attention_mask"] = [1] * len(required_input)
return encoded_inputs
......
......@@ -716,11 +716,15 @@ class LayoutLMv2TokenizerFast(PreTrainedTokenizerFast):
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
# Initialize attention mask if not present.
if return_attention_mask and "attention_mask" not in encoded_inputs:
encoded_inputs["attention_mask"] = [1] * len(required_input)
if needs_to_be_padded:
difference = max_length - len(required_input)
if self.padding_side == "right":
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
......@@ -734,7 +738,7 @@ class LayoutLMv2TokenizerFast(PreTrainedTokenizerFast):
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
elif self.padding_side == "left":
if return_attention_mask:
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
"token_type_ids"
......@@ -748,8 +752,6 @@ class LayoutLMv2TokenizerFast(PreTrainedTokenizerFast):
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
elif return_attention_mask and "attention_mask" not in encoded_inputs:
encoded_inputs["attention_mask"] = [1] * len(required_input)
return encoded_inputs
......
......@@ -1460,17 +1460,23 @@ class LukeTokenizer(RobertaTokenizer):
or (entities_provided and len(encoded_inputs["entity_ids"]) != max_entity_length)
)
# Initialize attention mask if not present.
if return_attention_mask and "attention_mask" not in encoded_inputs:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
if entities_provided and return_attention_mask and "entity_attention_mask" not in encoded_inputs:
encoded_inputs["entity_attention_mask"] = [1] * len(encoded_inputs["entity_ids"])
if needs_to_be_padded:
difference = max_length - len(encoded_inputs["input_ids"])
if entities_provided:
entity_difference = max_entity_length - len(encoded_inputs["entity_ids"])
if self.padding_side == "right":
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference
encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
if entities_provided:
encoded_inputs["entity_attention_mask"] = [1] * len(encoded_inputs["entity_ids"]) + [
0
] * entity_difference
encoded_inputs["entity_attention_mask"] = (
encoded_inputs["entity_attention_mask"] + [0] * entity_difference
)
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"] + [0] * difference
if entities_provided:
......@@ -1495,11 +1501,11 @@ class LukeTokenizer(RobertaTokenizer):
elif self.padding_side == "left":
if return_attention_mask:
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"])
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
if entities_provided:
encoded_inputs["entity_attention_mask"] = [0] * entity_difference + [1] * len(
encoded_inputs["entity_ids"]
)
encoded_inputs["entity_attention_mask"] = [0] * entity_difference + encoded_inputs[
"entity_attention_mask"
]
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = [0] * difference + encoded_inputs["token_type_ids"]
if entities_provided:
......@@ -1523,11 +1529,6 @@ class LukeTokenizer(RobertaTokenizer):
]
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
else:
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
if entities_provided:
encoded_inputs["entity_attention_mask"] = [1] * len(encoded_inputs["entity_ids"])
return encoded_inputs
......
......@@ -1819,11 +1819,15 @@ class TapasTokenizer(PreTrainedTokenizer):
padding_strategy != PaddingStrategy.DO_NOT_PAD and len(encoded_inputs["input_ids"]) != max_length
)
# Initialize attention mask if not present.
if return_attention_mask and "attention_mask" not in encoded_inputs:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
if needs_to_be_padded:
difference = max_length - len(encoded_inputs["input_ids"])
if self.padding_side == "right":
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference
encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] + [[self.pad_token_type_id] * 7] * difference
......@@ -1841,7 +1845,7 @@ class TapasTokenizer(PreTrainedTokenizer):
encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
elif self.padding_side == "left":
if return_attention_mask:
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"])
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = [[self.pad_token_type_id] * 7] * difference + encoded_inputs[
"token_type_ids"
......@@ -1859,9 +1863,6 @@ class TapasTokenizer(PreTrainedTokenizer):
encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
else:
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
return encoded_inputs
......
......@@ -3110,11 +3110,17 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
# Initialize attention mask if not present.
if return_attention_mask and "attention_mask" not in encoded_inputs:
encoded_inputs["attention_mask"] = [1] * len(required_input)
if needs_to_be_padded:
difference = max_length - len(required_input)
if self.padding_side == "right":
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
......@@ -3124,7 +3130,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
elif self.padding_side == "left":
if return_attention_mask:
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
"token_type_ids"
......@@ -3134,8 +3140,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
elif return_attention_mask and "attention_mask" not in encoded_inputs:
encoded_inputs["attention_mask"] = [1] * len(required_input)
return encoded_inputs
......
......@@ -1460,6 +1460,25 @@ class TokenizerTesterMixin:
pad_to_multiple_of=8,
)
def test_padding_with_attention_mask(self):
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
if tokenizer.pad_token is None:
self.skipTest("No padding token.")
if "attention_mask" not in tokenizer.model_input_names:
self.skipTest("This model does not use attention mask.")
features = [
{"input_ids": [1, 2, 3, 4, 5, 6], "attention_mask": [1, 1, 1, 1, 1, 0]},
{"input_ids": [1, 2, 3], "attention_mask": [1, 1, 0]},
]
padded_features = tokenizer.pad(features)
if tokenizer.padding_side == "right":
self.assertListEqual(padded_features["attention_mask"], [[1, 1, 1, 1, 1, 0], [1, 1, 0, 0, 0, 0]])
else:
self.assertListEqual(padded_features["attention_mask"], [[1, 1, 1, 1, 1, 0], [0, 0, 0, 1, 1, 0]])
def test_encode_plus_with_padding(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment