"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "dc07fc29da8a593f68080fbde0b9161a9a68bd36"
Unverified Commit 03ae1f06 authored by raghavanone, committed by GitHub

change the way sentinel tokens can be retrieved (#20373)

* change the way sentinel tokens can be retrieved

* Fix line length for doc string

* Fix line length for doc string

* Add a stronger test for t5 tokenization

* Format file changes

* Make a stronger test for filtering sentinel tokens

* fix file format issues
parent 81d82e4f
@@ -79,12 +79,11 @@ class T5Tokenizer(PreTrainedTokenizer):
         pad_token (`str`, *optional*, defaults to `"<pad>"`):
             The token used for padding, for example when batching sequences of different lengths.
         extra_ids (`int`, *optional*, defaults to 100):
-            Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
-            accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
-            indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary
-            like in T5 preprocessing see
-            [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
+            Add a number of extra ids added to the vocabulary for use as sentinels. These tokens are
+            accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. These tokens can be
+            retrieved by calling the get_sentinel_tokens method, and token ids can be retrieved by calling the
+            get_sentinel_token_ids method.
         additional_special_tokens (`List[str]`, *optional*):
             Additional special tokens used by the tokenizer.
         sp_model_kwargs (`dict`, *optional*):
             Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
@@ -213,6 +212,14 @@ class T5Tokenizer(PreTrainedTokenizer):
             return ([0] * len(token_ids_0)) + [1]
         return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
 
+    def get_sentinel_tokens(self):
+        return list(
+            set(filter(lambda x: re.search(r"<extra_id_\d+>", x) is not None, self.additional_special_tokens))
+        )
+
+    def get_sentinel_token_ids(self):
+        return [self._convert_token_to_id(token) for token in self.get_sentinel_tokens()]
+
     def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
         """Do not add eos again if user already added it."""
         if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
...
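As an aside, a minimal usage sketch of the two helpers added above; the "t5-small" checkpoint and the expected counts are illustrative assumptions, not part of this change:

    # Sketch only: assumes a transformers install that already contains this commit
    # and the standard "t5-small" checkpoint, which registers 100 extra ids by default.
    from transformers import T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    sentinels = tokenizer.get_sentinel_tokens()        # e.g. ["<extra_id_0>", ..., "<extra_id_99>"], order not guaranteed
    sentinel_ids = tokenizer.get_sentinel_token_ids()  # their ids, resolved through the tokenizer itself

    print(len(sentinels))      # expected 100, one per extra id
    print(min(sentinel_ids))   # the sentinels occupy the last ids of the vocabulary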
@@ -16,6 +16,7 @@
 import os
+import re
 import warnings
 from shutil import copyfile
 from typing import List, Optional, Tuple
@@ -90,11 +91,9 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
         pad_token (`str`, *optional*, defaults to `"<pad>"`):
             The token used for padding, for example when batching sequences of different lengths.
         extra_ids (`int`, *optional*, defaults to 100):
-            Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
-            accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
-            indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary
-            like in T5 preprocessing see
-            [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
+            Add a number of extra ids added to the vocabulary for use as sentinels. These tokens are accessible as
+            "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. These tokens can be retrieved by
+            calling the get_sentinel_tokens method, and token ids can be retrieved by calling get_sentinel_token_ids.
         additional_special_tokens (`List[str]`, *optional*):
             Additional special tokens used by the tokenizer.
     """
@@ -235,3 +234,11 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
         if token_ids_1 is None:
             return len(token_ids_0 + eos) * [0]
         return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
+
+    def get_sentinel_tokens(self):
+        return list(
+            set(filter(lambda x: re.search(r"<extra_id_\d+>", x) is not None, self.additional_special_tokens))
+        )
+
+    def get_sentinel_token_ids(self):
+        return [self.convert_tokens_to_ids(token) for token in self.get_sentinel_tokens()]
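The fast tokenizer mirrors the slow one; the only difference is that ids are resolved through the public convert_tokens_to_ids instead of the private _convert_token_to_id. A hedged parity check, again assuming the "t5-small" checkpoint is available:

    # Sketch only: for the same checkpoint, both classes should report the same
    # sentinel ids, since they expose the same "<extra_id_*>" special tokens.
    from transformers import T5Tokenizer, T5TokenizerFast

    slow = T5Tokenizer.from_pretrained("t5-small")
    fast = T5TokenizerFast.from_pretrained("t5-small")

    assert sorted(slow.get_sentinel_token_ids()) == sorted(fast.get_sentinel_token_ids())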
@@ -14,6 +14,7 @@
 # limitations under the License.
 import json
 import os
+import re
 import tempfile
 import unittest
@@ -379,3 +380,25 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             model_name="t5-base",
             revision="5a7ff2d8f5117c194c7e32ec1ccbf04642cca99b",
         )
+
+    def test_get_sentinel_tokens(self):
+        tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10)
+        sentinel_tokens = tokenizer.get_sentinel_tokens()
+        self.assertEqual(len(sentinel_tokens), 10)
+        self.assertListEqual(sorted(sentinel_tokens), sorted([f"<extra_id_{str(i)}>" for i in range(0, 10)]))
+        self.assertTrue(all(re.search(r"<extra_id_\d+>", token) is not None for token in sentinel_tokens))
+
+    def test_get_sentinel_token_ids(self):
+        tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10)
+        self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), sorted([i for i in range(1000, 1010)]))
+
+    def test_get_sentinel_tokens_for_fasttokenizer(self):
+        tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
+        sentinel_tokens = tokenizer.get_sentinel_tokens()
+        self.assertEqual(len(sentinel_tokens), 10)
+        self.assertListEqual(sorted(sentinel_tokens), sorted([f"<extra_id_{str(i)}>" for i in range(0, 10)]))
+        self.assertTrue(all(re.search(r"<extra_id_\d+>", token) is not None for token in sentinel_tokens))
+
+    def test_get_sentinel_token_ids_for_fasttokenizer(self):
+        tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
+        self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), sorted([i for i in range(1000, 1010)]))
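A quick note on the id range the two id tests expect: assuming SAMPLE_VOCAB is a SentencePiece test model with 1000 pieces (an assumption consistent with the assertions above, not stated in this diff), the 10 extra ids are appended after the base vocabulary, so the sentinels land exactly on ids 1000 through 1009:

    # Back-of-the-envelope check of the expected ids, under the assumed fixture size.
    base_vocab_size = 1000                      # assumed size of SAMPLE_VOCAB
    extra_ids = 10                              # as passed to the tokenizers above
    expected = list(range(base_vocab_size, base_vocab_size + extra_ids))
    assert expected == list(range(1000, 1010))  # the range the tests assert against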