Unverified Commit 8d6b5096 authored by st81, committed by GitHub

Add token type ids to CodeGenTokenizer (#29265)

* Add create token type ids to CodeGenTokenizer

* Fix inconsistent length of token type ids

* Format source code

* Fix inconsistent order of methods

* Update docstring

* Add test_tokenizer_integration test

* Format source code

* Add `copied from` comment to CodeGenTokenizerFast

* Add doc of create_token_type_ids_from_sequences

* Make return_token_type_ids False by default

* Mark test_tokenizer_integration as a slow test

* Add return_token_type_ids to tokenizer init arg

* Add test for tokenizer's init return_token_type_ids

* Format source code
parent 812a5de2
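In effect, the new flag is opt-in: passing `return_token_type_ids=True` when the tokenizer is created appends `"token_type_ids"` to its `model_input_names`, so encodings carry a token type mask alongside `input_ids` and `attention_mask`, while the default output stays unchanged. A minimal usage sketch (the checkpoint name is taken from the integration test below; the printed ids are illustrative):

```python
from transformers import CodeGenTokenizer

# Default: behaviour is unchanged and no token type ids are returned.
tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
print(list(tokenizer("def hello():").keys()))  # ['input_ids', 'attention_mask']

# Opt in at init time: "token_type_ids" is added to model_input_names,
# so every encoding also carries a token type mask (all zeros for a single sequence).
tokenizer = CodeGenTokenizer.from_pretrained(
    "Salesforce/codegen-350M-mono", return_token_type_ids=True
)
print(tokenizer("def hello():")["token_type_ids"])  # e.g. [0, 0, 0, 0, 0, 0]
```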
@@ -72,6 +72,7 @@ hello_world()
## CodeGenTokenizer
[[autodoc]] CodeGenTokenizer
- create_token_type_ids_from_sequences
- save_vocabulary
## CodeGenTokenizerFast
...
@@ -134,6 +134,8 @@ class CodeGenTokenizer(PreTrainedTokenizer):
other word. (CodeGen tokenizer detect beginning of words by the preceding space).
add_bos_token (`bool`, *optional*, defaults to `False`):
Whether to add a beginning of sequence token at the start of sequences.
return_token_type_ids (`bool`, *optional*, defaults to `False`):
Whether to return token type IDs.
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
...@@ -150,6 +152,7 @@ class CodeGenTokenizer(PreTrainedTokenizer): ...@@ -150,6 +152,7 @@ class CodeGenTokenizer(PreTrainedTokenizer):
pad_token=None, pad_token=None,
add_prefix_space=False, add_prefix_space=False,
add_bos_token=False, add_bos_token=False,
return_token_type_ids=False,
**kwargs,
):
bos_token = AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token
@@ -157,6 +160,9 @@ class CodeGenTokenizer(PreTrainedTokenizer):
unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
self.add_bos_token = add_bos_token
self.return_token_type_ids = return_token_type_ids
if self.return_token_type_ids:
self.model_input_names.append("token_type_ids")
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
@@ -181,6 +187,7 @@ class CodeGenTokenizer(PreTrainedTokenizer):
pad_token=pad_token,
add_prefix_space=add_prefix_space,
add_bos_token=add_bos_token,
return_token_type_ids=return_token_type_ids,
**kwargs,
)
@@ -270,6 +277,35 @@ class CodeGenTokenizer(PreTrainedTokenizer):
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
return text
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A sequence
pair mask has the following format:
```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```
If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
"""
sep = [self.sep_token_id] if self.sep_token_id is not None else []
cls = [self.cls_token_id] if self.cls_token_id is not None else []
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
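Since CodeGen defines no `cls`/`sep` tokens by default, `sep` and `cls` above are empty lists and the returned mask is simply zeros for the first sequence and ones for the second. A small sketch of the new method (the ids are made up for illustration):

```python
from transformers import CodeGenTokenizer

tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")

first = [41762, 364, 357]   # hypothetical ids for the first sequence
second = [13246, 51, 318]   # hypothetical ids for the second sequence

# Single sequence: one zero per input id.
print(tokenizer.create_token_type_ids_from_sequences(first))          # [0, 0, 0]

# Sequence pair: zeros for the first segment, ones for the second.
print(tokenizer.create_token_type_ids_from_sequences(first, second))  # [0, 0, 0, 1, 1, 1]
```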
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
...
@@ -91,6 +91,8 @@ class CodeGenTokenizerFast(PreTrainedTokenizerFast):
add_prefix_space (`bool`, *optional*, defaults to `False`):
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
other word. (CodeGen tokenizer detect beginning of words by the preceding space).
return_token_type_ids (`bool`, *optional*, defaults to `False`):
Whether to return token type IDs.
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
...@@ -106,8 +108,13 @@ class CodeGenTokenizerFast(PreTrainedTokenizerFast): ...@@ -106,8 +108,13 @@ class CodeGenTokenizerFast(PreTrainedTokenizerFast):
bos_token="<|endoftext|>", bos_token="<|endoftext|>",
eos_token="<|endoftext|>", eos_token="<|endoftext|>",
add_prefix_space=False, add_prefix_space=False,
return_token_type_ids=False,
**kwargs,
):
self.return_token_type_ids = return_token_type_ids
if self.return_token_type_ids:
self.model_input_names.append("token_type_ids")
super().__init__(
vocab_file,
merges_file,
@@ -116,6 +123,7 @@ class CodeGenTokenizerFast(PreTrainedTokenizerFast):
bos_token=bos_token,
eos_token=eos_token,
add_prefix_space=add_prefix_space,
return_token_type_ids=return_token_type_ids,
**kwargs,
)
@@ -157,6 +165,36 @@ class CodeGenTokenizerFast(PreTrainedTokenizerFast):
return super()._encode_plus(*args, **kwargs)
# Copied from transformers.models.codegen.tokenization_codegen.CodeGenTokenizer.create_token_type_ids_from_sequences
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A sequence
pair mask has the following format:
```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```
If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
"""
sep = [self.sep_token_id] if self.sep_token_id is not None else []
cls = [self.cls_token_id] if self.cls_token_id is not None else []
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
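Because the method is copied verbatim, the fast tokenizer should produce exactly the same mask as the slow one; a quick parity sketch under the same assumptions as above (made-up ids, `Salesforce/codegen-350M-mono` checkpoint):

```python
from transformers import CodeGenTokenizer, CodeGenTokenizerFast

slow = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
fast = CodeGenTokenizerFast.from_pretrained("Salesforce/codegen-350M-mono")

first, second = [41762, 364, 357], [13246, 51, 318]  # hypothetical ids

slow_mask = slow.create_token_type_ids_from_sequences(first, second)
fast_mask = fast.create_token_type_ids_from_sequences(first, second)
assert slow_mask == fast_mask  # both should be [0, 0, 0, 1, 1, 1]
```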
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
return tuple(files)
...
@@ -264,3 +264,55 @@ class CodeGenTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
# tokenizer has no padding token
def test_padding_different_model_input_name(self):
pass
@slow
def test_tokenizer_integration(self):
# Custom test since this tokenizer takes return_token_type_ids as an init argument for backward compatibility.
sequences = [
"Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides "
"general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet...) for Natural "
"Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained "
"models in 100+ languages and deep interoperability between Jax, PyTorch and TensorFlow.",
"BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly "
"conditioning on both left and right context in all layers.",
"The quick brown fox jumps over the lazy dog.",
]
tokenizer_classes = [self.tokenizer_class]
if self.test_rust_tokenizer:
tokenizer_classes.append(self.rust_tokenizer_class)
# Test default case. i.e. return_token_type_ids is False.
for tokenizer_class in tokenizer_classes:
tokenizer = tokenizer_class.from_pretrained("Salesforce/codegen-350M-mono")
encoding = tokenizer(sequences)
decoded_sequences = [tokenizer.decode(seq, skip_special_tokens=True) for seq in encoding["input_ids"]]
# fmt: off
expected_encoding = {'input_ids': [[41762, 364, 357, 36234, 1900, 355, 12972, 13165, 354, 12, 35636, 364, 290, 12972, 13165, 354, 12, 5310, 13363, 12, 4835, 8, 3769, 2276, 12, 29983, 45619, 357, 13246, 51, 11, 402, 11571, 12, 17, 11, 5564, 13246, 38586, 11, 16276, 44, 11, 4307, 346, 33, 861, 11, 16276, 7934, 23029, 329, 12068, 15417, 28491, 357, 32572, 52, 8, 290, 12068, 15417, 16588, 357, 32572, 38, 8, 351, 625, 3933, 10, 2181, 13363, 4981, 287, 1802, 10, 8950, 290, 2769, 48817, 1799, 1022, 449, 897, 11, 9485, 15884, 354, 290, 309, 22854, 37535, 13], [13246, 51, 318, 3562, 284, 662, 12, 27432, 2769, 8406, 4154, 282, 24612, 422, 9642, 9608, 276, 2420, 416, 26913, 21143, 319, 1111, 1364, 290, 826, 4732, 287, 477, 11685, 13], [464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # noqa: E501
# fmt: on
encoding_data = encoding.data
self.assertDictEqual(encoding_data, expected_encoding)
for expected, decoded in zip(sequences, decoded_sequences):
self.assertEqual(expected, decoded)
# Test return_token_type_ids is True case.
for tokenizer_class in tokenizer_classes:
tokenizer = tokenizer_class.from_pretrained("Salesforce/codegen-350M-mono", return_token_type_ids=True)
encoding = tokenizer(sequences)
decoded_sequences = [tokenizer.decode(seq, skip_special_tokens=True) for seq in encoding["input_ids"]]
# fmt: off
expected_encoding = {'input_ids': [[41762, 364, 357, 36234, 1900, 355, 12972, 13165, 354, 12, 35636, 364, 290, 12972, 13165, 354, 12, 5310, 13363, 12, 4835, 8, 3769, 2276, 12, 29983, 45619, 357, 13246, 51, 11, 402, 11571, 12, 17, 11, 5564, 13246, 38586, 11, 16276, 44, 11, 4307, 346, 33, 861, 11, 16276, 7934, 23029, 329, 12068, 15417, 28491, 357, 32572, 52, 8, 290, 12068, 15417, 16588, 357, 32572, 38, 8, 351, 625, 3933, 10, 2181, 13363, 4981, 287, 1802, 10, 8950, 290, 2769, 48817, 1799, 1022, 449, 897, 11, 9485, 15884, 354, 290, 309, 22854, 37535, 13], [13246, 51, 318, 3562, 284, 662, 12, 27432, 2769, 8406, 4154, 282, 24612, 422, 9642, 9608, 276, 2420, 416, 26913, 21143, 319, 1111, 1364, 290, 826, 4732, 287, 477, 11685, 13], [464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # noqa: E501
# fmt: on
encoding_data = encoding.data
self.assertDictEqual(encoding_data, expected_encoding)
for expected, decoded in zip(sequences, decoded_sequences):
self.assertEqual(expected, decoded)