Commit 1c12ee0e authored by thomwolf

fixing xlm-roberta tokenizer max_length and automodels

parent 65c75fc5
modeling_auto.py

@@ -20,7 +20,7 @@ import logging
 from .configuration_auto import (AlbertConfig, BertConfig, CamembertConfig, CTRLConfig,
                                  DistilBertConfig, GPT2Config, OpenAIGPTConfig, RobertaConfig,
-                                 TransfoXLConfig, XLMConfig, XLNetConfig)
+                                 TransfoXLConfig, XLMConfig, XLNetConfig, XLMRobertaConfig)
 from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering, \
     BertForTokenClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
@@ -41,7 +41,8 @@ from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertF
 from .modeling_albert import AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, \
     AlbertForQuestionAnswering, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
 from .modeling_t5 import T5Model, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP
-from .modeling_xlm_roberta import XLMRobertaModel, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification, XLMRobertaForMultipleChoice, XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
+from .modeling_xlm_roberta import XLMRobertaModel, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification, \
+    XLMRobertaForMultipleChoice, XLMRobertaForTokenClassification, XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
 from .modeling_utils import PreTrainedModel, SequenceSummary
@@ -146,6 +147,8 @@ class AutoModel(object):
             return AlbertModel(config)
         elif isinstance(config, CamembertConfig):
             return CamembertModel(config)
+        elif isinstance(config, XLMRobertaConfig):
+            return XLMRobertaModel(config)
         raise ValueError("Unrecognized configuration class {}".format(config))

     @classmethod
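A minimal usage sketch of the new dispatch (not part of the commit), assuming XLMRobertaConfig is exported at the package root and that from_config is the classmethod containing this isinstance chain:

    from transformers import AutoModel, XLMRobertaConfig

    config = XLMRobertaConfig()            # default config; no weights downloaded
    model = AutoModel.from_config(config)  # now resolves to XLMRobertaModel
    print(type(model).__name__)            # 'XLMRobertaModel'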
@@ -333,6 +336,8 @@ class AutoModelWithLMHead(object):
             return XLMWithLMHeadModel(config)
         elif isinstance(config, CTRLConfig):
             return CTRLLMHeadModel(config)
+        elif isinstance(config, XLMRobertaConfig):
+            return XLMRobertaForMaskedLM(config)
         raise ValueError("Unrecognized configuration class {}".format(config))

     @classmethod
@@ -509,6 +514,8 @@ class AutoModelForSequenceClassification(object):
             return XLNetForSequenceClassification(config)
         elif isinstance(config, XLMConfig):
             return XLMForSequenceClassification(config)
+        elif isinstance(config, XLMRobertaConfig):
+            return XLMRobertaForSequenceClassification(config)
         raise ValueError("Unrecognized configuration class {}".format(config))

     @classmethod
@@ -787,6 +794,8 @@ class AutoModelForTokenClassification:
             return XLNetForTokenClassification(config)
         elif isinstance(config, RobertaConfig):
             return RobertaForTokenClassification(config)
+        elif isinstance(config, XLMRobertaConfig):
+            return XLMRobertaForTokenClassification(config)
         raise ValueError("Unrecognized configuration class {}".format(config))

     @classmethod
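The same two-line pattern registers XLM-RoBERTa for each head class, so one sketch covers the remaining three; it assumes these Auto classes all expose the same from_config entry point shown above and are exported at the package root:

    from transformers import (AutoModelWithLMHead, AutoModelForSequenceClassification,
                              AutoModelForTokenClassification, XLMRobertaConfig)

    config = XLMRobertaConfig()
    for auto_cls in (AutoModelWithLMHead, AutoModelForSequenceClassification,
                     AutoModelForTokenClassification):
        # Prints XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
        # XLMRobertaForTokenClassification respectively.
        print(type(auto_cls.from_config(config)).__name__)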
@@ -865,6 +874,8 @@ class AutoModelForTokenClassification:
             return CamembertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
         elif 'distilbert' in pretrained_model_name_or_path:
             return DistilBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        elif 'xlm-roberta' in pretrained_model_name_or_path:
+            return XLMRobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
         elif 'roberta' in pretrained_model_name_or_path:
             return RobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
         elif 'bert' in pretrained_model_name_or_path:
@@ -873,4 +884,4 @@ class AutoModelForTokenClassification:
             return XLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
         raise ValueError("Unrecognized model identifier in {}. Should contains one of "
-                         "'bert', 'xlnet', 'camembert', 'distilbert', 'roberta'".format(pretrained_model_name_or_path))
+                         "'bert', 'xlnet', 'camembert', 'distilbert', 'xlm-roberta', 'roberta'".format(pretrained_model_name_or_path))
modeling_utils.py

@@ -415,7 +415,7 @@ class PreTrainedModel(nn.Module):
                 state_dict = torch.load(resolved_archive_file, map_location='cpu')
             except:
                 raise OSError("Unable to load weights from pytorch checkpoint file. "
-                          "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. ")
+                              "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. ")
         missing_keys = []
         unexpected_keys = []
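The error message points at the usual escape hatch; a sketch of what it recommends ('./tf_checkpoint' is a hypothetical local directory holding TF 2.0 weights):

    from transformers import BertModel

    # Loading TF 2.0 weights as a PyTorch checkpoint raises the OSError above;
    # from_tf=True tells from_pretrained to convert the TF weights instead.
    model = BertModel.from_pretrained('./tf_checkpoint', from_tf=True)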
pipelines.py

@@ -49,7 +49,7 @@ logger = logging.getLogger(__name__)

 def get_framework(model=None):
     """ Select framework (TensorFlow/PyTorch) to use.
-        If both frameworks are installed and no specific model is provided, defaults to using TensorFlow.
+        If both frameworks are installed and no specific model is provided, defaults to using PyTorch.
     """
     if is_tf_available() and is_torch_available() and model is not None and not isinstance(model, str):
         # Both framework are available but the use supplied a model class instance.
@@ -60,7 +60,8 @@ def get_framework(model=None):
                 "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
                 "To install PyTorch, read the instructions at https://pytorch.org/.")
     else:
-        framework = 'tf' if is_tf_available() else 'pt'
+        # framework = 'tf' if is_tf_available() else 'pt'
+        framework = 'pt' if is_torch_available() else 'tf'
     return framework


 class ArgumentHandler(ABC):
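A quick check of the flipped default, assuming both frameworks are installed and get_framework is reachable at transformers.pipelines:

    from transformers.pipelines import get_framework

    # With no model object passed, PyTorch now wins the tie;
    # 'tf' is only returned when torch is unavailable.
    print(get_framework())  # 'pt'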
tokenization_utils.py

@@ -434,7 +434,11 @@ class PreTrainedTokenizer(object):
             init_kwargs[key] = value

         # Instantiate tokenizer.
-        tokenizer = cls(*init_inputs, **init_kwargs)
+        try:
+            tokenizer = cls(*init_inputs, **init_kwargs)
+        except OSError:
+            raise OSError("Unable to load vocabulary from file. "
+                          "Please check that the provided vocabulary is accessible and not corrupted.")

         # Save inputs and kwargs for saving and re-loading with ``save_pretrained``
         tokenizer.init_inputs = init_inputs
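With the wrapper in place, a vocabulary file that cannot be read surfaces as the explicit OSError above instead of whatever the concrete tokenizer raises internally. A sketch ('./broken_vocab_dir' is a hypothetical directory containing a corrupted sentencepiece file):

    from transformers import XLMRobertaTokenizer

    # Fails with "Unable to load vocabulary from file. ..." rather than a
    # low-level sentencepiece error.
    tokenizer = XLMRobertaTokenizer.from_pretrained('./broken_vocab_dir')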
tokenization_xlm_roberta.py

@@ -40,8 +40,12 @@ PRETRAINED_VOCAB_FILES_MAP = {
 }

 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
-    'xlm-roberta-base': None,
-    'xlm-roberta-large': None,
+    'xlm-roberta-base': 512,
+    'xlm-roberta-large': 512,
+    'xlm-roberta-large-finetuned-conll02-dutch': 512,
+    'xlm-roberta-large-finetuned-conll02-spanish': 512,
+    'xlm-roberta-large-finetuned-conll03-english': 512,
+    'xlm-roberta-large-finetuned-conll03-german': 512,
 }
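This table is what from_pretrained consults to set the tokenizer's max_len per checkpoint, so the None entries left the length limit undefined; that is the max_length half of the commit message. A sketch of the effect:

    from transformers import XLMRobertaTokenizer

    tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
    print(tokenizer.max_len)  # 512, from PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    # Truncation now has a real model bound to work against.
    ids = tokenizer.encode("a long document " * 500, max_length=tokenizer.max_len)
    print(len(ids))  # <= 512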
@@ -58,10 +62,10 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
     def __init__(self, vocab_file, bos_token="<s>", eos_token="</s>", sep_token="</s>",
                  cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>',
                  **kwargs):
-        super(XLMRobertaTokenizer, self).__init__(max_len=512, bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
+        super(XLMRobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
                                                   sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
                                                   mask_token=mask_token,
                                                   **kwargs)
         self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
         self.max_len_sentences_pair = self.max_len - 4  # take into account special tokens
         self.sp_model = spm.SentencePieceProcessor()
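Since max_len now comes from the per-checkpoint table instead of a hardcoded max_len=512, the derived length budgets follow the loaded checkpoint; a quick check:

    from transformers import XLMRobertaTokenizer

    tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
    print(tokenizer.max_len_single_sentence)  # 510 = 512 - 2  (<s> ... </s>)
    print(tokenizer.max_len_sentences_pair)   # 508 = 512 - 4  (<s> A </s></s> B </s>)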