"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "ea61244c04f500475768c83cd081aa50868c1840"
Commit e14c6b52 authored by John Hewitt's avatar John Hewitt
Browse files

add BertTokenizer flag to skip basic tokenization

parent 2152bfea
...@@ -507,7 +507,7 @@ where ...@@ -507,7 +507,7 @@ where
Examples: Examples:
```python ```python
# BERT # BERT
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, do_basic_tokenize=True)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased') model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
# OpenAI GPT # OpenAI GPT
...@@ -803,11 +803,12 @@ This model *outputs*: ...@@ -803,11 +803,12 @@ This model *outputs*:
`BertTokenizer` perform end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization. `BertTokenizer` perform end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization.
This class has four arguments: This class has five arguments:
- `vocab_file`: path to a vocabulary file. - `vocab_file`: path to a vocabulary file.
- `do_lower_case`: convert text to lower-case while tokenizing. **Default = True**. - `do_lower_case`: convert text to lower-case while tokenizing. **Default = True**.
- `max_len`: max length to filter the input of the Transformer. Default to pre-trained value for the model if `None`. **Default = None** - `max_len`: max length to filter the input of the Transformer. Default to pre-trained value for the model if `None`. **Default = None**
- `do_basic_tokenize`: Do basic tokenization before wordpice tokenization. Set to false if text is pre-tokenized. **Default = True**.
- `never_split`: a list of tokens that should not be splitted during tokenization. **Default = `["[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"]`** - `never_split`: a list of tokens that should not be splitted during tokenization. **Default = `["[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"]`**
and three methods: and three methods:
......
...@@ -74,8 +74,14 @@ def whitespace_tokenize(text): ...@@ -74,8 +74,14 @@ def whitespace_tokenize(text):
class BertTokenizer(object): class BertTokenizer(object):
"""Runs end-to-end tokenization: punctuation splitting + wordpiece""" """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
def __init__(self, vocab_file, do_lower_case=True, max_len=None, def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
"""Constructs a BertTokenizer.
Args:
do_lower_case: Whether to lower case the input.
do_wordpiece_only: Whether to do basic tokenization before wordpiece.
"""
if not os.path.isfile(vocab_file): if not os.path.isfile(vocab_file):
raise ValueError( raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
...@@ -83,16 +89,21 @@ class BertTokenizer(object): ...@@ -83,16 +89,21 @@ class BertTokenizer(object):
self.vocab = load_vocab(vocab_file) self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict( self.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.vocab.items()]) [(ids, tok) for tok, ids in self.vocab.items()])
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, self.do_basic_tokenize = do_basic_tokenize
never_split=never_split) if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
never_split=never_split)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
self.max_len = max_len if max_len is not None else int(1e12) self.max_len = max_len if max_len is not None else int(1e12)
def tokenize(self, text): def tokenize(self, text):
split_tokens = [] if self.do_basic_tokenize:
for token in self.basic_tokenizer.tokenize(text): split_tokens = []
for sub_token in self.wordpiece_tokenizer.tokenize(token): for token in self.basic_tokenizer.tokenize(text):
split_tokens.append(sub_token) for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
else:
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens return split_tokens
def convert_tokens_to_ids(self, tokens): def convert_tokens_to_ids(self, tokens):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment