Commit b514a60c authored by thomwolf's avatar thomwolf

added tests for OpenAI GPT and Transformer-XL tokenizers

parent 9bdcba53
@@ -529,10 +529,10 @@ This model *outputs*:
`OpenAIGPTDoubleHeadsModel` includes the `OpenAIGPTModel` Transformer followed by two heads:
- a language modeling head with weights tied to the input embeddings (no additional parameters) and:
- a multiple choice classifier (linear layer).
- a multiple choice classifier (a linear layer that takes as input a hidden state in a sequence to compute a score, see details in the paper).
*Inputs* are the same as the inputs of the [`OpenAIGPTModel`](#-9.-`OpenAIGPTModel`) class plus a classification mask and two optional labels; a usage sketch follows this list:
- `multiple_choice_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with a value of 1 where the last hidden state is (usually the [CLS] token) and 0 otherwise.
- `multiple_choice_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token whose hidden state should be used as input for the multiple choice classifier (usually the [CLS] token for each choice).
- `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked); the loss is only computed for the labels set in [0, ..., vocab_size].
- `multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size] with indices selected in [0, ..., num_choices].
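
For illustration, a minimal sketch of a forward pass with these inputs. The `'openai-gpt'` shortcut name, the top-level import, the positional argument order, and the returned values without labels are assumptions based on the description above, not a verbatim API reference; see the class doc strings for the authoritative signature.

```python
import torch
from pytorch_pretrained_bert import OpenAIGPTDoubleHeadsModel

model = OpenAIGPTDoubleHeadsModel.from_pretrained('openai-gpt')
model.eval()

batch_size, num_choices, sequence_length = 2, 2, 5
input_ids = torch.randint(0, 1000, (batch_size, num_choices, sequence_length))
# Index of the token whose hidden state feeds the multiple choice head,
# one per choice (here: the last position of each choice).
multiple_choice_token_ids = torch.full((batch_size, num_choices),
                                       sequence_length - 1, dtype=torch.long)

with torch.no_grad():
    # Without the optional labels, the model returns logits rather than losses.
    lm_logits, multiple_choice_logits = model(input_ids, multiple_choice_token_ids)
```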
@@ -613,9 +613,9 @@ Please refer to the doc strings and code in [`tokenization_openai.py`](./pytorch
#### `TransfoXLTokenizer`
`TransfoXLTokenizer` performs word tokenization.
`TransfoXLTokenizer` performs word tokenization. This tokenizer can be used for adaptive softmax and has utilities for counting tokens in a corpus to create a vocabulary ordered by token frequency. See the adaptive softmax paper ([Efficient softmax approximation for GPUs](http://arxiv.org/abs/1609.04309)) for more details.
Please refer to the doc strings and code in [`tokenization_transfo_xl.py`](./pytorch_pretrained_bert/tokenization_transfo_xl.py) for the details of the `TransfoXLTokenizer`.
Please refer to the doc strings and code in [`tokenization_transfo_xl.py`](./pytorch_pretrained_bert/tokenization_transfo_xl.py) for the details of these additional methods in `TransfoXLTokenizer`.
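
For illustration, a minimal sketch of the counting workflow described above. The corpus path is a placeholder; the constructor arguments and the `count_file`, `build_vocab`, `tokenize`, and `convert_tokens_to_ids` methods follow the class as it appears in the diff below.

```python
from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer

# Count token frequencies over a training corpus, then build a vocabulary
# ordered by token frequency and capped at max_size symbols, as adaptive
# softmax expects.
tokenizer = TransfoXLTokenizer(special=['<eos>'], min_freq=2, max_size=50000)
tokenizer.count_file('train.txt')  # placeholder path
tokenizer.build_vocab()

tokens = tokenizer.tokenize('hello world', add_eos=True)
ids = tokenizer.convert_tokens_to_ids(tokens)
```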
### Optimizers:
......
@@ -70,7 +70,10 @@ def text_standardize(text):
class OpenAIGPTTokenizer(object):
"""
mostly a wrapper for a public python bpe tokenizer
BPE tokenizer. Peculiarities:
- lower cases all inputs
- uses the SpaCy tokenizer
- special tokens: additional symbols (e.g. "__classify__") to add to a vocabulary.
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
@@ -150,7 +153,7 @@ class OpenAIGPTTokenizer(object):
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
word = tuple(token[:-1]) + (token[-1] + '</w>',)
if token in self.cache:
return self.cache[token]
pairs = get_pairs(word)
@@ -159,7 +162,7 @@ class OpenAIGPTTokenizer(object):
return token+'</w>'
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
......
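
For reference, the greedy merge loop in `bpe()` above implements standard byte pair encoding: repeatedly merge the adjacent symbol pair with the lowest merge rank until no ranked pair remains. A minimal standalone sketch of that idea (independent of the class; `bpe_ranks` maps symbol pairs to their merge priority):

```python
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word (tuple of symbols)."""
    return set(zip(word, word[1:]))

def bpe_merge(word, bpe_ranks):
    """Greedily apply BPE merges to a tuple of symbols."""
    pairs = get_pairs(word)
    while pairs:
        # Pick the pair that was learned earliest (lowest rank).
        bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float('inf')))
        if bigram not in bpe_ranks:
            break
        first, second = bigram
        new_word, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
        pairs = get_pairs(word)
    return word

# e.g. with bpe_ranks = {('l', 'o'): 0, ('lo', 'w'): 1}:
# bpe_merge(('l', 'o', 'w', 'e', 'r</w>'), bpe_ranks) -> ('low', 'e', 'r</w>')
```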
@@ -25,6 +25,7 @@ import os
import sys
from collections import Counter, OrderedDict
from io import open
import unicodedata
import torch
import numpy as np
@@ -89,8 +90,8 @@ class TransfoXLTokenizer(object):
tokenizer.__dict__[key] = value
return tokenizer
def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
delimiter=None, vocab_file=None):
def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False,
delimiter=None, vocab_file=None, never_split=("<unk>", "<eos>", "<formula>")):
self.counter = Counter()
self.special = special
self.min_freq = min_freq
@@ -98,6 +99,7 @@ class TransfoXLTokenizer(object):
self.lower_case = lower_case
self.delimiter = delimiter
self.vocab_file = vocab_file
self.never_split = never_split
def count_file(self, path, verbose=False, add_eos=False):
if verbose: print('counting file {} ...'.format(path))
@@ -132,7 +134,12 @@ class TransfoXLTokenizer(object):
for line in f:
symb = line.strip().split()[0]
self.add_symbol(symb)
if '<UNK>' in self.sym2idx:
self.unk_idx = self.sym2idx['<UNK>']
elif '<unk>' in self.sym2idx:
self.unk_idx = self.sym2idx['<unk>']
else:
raise ValueError('No <unk> token in vocabulary')
def build_vocab(self):
if self.vocab_file:
@@ -198,7 +205,7 @@ class TransfoXLTokenizer(object):
self.sym2idx[sym] = len(self.idx2sym) - 1
def get_sym(self, idx):
assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)
return self.idx2sym[idx]
def get_idx(self, sym):
@@ -206,9 +213,16 @@ class TransfoXLTokenizer(object):
return self.sym2idx[sym]
else:
# print('encounter unk {}'.format(sym))
assert '<eos>' not in sym
assert hasattr(self, 'unk_idx')
# assert '<eos>' not in sym
if hasattr(self, 'unk_idx'):
return self.sym2idx.get(sym, self.unk_idx)
# Backward compatibility with pre-trained models
elif '<unk>' in self.sym2idx:
return self.sym2idx['<unk>']
elif '<UNK>' in self.sym2idx:
return self.sym2idx['<UNK>']
else:
raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')
def convert_ids_to_tokens(self, indices):
"""Converts a sequence of indices in symbols using the vocab."""
@@ -231,24 +245,82 @@ class TransfoXLTokenizer(object):
def __len__(self):
return len(self.idx2sym)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
if text in self.never_split:
return [text]
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
def whitespace_tokenize(self, text):
"""Runs basic whitespace cleaning and splitting on a peice of text."""
text = text.strip()
if not text:
return []
if self.delimiter == '':
tokens = text
else:
tokens = text.split(self.delimiter)
return tokens
def tokenize(self, line, add_eos=False, add_double_eos=False):
line = self._clean_text(line)
line = line.strip()
# convert to lower case
if self.lower_case:
line = line.lower()
# empty delimiter '' will evaluate False
if self.delimiter == '':
symbols = line
else:
symbols = line.split(self.delimiter)
symbols = self.whitespace_tokenize(line)
split_symbols = []
for symbol in symbols:
if self.lower_case and symbol not in self.never_split:
symbol = symbol.lower()
symbol = self._run_strip_accents(symbol)
split_symbols.extend(self._run_split_on_punc(symbol))
if add_double_eos: # lm1b
return ['<S>'] + symbols + ['<S>']
return ['<S>'] + split_symbols + ['<S>']
elif add_eos:
return symbols + ['<eos>']
return split_symbols + ['<eos>']
else:
return symbols
return split_symbols
class LMOrderedIterator(object):
@@ -556,3 +628,42 @@ def get_lm_corpus(datadir, dataset):
torch.save(corpus, fn)
return corpus
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest
import json
from io import open
from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer
class OpenAIGPTTokenizationTest(unittest.TestCase):
def test_full_tokenizer(self):
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
"w</w>", "r</w>", "t</w>",
"lo", "low", "er</w>",
"low</w>", "lowest</w>", "newer</w>", "wider</w>"]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
with open("/tmp/openai_tokenizer_vocab_test.json", "w", encoding='utf-8') as fp:
json.dump(vocab_tokens, fp)
vocab_file = fp.name
with open("/tmp/openai_tokenizer_merges_test.txt", "w", encoding='utf-8') as fp:
fp.write("\n".join(merges))
merges_file = fp.name
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>"])
os.remove(vocab_file)
os.remove(merges_file)
text = "lower"
bpe_tokens = ["low", "er</w>"]
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + ["<unk>"]
input_bpe_tokens = [14, 15, 20]
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
if __name__ == '__main__':
unittest.main()
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest
from io import open
from pytorch_pretrained_bert.tokenization_transfo_xl import (TransfoXLTokenizer,
_is_control, _is_punctuation,
_is_whitespace)
class TransfoXLTokenizationTest(unittest.TestCase):
def test_full_tokenizer(self):
vocab_tokens = [
"<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", "running", ","
]
with open("/tmp/transfo_xl_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
vocab_file = vocab_writer.name
tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
tokenizer.build_vocab()
os.remove(vocab_file)
tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
def test_full_tokenizer_lower(self):
tokenizer = TransfoXLTokenizer(lower_case=True)
self.assertListEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
["hello", "!", "how", "are", "you", "?"])
self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
def test_full_tokenizer_no_lower(self):
tokenizer = TransfoXLTokenizer(lower_case=False)
self.assertListEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
["HeLLo", "!", "how", "Are", "yoU", "?"])
def test_is_whitespace(self):
self.assertTrue(_is_whitespace(u" "))
self.assertTrue(_is_whitespace(u"\t"))
self.assertTrue(_is_whitespace(u"\r"))
self.assertTrue(_is_whitespace(u"\n"))
self.assertTrue(_is_whitespace(u"\u00A0"))
self.assertFalse(_is_whitespace(u"A"))
self.assertFalse(_is_whitespace(u"-"))
def test_is_control(self):
self.assertTrue(_is_control(u"\u0005"))
self.assertFalse(_is_control(u"A"))
self.assertFalse(_is_control(u" "))
self.assertFalse(_is_control(u"\t"))
self.assertFalse(_is_control(u"\r"))
def test_is_punctuation(self):
self.assertTrue(_is_punctuation(u"-"))
self.assertTrue(_is_punctuation(u"$"))
self.assertTrue(_is_punctuation(u"`"))
self.assertTrue(_is_punctuation(u"."))
self.assertFalse(_is_punctuation(u"A"))
self.assertFalse(_is_punctuation(u" "))
if __name__ == '__main__':
unittest.main()