"vscode:/vscode.git/clone" did not exist on "59cd9de39da4c406e12c1785f70ee73806ebc6ba"
Unverified Commit a6271967, authored by Yih-Dar, committed by GitHub

Override _pad in LEDTokenizer to deal with global_attention_mask (#15940)



* Override _pad in LEDTokenizer

* Override _pad in LEDTokenizerFast

* add Copied from

* calling the super method

* add comment about -1
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent cb2b0276
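For context (not part of the diff): with the override below, `tokenizer.pad` also extends a user-supplied `global_attention_mask` when it pads `input_ids`, instead of leaving it at its original length. A minimal usage sketch, assuming the `allenai/led-base-16384` checkpoint and a hand-built mask with global attention on the first token:

```python
from transformers import LEDTokenizer

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")  # assumed checkpoint

features = []
for text in ["short document", "a considerably longer document about padding"]:
    encoded = tokenizer(text)
    # Global attention on the first (<s>) token only; one entry per real token.
    encoded["global_attention_mask"] = [1] + [0] * (len(encoded["input_ids"]) - 1)
    features.append(encoded)

batch = tokenizer.pad(features, padding=True)
# With this commit, the shorter `global_attention_mask` is extended to the padded
# length with -1, rather than being left shorter than `input_ids`.
print([len(mask) for mask in batch["global_attention_mask"]])
```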
src/transformers/models/led/tokenization_led.py
@@ -13,6 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LED."""
from typing import Dict, Optional, Union

from ...file_utils import PaddingStrategy
from ...tokenization_utils_base import BatchEncoding, EncodedInput
from ...utils import logging
from ..bart.tokenization_bart import BartTokenizer
@@ -48,3 +53,45 @@ class LEDTokenizer(BartTokenizer):
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def _pad(
        self,
        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
    ) -> dict:
        encoded_inputs = super()._pad(
            encoded_inputs=encoded_inputs,
            max_length=max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )

        # Load from model defaults
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        if return_attention_mask and "global_attention_mask" in encoded_inputs:
            required_input = encoded_inputs[self.model_input_names[0]]
            # `global_attention_mask` needs to have the same length as the other (sequential) inputs.
            needs_to_be_padded = len(encoded_inputs["global_attention_mask"]) != len(required_input)

            if needs_to_be_padded:
                difference = len(required_input) - len(encoded_inputs["global_attention_mask"])

                if self.padding_side == "right":
                    # Use `-1` since `0` in `global_attention_mask` means `local attention` instead of `not to attend`
                    encoded_inputs["global_attention_mask"] = (
                        encoded_inputs["global_attention_mask"] + [-1] * difference
                    )
                elif self.padding_side == "left":
                    encoded_inputs["global_attention_mask"] = [-1] * difference + encoded_inputs[
                        "global_attention_mask"
                    ]
                else:
                    raise ValueError("Invalid padding strategy:" + str(self.padding_side))

        return encoded_inputs
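The core rule, isolated as a standalone sketch (the helper name is hypothetical and exists only to illustrate why the fill value is `-1` rather than `0`):

```python
def pad_global_attention_mask(mask, target_length, padding_side="right"):
    """Extend `mask` to `target_length` with -1; 0 is reserved for 'local attention'."""
    difference = target_length - len(mask)
    if difference <= 0:
        return mask
    if padding_side == "right":
        return mask + [-1] * difference
    if padding_side == "left":
        return [-1] * difference + mask
    raise ValueError(f"Invalid padding side: {padding_side}")


print(pad_global_attention_mask([1, 0, 0], 5))          # [1, 0, 0, -1, -1]
print(pad_global_attention_mask([1, 0, 0], 5, "left"))  # [-1, -1, 1, 0, 0]
```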
src/transformers/models/led/tokenization_led_fast.py
@@ -13,6 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LED."""
from typing import Dict, Optional, Union

from ...file_utils import PaddingStrategy
from ...tokenization_utils_base import BatchEncoding, EncodedInput
from ...utils import logging
from ..bart.tokenization_bart_fast import BartTokenizerFast
from .tokenization_led import LEDTokenizer
@@ -50,3 +55,46 @@ class LEDTokenizerFast(BartTokenizerFast):
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    slow_tokenizer_class = LEDTokenizer

    # Copied from transformers.models.led.tokenization_led.LEDTokenizer._pad
    def _pad(
        self,
        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
    ) -> dict:
        encoded_inputs = super()._pad(
            encoded_inputs=encoded_inputs,
            max_length=max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )

        # Load from model defaults
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        if return_attention_mask and "global_attention_mask" in encoded_inputs:
            required_input = encoded_inputs[self.model_input_names[0]]
            # `global_attention_mask` needs to have the same length as the other (sequential) inputs.
            needs_to_be_padded = len(encoded_inputs["global_attention_mask"]) != len(required_input)

            if needs_to_be_padded:
                difference = len(required_input) - len(encoded_inputs["global_attention_mask"])

                if self.padding_side == "right":
                    # Use `-1` since `0` in `global_attention_mask` means `local attention` instead of `not to attend`
                    encoded_inputs["global_attention_mask"] = (
                        encoded_inputs["global_attention_mask"] + [-1] * difference
                    )
                elif self.padding_side == "left":
                    encoded_inputs["global_attention_mask"] = [-1] * difference + encoded_inputs[
                        "global_attention_mask"
                    ]
                else:
                    raise ValueError("Invalid padding strategy:" + str(self.padding_side))

        return encoded_inputs
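A sketch of how one might check that the slow and fast tokenizers now pad identically (assumed checkpoint; the token IDs are arbitrary dummies, not real LED encodings), which is the intent behind the `# Copied from` marker:

```python
from transformers import LEDTokenizer, LEDTokenizerFast

slow = LEDTokenizer.from_pretrained("allenai/led-base-16384")  # assumed checkpoint
fast = LEDTokenizerFast.from_pretrained("allenai/led-base-16384")

features = [
    {"input_ids": [0, 9, 10, 2], "global_attention_mask": [1, 0, 0, 0]},
    {"input_ids": [0, 9, 2], "global_attention_mask": [1, 0, 0]},
]

slow_batch = slow.pad(features, padding=True)
fast_batch = fast.pad(features, padding=True)

# Both tokenizers pad the second mask to [1, 0, 0, -1].
assert slow_batch["global_attention_mask"] == fast_batch["global_attention_mask"]
print(slow_batch["global_attention_mask"])
```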