Unverified Commit 3471ff0d authored by Anthony MOI

FastPreTrainedTokenizer

parent 81db12c3
@@ -1410,3 +1410,130 @@ class PreTrainedTokenizer(object):
            .replace(" 're", "'re")
        )
        return out_string
class FastPreTrainedTokenizer(PreTrainedTokenizer):
    # Base class for tokenizers backed by an external fast tokenizer object.
    # Concrete subclasses are expected to provide self._tokenizer and self._decoder.

    def __init__(self, **kwargs):
        super(FastPreTrainedTokenizer, self).__init__(**kwargs)

    @property
    def tokenizer(self):
        if self._tokenizer is None:
            raise NotImplementedError
        return self._tokenizer

    @property
    def decoder(self):
        if self._decoder is None:
            raise NotImplementedError
        return self._decoder

    @property
    def vocab_size(self):
        # Vocabulary size without the added special tokens
        return self.tokenizer.get_vocab_size(False)

    def __len__(self):
        # Full vocabulary size, including the added special tokens
        return self.tokenizer.get_vocab_size(True)

    def _update_special_tokens(self):
        self.tokenizer.add_special_tokens(self.all_special_tokens)
    @staticmethod
    def _convert_encoding(encoding,
                          return_tensors=None,
                          return_token_type_ids=True,
                          return_attention_mask=True,
                          return_overflowing_tokens=False,
                          return_special_tokens_mask=False):
        encoding_dict = {
            "input_ids": encoding.ids,
        }
        if return_token_type_ids:
            encoding_dict["token_type_ids"] = encoding.type_ids
        if return_attention_mask:
            encoding_dict["attention_mask"] = encoding.attention_mask
        if return_overflowing_tokens:
            overflowing = encoding.overflowing
            encoding_dict["overflowing_tokens"] = overflowing.ids if overflowing is not None else []
        if return_special_tokens_mask:
            encoding_dict["special_tokens_mask"] = encoding.special_tokens_mask

        # Prepare inputs as tensors if asked
        if return_tensors == 'tf' and is_tf_available():
            encoding_dict["input_ids"] = tf.constant([encoding_dict["input_ids"]])
            encoding_dict["token_type_ids"] = tf.constant([encoding_dict["token_type_ids"]])

            if "attention_mask" in encoding_dict:
                encoding_dict["attention_mask"] = tf.constant([encoding_dict["attention_mask"]])

        elif return_tensors == 'pt' and is_torch_available():
            encoding_dict["input_ids"] = torch.tensor([encoding_dict["input_ids"]])
            encoding_dict["token_type_ids"] = torch.tensor([encoding_dict["token_type_ids"]])

            if "attention_mask" in encoding_dict:
                encoding_dict["attention_mask"] = torch.tensor([encoding_dict["attention_mask"]])

        elif return_tensors is not None:
            logger.warning(
                "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
                    return_tensors))

        return encoding_dict
    def encode_plus(self,
                    text,
                    text_pair=None,
                    return_tensors=None,
                    return_token_type_ids=True,
                    return_attention_mask=True,
                    return_overflowing_tokens=False,
                    return_special_tokens_mask=False,
                    **kwargs):
        encoding = self.tokenizer.encode(text, text_pair)
        return self._convert_encoding(encoding,
                                      return_tensors=return_tensors,
                                      return_token_type_ids=return_token_type_ids,
                                      return_attention_mask=return_attention_mask,
                                      return_overflowing_tokens=return_overflowing_tokens,
                                      return_special_tokens_mask=return_special_tokens_mask)

    def tokenize(self, text):
        return self.tokenizer.encode(text).tokens

    def _convert_token_to_id_with_added_voc(self, token):
        return self.tokenizer.token_to_id(token)

    def _convert_id_to_token(self, index):
        return self.tokenizer.id_to_token(int(index))

    def convert_tokens_to_string(self, tokens):
        return self.decoder.decode(tokens)

    def add_tokens(self, new_tokens):
        self.tokenizer.add_tokens(new_tokens)
    def encode_batch(self, texts,
                     return_tensors=None,
                     return_token_type_ids=True,
                     return_attention_mask=True,
                     return_overflowing_tokens=False,
                     return_special_tokens_mask=False):
        return [self._convert_encoding(encoding,
                                       return_tensors=return_tensors,
                                       return_token_type_ids=return_token_type_ids,
                                       return_attention_mask=return_attention_mask,
                                       return_overflowing_tokens=return_overflowing_tokens,
                                       return_special_tokens_mask=return_special_tokens_mask)
                for encoding in self.tokenizer.encode_batch(texts)]
    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        text = self.tokenizer.decode(token_ids, skip_special_tokens)

        if clean_up_tokenization_spaces:
            clean_text = self.clean_up_tokenization(text)
            return clean_text
        else:
            return text

    def decode_batch(self, ids_batch, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        return [self.clean_up_tokenization(text)
                if clean_up_tokenization_spaces else text
                for text in self.tokenizer.decode_batch(ids_batch, skip_special_tokens)]
\ No newline at end of file
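
The class added above only delegates; it does not construct a backend itself. Below is a minimal, hypothetical sketch (not part of this commit) of how a concrete subclass could wire one in. The names MyFastTokenizer, backend_tokenizer and backend_decoder are illustrative, and the backend is assumed to be any object exposing the methods used in the diff (encode, encode_batch, decode, decode_batch, token_to_id, id_to_token, add_tokens, add_special_tokens, get_vocab_size).

# Hypothetical sketch, assuming a transformers build that includes this commit.
from transformers.tokenization_utils import FastPreTrainedTokenizer


class MyFastTokenizer(FastPreTrainedTokenizer):  # illustrative subclass name
    def __init__(self, backend_tokenizer, backend_decoder=None, **kwargs):
        super(MyFastTokenizer, self).__init__(**kwargs)
        # The base class reads these through its `tokenizer` and `decoder`
        # properties and raises NotImplementedError if they stay None.
        self._tokenizer = backend_tokenizer
        self._decoder = backend_decoder


# Usage (backend object elided):
#   tok = MyFastTokenizer(backend_tokenizer=my_backend)
#   enc = tok.encode_plus("Hello world")
#   enc["input_ids"], enc["attention_mask"]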