"...linux/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "349658c0524427c2bc496aa62336aff0fe0075f8"
Unverified Commit 951ae99b authored by Anthony MOI

BertTokenizerFast

parent 041eac2d
@@ -103,7 +103,7 @@ from .pipelines import (
 )
 from .tokenization_albert import AlbertTokenizer
 from .tokenization_auto import AutoTokenizer
-from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
+from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
 from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
 from .tokenization_camembert import CamembertTokenizer
 from .tokenization_ctrl import CTRLTokenizer
...
@@ -20,7 +20,7 @@ import logging
 import os
 import unicodedata

-from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils import PreTrainedTokenizer, FastPreTrainedTokenizer

 logger = logging.getLogger(__name__)
@@ -525,3 +525,54 @@ def _is_punctuation(char):
     if cat.startswith("P"):
         return True
     return False
+
+
+class BertTokenizerFast(FastPreTrainedTokenizer):
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+
+    def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
+                 unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
+                 mask_token="[MASK]", tokenize_chinese_chars=True,
+                 max_length=None, pad_to_max_length=False, stride=0,
+                 truncation_strategy='longest_first', add_special_tokens=True, **kwargs):
+        try:
+            from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors
+
+            super(BertTokenizerFast, self).__init__(unk_token=unk_token, sep_token=sep_token,
+                                                    pad_token=pad_token, cls_token=cls_token,
+                                                    mask_token=mask_token, **kwargs)
+
+            self._tokenizer = Tokenizer(models.WordPiece.from_files(
+                vocab_file,
+                unk_token=unk_token
+            ))
+            self._update_special_tokens()
+            self._tokenizer.with_pre_tokenizer(pre_tokenizers.BertPreTokenizer.new(
+                do_basic_tokenize=do_basic_tokenize,
+                do_lower_case=do_lower_case,
+                tokenize_chinese_chars=tokenize_chinese_chars,
+                never_split=never_split if never_split is not None else [],
+            ))
+            self._tokenizer.with_decoder(decoders.WordPiece.new())
+
+            if add_special_tokens:
+                self._tokenizer.with_post_processor(processors.BertProcessing.new(
+                    (sep_token, self._tokenizer.token_to_id(sep_token)),
+                    (cls_token, self._tokenizer.token_to_id(cls_token)),
+                ))
+            if max_length is not None:
+                self._tokenizer.with_truncation(max_length, stride, truncation_strategy)
+            self._tokenizer.with_padding(
+                max_length if pad_to_max_length else None,
+                self.padding_side,
+                self.pad_token_id,
+                self.pad_token_type_id,
+                self.pad_token
+            )
+            self._decoder = decoders.WordPiece.new()
+        except (AttributeError, ImportError) as e:
+            logger.error("Make sure you installed `tokenizers` with `pip install tokenizers==0.0.8`")
+            raise e
\ No newline at end of file
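
For context, here is a usage sketch of the class this commit adds; the sketch itself is not part of the commit. It assumes `tokenizers==0.0.8` is installed (the version the error message pins), that "vocab.txt" is a local BERT WordPiece vocabulary file, and that `FastPreTrainedTokenizer` exposes the same `encode` interface as the existing slow tokenizers; all three are assumptions, not shown in the diff.

    from transformers import BertTokenizerFast

    # "vocab.txt" is a hypothetical local WordPiece vocabulary file.
    tokenizer = BertTokenizerFast("vocab.txt", do_lower_case=True, max_length=128)

    # Assuming `encode` is inherited from FastPreTrainedTokenizer, the
    # BertProcessing post-processor should add [CLS]/[SEP] around the ids.
    ids = tokenizer.encode("Hello, world!")
    print(ids)

Note the design choice in the constructor: wrapping both the import and the setup in a single try/except over (AttributeError, ImportError) makes a missing or API-incompatible `tokenizers` package fail loudly at instantiation time rather than at first use, which is why the logged error pins the exact version.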