Unverified Commit 39b5d1a6 authored by SaulLu's avatar SaulLu Committed by GitHub
Browse files

fix set truncation attribute in `__init__` of `PreTrainedTokenizerBase` (#15456)



* change truncation_side in init of `PreTrainedTokenizerBase`
Co-authored-by: default avatarLSinev <LSinev@users.noreply.github.com>

* add test

* Revert "replace assert with exception for `padding_side` arg in `PreTrainedTokenizerBase` `__init__`"

This reverts commit 7a98b87962d2635c7e4d4f00db3948b694624843.

* fix kwargs

* Revert "fix kwargs"

This reverts commit 67b0a5270e8cf1dbf70e6b0232e94c0452b6946f.

* Update tests/test_tokenization_common.py
Co-authored-by: default avatarNicolas Patry <patry.nicolas@protonmail.com>

* delete truncation_side variable

* reorganize test

* format

* complete doc

* Revert "Revert "replace assert with exception for `padding_side` arg in `PreTrainedTokenizerBase` `__init__`""

This reverts commit d5a10a7e2680539e5d9e98ae5d896c893d224b80.

* fix typo

* fix typos to render documentation

* Revert "Revert "Revert "replace assert with exception for `padding_side` arg in `PreTrainedTokenizerBase` `__init__`"""

This reverts commit 16cf58811943a08f43409a7c83eaa330686591d0.

* format
Co-authored-by: default avatarLSinev <LSinev@users.noreply.github.com>
Co-authored-by: default avatarNicolas Patry <patry.nicolas@protonmail.com>
parent 45cac3fa
...@@ -1383,6 +1383,8 @@ INIT_TOKENIZER_DOCSTRING = r""" ...@@ -1383,6 +1383,8 @@ INIT_TOKENIZER_DOCSTRING = r"""
- **model_input_names** (`List[str]`) -- A list of inputs expected in the forward pass of the model. - **model_input_names** (`List[str]`) -- A list of inputs expected in the forward pass of the model.
- **padding_side** (`str`) -- The default value for the side on which the model should have padding applied. - **padding_side** (`str`) -- The default value for the side on which the model should have padding applied.
Should be `'right'` or `'left'`. Should be `'right'` or `'left'`.
- **truncation_side** (`str`) -- The default value for the side on which the model should have truncation
applied. Should be `'right'` or `'left'`.
Args: Args:
model_max_length (`int`, *optional*): model_max_length (`int`, *optional*):
...@@ -1393,6 +1395,9 @@ INIT_TOKENIZER_DOCSTRING = r""" ...@@ -1393,6 +1395,9 @@ INIT_TOKENIZER_DOCSTRING = r"""
padding_side (`str`, *optional*): padding_side (`str`, *optional*):
The side on which the model should have padding applied. Should be selected between ['right', 'left']. The side on which the model should have padding applied. Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name. Default value is picked from the class attribute of the same name.
truncation_side (`str`, *optional*):
The side on which the model should have truncation applied. Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name.
model_input_names (`List[string]`, *optional*): model_input_names (`List[string]`, *optional*):
The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or
`"attention_mask"`). Default value is picked from the class attribute of the same name. `"attention_mask"`). Default value is picked from the class attribute of the same name.
...@@ -1456,12 +1461,20 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -1456,12 +1461,20 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None)) model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None))
self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER
# Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed. # Padding and truncation side are right by default and overridden in subclasses. If specified in the kwargs, it
# is changed.
self.padding_side = kwargs.pop("padding_side", self.padding_side) self.padding_side = kwargs.pop("padding_side", self.padding_side)
if self.padding_side not in ["right", "left"]: if self.padding_side not in ["right", "left"]:
raise ValueError( raise ValueError(
f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}" f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
) )
self.truncation_side = kwargs.pop("truncation_side", self.truncation_side)
if self.truncation_side not in ["right", "left"]:
raise ValueError(
f"Padding side should be selected between 'right' and 'left', current value: {self.truncation_side}"
)
self.model_input_names = kwargs.pop("model_input_names", self.model_input_names) self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
self.deprecation_warnings = ( self.deprecation_warnings = (
......
...@@ -1415,6 +1415,47 @@ class TokenizerTesterMixin: ...@@ -1415,6 +1415,47 @@ class TokenizerTesterMixin:
**kwargs, **kwargs,
) )
def test_truncation_side_in_kwargs(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
if self.test_rust_tokenizer:
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
pretrained_name, truncation_side="left", **kwargs
)
self.assertEqual(tokenizer_r.truncation_side, "left")
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
pretrained_name, truncation_side="right", **kwargs
)
self.assertEqual(tokenizer_r.truncation_side, "right")
self.assertRaises(
ValueError,
self.rust_tokenizer_class.from_pretrained,
pretrained_name,
truncation_side="unauthorized",
**kwargs,
)
if self.test_slow_tokenizer:
tokenizer_p = self.tokenizer_class.from_pretrained(
pretrained_name, truncation_side="left", **kwargs
)
self.assertEqual(tokenizer_p.truncation_side, "left")
tokenizer_p = self.tokenizer_class.from_pretrained(
pretrained_name, truncation_side="right", **kwargs
)
self.assertEqual(tokenizer_p.truncation_side, "right")
self.assertRaises(
ValueError,
self.tokenizer_class.from_pretrained,
pretrained_name,
truncation_side="unauthorized",
**kwargs,
)
def test_right_and_left_padding(self): def test_right_and_left_padding(self):
tokenizers = self.get_tokenizers(do_lower_case=False) tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers: for tokenizer in tokenizers:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment