Commit 4d47f498 authored by thomwolf's avatar thomwolf
Browse files

slight refactoring, add abstract class for model loading

parent 59cefd4f
......@@ -28,4 +28,5 @@ from .optimization_openai import OpenAIAdam
from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
from .model_utils import (WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig)
from .model_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
......@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
TF_WEIGHTS_NAME = 'model.ckpt'
class PretrainedConfig(object):
......@@ -131,6 +132,169 @@ class PretrainedConfig(object):
writer.write(self.to_json_string())
class PreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
config_class = PretrainedConfig
pretrained_model_archive_map = {}
pretrained_config_archive_map = {}
load_tf_weights = lambda model, config, path: None
base_model_prefix = ""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedModel, self).__init__()
if not isinstance(config, PretrainedConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load, or
- a path or url to a pretrained model archive containing:
. `config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a XLNetForPreTraining instance
- a path or url to a tensorflow pretrained model checkpoint containing:
. `config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use
instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific XLNet class
(ex: num_labels for XLNetForSequenceClassification)
"""
state_dict = kwargs.get('state_dict', None)
kwargs.pop('state_dict', None)
cache_dir = kwargs.get('cache_dir', None)
kwargs.pop('cache_dir', None)
from_tf = kwargs.get('from_tf', False)
kwargs.pop('from_tf', None)
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
else:
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_model_archive_map.keys()),
archive_file))
return None
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_config_archive_map.keys()),
config_file))
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = cls.config_class.from_json_file(resolved_config_file)
# Update config with kwargs if needed
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
if from_tf:
# Directly load from a TensorFlow checkpoint
return load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
# Load from a PyTorch state_dict
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
start_prefix = ''
if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
start_prefix = cls.base_model_prefix + '.' # Used to be able to load base models as well as derived modesl (with heads)
load(model, prefix=start_prefix)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
if hasattr(model, tie_weights):
model.tie_weights() # make sure word embedding weights are still tied
return model
def prune_linear_layer(layer, index, dim=0):
""" Prune a linear layer (a model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True.
......@@ -197,3 +361,16 @@ def prune_conv1d_layer(layer, index, dim=1):
new_layer.bias.copy_(b.contiguous())
new_layer.bias.requires_grad = True
return new_layer
def prune_layer(layer, index, dim=None):
""" Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True.
Used to remove heads.
"""
if isinstance(layer, nn.Linear):
return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
elif isinstance(layer, Conv1D):
return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
else:
raise ValueError("Can't prune layer of class {}".format(layer.__class__))
......@@ -30,7 +30,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
from .model_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, prune_linear_layer
from .model_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer
logger = logging.getLogger(__name__)
......@@ -64,11 +64,9 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
}
BERT_CONFIG_NAME = 'bert_config.json'
TF_WEIGHTS_NAME = 'model.ckpt'
def load_tf_weights_in_bert(model, tf_checkpoint_path):
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model
"""
try:
......@@ -168,7 +166,8 @@ class BertConfig(PretrainedConfig):
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12):
layer_norm_eps=1e-12,
finetuning_task=None):
"""Constructs BertConfig.
Args:
......@@ -193,6 +192,7 @@ class BertConfig(PretrainedConfig):
initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices.
layer_norm_eps: The epsilon used by LayerNorm.
finetuning_task: name of the glue task on which the model was fine-tuned if any
"""
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
......@@ -213,6 +213,7 @@ class BertConfig(PretrainedConfig):
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.finetuning_task = finetuning_task
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
......@@ -539,20 +540,18 @@ class BertPreTrainingHeads(nn.Module):
return prediction_scores, seq_relationship_score
class BertPreTrainedModel(nn.Module):
class BertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(BertPreTrainedModel, self).__init__()
if not isinstance(config, BertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
config_class = BertConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
def __init__(self, *inputs, **kwargs):
super(BertPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
......@@ -567,152 +566,6 @@ class BertPreTrainedModel(nn.Module):
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `bert-base-uncased`
. `bert-large-uncased`
. `bert-base-cased`
. `bert-large-cased`
. `bert-base-multilingual-uncased`
. `bert-base-multilingual-cased`
. `bert-base-chinese`
. `bert-base-german-cased`
. `bert-large-uncased-whole-word-masking`
. `bert-large-cased-whole-word-masking`
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
"""
state_dict = kwargs.get('state_dict', None)
kwargs.pop('state_dict', None)
cache_dir = kwargs.get('cache_dir', None)
kwargs.pop('cache_dir', None)
from_tf = kwargs.get('from_tf', False)
kwargs.pop('from_tf', None)
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, BERT_CONFIG_NAME)
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
archive_file))
return None
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
config_file))
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = BertConfig.from_json_file(resolved_config_file)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
if from_tf:
# Directly load from a TensorFlow checkpoint
return load_tf_weights_in_bert(model, resolved_archive_file)
# Load from a PyTorch state_dict
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
start_prefix = ''
if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
start_prefix = 'bert.'
load(model, prefix=start_prefix)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
return model
class BertModel(BertPreTrainedModel):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
......
......@@ -32,7 +32,7 @@ from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .file_utils import cached_path
from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, prune_conv1d_layer
from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer
from .modeling import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
......@@ -42,7 +42,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.hugging
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
""" Load tf checkpoints in a pytorch model
"""
try:
......@@ -356,22 +356,18 @@ class GPT2MultipleChoiceHead(nn.Module):
return multiple_choice_logits
class GPT2PreTrainedModel(nn.Module):
class GPT2PreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
config_class = GPT2Config
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_gpt2
base_model_prefix = "transformer"
def __init__(self, config, *inputs, **kwargs):
super(GPT2PreTrainedModel, self).__init__()
if not isinstance(config, GPT2Config):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
)
)
self.config = config
def __init__(self, *inputs, **kwargs):
super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
......@@ -407,130 +403,130 @@ class GPT2PreTrainedModel(nn.Module):
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific GPT2 class
"""
state_dict = kwargs.get('state_dict', None)
kwargs.pop('state_dict', None)
cache_dir = kwargs.get('cache_dir', None)
kwargs.pop('cache_dir', None)
from_tf = kwargs.get('from_tf', False)
kwargs.pop('from_tf', None)
# state_dict = kwargs.get('state_dict', None)
# kwargs.pop('state_dict', None)
# cache_dir = kwargs.get('cache_dir', None)
# kwargs.pop('cache_dir', None)
# from_tf = kwargs.get('from_tf', False)
# kwargs.pop('from_tf', None)
num_special_tokens = kwargs.get('num_special_tokens', None)
kwargs.pop('num_special_tokens', None)
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url.".format(
pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
archive_file
)
)
return None
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url.".format(
pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
config_file
)
)
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = GPT2Config.from_json_file(resolved_config_file)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
if from_tf:
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
return load_tf_weights_in_gpt2(model, resolved_archive_file)
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if key.endswith(".g"):
new_key = key[:-2] + ".weight"
elif key.endswith(".b"):
new_key = key[:-2] + ".bias"
elif key.endswith(".w"):
new_key = key[:-2] + ".weight"
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
start_model = model
if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
start_model = model.transformer
load(start_model, prefix="")
if len(missing_keys) > 0:
logger.info(
"Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
)
if len(unexpected_keys) > 0:
logger.info(
"Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
)
if len(error_msgs) > 0:
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
)
# if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
# archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
# config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
# else:
# archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
# config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# # redirect to the cache, if necessary
# try:
# resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
# except EnvironmentError:
# if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
# logger.error(
# "Couldn't reach server at '{}' to download pretrained weights.".format(
# archive_file))
# else:
# logger.error(
# "Model name '{}' was not found in model name list ({}). "
# "We assumed '{}' was a path or url but couldn't find file {} "
# "at this path or url.".format(
# pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
# archive_file
# )
# )
# return None
# try:
# resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
# except EnvironmentError:
# if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
# logger.error(
# "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
# config_file))
# else:
# logger.error(
# "Model name '{}' was not found in model name list ({}). "
# "We assumed '{}' was a path or url but couldn't find file {} "
# "at this path or url.".format(
# pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
# config_file
# )
# )
# return None
# if resolved_archive_file == archive_file and resolved_config_file == config_file:
# logger.info("loading weights file {}".format(archive_file))
# logger.info("loading configuration file {}".format(config_file))
# else:
# logger.info("loading weights file {} from cache at {}".format(
# archive_file, resolved_archive_file))
# logger.info("loading configuration file {} from cache at {}".format(
# config_file, resolved_config_file))
# # Load config
# config = GPT2Config.from_json_file(resolved_config_file)
# logger.info("Model config {}".format(config))
# # Instantiate model.
# model = cls(config, *inputs, **kwargs)
# if state_dict is None and not from_tf:
# state_dict = torch.load(resolved_archive_file, map_location='cpu')
# if from_tf:
# # Directly load from a TensorFlow checkpoint (stored as NumPy array)
# return load_tf_weights_in_gpt2(model, resolved_archive_file)
# old_keys = []
# new_keys = []
# for key in state_dict.keys():
# new_key = None
# if key.endswith(".g"):
# new_key = key[:-2] + ".weight"
# elif key.endswith(".b"):
# new_key = key[:-2] + ".bias"
# elif key.endswith(".w"):
# new_key = key[:-2] + ".weight"
# if new_key:
# old_keys.append(key)
# new_keys.append(new_key)
# for old_key, new_key in zip(old_keys, new_keys):
# state_dict[new_key] = state_dict.pop(old_key)
# missing_keys = []
# unexpected_keys = []
# error_msgs = []
# # copy state_dict so _load_from_state_dict can modify it
# metadata = getattr(state_dict, "_metadata", None)
# state_dict = state_dict.copy()
# if metadata is not None:
# state_dict._metadata = metadata
# def load(module, prefix=""):
# local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
# module._load_from_state_dict(
# state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
# )
# for name, child in module._modules.items():
# if child is not None:
# load(child, prefix + name + ".")
# start_model = model
# if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
# start_model = model.transformer
# load(start_model, prefix="")
# if len(missing_keys) > 0:
# logger.info(
# "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
# )
# if len(unexpected_keys) > 0:
# logger.info(
# "Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
# )
# if len(error_msgs) > 0:
# raise RuntimeError(
# "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
# )
# Add additional embeddings for special tokens if needed
# This step also make sure we are still sharing the output and input embeddings after loading weights
model.set_num_special_tokens(num_special_tokens if num_special_tokens is not None else config.n_special)
model.set_num_special_tokens(num_special_tokens)
return model
......@@ -608,9 +604,9 @@ class GPT2Model(GPT2PreTrainedModel):
self.apply(self.init_weights)
def set_num_special_tokens(self, num_special_tokens):
def set_num_special_tokens(self, num_special_tokens=None):
" Update input embeddings with new embedding matrice if needed "
if self.config.n_special == num_special_tokens:
if num_special_tokens is None or self.config.n_special == num_special_tokens:
return
# Update config
self.config.n_special = num_special_tokens
......
......@@ -32,7 +32,7 @@ from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .file_utils import cached_path
from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, prune_conv1d_layer
from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer
from .modeling import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
......@@ -41,12 +41,17 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.h
PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
"""
import re
import numpy as np
print("Loading weights...")
if '.ckpt' in openai_checkpoint_folder_path:
openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path)
logger.info("Loading weights from {}".format(openai_checkpoint_folder_path))
names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
offsets = np.cumsum([np.prod(shape) for shape in shapes])
......@@ -377,22 +382,18 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
return multiple_choice_logits
class OpenAIGPTPreTrainedModel(nn.Module):
class OpenAIGPTPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
config_class = OpenAIGPTConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_openai_gpt
base_model_prefix = "transformer"
def __init__(self, config, *inputs, **kwargs):
super(OpenAIGPTPreTrainedModel, self).__init__()
if not isinstance(config, OpenAIGPTConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `OpenAIGPTConfig`. "
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
)
)
self.config = config
def __init__(self, *inputs, **kwargs):
super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
......@@ -408,7 +409,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
module.bias.data.zero_()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, num_special_tokens=None, *inputs, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
......@@ -416,140 +417,25 @@ class OpenAIGPTPreTrainedModel(nn.Module):
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `openai-gpt`
- a path or url to a pretrained model archive containing:
. `openai_gpt_config.json` a configuration file for the model
. `config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
- a path or url to a pretrained model archive containing:
. `openai-gpt-config.json` a configuration file for the model
. `config.json` a configuration file for the model
. a series of NumPy files containing OpenAI TensorFlow trained weights
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific OpenAI-GPT class
"""
state_dict = kwargs.get('state_dict', None)
kwargs.pop('state_dict', None)
cache_dir = kwargs.get('cache_dir', None)
kwargs.pop('cache_dir', None)
from_tf = kwargs.get('from_tf', False)
kwargs.pop('from_tf', None)
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url.".format(
pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
archive_file
)
)
return None
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url.".format(
pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
config_file
)
)
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = OpenAIGPTConfig.from_json_file(resolved_config_file)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
if from_tf:
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
return load_tf_weights_in_openai_gpt(model, resolved_archive_file)
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if key.endswith(".g"):
new_key = key[:-2] + ".weight"
elif key.endswith(".b"):
new_key = key[:-2] + ".bias"
elif key.endswith(".w"):
new_key = key[:-2] + ".weight"
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
start_model = model
if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
start_model = model.transformer
load(start_model, prefix="")
if len(missing_keys) > 0:
logger.info(
"Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
)
if len(unexpected_keys) > 0:
logger.info(
"Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
)
if len(error_msgs) > 0:
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
)
num_special_tokens = kwargs.get('num_special_tokens', None)
kwargs.pop('num_special_tokens', None)
model = PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs)
# Add additional embeddings for special tokens if needed
# This step also make sure we are still sharing the output and input embeddings after loading weights
model.set_num_special_tokens(num_special_tokens if num_special_tokens is not None else config.n_special)
model.set_num_special_tokens(num_special_tokens)
return model
......@@ -621,9 +507,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
self.apply(self.init_weights)
def set_num_special_tokens(self, num_special_tokens):
def set_num_special_tokens(self, num_special_tokens=None):
" Update input embeddings with new embedding matrice if needed "
if self.config.n_special == num_special_tokens:
if num_special_tokens is None or self.config.n_special == num_special_tokens:
return
# Update config
self.config.n_special = num_special_tokens
......
......@@ -38,7 +38,7 @@ from torch.nn.parameter import Parameter
from .modeling import BertLayerNorm as LayerNorm
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
from .file_utils import cached_path
from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig
from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
logger = logging.getLogger(__name__)
......@@ -49,8 +49,6 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
}
TF_WEIGHTS_NAME = 'model.ckpt'
def build_tf_to_pytorch_map(model, config):
""" A map of modules from TF to PyTorch.
This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
......@@ -787,28 +785,26 @@ class AdaptiveEmbedding(nn.Module):
return embed
class TransfoXLPreTrainedModel(nn.Module):
class TransfoXLPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(TransfoXLPreTrainedModel, self).__init__()
if not isinstance(config, TransfoXLConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `TransfoXLConfig`. "
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
def init_weight(self, weight):
config_class = TransfoXLConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_transfo_xl
base_model_prefix = "transformer"
def __init__(self, *inputs, **kwargs):
super(TransfoXLPreTrainedModel, self).__init__(*inputs, **kwargs)
def _init_weight(self, weight):
if self.config.init == 'uniform':
nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
elif self.config.init == 'normal':
nn.init.normal_(weight, 0.0, self.config.init_std)
def init_bias(self, bias):
def _init_bias(self, bias):
nn.init.constant_(bias, 0.0)
def init_weights(self, m):
......@@ -817,9 +813,9 @@ class TransfoXLPreTrainedModel(nn.Module):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
if hasattr(m, 'weight') and m.weight is not None:
self.init_weight(m.weight)
self._init_weight(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
self.init_bias(m.bias)
self._init_bias(m.bias)
elif classname.find('AdaptiveEmbedding') != -1:
if hasattr(m, 'emb_projs'):
for i in range(len(m.emb_projs)):
......@@ -827,12 +823,12 @@ class TransfoXLPreTrainedModel(nn.Module):
nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std)
elif classname.find('Embedding') != -1:
if hasattr(m, 'weight'):
self.init_weight(m.weight)
self._init_weight(m.weight)
elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
self.init_weight(m.cluster_weight)
self._init_weight(m.cluster_weight)
if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
self.init_bias(m.cluster_bias)
self._init_bias(m.cluster_bias)
if hasattr(m, 'out_projs'):
for i in range(len(m.out_projs)):
if m.out_projs[i] is not None:
......@@ -841,144 +837,20 @@ class TransfoXLPreTrainedModel(nn.Module):
if hasattr(m, 'weight'):
nn.init.normal_(m.weight, 1.0, self.config.init_std)
if hasattr(m, 'bias') and m.bias is not None:
self.init_bias(m.bias)
self._init_bias(m.bias)
elif classname.find('TransformerLM') != -1:
if hasattr(m, 'r_emb'):
self.init_weight(m.r_emb)
self._init_weight(m.r_emb)
if hasattr(m, 'r_w_bias'):
self.init_weight(m.r_w_bias)
self._init_weight(m.r_w_bias)
if hasattr(m, 'r_r_bias'):
self.init_weight(m.r_r_bias)
self._init_weight(m.r_r_bias)
if hasattr(m, 'r_bias'):
self.init_bias(m.r_bias)
self._init_bias(m.r_bias)
def set_num_special_tokens(self, num_special_tokens):
pass
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `transfo-xl-wt103`
- a path or url to a pretrained model archive containing:
. `transfo_xl_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
- a path or url to a pretrained model archive containing:
. `transfo_xl_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific TransformerXL class
"""
state_dict = kwargs.get('state_dict', None)
kwargs.pop('state_dict', None)
cache_dir = kwargs.get('cache_dir', None)
kwargs.pop('cache_dir', None)
from_tf = kwargs.get('from_tf', False)
kwargs.pop('from_tf', None)
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url.".format(
pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
archive_file
)
)
return None
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url.".format(
pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
config_file
)
)
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = TransfoXLConfig.from_json_file(resolved_config_file)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
if from_tf:
# Directly load from a TensorFlow checkpoint
return load_tf_weights_in_transfo_xl(model, config, pretrained_model_name_or_path)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
start_prefix = ''
if not hasattr(model, 'transformer') and any(s.startswith('transformer.') for s in state_dict.keys()):
start_prefix = 'transformer.'
load(model, prefix=start_prefix)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
# Make sure we are still sharing the input and output embeddings
if hasattr(model, 'tie_weights'):
model.tie_weights()
return model
class TransfoXLModel(TransfoXLPreTrainedModel):
"""Transformer XL model ("Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context").
......
......@@ -36,7 +36,7 @@ from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig
from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
logger = logging.getLogger(__name__)
......@@ -390,20 +390,18 @@ class BeamHypotheses(object):
return self.worst_score >= best_sum_logprobs / self.max_len ** self.length_penalty
class XLMPreTrainedModel(nn.Module):
class XLMPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(XLMPreTrainedModel, self).__init__()
if not isinstance(config, XLMBaseConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `XLMBaseConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
config_class = XLMConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights = None
base_model_prefix = "xlm"
def __init__(self, *inputs, **kwargs):
super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
......@@ -423,138 +421,6 @@ class XLMPreTrainedModel(nn.Module):
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate a XLMPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `xlnet-large-cased`
- a path or url to a pretrained model archive containing:
. `config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a XLMForPreTraining instance
- a path or url to a pretrained model archive containing:
. `xlnet_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific XLM class
(ex: num_labels for XLMForSequenceClassification)
"""
state_dict = kwargs.get('state_dict', None)
kwargs.pop('state_dict', None)
cache_dir = kwargs.get('cache_dir', None)
kwargs.pop('cache_dir', None)
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, XLNET_CONFIG_NAME)
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
archive_file))
return None
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
config_file))
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = XLMConfig.from_json_file(resolved_config_file)
# Update config with kwargs if needed
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
# Load from a PyTorch state_dict
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
start_prefix = ''
if not hasattr(model, 'transformer') and any(s.startswith('transformer') for s in state_dict.keys()):
start_prefix = 'transformer.'
load(model, prefix=start_prefix)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
if isinstance(model, XLMLMHeadModel):
model.tie_weights() # make sure word embedding weights are still tied
return model
class XLMModel(XLMPreTrainedModel):
......
......@@ -33,7 +33,7 @@ from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig
from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
logger = logging.getLogger(__name__)
......@@ -44,11 +44,9 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
}
XLNET_CONFIG_NAME = 'xlnet_config.json'
TF_WEIGHTS_NAME = 'model.ckpt'
def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None, finetuning_task=None):
def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
""" A map of modules from TF to PyTorch.
I use a map to keep the PyTorch model as
identical to the original PyTorch model as possible.
......@@ -64,9 +62,9 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None, finetuning_tas
# We will load also the sequence summary
tf_to_pt_map['model/sequnece_summary/summary/kernel'] = model.sequence_summary.summary.weight
tf_to_pt_map['model/sequnece_summary/summary/bias'] = model.sequence_summary.summary.bias
if hasattr(model, 'logits_proj') and finetuning_task is not None and 'model/regression_{}/logit/kernel'.format(finetuning_task) in tf_weights:
tf_to_pt_map['model/regression_{}/logit/kernel'.format(finetuning_task)] = model.logits_proj.weight
tf_to_pt_map['model/regression_{}/logit/bias'.format(finetuning_task)] = model.logits_proj.bias
if hasattr(model, 'logits_proj') and config.finetuning_task is not None and 'model/regression_{}/logit/kernel'.format(finetuning_task) in tf_weights:
tf_to_pt_map['model/regression_{}/logit/kernel'.format(config.finetuning_task)] = model.logits_proj.weight
tf_to_pt_map['model/regression_{}/logit/bias'.format(config.finetuning_task)] = model.logits_proj.bias
# Now load the rest of the transformer
model = model.transformer
......@@ -117,7 +115,7 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None, finetuning_tas
'model/transformer/seg_embed': seg_embed_list})
return tf_to_pt_map
def load_tf_weights_in_xlnet(model, config, tf_path, finetuning_task=None):
def load_tf_weights_in_xlnet(model, config, tf_path):
""" Load tf checkpoints in a pytorch model
"""
try:
......@@ -138,7 +136,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path, finetuning_task=None):
input("Press Enter to continue...")
# Build TF to PyTorch weights loading map
tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights, finetuning_task)
tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights)
for name, pointer in tf_to_pt_map.items():
print("Importing {}".format(name))
......@@ -223,7 +221,8 @@ class XLNetConfig(PretrainedConfig):
reuse_len=None,
bi_data=False,
clamp_len=-1,
same_length=False):
same_length=False,
finetuning_task=None):
"""Constructs XLNetConfig.
Args:
......@@ -265,6 +264,7 @@ class XLNetConfig(PretrainedConfig):
clamp_len: int, clamp all relative distances larger than clamp_len.
-1 means no clamping.
same_length: bool, whether to use the same attention length for each token.
finetuning_task: name of the glue task on which the model was fine-tuned if any
"""
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
......@@ -298,6 +298,7 @@ class XLNetConfig(PretrainedConfig):
self.bi_data = bi_data
self.clamp_len = clamp_len
self.same_length = same_length
self.finetuning_task = finetuning_task
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
......@@ -550,20 +551,19 @@ class XLNetLayer(nn.Module):
# return attentions, layer_output
return output_h, output_g
class XLNetPreTrainedModel(nn.Module):
class XLNetPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(XLNetPreTrainedModel, self).__init__()
if not isinstance(config, XLNetConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `XLNetConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
config_class = XLNetConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_xlnet
base_model_prefix = "transformer"
def __init__(self, *inputs, **kwargs):
super(XLNetPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
......@@ -583,144 +583,6 @@ class XLNetPreTrainedModel(nn.Module):
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate a XLNetPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `xlnet-large-cased`
- a path or url to a pretrained model archive containing:
. `config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a XLNetForPreTraining instance
- a path or url to a pretrained model archive containing:
. `xlnet_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific XLNet class
(ex: num_labels for XLNetForSequenceClassification)
"""
state_dict = kwargs.get('state_dict', None)
kwargs.pop('state_dict', None)
cache_dir = kwargs.get('cache_dir', None)
kwargs.pop('cache_dir', None)
from_tf = kwargs.get('from_tf', False)
kwargs.pop('from_tf', None)
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, XLNET_CONFIG_NAME)
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
archive_file))
return None
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
config_file))
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = XLNetConfig.from_json_file(resolved_config_file)
# Update config with kwargs if needed
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
if from_tf:
# Directly load from a TensorFlow checkpoint
return load_tf_weights_in_xlnet(model, config, resolved_archive_file)
# Load from a PyTorch state_dict
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
start_prefix = ''
if not hasattr(model, 'transformer') and any(s.startswith('transformer') for s in state_dict.keys()):
start_prefix = 'transformer.'
load(model, prefix=start_prefix)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
if isinstance(model, XLNetLMHeadModel):
model.tie_weights() # make sure word embedding weights are still tied
return model
class XLNetModel(XLNetPreTrainedModel):
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment