"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "5e19b159b02ae1b519ce6c3a0d1e7ff04ab9fcd5"
Unverified commit 4784b04f authored by Thomas Wolf, committed by GitHub

Merge pull request #325 from john-hewitt/master

add BertTokenizer flag to skip basic tokenization
parents 2152bfea 4d1ad832
@@ -507,7 +507,7 @@ where
 Examples:
 ```python
 # BERT
-tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
+tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, do_basic_tokenize=True)
 model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
 # OpenAI GPT
@@ -803,11 +803,12 @@ This model *outputs*:
 `BertTokenizer` perform end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization.
-This class has four arguments:
+This class has five arguments:
 - `vocab_file`: path to a vocabulary file.
 - `do_lower_case`: convert text to lower-case while tokenizing. **Default = True**.
 - `max_len`: max length to filter the input of the Transformer. Default to pre-trained value for the model if `None`. **Default = None**
+- `do_basic_tokenize`: Do basic tokenization before wordpiece tokenization. Set to false if text is pre-tokenized. **Default = True**.
 - `never_split`: a list of tokens that should not be splitted during tokenization. **Default = `["[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"]`**
 and three methods:
...
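To make the README change concrete, here is a small usage sketch (not part of the diff). It assumes the package's `BertTokenizer` import path (`pytorch_pretrained_bert`) and network access to fetch the `bert-base-uncased` vocabulary. With `do_basic_tokenize=False`, lower-casing and punctuation splitting are skipped, so the flag is only meant for text that is already pre-tokenized.

```python
# Usage sketch, not part of this diff. Assumes pytorch_pretrained_bert is
# installed and the 'bert-base-uncased' vocabulary can be downloaded.
from pytorch_pretrained_bert import BertTokenizer

# Default path: basic tokenization (lower-casing, punctuation splitting)
# followed by WordPiece.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                          do_lower_case=True,
                                          do_basic_tokenize=True)
print(tokenizer.tokenize("Who was Jim Henson?"))

# Pre-tokenized input: skip the basic step. Note that lower-casing is done by
# BasicTokenizer, so do_lower_case has no effect on this code path.
pretok = BertTokenizer.from_pretrained('bert-base-uncased',
                                       do_basic_tokenize=False)
print(pretok.tokenize("henson"))  # WordPiece applied directly to the raw string
```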
@@ -74,8 +74,22 @@ def whitespace_tokenize(text):
 class BertTokenizer(object):
     """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
-    def __init__(self, vocab_file, do_lower_case=True, max_len=None,
+    def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
                  never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
+        """Constructs a BertTokenizer.
+
+        Args:
+          vocab_file: Path to a one-wordpiece-per-line vocabulary file
+          do_lower_case: Whether to lower case the input
+                         Only has an effect when do_wordpiece_only=False
+          do_basic_tokenize: Whether to do basic tokenization before wordpiece.
+          max_len: An artificial maximum length to truncate tokenized sequences to;
+                   Effective maximum length is always the minimum of this
+                   value (if specified) and the underlying BERT model's
+                   sequence length.
+          never_split: List of tokens which will never be split during tokenization.
+                       Only has an effect when do_wordpiece_only=False
+        """
         if not os.path.isfile(vocab_file):
             raise ValueError(
                 "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
@@ -83,16 +97,21 @@ class BertTokenizer(object):
         self.vocab = load_vocab(vocab_file)
         self.ids_to_tokens = collections.OrderedDict(
             [(ids, tok) for tok, ids in self.vocab.items()])
-        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
-                                              never_split=never_split)
+        self.do_basic_tokenize = do_basic_tokenize
+        if do_basic_tokenize:
+            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
+                                                  never_split=never_split)
         self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
         self.max_len = max_len if max_len is not None else int(1e12)

     def tokenize(self, text):
-        split_tokens = []
-        for token in self.basic_tokenizer.tokenize(text):
-            for sub_token in self.wordpiece_tokenizer.tokenize(token):
-                split_tokens.append(sub_token)
+        if self.do_basic_tokenize:
+            split_tokens = []
+            for token in self.basic_tokenizer.tokenize(text):
+                for sub_token in self.wordpiece_tokenizer.tokenize(token):
+                    split_tokens.append(sub_token)
+        else:
+            split_tokens = self.wordpiece_tokenizer.tokenize(text)
         return split_tokens

     def convert_tokens_to_ids(self, tokens):
...
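The branching added to `tokenize()` above can be illustrated with a minimal, self-contained sketch. `BasicStub` and `WordpieceStub` below are hypothetical stand-ins for `BasicTokenizer` and `WordpieceTokenizer`, used only so the snippet runs without a vocabulary file.

```python
# Minimal sketch of the control flow added in this commit. BasicStub and
# WordpieceStub are hypothetical stand-ins, not part of the library.
class BasicStub:
    def tokenize(self, text):
        # crude stand-in for BasicTokenizer: lower-case and split on whitespace
        return text.lower().split()

class WordpieceStub:
    def tokenize(self, token):
        # crude stand-in for WordpieceTokenizer: pass tokens through unchanged
        return [token]

def tokenize(text, do_basic_tokenize=True):
    basic_tokenizer = BasicStub()
    wordpiece_tokenizer = WordpieceStub()
    if do_basic_tokenize:
        # basic tokenization first, then WordPiece on each resulting token
        split_tokens = []
        for token in basic_tokenizer.tokenize(text):
            for sub_token in wordpiece_tokenizer.tokenize(token):
                split_tokens.append(sub_token)
    else:
        # pre-tokenized text: hand the raw string straight to WordPiece
        split_tokens = wordpiece_tokenizer.tokenize(text)
    return split_tokens

print(tokenize("Hello world"))                     # ['hello', 'world']
print(tokenize("hello", do_basic_tokenize=False))  # ['hello']
```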