"docs/vscode:/vscode.git/clone" did not exist on "37d113cce729007558cbb95ebb081d39fa6ebcff"
Commit 6c41a8f5 authored by LysandreJik

Encode and Decode are back in the superclass. They now handle sentence-pair special tokens.

parent e367ac46
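As a minimal usage sketch of the behaviour this commit describes, based only on the `encode`/`decode` signatures visible in the `tokenization_utils.py` hunks further down; the choice of `BertTokenizer` and the `'bert-base-uncased'` checkpoint is an assumption for illustration, not something this commit touches:

```python
# Assumed example (not part of this diff) of the pair-aware encode/decode
# now living in the PreTrainedTokenizer superclass.
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Single sentence: same as convert_tokens_to_ids(tokenize(text)).
single_ids = tokenizer.encode("Hello, world!")

# Sentence pair: the superclass inserts the cls/sep special-token ids itself.
pair_ids = tokenizer.encode("First sentence.", "Second sentence.")

# decode now detects the sep token in the rebuilt text and returns
# one cleaned string per sentence instead of a single string.
sentences = tokenizer.decode(pair_ids, clean_up_tokenization_spaces=True)
print(sentences)  # roughly ['first sentence.', 'second sentence.']
```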
@@ -7,7 +7,6 @@ from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
from .tokenization_xlm import XLMTokenizer
from .tokenization_roberta import RobertaTokenizer
from .tokenization_utils import (PreTrainedTokenizer, clean_up_tokenization)
from .tokenization_utils import (PreTrainedTokenizer)
@@ -39,7 +38,7 @@ from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel,
XLMWithLMHeadModel, XLMForSequenceClassification,
XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel,
from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
@@ -23,7 +23,7 @@ import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.nn import CrossEntropyLoss, MSELoss
from pytorch_transformers.modeling_bert import (BertConfig, BertEmbeddings,
BertLayerNorm, BertModel,
@@ -144,7 +144,6 @@ class RobertaLMHead(nn.Module):
return x
class RobertaForSequenceClassification(BertPreTrainedModel):
"""
Roberta Model with a classifier head on top.
@@ -21,18 +21,19 @@ import logging
import re
from io import open
import six
import os
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_gpt2 import GPT2Tokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {
'vocab_file': 'dict.txt',
DICT_FILES_NAMES = {
'dict_file': 'dict.txt',
}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
PRETRAINED_DICT_FILES_MAP = {
'dict_file':
{
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
@@ -178,89 +179,62 @@ class RobertaTokenizer(PreTrainedTokenizer):
RoBERTa tokenizer. Peculiarities:
- GPT-2 tokenizer with a different integer mapping on top.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
vocab_files_names = DICT_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_DICT_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file,
bos_token="<s>", eos_token="</s>", **kwargs):
super(RobertaTokenizer, self).__init__(cls_token=bos_token, sep_token=eos_token, eos_token=eos_token, **kwargs)
def __init__(self, dict_file, bpe_tokenizer=None, bos_token="<s>", eos_token="</s>", sep_token="</s>", cls_token="<s>",
unk_token="<unk>", **kwargs):
super(RobertaTokenizer, self).__init__(cls_token=bos_token, sep_token=eos_token, eos_token=eos_token,
unk_token=unk_token, **kwargs)
self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
self.dictionary = Dictionary.load(vocab_file)
self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") if bpe_tokenizer is None else bpe_tokenizer
self.dictionary = Dictionary.load(dict_file)
@property
def vocab_size(self):
return len(self.dictionary.indices)
def _tokenize(self, text):
""" Use GPT-2 Tokenizer """
return self.gpt2_tokenizer._tokenize(text)
def encode(self, text, *args):
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
"""
bpe_sentence = [self.cls_token] + \
self.gpt2_tokenizer.convert_tokens_to_ids(self.tokenize(text)) + \
[self.sep_token]
if len(args):
for additional_sentence in args:
bpe_sentence += [self.sep_token] + \
self.gpt2_tokenizer.convert_tokens_to_ids(self.tokenize(additional_sentence)) + \
[self.sep_token]
return self.dictionary.encode_line(' '.join([str(token) for token in bpe_sentence]), append_eos=False)
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
with options to remove special tokens and clean up tokenization spaces.
Handles sentence pairs.
"""
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
if any(isinstance(element, list) for element in filtered_tokens):
texts = []
for element in filtered_tokens:
text = self.convert_tokens_to_string(element)
if clean_up_tokenization_spaces:
text = clean_up_tokenization(text)
texts.append(text)
return texts
else:
text = self.convert_tokens_to_string(filtered_tokens)
if clean_up_tokenization_spaces:
text = clean_up_tokenization(text)
return text
def _convert_token_to_id(self, token):
return self.dictionary.index(token)
if self.dictionary.index(token) != 3:
return self.dictionary.index(token)
return self.dictionary.index(str(self.gpt2_tokenizer.convert_tokens_to_ids(token)))
def _convert_id_to_token(self, index):
symbol = self.dictionary[index]
try:
idx = int(symbol)
return self.gpt2_tokenizer._convert_id_to_token(idx)
except:
except ValueError:
return symbol
def convert_tokens_to_string(self, tokens):
return self.gpt2_tokenizer.convert_tokens_to_string(tokens)
def convert_tokens_to_ids(self, tokens, no_sep_cls_tokens=False):
cls = [self._convert_token_to_id(self.cls_token)]
tokens = super().convert_tokens_to_ids(tokens)
sep = [self._convert_token_to_id(self.sep_token)]
return (cls + tokens + sep) if (isinstance(tokens, list) and not no_sep_cls_tokens) else tokens
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
# Remove the first and last tokens which are cls and sep tokens
ids = ids[1:-1]
# If multi sentence, then split (multi sentence found by looking for two sequential sep tokens)
ids = [list(map(int, example.split(' '))) for example in ' '.join([str(id) for id in ids]).split(' 2 2 ')]
return super().convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)[1:-1]
if len(ids) == 1:
tokens = self.gpt2_tokenizer.convert_ids_to_tokens(list(map(lambda id: int(self.dictionary[id]), ids[0])))
else:
tokens = []
for example in ids:
tokens += [
self.gpt2_tokenizer.convert_ids_to_tokens(list(map(lambda id: int(self.dictionary[id]), example)))]
return tokens
def save_vocabulary(self, save_directory):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
dict_file = os.path.join(save_directory, DICT_FILES_NAMES['dict_file'])
with open(dict_file, 'w', encoding='utf-8') as f:
for i in range(self.dictionary.nspecial, len(self.dictionary.count)):
f.write(f"{list(self.dictionary.indices.keys())[i]} {self.dictionary.count[i]}\n")
def convert_tokens_to_ids(self, tokens):
tokens = " ".join(str(x) for x in self.gpt2_tokenizer.convert_tokens_to_ids(tokens))
bpe_sentence = '<s> ' + tokens + ' </s>'
return self.dictionary.encode_line(bpe_sentence, append_eos=False)
vocab_files = self.gpt2_tokenizer.save_pretrained(save_directory)
return vocab_files + (dict_file,)
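The RoBERTa-specific changes above route every token through two vocabularies: the GPT-2 BPE vocabulary and a fairseq-style `dict.txt`. A rough sketch of that round trip, mirroring `_convert_token_to_id` / `_convert_id_to_token` in the diff (the assumption that index 3 is the dictionary's unknown symbol follows the check above):

```python
# Illustrative helpers (not part of the diff) showing the two-stage RoBERTa mapping:
# GPT-2 BPE token -> GPT-2 id (stored as a string symbol) -> fairseq dictionary index.
def roberta_token_to_id(gpt2_tokenizer, dictionary, token):
    idx = dictionary.index(token)
    if idx != 3:  # 3 is assumed to be the dictionary's <unk> index, as in the diff
        return idx
    # Unknown as a raw symbol: retry with the token's GPT-2 id rendered as a string.
    return dictionary.index(str(gpt2_tokenizer.convert_tokens_to_ids(token)))

def roberta_id_to_token(gpt2_tokenizer, dictionary, index):
    symbol = dictionary[index]
    try:
        # Numeric symbols are GPT-2 ids; map them back to a BPE token.
        return gpt2_tokenizer._convert_id_to_token(int(symbol))
    except ValueError:
        # Non-numeric symbols (e.g. <s>, </s>, <pad>) are returned as-is.
        return symbol
```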
@@ -495,7 +495,7 @@ class PreTrainedTokenizer(object):
"""
raise NotImplementedError
def convert_tokens_to_ids(self, tokens):
def convert_tokens_to_ids(self, tokens, **kwargs):
""" Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
(resp. a sequence of ids), using the vocabulary.
"""
@@ -520,12 +520,29 @@
raise NotImplementedError
def encode(self, text):
def encode(self, *text, cls_token_at_end=False, double_sep_token=False, no_sep_cls_tokens=False):
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
"""
return self.convert_tokens_to_ids(self.tokenize(text))
if len(text) == 1:
return self.convert_tokens_to_ids(self.tokenize(text[0]), no_sep_cls_tokens=no_sep_cls_tokens)
if len(text) > 2:
logger.warning("Tokenization currently only supports sentence pairs. Ignoring every string following the "
"initial two.")
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text[0])]
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text[1])]
sep = [self._convert_token_to_id(self.sep_token)]
cls = [self._convert_token_to_id(self.cls_token)]
n_sep_token = 2 if double_sep_token else 1
tokens = first_sentence_tokens + sep * n_sep_token + second_sentence_tokens + sep
tokens = (tokens + cls) if cls_token_at_end else (cls + tokens)
return tokens
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
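To make the layout produced by the new `encode` concrete, the toy sketch below reproduces its sentence-pair branch with made-up ids (all names and values here are illustrative, not from the diff): with `cls = [C]` and `sep = [S]`, the default layout is `cls + A + sep + B + sep`, `double_sep_token=True` doubles the separator between the two sentences, and `cls_token_at_end=True` moves the classifier token to the end.

```python
# Toy reproduction of the sentence-pair layout in PreTrainedTokenizer.encode
# (dummy ids only; real tokenizers supply cls/sep from their own vocab).
def pair_layout(first, second, cls, sep, cls_token_at_end=False, double_sep_token=False):
    n_sep = 2 if double_sep_token else 1
    tokens = first + sep * n_sep + second + sep
    return (tokens + cls) if cls_token_at_end else (cls + tokens)

A, B, C, S = [11, 12], [21, 22], [101], [102]
assert pair_layout(A, B, C, S) == [101, 11, 12, 102, 21, 22, 102]
assert pair_layout(A, B, C, S, double_sep_token=True) == [101, 11, 12, 102, 102, 21, 22, 102]
assert pair_layout(A, B, C, S, cls_token_at_end=True) == [11, 12, 102, 21, 22, 102, 101]
```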
@@ -560,7 +577,8 @@ class PreTrainedTokenizer(object):
"""
return ' '.join(self.convert_ids_to_tokens(tokens))
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True, cls_token_at_end=False,
double_sep_token=False):
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
with options to remove special tokens and clean up tokenization spaces.
@@ -568,9 +586,21 @@
"""
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
text = self.convert_tokens_to_string(filtered_tokens)
if clean_up_tokenization_spaces:
text = self.clean_up_tokenization(text)
return text
if self.sep_token is not None and self.sep_token in text:
text = text.replace(self.cls_token, self.sep_token)
split_text = list(filter(lambda sentence: len(sentence) > 0, text.split(self.sep_token)))
if clean_up_tokenization_spaces:
clean_text = [self.clean_up_tokenization(text) for text in split_text]
return clean_text
else:
return split_text
else:
if clean_up_tokenization_spaces:
clean_text = self.clean_up_tokenization(text)
return clean_text
else:
return text
@property
def special_tokens_map(self):
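The decode half of the change works purely on the rebuilt string: when the tokenizer's sep token shows up in the text, the cls token is rewritten to sep, the text is split on sep, and each non-empty piece is cleaned separately. A small string-only sketch (the `<cls>`/`<sep>` markers are made up for illustration, not from any real vocabulary):

```python
# String-level sketch of the sentence-pair branch of PreTrainedTokenizer.decode.
def split_pair_text(text, cls_token="<cls>", sep_token="<sep>"):
    if sep_token in text:
        text = text.replace(cls_token, sep_token)
        return [piece for piece in text.split(sep_token) if len(piece) > 0]
    return text

print(split_pair_text("<cls> first sentence <sep> second sentence <sep>"))
# -> [' first sentence ', ' second sentence ']
```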
@@ -602,7 +632,7 @@ class PreTrainedTokenizer(object):
class attributes (cls_token, unk_token...).
"""
all_toks = self.all_special_tokens
all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks)
all_ids = list(self._convert_token_to_id(t) for t in all_toks)
return all_ids
@staticmethod