Commit 4d47f498 authored by thomwolf

slight refactoring, add abstract class for model loading

parent 59cefd4f
@@ -28,4 +28,5 @@ from .optimization_openai import OpenAIAdam
from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
- from .model_utils import (WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig)
+ from .model_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
+                           PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
TF_WEIGHTS_NAME = 'model.ckpt'
class PretrainedConfig(object):
@@ -131,6 +132,169 @@ class PretrainedConfig(object):
writer.write(self.to_json_string())
class PreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = PretrainedConfig
pretrained_model_archive_map = {}
pretrained_config_archive_map = {}
load_tf_weights = lambda model, config, path: None
base_model_prefix = ""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedModel, self).__init__()
if not isinstance(config, PretrainedConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load, or
- a path or url to a pretrained model archive containing:
. `config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of an instance of the model class
- a path or url to a tensorflow pretrained model checkpoint containing:
. `config.json` a configuration file for the model
. `model.ckpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use
instead of the downloaded pre-trained model weights
*inputs, **kwargs: additional inputs for the specific model class
(ex: num_labels for a sequence classification model)
"""
state_dict = kwargs.pop('state_dict', None)
cache_dir = kwargs.pop('cache_dir', None)
from_tf = kwargs.pop('from_tf', False)
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
else:
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_model_archive_map.keys()),
archive_file))
return None
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_config_archive_map.keys()),
config_file))
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = cls.config_class.from_json_file(resolved_config_file)
# Update config with kwargs if needed
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
if from_tf:
# Directly load from a TensorFlow checkpoint
return cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
# Load from a PyTorch state_dict
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
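# Note: nn.Module._load_from_state_dict only handles a module's direct
# parameters and buffers, so the load() helper below walks the submodule
# tree recursively, extending the state_dict key prefix at each level.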
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
start_prefix = ''
if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
start_prefix = cls.base_model_prefix + '.'  # Allows loading base models as well as derived models (with heads)
load(model, prefix=start_prefix)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
if hasattr(model, 'tie_weights'):
model.tie_weights()  # make sure word embedding weights are still tied
return model
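For orientation (illustrative, not part of this commit's diff): a model family now only declares the five class attributes and inherits the whole download-and-load pipeline. `MyConfig`, `MyModel` and the empty archive maps below are hypothetical placeholders.

    # A minimal sketch of the new extension point, assuming PreTrainedModel
    # and PretrainedConfig are importable from model_utils.
    class MyConfig(PretrainedConfig):
        pass

    class MyModel(PreTrainedModel):
        config_class = MyConfig                 # config type enforced in __init__
        pretrained_model_archive_map = {}       # shortcut name -> weights URL
        pretrained_config_archive_map = {}      # shortcut name -> config URL
        load_tf_weights = None                  # optional TF-checkpoint converter
        base_model_prefix = "my_model"          # prefix matched in state_dict keys

    # Kwargs matching config attributes update the config; the rest are
    # forwarded to the model constructor.
    model = MyModel.from_pretrained('/path/to/checkpoint_dir')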
def prune_linear_layer(layer, index, dim=0):
""" Prune a linear layer (model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True.
@@ -197,3 +361,16 @@ def prune_conv1d_layer(layer, index, dim=1):
new_layer.bias.copy_(b.contiguous())
new_layer.bias.requires_grad = True
return new_layer
def prune_layer(layer, index, dim=None):
""" Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True.
Used to remove heads.
"""
if isinstance(layer, nn.Linear):
return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
elif isinstance(layer, Conv1D):
return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
else:
raise ValueError("Can't prune layer of class {}".format(layer.__class__))
@@ -30,7 +30,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
- from .model_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, prune_linear_layer
+ from .model_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer
logger = logging.getLogger(__name__)
@@ -64,11 +64,9 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
}
- BERT_CONFIG_NAME = 'bert_config.json'
- TF_WEIGHTS_NAME = 'model.ckpt'
- def load_tf_weights_in_bert(model, tf_checkpoint_path):
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model
"""
try:
@@ -168,7 +166,8 @@ class BertConfig(PretrainedConfig):
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
- layer_norm_eps=1e-12):
+ layer_norm_eps=1e-12,
+ finetuning_task=None):
"""Constructs BertConfig.
Args:
@@ -193,6 +192,7 @@ class BertConfig(PretrainedConfig):
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
layer_norm_eps: The epsilon used by LayerNorm.
+ finetuning_task: name of the GLUE task on which the model was fine-tuned, if any
"""
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
@@ -213,6 +213,7 @@ class BertConfig(PretrainedConfig):
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
+ self.finetuning_task = finetuning_task
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@@ -539,20 +540,18 @@ class BertPreTrainingHeads(nn.Module):
return prediction_scores, seq_relationship_score
- class BertPreTrainedModel(nn.Module):
+ class BertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
- def __init__(self, config, *inputs, **kwargs):
-     super(BertPreTrainedModel, self).__init__()
-     if not isinstance(config, BertConfig):
-         raise ValueError(
-             "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
-             "To create a model from a Google pretrained model use "
-             "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
-                 self.__class__.__name__, self.__class__.__name__
-             ))
-     self.config = config
+ config_class = BertConfig
+ pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+ pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+ load_tf_weights = load_tf_weights_in_bert
+ base_model_prefix = "bert"
+ def __init__(self, *inputs, **kwargs):
+     super(BertPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
@@ -567,152 +566,6 @@ class BertPreTrainedModel(nn.Module):
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `bert-base-uncased`
. `bert-large-uncased`
. `bert-base-cased`
. `bert-large-cased`
. `bert-base-multilingual-uncased`
. `bert-base-multilingual-cased`
. `bert-base-chinese`
. `bert-base-german-cased`
. `bert-large-uncased-whole-word-masking`
. `bert-large-cased-whole-word-masking`
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
"""
state_dict = kwargs.get('state_dict', None)
kwargs.pop('state_dict', None)
cache_dir = kwargs.get('cache_dir', None)
kwargs.pop('cache_dir', None)
from_tf = kwargs.get('from_tf', False)
kwargs.pop('from_tf', None)
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, BERT_CONFIG_NAME)
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
archive_file))
return None
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
config_file))
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = BertConfig.from_json_file(resolved_config_file)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
if from_tf:
# Directly load from a TensorFlow checkpoint
return load_tf_weights_in_bert(model, resolved_archive_file)
# Load from a PyTorch state_dict
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
start_prefix = ''
if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
start_prefix = 'bert.'
load(model, prefix=start_prefix)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
return model
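With the method above removed, BERT loading now flows through the shared PreTrainedModel.from_pretrained in model_utils; the call sites are intended to stay unchanged. A hedged sketch, using a shortcut name from the archive maps in this file:

    # cache_dir, from_tf and state_dict keep working as before; kwargs that are
    # not config attributes (e.g. num_labels) go to the model constructor.
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)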
class BertModel(BertPreTrainedModel):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
......
@@ -32,7 +32,7 @@ from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .file_utils import cached_path
- from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, prune_conv1d_layer
+ from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer
from .modeling import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
@@ -42,7 +42,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.hugging
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
- def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
+ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
""" Load tf checkpoints in a pytorch model
"""
try:
@@ -356,22 +356,18 @@ class GPT2MultipleChoiceHead(nn.Module):
return multiple_choice_logits
- class GPT2PreTrainedModel(nn.Module):
+ class GPT2PreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
- def __init__(self, config, *inputs, **kwargs):
-     super(GPT2PreTrainedModel, self).__init__()
-     if not isinstance(config, GPT2Config):
-         raise ValueError(
-             "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
-             "To create a model from a pretrained model use "
-             "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
-                 self.__class__.__name__, self.__class__.__name__
-             )
-         )
-     self.config = config
+ config_class = GPT2Config
+ pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+ pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+ load_tf_weights = load_tf_weights_in_gpt2
+ base_model_prefix = "transformer"
+ def __init__(self, *inputs, **kwargs):
+     super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
@@ -407,130 +403,130 @@ class GPT2PreTrainedModel(nn.Module):
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific GPT2 class
"""
# state_dict = kwargs.get('state_dict', None)
# kwargs.pop('state_dict', None)
# cache_dir = kwargs.get('cache_dir', None)
# kwargs.pop('cache_dir', None)
# from_tf = kwargs.get('from_tf', False)
# kwargs.pop('from_tf', None)
num_special_tokens = kwargs.get('num_special_tokens', None)
kwargs.pop('num_special_tokens', None)
model = super(GPT2PreTrainedModel, cls).from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
# if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
#     archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
#     config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
# else:
#     archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
#     config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# # redirect to the cache, if necessary
# try:
#     resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
# except EnvironmentError:
#     if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
#         logger.error(
#             "Couldn't reach server at '{}' to download pretrained weights.".format(
#                 archive_file))
#     else:
#         logger.error(
#             "Model name '{}' was not found in model name list ({}). "
#             "We assumed '{}' was a path or url but couldn't find file {} "
#             "at this path or url.".format(
#                 pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
#                 archive_file
#             )
#         )
#     return None
# try:
#     resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
# except EnvironmentError:
#     if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
#         logger.error(
#             "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
#                 config_file))
#     else:
#         logger.error(
#             "Model name '{}' was not found in model name list ({}). "
#             "We assumed '{}' was a path or url but couldn't find file {} "
#             "at this path or url.".format(
#                 pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
#                 config_file
#             )
#         )
#     return None
# if resolved_archive_file == archive_file and resolved_config_file == config_file:
#     logger.info("loading weights file {}".format(archive_file))
#     logger.info("loading configuration file {}".format(config_file))
# else:
#     logger.info("loading weights file {} from cache at {}".format(
#         archive_file, resolved_archive_file))
#     logger.info("loading configuration file {} from cache at {}".format(
#         config_file, resolved_config_file))
# # Load config
# config = GPT2Config.from_json_file(resolved_config_file)
# logger.info("Model config {}".format(config))
# # Instantiate model.
# model = cls(config, *inputs, **kwargs)
# if state_dict is None and not from_tf:
#     state_dict = torch.load(resolved_archive_file, map_location='cpu')
# if from_tf:
#     # Directly load from a TensorFlow checkpoint (stored as NumPy array)
#     return load_tf_weights_in_gpt2(model, resolved_archive_file)
# old_keys = []
# new_keys = []
# for key in state_dict.keys():
#     new_key = None
#     if key.endswith(".g"):
#         new_key = key[:-2] + ".weight"
#     elif key.endswith(".b"):
#         new_key = key[:-2] + ".bias"
#     elif key.endswith(".w"):
#         new_key = key[:-2] + ".weight"
#     if new_key:
#         old_keys.append(key)
#         new_keys.append(new_key)
# for old_key, new_key in zip(old_keys, new_keys):
#     state_dict[new_key] = state_dict.pop(old_key)
# missing_keys = []
# unexpected_keys = []
# error_msgs = []
# # copy state_dict so _load_from_state_dict can modify it
# metadata = getattr(state_dict, "_metadata", None)
# state_dict = state_dict.copy()
# if metadata is not None:
#     state_dict._metadata = metadata
# def load(module, prefix=""):
#     local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
#     module._load_from_state_dict(
#         state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
#     )
#     for name, child in module._modules.items():
#         if child is not None:
#             load(child, prefix + name + ".")
# start_model = model
# if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
#     start_model = model.transformer
# load(start_model, prefix="")
# if len(missing_keys) > 0:
#     logger.info(
#         "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
#     )
# if len(unexpected_keys) > 0:
#     logger.info(
#         "Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
#     )
# if len(error_msgs) > 0:
#     raise RuntimeError(
#         "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
#     )
# Add additional embeddings for special tokens if needed
# This step also makes sure we are still sharing the output and input embeddings after loading weights
- model.set_num_special_tokens(num_special_tokens if num_special_tokens is not None else config.n_special)
+ model.set_num_special_tokens(num_special_tokens)
return model
@@ -608,9 +604,9 @@ class GPT2Model(GPT2PreTrainedModel):
self.apply(self.init_weights)
- def set_num_special_tokens(self, num_special_tokens):
+ def set_num_special_tokens(self, num_special_tokens=None):
" Update input embeddings with a new embedding matrix if needed "
- if self.config.n_special == num_special_tokens:
+ if num_special_tokens is None or self.config.n_special == num_special_tokens:
return
# Update config
self.config.n_special = num_special_tokens
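A hedged illustration of the new default (assuming `model` is a loaded GPT-2 model exposing this method):

    model.set_num_special_tokens()     # num_special_tokens=None -> no-op
    model.set_num_special_tokens(3)    # resizes the special-token embeddings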
......
@@ -32,7 +32,7 @@ from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .file_utils import cached_path
- from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, prune_conv1d_layer
+ from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer
from .modeling import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
@@ -41,12 +41,17 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.h
PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
- def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
+ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
"""
import re
import numpy as np
- print("Loading weights...")
+ if '.ckpt' in openai_checkpoint_folder_path:
+     openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path)
+ logger.info("Loading weights from {}".format(openai_checkpoint_folder_path))
names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
offsets = np.cumsum([np.prod(shape) for shape in shapes])
@@ -377,22 +382,18 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
return multiple_choice_logits
- class OpenAIGPTPreTrainedModel(nn.Module):
+ class OpenAIGPTPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
- def __init__(self, config, *inputs, **kwargs):
-     super(OpenAIGPTPreTrainedModel, self).__init__()
-     if not isinstance(config, OpenAIGPTConfig):
-         raise ValueError(
-             "Parameter config in `{}(config)` should be an instance of class `OpenAIGPTConfig`. "
-             "To create a model from a pretrained model use "
-             "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
-                 self.__class__.__name__, self.__class__.__name__
-             )
-         )
-     self.config = config
+ config_class = OpenAIGPTConfig
+ pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+ pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+ load_tf_weights = load_tf_weights_in_openai_gpt
+ base_model_prefix = "transformer"
+ def __init__(self, *inputs, **kwargs):
+     super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
@@ -408,7 +409,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
module.bias.data.zero_()
@classmethod
- def from_pretrained(cls, pretrained_model_name_or_path, num_special_tokens=None, *inputs, **kwargs):
+ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
@@ -416,140 +417,25 @@ class OpenAIGPTPreTrainedModel(nn.Module):
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of: `openai-gpt`
- a path or url to a pretrained model archive containing:
. `config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
- a path or url to a pretrained model archive containing:
. `config.json` a configuration file for the model
. a series of NumPy files containing OpenAI TensorFlow trained weights
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific OpenAI-GPT class
"""
- state_dict = kwargs.get('state_dict', None)
- kwargs.pop('state_dict', None)
- cache_dir = kwargs.get('cache_dir', None)
- kwargs.pop('cache_dir', None)
+ num_special_tokens = kwargs.get('num_special_tokens', None)
+ kwargs.pop('num_special_tokens', None)
+ model = super(OpenAIGPTPreTrainedModel, cls).from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
from_tf = kwargs.get('from_tf', False)
kwargs.pop('from_tf', None)
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url.".format(
pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
archive_file
)
)
return None
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url.".format(
pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
config_file
)
)
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = OpenAIGPTConfig.from_json_file(resolved_config_file)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
if from_tf:
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
return load_tf_weights_in_openai_gpt(model, resolved_archive_file)
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if key.endswith(".g"):
new_key = key[:-2] + ".weight"
elif key.endswith(".b"):
new_key = key[:-2] + ".bias"
elif key.endswith(".w"):
new_key = key[:-2] + ".weight"
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
start_model = model
if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
start_model = model.transformer
load(start_model, prefix="")
if len(missing_keys) > 0:
logger.info(
"Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
)
if len(unexpected_keys) > 0:
logger.info(
"Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
)
if len(error_msgs) > 0:
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
)
# Add additional embeddings for special tokens if needed
# This step also makes sure we are still sharing the output and input embeddings after loading weights
- model.set_num_special_tokens(num_special_tokens if num_special_tokens is not None else config.n_special)
+ model.set_num_special_tokens(num_special_tokens)
return model
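An illustrative call (not part of the diff), using the `openai-gpt` shortcut defined above; `num_special_tokens` is stripped from kwargs before delegating to the shared loader, then applied after loading:

    model = OpenAIGPTModel.from_pretrained('openai-gpt', num_special_tokens=3)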
@@ -621,9 +507,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
self.apply(self.init_weights)
- def set_num_special_tokens(self, num_special_tokens):
+ def set_num_special_tokens(self, num_special_tokens=None):
" Update input embeddings with a new embedding matrix if needed "
- if self.config.n_special == num_special_tokens:
+ if num_special_tokens is None or self.config.n_special == num_special_tokens:
return
# Update config
self.config.n_special = num_special_tokens
......
@@ -38,7 +38,7 @@ from torch.nn.parameter import Parameter
from .modeling import BertLayerNorm as LayerNorm
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
from .file_utils import cached_path
- from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig
+ from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
logger = logging.getLogger(__name__)
@@ -49,8 +49,6 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
}
- TF_WEIGHTS_NAME = 'model.ckpt'
def build_tf_to_pytorch_map(model, config):
""" A map of modules from TF to PyTorch.
This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
@@ -787,28 +785,26 @@ class AdaptiveEmbedding(nn.Module):
return embed
- class TransfoXLPreTrainedModel(nn.Module):
+ class TransfoXLPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
- def __init__(self, config, *inputs, **kwargs):
-     super(TransfoXLPreTrainedModel, self).__init__()
-     if not isinstance(config, TransfoXLConfig):
-         raise ValueError(
-             "Parameter config in `{}(config)` should be an instance of class `TransfoXLConfig`. "
-             "To create a model from a pretrained model use "
-             "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
-                 self.__class__.__name__, self.__class__.__name__
-             ))
-     self.config = config
+ config_class = TransfoXLConfig
+ pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+ pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+ load_tf_weights = load_tf_weights_in_transfo_xl
+ base_model_prefix = "transformer"
+ def __init__(self, *inputs, **kwargs):
+     super(TransfoXLPreTrainedModel, self).__init__(*inputs, **kwargs)
- def init_weight(self, weight):
+ def _init_weight(self, weight):
if self.config.init == 'uniform':
nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
elif self.config.init == 'normal':
nn.init.normal_(weight, 0.0, self.config.init_std)
- def init_bias(self, bias):
+ def _init_bias(self, bias):
nn.init.constant_(bias, 0.0)
def init_weights(self, m):
@@ -817,9 +813,9 @@ class TransfoXLPreTrainedModel(nn.Module):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
if hasattr(m, 'weight') and m.weight is not None:
- self.init_weight(m.weight)
+ self._init_weight(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
- self.init_bias(m.bias)
+ self._init_bias(m.bias)
elif classname.find('AdaptiveEmbedding') != -1:
if hasattr(m, 'emb_projs'):
for i in range(len(m.emb_projs)):
@@ -827,12 +823,12 @@ class TransfoXLPreTrainedModel(nn.Module):
nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std)
elif classname.find('Embedding') != -1:
if hasattr(m, 'weight'):
- self.init_weight(m.weight)
+ self._init_weight(m.weight)
elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
- self.init_weight(m.cluster_weight)
+ self._init_weight(m.cluster_weight)
if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
- self.init_bias(m.cluster_bias)
+ self._init_bias(m.cluster_bias)
if hasattr(m, 'out_projs'):
for i in range(len(m.out_projs)):
if m.out_projs[i] is not None:
@@ -841,144 +837,20 @@ class TransfoXLPreTrainedModel(nn.Module):
if hasattr(m, 'weight'):
nn.init.normal_(m.weight, 1.0, self.config.init_std)
if hasattr(m, 'bias') and m.bias is not None:
- self.init_bias(m.bias)
+ self._init_bias(m.bias)
elif classname.find('TransformerLM') != -1:
if hasattr(m, 'r_emb'):
- self.init_weight(m.r_emb)
+ self._init_weight(m.r_emb)
if hasattr(m, 'r_w_bias'):
- self.init_weight(m.r_w_bias)
+ self._init_weight(m.r_w_bias)
if hasattr(m, 'r_r_bias'):
- self.init_weight(m.r_r_bias)
+ self._init_weight(m.r_r_bias)
if hasattr(m, 'r_bias'):
- self.init_bias(m.r_bias)
+ self._init_bias(m.r_bias)
def set_num_special_tokens(self, num_special_tokens):
pass
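For context (illustrative, assuming TransfoXLConfig instantiates with default constructor values): init_weights (plural) is the per-module hook passed to nn.Module.apply, while the renamed underscore helpers act on single tensors.

    config = TransfoXLConfig()           # hypothetical: default config values
    model = TransfoXLModel(config)
    model.apply(model.init_weights)      # dispatches on each submodule's class name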
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `transfo-xl-wt103`
- a path or url to a pretrained model archive containing:
. `transfo_xl_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
- a path or url to a pretrained model archive containing:
. `transfo_xl_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific TransformerXL class
"""
state_dict = kwargs.get('state_dict', None)
kwargs.pop('state_dict', None)
cache_dir = kwargs.get('cache_dir', None)
kwargs.pop('cache_dir', None)
from_tf = kwargs.get('from_tf', False)
kwargs.pop('from_tf', None)
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url.".format(
pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
archive_file
)
)
return None
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url.".format(
pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
config_file
)
)
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = TransfoXLConfig.from_json_file(resolved_config_file)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
if from_tf:
# Directly load from a TensorFlow checkpoint
return load_tf_weights_in_transfo_xl(model, config, pretrained_model_name_or_path)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
start_prefix = ''
if not hasattr(model, 'transformer') and any(s.startswith('transformer.') for s in state_dict.keys()):
start_prefix = 'transformer.'
load(model, prefix=start_prefix)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
# Make sure we are still sharing the input and output embeddings
if hasattr(model, 'tie_weights'):
model.tie_weights()
return model
class TransfoXLModel(TransfoXLPreTrainedModel):
"""Transformer XL model ("Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context").
......
@@ -36,7 +36,7 @@ from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
- from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig
+ from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
logger = logging.getLogger(__name__)
...@@ -390,20 +390,18 @@ class BeamHypotheses(object): ...@@ -390,20 +390,18 @@ class BeamHypotheses(object):
return self.worst_score >= best_sum_logprobs / self.max_len ** self.length_penalty return self.worst_score >= best_sum_logprobs / self.max_len ** self.length_penalty
class XLMPreTrainedModel(nn.Module): class XLMPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and """ An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
def __init__(self, config, *inputs, **kwargs): config_class = XLMConfig
super(XLMPreTrainedModel, self).__init__() pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
if not isinstance(config, XLMBaseConfig): pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
raise ValueError( load_tf_weights = None
"Parameter config in `{}(config)` should be an instance of class `XLMBaseConfig`. " base_model_prefix = "xlm"
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( def __init__(self, *inputs, **kwargs):
self.__class__.__name__, self.__class__.__name__ super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)
))
self.config = config
def init_weights(self, module): def init_weights(self, module):
""" Initialize the weights. """ Initialize the weights.
...@@ -423,138 +421,6 @@ class XLMPreTrainedModel(nn.Module): ...@@ -423,138 +421,6 @@ class XLMPreTrainedModel(nn.Module):
if isinstance(module, nn.Linear) and module.bias is not None: if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
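# --- Editor's sketch (not part of this commit) ---------------------------------
# With the class attributes above (config_class, the two archive maps,
# base_model_prefix), the shared PreTrainedModel.from_pretrained can resolve
# per-model resources generically instead of each subclass reimplementing the
# lookup. Toy illustration; the Dummy* names are hypothetical.
class DummyConfig(object):
    @classmethod
    def from_json_file(cls, json_file):
        return cls()

class DummyModel(object):
    config_class = DummyConfig
    pretrained_model_archive_map = {'dummy-base': 'https://example.com/dummy.bin'}

    @classmethod
    def resolve_archive(cls, name_or_path):
        # Known shortcut names map to hosted files; anything else is a local path.
        return cls.pretrained_model_archive_map.get(name_or_path, name_or_path)

print(DummyModel.resolve_archive('dummy-base'))  # -> hosted url
print(DummyModel.resolve_archive('./my_model'))  # -> local path
# --------------------------------------------------------------------------------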
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate an XLMPreTrainedModel from a pre-trained model file or a PyTorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `xlnet-large-cased`
- a path or url to a pretrained model archive containing:
. `config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of an XLMForPreTraining instance
- a path or url to a pretrained model archive containing:
. `xlnet_config.json` a configuration file for the model
. `model.ckpt` a TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific XLM class
(ex: num_labels for XLMForSequenceClassification)
"""
state_dict = kwargs.get('state_dict', None)
kwargs.pop('state_dict', None)
cache_dir = kwargs.get('cache_dir', None)
kwargs.pop('cache_dir', None)
from_tf = kwargs.get('from_tf', False)
kwargs.pop('from_tf', None)
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, XLNET_CONFIG_NAME)
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
archive_file))
return None
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
config_file))
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = XLMConfig.from_json_file(resolved_config_file)
# Update config with kwargs if needed
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
# Load from a PyTorch state_dict
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
start_prefix = ''
if not hasattr(model, 'transformer') and any(s.startswith('transformer') for s in state_dict.keys()):
start_prefix = 'transformer.'
load(model, prefix=start_prefix)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
if isinstance(model, XLMLMHeadModel):
model.tie_weights() # make sure word embedding weights are still tied
return model
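# --- Editor's sketch (not part of this commit) ---------------------------------
# The kwargs-override loop above, in isolation: keys that name existing config
# attributes update the config and are removed from kwargs; the rest pass
# through to the model constructor. Names below are hypothetical.
class Cfg(object):
    dropout = 0.1
    n_layer = 12

def split_config_kwargs(config, kwargs):
    for key in [k for k in kwargs if hasattr(config, k)]:
        setattr(config, key, kwargs.pop(key))
    return config, kwargs

cfg, rest = split_config_kwargs(Cfg(), {'dropout': 0.2, 'num_labels': 3})
print(cfg.dropout, rest)  # 0.2 {'num_labels': 3}
# --------------------------------------------------------------------------------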
class XLMModel(XLMPreTrainedModel): class XLMModel(XLMPreTrainedModel):
......
...@@ -33,7 +33,7 @@ from torch.nn import functional as F ...@@ -33,7 +33,7 @@ from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path from .file_utils import cached_path
from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -44,11 +44,9 @@ PRETRAINED_MODEL_ARCHIVE_MAP = { ...@@ -44,11 +44,9 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
PRETRAINED_CONFIG_ARCHIVE_MAP = { PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json", 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
} }
XLNET_CONFIG_NAME = 'xlnet_config.json'
TF_WEIGHTS_NAME = 'model.ckpt'
def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None, finetuning_task=None): def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
""" A map of modules from TF to PyTorch. """ A map of modules from TF to PyTorch.
I use a map to keep the PyTorch model as I use a map to keep the PyTorch model as
identical to the original PyTorch model as possible. identical to the original PyTorch model as possible.
...@@ -64,9 +62,9 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None, finetuning_tas ...@@ -64,9 +62,9 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None, finetuning_tas
# We will load also the sequence summary # We will load also the sequence summary
tf_to_pt_map['model/sequnece_summary/summary/kernel'] = model.sequence_summary.summary.weight tf_to_pt_map['model/sequnece_summary/summary/kernel'] = model.sequence_summary.summary.weight
tf_to_pt_map['model/sequnece_summary/summary/bias'] = model.sequence_summary.summary.bias tf_to_pt_map['model/sequnece_summary/summary/bias'] = model.sequence_summary.summary.bias
if hasattr(model, 'logits_proj') and finetuning_task is not None and 'model/regression_{}/logit/kernel'.format(finetuning_task) in tf_weights: if hasattr(model, 'logits_proj') and config.finetuning_task is not None and 'model/regression_{}/logit/kernel'.format(config.finetuning_task) in tf_weights:
tf_to_pt_map['model/regression_{}/logit/kernel'.format(finetuning_task)] = model.logits_proj.weight tf_to_pt_map['model/regression_{}/logit/kernel'.format(config.finetuning_task)] = model.logits_proj.weight
tf_to_pt_map['model/regression_{}/logit/bias'.format(finetuning_task)] = model.logits_proj.bias tf_to_pt_map['model/regression_{}/logit/bias'.format(config.finetuning_task)] = model.logits_proj.bias
# Now load the rest of the transformer # Now load the rest of the transformer
model = model.transformer model = model.transformer
...@@ -117,7 +115,7 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None, finetuning_tas ...@@ -117,7 +115,7 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None, finetuning_tas
'model/transformer/seg_embed': seg_embed_list}) 'model/transformer/seg_embed': seg_embed_list})
return tf_to_pt_map return tf_to_pt_map
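# --- Editor's sketch (not part of this commit) ---------------------------------
# Once a {tf_name: pytorch_parameter} map like the one returned above exists,
# copying weights reduces to a loop of this shape. `tf_weights` is assumed to
# be a dict of numpy arrays read from the checkpoint.
import torch

def copy_mapped_weights(tf_to_pt_map, tf_weights):
    for name, pointer in tf_to_pt_map.items():
        array = tf_weights[name]
        if tuple(pointer.shape) != tuple(array.shape):
            raise ValueError("shape mismatch for {}".format(name))
        pointer.data = torch.from_numpy(array)
# --------------------------------------------------------------------------------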
def load_tf_weights_in_xlnet(model, config, tf_path, finetuning_task=None): def load_tf_weights_in_xlnet(model, config, tf_path):
""" Load tf checkpoints in a pytorch model """ Load tf checkpoints in a pytorch model
""" """
try: try:
...@@ -138,7 +136,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path, finetuning_task=None): ...@@ -138,7 +136,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path, finetuning_task=None):
input("Press Enter to continue...") input("Press Enter to continue...")
# Build TF to PyTorch weights loading map # Build TF to PyTorch weights loading map
tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights, finetuning_task) tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights)
for name, pointer in tf_to_pt_map.items(): for name, pointer in tf_to_pt_map.items():
print("Importing {}".format(name)) print("Importing {}".format(name))
...@@ -223,7 +221,8 @@ class XLNetConfig(PretrainedConfig): ...@@ -223,7 +221,8 @@ class XLNetConfig(PretrainedConfig):
reuse_len=None, reuse_len=None,
bi_data=False, bi_data=False,
clamp_len=-1, clamp_len=-1,
same_length=False): same_length=False,
finetuning_task=None):
"""Constructs XLNetConfig. """Constructs XLNetConfig.
Args: Args:
...@@ -265,6 +264,7 @@ class XLNetConfig(PretrainedConfig): ...@@ -265,6 +264,7 @@ class XLNetConfig(PretrainedConfig):
clamp_len: int, clamp all relative distances larger than clamp_len. clamp_len: int, clamp all relative distances larger than clamp_len.
-1 means no clamping. -1 means no clamping.
same_length: bool, whether to use the same attention length for each token. same_length: bool, whether to use the same attention length for each token.
finetuning_task: name of the GLUE task on which the model was fine-tuned, if any
""" """
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)): and isinstance(vocab_size_or_config_json_file, unicode)):
...@@ -298,6 +298,7 @@ class XLNetConfig(PretrainedConfig): ...@@ -298,6 +298,7 @@ class XLNetConfig(PretrainedConfig):
self.bi_data = bi_data self.bi_data = bi_data
self.clamp_len = clamp_len self.clamp_len = clamp_len
self.same_length = same_length self.same_length = same_length
self.finetuning_task = finetuning_task
else: else:
raise ValueError("First argument must be either a vocabulary size (int)" raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)") "or the path to a pretrained model config file (str)")
...@@ -550,20 +551,19 @@ class XLNetLayer(nn.Module): ...@@ -550,20 +551,19 @@ class XLNetLayer(nn.Module):
# return attentions, layer_output # return attentions, layer_output
return output_h, output_g return output_h, output_g
class XLNetPreTrainedModel(nn.Module):
class XLNetPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and """ An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
def __init__(self, config, *inputs, **kwargs): config_class = XLNetConfig
super(XLNetPreTrainedModel, self).__init__() pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
if not isinstance(config, XLNetConfig): pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
raise ValueError( load_tf_weights = load_tf_weights_in_xlnet
"Parameter config in `{}(config)` should be an instance of class `XLNetConfig`. " base_model_prefix = "transformer"
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( def __init__(self, *inputs, **kwargs):
self.__class__.__name__, self.__class__.__name__ super(XLNetPreTrainedModel, self).__init__(*inputs, **kwargs)
))
self.config = config
def init_weights(self, module): def init_weights(self, module):
""" Initialize the weights. """ Initialize the weights.
...@@ -583,144 +583,6 @@ class XLNetPreTrainedModel(nn.Module): ...@@ -583,144 +583,6 @@ class XLNetPreTrainedModel(nn.Module):
if isinstance(module, nn.Linear) and module.bias is not None: if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
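# --- Editor's sketch (not part of this commit) ---------------------------------
# init_weights above is a per-module callback; subclasses apply it over the
# whole module tree with nn.Module.apply once construction is done. Toy
# equivalent of that pattern:
import torch.nn as nn

def toy_init_weights(module):
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()

net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
net.apply(toy_init_weights)  # visits every submodule, containers included
# --------------------------------------------------------------------------------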
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate an XLNetPreTrainedModel from a pre-trained model file or a PyTorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `xlnet-large-cased`
- a path or url to a pretrained model archive containing:
. `config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of an XLNetForPreTraining instance
- a path or url to a pretrained model archive containing:
. `xlnet_config.json` a configuration file for the model
. `model.ckpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific XLNet class
(ex: num_labels for XLNetForSequenceClassification)
"""
state_dict = kwargs.get('state_dict', None)
kwargs.pop('state_dict', None)
cache_dir = kwargs.get('cache_dir', None)
kwargs.pop('cache_dir', None)
from_tf = kwargs.get('from_tf', False)
kwargs.pop('from_tf', None)
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, XLNET_CONFIG_NAME)
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
archive_file))
return None
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
config_file))
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = XLNetConfig.from_json_file(resolved_config_file)
# Update config with kwargs if needed
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
if from_tf:
# Directly load from a TensorFlow checkpoint
return load_tf_weights_in_xlnet(model, config, resolved_archive_file)
# Load from a PyTorch state_dict
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
start_prefix = ''
if not hasattr(model, 'transformer') and any(s.startswith('transformer') for s in state_dict.keys()):
start_prefix = 'transformer.'
load(model, prefix=start_prefix)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
if isinstance(model, XLNetLMHeadModel):
model.tie_weights() # make sure word embedding weights are still tied
return model
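# --- Editor's sketch (not part of this commit) ---------------------------------
# The tie_weights call above keeps input and output embeddings shared after
# loading. A minimal version of weight tying:
import torch.nn as nn

vocab, dim = 100, 16
embed = nn.Embedding(vocab, dim)
lm_head = nn.Linear(dim, vocab, bias=False)
lm_head.weight = embed.weight  # one Parameter shared by both modules
assert lm_head.weight.data_ptr() == embed.weight.data_ptr()
# --------------------------------------------------------------------------------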
class XLNetModel(XLNetPreTrainedModel): class XLNetModel(XLNetPreTrainedModel):
def __init__(self, config, output_attentions=False, keep_multihead_output=False): def __init__(self, config, output_attentions=False, keep_multihead_output=False):
......