Commit b514a60c authored by thomwolf

added tests for OpenAI GPT and Transformer-XL tokenizers

parent 9bdcba53
...@@ -529,10 +529,10 @@ This model *outputs*:
`OpenAIGPTDoubleHeadsModel` includes the `OpenAIGPTModel` Transformer followed by two heads:
- a language modeling head with weights tied to the input embeddings (no additional parameters), and
- a multiple choice classifier (a linear layer that takes the hidden state of one token in the sequence as input to compute a score; see the paper for details).

*Inputs* are the same as the inputs of the [`OpenAIGPTModel`](#-9.-`OpenAIGPTModel`) class plus a classification token index and two optional labels (see the usage sketch after this list):
- `multiple_choice_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token whose hidden state should be used as input for the multiple choice classifier (usually the [CLS] token for each choice).
- `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked); the loss is only computed for the labels set in [0, ..., vocab_size].
- `multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size] with indices selected in [0, ..., num_choices].
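The expected tensor shapes are easiest to see in a short sketch. This is a hedged illustration only: the sizes, the `'openai-gpt'` shortcut name, and the exact forward signature are assumptions, not taken from this commit; the shapes follow the description above.

```python
import torch

# Made-up sizes: a batch of 2 examples, each with 3 answer choices of 7 tokens.
batch_size, num_choices, seq_len, vocab_size = 2, 3, 7, 1000

# [batch_size, num_choices, sequence_length] token indices, as for OpenAIGPTModel
input_ids = torch.randint(0, vocab_size, (batch_size, num_choices, seq_len))
# [batch_size, num_choices]: index of the classification token in each choice,
# here assumed to be the last token of every sequence
multiple_choice_token_ids = torch.full((batch_size, num_choices), seq_len - 1).long()
# [batch_size]: index of the correct choice for each example
multiple_choice_labels = torch.tensor([0, 2])

# Assumed call pattern (check modeling_openai.py for the exact signature):
# model = OpenAIGPTDoubleHeadsModel.from_pretrained('openai-gpt')
# losses = model(input_ids, multiple_choice_token_ids,
#                lm_labels=None, multiple_choice_labels=multiple_choice_labels)
```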
...@@ -613,9 +613,9 @@ Please refer to the doc strings and code in [`tokenization_openai.py`](./pytorch
#### `TransfoXLTokenizer`

`TransfoXLTokenizer` performs word-level tokenization. It also provides utilities for counting tokens in a corpus to create a vocabulary ordered by token frequency, as needed for adaptive softmax. See the adaptive softmax paper ([Efficient softmax approximation for GPUs](http://arxiv.org/abs/1609.04309)) for more details.

Please refer to the doc strings and code in [`tokenization_transfo_xl.py`](./pytorch_pretrained_bert/tokenization_transfo_xl.py) for the details of these additional methods in `TransfoXLTokenizer`.
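For example, a minimal sketch of building such a frequency-ordered vocabulary (the corpus path is a placeholder; the calls mirror the `count_file`/`build_vocab`/`tokenize` methods visible in the diff below):

```python
from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer

# Count token frequencies over a training corpus, then build the vocabulary
# ordered by frequency, which is what adaptive softmax expects.
tokenizer = TransfoXLTokenizer(special=['<unk>', '<eos>'], lower_case=False)
tokenizer.count_file('train.txt', add_eos=True)   # 'train.txt' is a placeholder path
tokenizer.build_vocab()

tokens = tokenizer.tokenize('hello world', add_eos=True)
ids = tokenizer.convert_tokens_to_ids(tokens)
```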
### Optimizers:
...
...@@ -70,7 +70,10 @@ def text_standardize(text):
class OpenAIGPTTokenizer(object):
    """
    BPE tokenizer. Peculiarities:
        - lower-cases all inputs
        - uses the SpaCy tokenizer
        - special tokens: additional symbols (e.g. "__classify__") can be added to the vocabulary.
    """
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
...@@ -150,7 +153,7 @@ class OpenAIGPTTokenizer(object):
        logger.info("Special tokens {}".format(self.special_tokens))

    def bpe(self, token):
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        if token in self.cache:
            return self.cache[token]
        pairs = get_pairs(word)
...@@ -159,7 +162,7 @@ class OpenAIGPTTokenizer(object):
            return token+'</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
...
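To make the `OpenAIGPTTokenizer` docstring above concrete, here is a minimal usage sketch (the `'openai-gpt'` shortcut name is an assumption for illustration; pre-tokenization relies on SpaCy, as the docstring notes):

```python
from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer

# Assumed shortcut name; a local path to saved vocab/merges files works too.
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')

# Inputs are lower-cased and split with SpaCy before BPE is applied.
tokens = tokenizer.tokenize("Hello world, this is a test sentence.")
ids = tokenizer.convert_tokens_to_ids(tokens)
```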
...@@ -25,6 +25,7 @@ import os
import sys
from collections import Counter, OrderedDict
from io import open
import unicodedata

import torch
import numpy as np
...@@ -89,8 +90,8 @@ class TransfoXLTokenizer(object):
            tokenizer.__dict__[key] = value
        return tokenizer

    def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False,
                 delimiter=None, vocab_file=None, never_split=("<unk>", "<eos>", "<formula>")):
        self.counter = Counter()
        self.special = special
        self.min_freq = min_freq
...@@ -98,6 +99,7 @@ class TransfoXLTokenizer(object):
        self.lower_case = lower_case
        self.delimiter = delimiter
        self.vocab_file = vocab_file
        self.never_split = never_split

    def count_file(self, path, verbose=False, add_eos=False):
        if verbose: print('counting file {} ...'.format(path))
...@@ -132,7 +134,12 @@ class TransfoXLTokenizer(object):
            for line in f:
                symb = line.strip().split()[0]
                self.add_symbol(symb)
        if '<UNK>' in self.sym2idx:
            self.unk_idx = self.sym2idx['<UNK>']
        elif '<unk>' in self.sym2idx:
            self.unk_idx = self.sym2idx['<unk>']
        else:
            raise ValueError('No <unk> token in vocabulary')

    def build_vocab(self):
        if self.vocab_file:
...@@ -198,7 +205,7 @@ class TransfoXLTokenizer(object):
            self.sym2idx[sym] = len(self.idx2sym) - 1

    def get_sym(self, idx):
        assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)
        return self.idx2sym[idx]

    def get_idx(self, sym):
...@@ -206,9 +213,16 @@ class TransfoXLTokenizer(object):
            return self.sym2idx[sym]
        else:
            # print('encounter unk {}'.format(sym))
            # assert '<eos>' not in sym
            if hasattr(self, 'unk_idx'):
                return self.sym2idx.get(sym, self.unk_idx)
            # Backward compatibility with pre-trained models
            elif '<unk>' in self.sym2idx:
                return self.sym2idx['<unk>']
            elif '<UNK>' in self.sym2idx:
                return self.sym2idx['<UNK>']
            else:
                raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')

    def convert_ids_to_tokens(self, indices):
        """Converts a sequence of indices in symbols using the vocab."""
...@@ -231,24 +245,82 @@ class TransfoXLTokenizer(object):
    def __len__(self):
        return len(self.idx2sym)
    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        if text in self.never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1
        return ["".join(x) for x in output]

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def whitespace_tokenize(self, text):
        """Runs basic whitespace cleaning and splitting on a piece of text."""
        text = text.strip()
        if not text:
            return []
        if self.delimiter == '':
            tokens = text
        else:
            tokens = text.split(self.delimiter)
        return tokens
    def tokenize(self, line, add_eos=False, add_double_eos=False):
        line = self._clean_text(line)
        line = line.strip()

        symbols = self.whitespace_tokenize(line)

        split_symbols = []
        for symbol in symbols:
            if self.lower_case and symbol not in self.never_split:
                symbol = symbol.lower()
                symbol = self._run_strip_accents(symbol)
            split_symbols.extend(self._run_split_on_punc(symbol))

        if add_double_eos: # lm1b
            return ['<S>'] + split_symbols + ['<S>']
        elif add_eos:
            return split_symbols + ['<eos>']
        else:
            return split_symbols
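    # Example of the pipeline above (mirrors the tests added at the end of this commit):
    #   with lower_case=True,  tokenize(u" \tHeLLo!how \n Are yoU? ")
    #   returns ["hello", "!", "how", "are", "you", "?"], while with
    #   lower_case=False the original casing is kept.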
class LMOrderedIterator(object):
...@@ -556,3 +628,42 @@ def get_lm_corpus(datadir, dataset):
        torch.save(corpus, fn)

    return corpus
def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import unittest
import json
from io import open

from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer


class OpenAIGPTTokenizationTest(unittest.TestCase):

    def test_full_tokenizer(self):
        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                 "w</w>", "r</w>", "t</w>",
                 "lo", "low", "er</w>",
                 "low</w>", "lowest</w>", "newer</w>", "wider</w>"]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]

        with open("/tmp/openai_tokenizer_vocab_test.json", "w", encoding='utf-8') as fp:
            json.dump(vocab_tokens, fp)
            vocab_file = fp.name
        with open("/tmp/openai_tokenizer_merges_test.txt", "w", encoding='utf-8') as fp:
            fp.write("\n".join(merges))
            merges_file = fp.name

        tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>"])
        os.remove(vocab_file)
        os.remove(merges_file)

        text = "lower"
        bpe_tokens = ["low", "er</w>"]
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = tokens + ["<unk>"]
        input_bpe_tokens = [14, 15, 20]
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)


if __name__ == '__main__':
    unittest.main()
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import unittest
from io import open

from pytorch_pretrained_bert.tokenization_transfo_xl import (TransfoXLTokenizer,
                                                             _is_control, _is_punctuation,
                                                             _is_whitespace)


class TransfoXLTokenizationTest(unittest.TestCase):

    def test_full_tokenizer(self):
        vocab_tokens = [
            "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", "running", ","
        ]
        with open("/tmp/transfo_xl_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
            vocab_file = vocab_writer.name

        tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
        tokenizer.build_vocab()
        os.remove(vocab_file)

        tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])

    def test_full_tokenizer_lower(self):
        tokenizer = TransfoXLTokenizer(lower_case=True)

        self.assertListEqual(
            tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
            ["hello", "!", "how", "are", "you", "?"])
        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])

    def test_full_tokenizer_no_lower(self):
        tokenizer = TransfoXLTokenizer(lower_case=False)

        self.assertListEqual(
            tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
            ["HeLLo", "!", "how", "Are", "yoU", "?"])

    def test_is_whitespace(self):
        self.assertTrue(_is_whitespace(u" "))
        self.assertTrue(_is_whitespace(u"\t"))
        self.assertTrue(_is_whitespace(u"\r"))
        self.assertTrue(_is_whitespace(u"\n"))
        self.assertTrue(_is_whitespace(u"\u00A0"))

        self.assertFalse(_is_whitespace(u"A"))
        self.assertFalse(_is_whitespace(u"-"))

    def test_is_control(self):
        self.assertTrue(_is_control(u"\u0005"))

        self.assertFalse(_is_control(u"A"))
        self.assertFalse(_is_control(u" "))
        self.assertFalse(_is_control(u"\t"))
        self.assertFalse(_is_control(u"\r"))

    def test_is_punctuation(self):
        self.assertTrue(_is_punctuation(u"-"))
        self.assertTrue(_is_punctuation(u"$"))
        self.assertTrue(_is_punctuation(u"`"))
        self.assertTrue(_is_punctuation(u"."))

        self.assertFalse(_is_punctuation(u"A"))
        self.assertFalse(_is_punctuation(u" "))


if __name__ == '__main__':
    unittest.main()