"vscode:/vscode.git/clone" did not exist on "608a8f5b567f81f3cc997a195496dd8bf1c28158"
Commit 62df4ba5 authored by thomwolf

add dilbert tokenizer and tests

parent 4ce5f36f
@@ -7,6 +7,7 @@ from .tokenization_gpt2 import GPT2Tokenizer
 from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
 from .tokenization_xlm import XLMTokenizer
 from .tokenization_roberta import RobertaTokenizer
+from .tokenization_dilbert import DilBertTokenizer
 from .tokenization_utils import (PreTrainedTokenizer)
...
@@ -42,7 +42,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

     def get_tokenizer(self):
-        return BertTokenizer.from_pretrained(self.tmpdirname)
+        return self.tokenizer_class.from_pretrained(self.tmpdirname)

     def get_input_output_texts(self):
         input_text = u"UNwant\u00E9d,running"
@@ -50,7 +50,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
         return input_text, output_text

     def test_full_tokenizer(self):
-        tokenizer = BertTokenizer(self.vocab_file)
+        tokenizer = self.tokenizer_class(self.vocab_file)
         tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
         self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
@@ -126,7 +126,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
         self.assertFalse(_is_punctuation(u" "))

     def test_sequence_builders(self):
-        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+        tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")
         text = tokenizer.encode("sequence builders")
         text_2 = tokenizer.encode("multi-sequence build")
...
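The three hunks above replace hard-coded BertTokenizer references in the BERT tokenizer tests with a self.tokenizer_class hook, so the same test suite can be rerun against any compatible tokenizer by subclassing and overriding that attribute. The two new files below rely on exactly this hook: the DilBert test case inherits from BertTokenizationTest and overrides only tokenizer_class, get_tokenizer, and test_sequence_builders.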
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import unittest
from io import open

from pytorch_transformers.tokenization_dilbert import (DilBertTokenizer)

from .tokenization_tests_commons import CommonTestCases
from .tokenization_bert_test import BertTokenizationTest


class DilBertTokenizationTest(BertTokenizationTest):

    tokenizer_class = DilBertTokenizer

    def get_tokenizer(self):
        return DilBertTokenizer.from_pretrained(self.tmpdirname)

    def test_sequence_builders(self):
        tokenizer = DilBertTokenizer.from_pretrained("dilbert-base-uncased")

        text = tokenizer.encode("sequence builders")
        text_2 = tokenizer.encode("multi-sequence build")

        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)

        assert encoded_sentence == [101] + text + [102]
        assert encoded_pair == [101] + text + [102] + text_2 + [102]


if __name__ == '__main__':
    unittest.main()
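A note on the magic numbers in the assertions: in the bert-base-uncased vocabulary, which the dilbert checkpoints reuse (see the vocab map in the tokenizer file below), id 101 is [CLS] and id 102 is [SEP]. A minimal sketch to verify this, not part of the commit, assuming the package at this commit is installed:

# Sketch: confirm the special-token ids the test hard-codes.
# 101/102 are the [CLS]/[SEP] ids in the bert-base-uncased vocab.
from pytorch_transformers.tokenization_dilbert import DilBertTokenizer

tokenizer = DilBertTokenizer.from_pretrained("dilbert-base-uncased")
assert tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]"]) == [101, 102]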
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for DilBERT."""
from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import logging
import os
import unicodedata
from io import open

from .tokenization_bert import BertTokenizer

logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}

PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file':
    {
        'dilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
        'dilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'dilbert-base-uncased': 512,
    'dilbert-base-uncased-distilled-squad': 512,
}
class DilBertTokenizer(BertTokenizer):
    r"""
    Constructs a DilBertTokenizer.

    :class:`~pytorch_transformers.DilBertTokenizer` is identical to :class:`BertTokenizer` and runs end-to-end
    tokenization: punctuation splitting + wordpiece.

    Args:
        vocab_file: Path to a one-wordpiece-per-line vocabulary file.
        do_lower_case: Whether to lower-case the input. Only has an effect when do_wordpiece_only=False.
        do_basic_tokenize: Whether to do basic tokenization before wordpiece.
        max_len: An artificial maximum length to truncate tokenized sequences to; the effective maximum length is
            always the minimum of this value (if specified) and the underlying BERT model's sequence length.
        never_split: List of tokens which will never be split during tokenization. Only has an effect when
            do_wordpiece_only=False.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
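Since DilBertTokenizer only re-points BertTokenizer at the dilbert vocab and size maps, it inherits the full BERT tokenizer API. A minimal usage sketch under that assumption, not part of the commit (from_pretrained downloads the vocab registered above on first use):

# Sketch: end-to-end use of the new tokenizer.
from pytorch_transformers import DilBertTokenizer

tokenizer = DilBertTokenizer.from_pretrained("dilbert-base-uncased")
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
# WordPiece output: ["un", "##want", "##ed", ",", "runn", "##ing"]
ids = tokenizer.encode("sequence builders")
ids_with_special = tokenizer.add_special_tokens_single_sentence(ids)
# -> [101] + ids + [102], i.e. [CLS] ... [SEP]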