Commit d77dd62f authored by thomwolf's avatar thomwolf
Browse files

directly load from TF checkpoints + code cleanup

parent 9c35c132
...@@ -2,6 +2,7 @@ __version__ = "0.5.0" ...@@ -2,6 +2,7 @@ __version__ = "0.5.0"
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .tokenization_openai import OpenAIGPTTokenizer from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
from .modeling import (BertConfig, BertModel, BertForPreTraining, from .modeling import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction, BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForMultipleChoice, BertForSequenceClassification, BertForMultipleChoice,
...@@ -9,6 +10,11 @@ from .modeling import (BertConfig, BertModel, BertForPreTraining, ...@@ -9,6 +10,11 @@ from .modeling import (BertConfig, BertModel, BertForPreTraining,
from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel, from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel) from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel)
from .optimization import BertAdam from .optimization import BertAdam
from .optimization_openai import OpenAIAdam from .optimization_openai import OpenAIAdam
from .convert_openai_checkpoint_to_pytorch import load_tf_weights_in_openai_gpt
from .convert_tf_checkpoint_to_pytorch import load_tf_weights_in_bert
from .convert_transfo_xl_checkpoint_to_pytorch import load_tf_weights_in_transfo_xl
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
...@@ -26,9 +26,29 @@ import numpy as np ...@@ -26,9 +26,29 @@ import numpy as np
from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
    """Convert an OpenAI GPT checkpoint into a PyTorch weights file plus a JSON config.

    Builds an `OpenAIGPTModel`, fills it with the weights found in
    `openai_checkpoint_folder_path`, then writes the model's state dict and its
    configuration into `pytorch_dump_folder_path`.

    Args:
        openai_checkpoint_folder_path: folder holding the original weights; passed
            unchanged to `load_tf_weights_in_openai_gpt`.
        openai_config_file: path to a JSON configuration file. An empty string
            selects the default `OpenAIGPTConfig()`.
        pytorch_dump_folder_path: folder receiving the `WEIGHTS_NAME` and
            `CONFIG_NAME` files. It is not created by this function.
    """
    # Local import so the block does not depend on a module-level `import os`.
    import os

    # Construct model
    if openai_config_file == "":
        config = OpenAIGPTConfig()
    else:
        config = OpenAIGPTConfig(openai_config_file)
    model = OpenAIGPTModel(config)

    # Load weights from numpy (the loader fills `model` in place)
    load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path)

    # Save pytorch-model
    # os.path.join keeps the paths well-formed when the folder has a trailing
    # separator or is empty, matching the Transformer-XL converter.
    pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
    pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print("Save configuration file to {}".format(pytorch_config_dump_path))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
"""
print("Loading weights...") print("Loading weights...")
names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8')) names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8')) shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
...@@ -36,35 +56,11 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c ...@@ -36,35 +56,11 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)] init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1] init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)] init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
# if n_ctx > 0:
# init_params[0] = init_params[0][:n_ctx]
# if n_special > 0:
# init_params[0] = np.concatenate(
# [init_params[1],
# (np.random.randn(n_special, n_embd) * 0.02).astype(np.float32),
# init_params[0]
# ], 0)
# else:
# init_params[0] = np.concatenate(
# [init_params[1],
# init_params[0]
# ], 0)
# del init_params[1]
# if n_transfer == -1:
# n_transfer = 0
# else:
# n_transfer = 1 + n_transfer * 12
init_params[0] = np.concatenate([init_params[1], init_params[0]], 0) init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
del init_params[1] del init_params[1]
init_params = [arr.squeeze() for arr in init_params] init_params = [arr.squeeze() for arr in init_params]
# Construct model
if openai_config_file == "":
config = OpenAIGPTConfig()
else:
config = OpenAIGPTConfig(openai_config_file)
model = OpenAIGPTModel(config)
try: try:
assert model.embed.weight.shape == init_params[0].shape assert model.embed.weight.shape == init_params[0].shape
except AssertionError as e: except AssertionError as e:
...@@ -109,15 +105,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c ...@@ -109,15 +105,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
raise raise
print("Initialize PyTorch weight {}".format(name)) print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array) pointer.data = torch.from_numpy(array)
return model
# Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
torch.save(model.state_dict(), pytorch_weights_dump_path)
print("Save configuration file to {}".format(pytorch_config_dump_path))
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
f.write(config.to_json_string())
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
...@@ -28,9 +28,23 @@ import numpy as np ...@@ -28,9 +28,23 @@ import numpy as np
from .modeling import BertConfig, BertForPreTraining from .modeling import BertConfig, BertForPreTraining
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    """Convert a TensorFlow BERT checkpoint into a PyTorch state-dict file.

    Args:
        tf_checkpoint_path: path of the TensorFlow checkpoint; passed unchanged
            to `load_tf_weights_in_bert`.
        bert_config_file: JSON file read with `BertConfig.from_json_file` to
            build the model.
        pytorch_dump_path: file path handed to `torch.save` for the resulting
            state dict.
    """
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Load weights from tf checkpoint (the loader fills `model` in place)
    load_tf_weights_in_bert(model, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
def load_tf_weights_in_bert(model, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model
"""
tf_path = os.path.abspath(tf_checkpoint_path) tf_path = os.path.abspath(tf_checkpoint_path)
print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path)) print("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model # Load weights from TF model
init_vars = tf.train.list_variables(tf_path) init_vars = tf.train.list_variables(tf_path)
names = [] names = []
...@@ -41,11 +55,6 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor ...@@ -41,11 +55,6 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
names.append(name) names.append(name)
arrays.append(array) arrays.append(array)
# Initialise PyTorch model
config = BertConfig.from_json_file(bert_config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = BertForPreTraining(config)
for name, array in zip(names, arrays): for name, array in zip(names, arrays):
name = name.split('/') name = name.split('/')
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
...@@ -81,11 +90,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor ...@@ -81,11 +90,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
raise raise
print("Initialize PyTorch weight {}".format(name)) print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array) pointer.data = torch.from_numpy(array)
return model
# Save pytorch-model
print("Save PyTorch model to {}".format(pytorch_dump_path))
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
...@@ -106,7 +106,6 @@ def build_tf_to_pytorch_map(model, config): ...@@ -106,7 +106,6 @@ def build_tf_to_pytorch_map(model, config):
'transformer/r_w_bias': r_w_list}) 'transformer/r_w_bias': r_w_list})
return tf_to_pt_map return tf_to_pt_map
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
transfo_xl_config_file, transfo_xl_config_file,
pytorch_dump_folder_path, pytorch_dump_folder_path,
...@@ -140,50 +139,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, ...@@ -140,50 +139,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
print("Building PyTorch model from configuration: {}".format(str(config))) print("Building PyTorch model from configuration: {}".format(str(config)))
model = TransfoXLModel(config) model = TransfoXLModel(config)
# Build TF to PyTorch weights loading map model = load_tf_weights_in_transfo_xl(model, config, tf_path)
tf_to_pt_map = build_tf_to_pytorch_map(model, config)
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
tf_weights = {}
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
tf_weights[name] = array
for name, pointer in tf_to_pt_map.items():
assert name in tf_weights
array = tf_weights[name]
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if 'kernel' in name or 'proj' in name:
array = np.transpose(array)
if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
# Here we will split the TF weigths
assert len(pointer) == array.shape[0]
for i, p_i in enumerate(pointer):
arr_i = array[i, ...]
try:
assert p_i.shape == arr_i.shape
except AssertionError as e:
e.args += (p_i.shape, arr_i.shape)
raise
print("Initialize PyTorch weight {} for layer {}".format(name, i))
p_i.data = torch.from_numpy(arr_i)
else:
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
tf_weights.pop(name, None)
tf_weights.pop(name + '/Adam', None)
tf_weights.pop(name + '/Adam_1', None)
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
# Save pytorch-model # Save pytorch-model
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
...@@ -194,6 +150,54 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, ...@@ -194,6 +150,54 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
f.write(config.to_json_string()) f.write(config.to_json_string())
def _assign_tf_array(param, array, description):
    """Copy NumPy `array` into `param.data` after checking that both shapes match."""
    if param.shape != array.shape:
        # Explicit raise instead of `assert`: same exception type and args as
        # before, but the check is no longer stripped by `python -O`.
        raise AssertionError(param.shape, array.shape)
    print("Initialize PyTorch weight {}".format(description))
    param.data = torch.from_numpy(array)


def load_tf_weights_in_transfo_xl(model, config, tf_path):
    """Load the variables of a TensorFlow Transformer-XL checkpoint into `model`.

    Args:
        model: PyTorch model whose parameters are overwritten in place.
        config: configuration passed, together with `model`, to
            `build_tf_to_pytorch_map`.
        tf_path: path of the TensorFlow checkpoint to read.

    Returns:
        The same `model` object, with its mapped parameters initialised.

    Raises:
        AssertionError: if a mapped variable is missing from the checkpoint or
            a shape does not match its target parameter.
    """
    # Build TF to PyTorch weights loading map: TF variable name -> parameter,
    # or a sequence of parameters for biases that are split below.
    tf_to_pt_map = build_tf_to_pytorch_map(model, config)

    # Load weights from TF model
    tf_weights = {}
    for name, shape in tf.train.list_variables(tf_path):
        print("Loading TF weight {} with shape {}".format(name, shape))
        tf_weights[name] = tf.train.load_variable(tf_path, name)

    for name, pointer in tf_to_pt_map.items():
        if name not in tf_weights:
            raise AssertionError("TF variable {} not found in checkpoint {}".format(name, tf_path))
        array = tf_weights[name]
        # 'kernel' and 'proj' variables are transposed before being copied.
        if 'kernel' in name or 'proj' in name:
            array = np.transpose(array)
        if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
            # Here we split the TF weights: slice i along the first axis goes
            # to the i-th target parameter.
            if len(pointer) != array.shape[0]:
                raise AssertionError(len(pointer), array.shape[0])
            for i, p_i in enumerate(pointer):
                _assign_tf_array(p_i, array[i, ...], "{} for layer {}".format(name, i))
        else:
            _assign_tf_array(pointer, array, name)
        # Drop the copied variable and its '/Adam', '/Adam_1' companions
        # (presumably optimizer slots) so they are not reported below.
        tf_weights.pop(name, None)
        tf_weights.pop(name + '/Adam', None)
        tf_weights.pop(name + '/Adam_1', None)

    print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
    return model
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
## Required parameters ## Required parameters
......
...@@ -33,6 +33,7 @@ from torch import nn ...@@ -33,6 +33,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from .file_utils import cached_path from .file_utils import cached_path
from .convert_tf_checkpoint_to_pytorch import load_tf_weights_in_bert
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -47,6 +48,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = { ...@@ -47,6 +48,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
} }
CONFIG_NAME = 'bert_config.json' CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin' WEIGHTS_NAME = 'pytorch_model.bin'
TF_WEIGHTS_NAME = 'model.ckpt'
def gelu(x): def gelu(x):
"""Implementation of the gelu activation function. """Implementation of the gelu activation function.
...@@ -445,7 +447,8 @@ class BertPreTrainedModel(nn.Module): ...@@ -445,7 +447,8 @@ class BertPreTrainedModel(nn.Module):
module.bias.data.zero_() module.bias.data.zero_()
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs): def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
from_tf=False, *inputs, **kwargs):
""" """
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed. Download and cache the pre-trained model file if needed.
...@@ -463,6 +466,10 @@ class BertPreTrainedModel(nn.Module): ...@@ -463,6 +466,10 @@ class BertPreTrainedModel(nn.Module):
- a path or url to a pretrained model archive containing: - a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model . `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `model.ckpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached. cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific Bert class *inputs, **kwargs: additional input for the specific Bert class
...@@ -490,7 +497,7 @@ class BertPreTrainedModel(nn.Module): ...@@ -490,7 +497,7 @@ class BertPreTrainedModel(nn.Module):
logger.info("loading archive file {} from cache at {}".format( logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file)) archive_file, resolved_archive_file))
tempdir = None tempdir = None
if os.path.isdir(resolved_archive_file): if os.path.isdir(resolved_archive_file) or from_tf:
serialization_dir = resolved_archive_file serialization_dir = resolved_archive_file
else: else:
# Extract archive to temp dir # Extract archive to temp dir
...@@ -506,10 +513,17 @@ class BertPreTrainedModel(nn.Module): ...@@ -506,10 +513,17 @@ class BertPreTrainedModel(nn.Module):
logger.info("Model config {}".format(config)) logger.info("Model config {}".format(config))
# Instantiate model. # Instantiate model.
model = cls(config, *inputs, **kwargs) model = cls(config, *inputs, **kwargs)
if state_dict is None: if state_dict is None and not from_tf:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path) state_dict = torch.load(weights_path)
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
if from_tf:
# Directly load from a TensorFlow checkpoint
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
return load_tf_weights_in_bert(model, weights_path)
# Load from a PyTorch state_dict
old_keys = [] old_keys = []
new_keys = [] new_keys = []
for key in state_dict.keys(): for key in state_dict.keys():
...@@ -550,9 +564,6 @@ class BertPreTrainedModel(nn.Module): ...@@ -550,9 +564,6 @@ class BertPreTrainedModel(nn.Module):
if len(error_msgs) > 0: if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs))) model.__class__.__name__, "\n\t".join(error_msgs)))
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return model return model
......
...@@ -32,14 +32,14 @@ from torch.nn.parameter import Parameter ...@@ -32,14 +32,14 @@ from torch.nn.parameter import Parameter
from .modeling import BertLayerNorm as LayerNorm from .modeling import BertLayerNorm as LayerNorm
from .file_utils import cached_path from .file_utils import cached_path
from .convert_openai_checkpoint_to_pytorch import load_tf_weights_in_openai_gpt
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = { PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz"}
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz", CONFIG_NAME = "openai_gpt_config.json"
} WEIGHTS_NAME = "pytorch_model.bin"
CONFIG_NAME = 'openai_gpt_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
def gelu(x):
    """Apply the tanh-based approximation of the GELU activation to `x`.

    Computes, element-wise,
    0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3))).
    """
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
...@@ -49,27 +49,27 @@ def swish(x): ...@@ -49,27 +49,27 @@ def swish(x):
return x * torch.sigmoid(x) return x * torch.sigmoid(x)
ACT_FNS = { ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu}
'relu': nn.ReLU,
'swish': swish,
'gelu': gelu
}
class OpenAIGPTConfig(object): class OpenAIGPTConfig(object):
"""Configuration class to store the configuration of a `OpenAIGPTModel`. """Configuration class to store the configuration of a `OpenAIGPTModel`.
""" """
def __init__(self,
vocab_size_or_config_json_file=40478, def __init__(
n_special=0, self,
n_ctx=512, vocab_size_or_config_json_file=40478,
n_embd=768, n_special=0,
n_layer=12, n_ctx=512,
n_head=12, n_embd=768,
afn="gelu", n_layer=12,
resid_pdrop=0.1, n_head=12,
embd_pdrop=0.1, afn="gelu",
attn_pdrop=0.1, resid_pdrop=0.1,
initializer_range=0.02): embd_pdrop=0.1,
attn_pdrop=0.1,
initializer_range=0.02,
):
"""Constructs OpenAIGPTConfig. """Constructs OpenAIGPTConfig.
Args: Args:
...@@ -91,7 +91,7 @@ class OpenAIGPTConfig(object): ...@@ -91,7 +91,7 @@ class OpenAIGPTConfig(object):
initializing all weight matrices. initializing all weight matrices.
""" """
if isinstance(vocab_size_or_config_json_file, str): if isinstance(vocab_size_or_config_json_file, str):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read()) json_config = json.loads(reader.read())
for key, value in json_config.items(): for key, value in json_config.items():
self.__dict__[key] = value self.__dict__[key] = value
...@@ -108,8 +108,10 @@ class OpenAIGPTConfig(object): ...@@ -108,8 +108,10 @@ class OpenAIGPTConfig(object):
self.attn_pdrop = attn_pdrop self.attn_pdrop = attn_pdrop
self.initializer_range = initializer_range self.initializer_range = initializer_range
else: else:
raise ValueError("First argument must be either a vocabulary size (int)" raise ValueError(
"or the path to a pretrained model config file (str)") "First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
@property @property
def total_num_embeddings(self): def total_num_embeddings(self):
...@@ -126,7 +128,7 @@ class OpenAIGPTConfig(object): ...@@ -126,7 +128,7 @@ class OpenAIGPTConfig(object):
@classmethod @classmethod
def from_json_file(cls, json_file): def from_json_file(cls, json_file):
"""Constructs a `OpenAIGPTConfig` from a json file of parameters.""" """Constructs a `OpenAIGPTConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader: with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read() text = reader.read()
return cls.from_dict(json.loads(text)) return cls.from_dict(json.loads(text))
...@@ -142,6 +144,7 @@ class OpenAIGPTConfig(object): ...@@ -142,6 +144,7 @@ class OpenAIGPTConfig(object):
"""Serializes this instance to a JSON string.""" """Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class Conv1D(nn.Module): class Conv1D(nn.Module):
def __init__(self, nf, rf, nx): def __init__(self, nf, rf, nx):
super(Conv1D, self).__init__() super(Conv1D, self).__init__()
...@@ -171,7 +174,7 @@ class Attention(nn.Module): ...@@ -171,7 +174,7 @@ class Attention(nn.Module):
n_state = nx # in Attention: n_state=768 (nx=n_embd) n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem] # [switch nx => n_state from Block to Attention to keep identical to TF implem]
assert n_state % config.n_head == 0 assert n_state % config.n_head == 0
self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) self.register_buffer("b", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
self.n_head = config.n_head self.n_head = config.n_head
self.split_size = n_state self.split_size = n_state
self.scale = scale self.scale = scale
...@@ -186,7 +189,7 @@ class Attention(nn.Module): ...@@ -186,7 +189,7 @@ class Attention(nn.Module):
w = w / math.sqrt(v.size(-1)) w = w / math.sqrt(v.size(-1))
# w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights # w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights
# XD: self.b may be larger than w, so we need to crop it # XD: self.b may be larger than w, so we need to crop it
b = self.b[:, :, :w.size(-2), :w.size(-1)] b = self.b[:, :, : w.size(-2), : w.size(-1)]
w = w * b + -1e9 * (1 - b) w = w * b + -1e9 * (1 - b)
w = nn.Softmax(dim=-1)(w) w = nn.Softmax(dim=-1)(w)
...@@ -262,7 +265,7 @@ class OpenAIGPTLMHead(nn.Module): ...@@ -262,7 +265,7 @@ class OpenAIGPTLMHead(nn.Module):
def set_embeddings_weights(self, model_embeddings_weights): def set_embeddings_weights(self, model_embeddings_weights):
embed_shape = model_embeddings_weights.shape embed_shape = model_embeddings_weights.shape
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False) self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
self.decoder.weight = model_embeddings_weights # Tied weights self.decoder.weight = model_embeddings_weights # Tied weights
def forward(self, hidden_state): def forward(self, hidden_state):
# Truncated Language modeling logits (we remove the last token) # Truncated Language modeling logits (we remove the last token)
...@@ -281,14 +284,15 @@ class OpenAIGPTMultipleChoiceHead(nn.Module): ...@@ -281,14 +284,15 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
self.dropout = nn.Dropout2d(config.resid_pdrop) # To reproduce the noise_shape parameter of TF implementation self.dropout = nn.Dropout2d(config.resid_pdrop) # To reproduce the noise_shape parameter of TF implementation
self.linear = nn.Linear(config.n_embd, 1) self.linear = nn.Linear(config.n_embd, 1)
nn.init.normal_(self.linear.weight, std = 0.02) nn.init.normal_(self.linear.weight, std=0.02)
nn.init.normal_(self.linear.bias, 0) nn.init.normal_(self.linear.bias, 0)
def forward(self, hidden_states, multiple_choice_token_mask): def forward(self, hidden_states, mc_token_mask):
# Classification logits # Classification logits
# hidden_states = hidden_states.view(-1, self.n_embd) # hidden_states = hidden_states.view(-1, self.n_embd)
# multiple_choice_token_mask = multiple_choice_token_mask.view(-1, 1).expand_as(hidden_states) # mc_token_mask = mc_token_mask.view(-1, 1).expand_as(hidden_states)
multiple_choice_h = hidden_states * multiple_choice_token_mask.unsqueeze(-1) mc_token_mask = mc_token_mask.float()
multiple_choice_h = hidden_states * mc_token_mask.unsqueeze(-1)
multiple_choice_h = multiple_choice_h.sum(dim=-2) multiple_choice_h = multiple_choice_h.sum(dim=-2)
# flat = x[..., 0].contiguous().view(-1) # flat = x[..., 0].contiguous().view(-1)
# multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :] # multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :]
...@@ -307,6 +311,7 @@ class OpenAIGPTPreTrainedModel(nn.Module): ...@@ -307,6 +311,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and """ An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super(OpenAIGPTPreTrainedModel, self).__init__() super(OpenAIGPTPreTrainedModel, self).__init__()
if not isinstance(config, OpenAIGPTConfig): if not isinstance(config, OpenAIGPTConfig):
...@@ -315,7 +320,8 @@ class OpenAIGPTPreTrainedModel(nn.Module): ...@@ -315,7 +320,8 @@ class OpenAIGPTPreTrainedModel(nn.Module):
"To create a model from a pretrained model use " "To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__ self.__class__.__name__, self.__class__.__name__
)) )
)
self.config = config self.config = config
def init_weights(self, module): def init_weights(self, module):
...@@ -335,8 +341,9 @@ class OpenAIGPTPreTrainedModel(nn.Module): ...@@ -335,8 +341,9 @@ class OpenAIGPTPreTrainedModel(nn.Module):
pass pass
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name, num_special_tokens=0, state_dict=None, cache_dir=None, def from_pretrained(
*inputs, **kwargs): cls, pretrained_model_name, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
):
""" """
Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict. Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed. Download and cache the pre-trained model file if needed.
...@@ -348,6 +355,10 @@ class OpenAIGPTPreTrainedModel(nn.Module): ...@@ -348,6 +355,10 @@ class OpenAIGPTPreTrainedModel(nn.Module):
- a path or url to a pretrained model archive containing: - a path or url to a pretrained model archive containing:
. `openai_gpt_config.json` a configuration file for the model . `openai_gpt_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance . `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
- a path or url to a pretrained model archive containing:
. `openai_gpt_config.json` a configuration file for the model
. a series of NumPy files containing OpenAI TensorFlow trained weights
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached. cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific Bert class *inputs, **kwargs: additional input for the specific Bert class
...@@ -365,24 +376,22 @@ class OpenAIGPTPreTrainedModel(nn.Module): ...@@ -365,24 +376,22 @@ class OpenAIGPTPreTrainedModel(nn.Module):
"Model name '{}' was not found in model name list ({}). " "Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file " "We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format( "associated to this path or url.".format(
pretrained_model_name, pretrained_model_name, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), archive_file
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), )
archive_file)) )
return None return None
if resolved_archive_file == archive_file: if resolved_archive_file == archive_file:
logger.info("loading archive file {}".format(archive_file)) logger.info("loading archive file {}".format(archive_file))
else: else:
logger.info("loading archive file {} from cache at {}".format( logger.info("loading archive file {} from cache at {}".format(archive_file, resolved_archive_file))
archive_file, resolved_archive_file))
tempdir = None tempdir = None
if os.path.isdir(resolved_archive_file): if os.path.isdir(resolved_archive_file):
serialization_dir = resolved_archive_file serialization_dir = resolved_archive_file
else: else:
# Extract archive to temp dir # Extract archive to temp dir
tempdir = tempfile.mkdtemp() tempdir = tempfile.mkdtemp()
logger.info("extracting archive file {} to temp dir {}".format( logger.info("extracting archive file {} to temp dir {}".format(resolved_archive_file, tempdir))
resolved_archive_file, tempdir)) with tarfile.open(resolved_archive_file, "r:gz") as archive:
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir) archive.extractall(tempdir)
serialization_dir = tempdir serialization_dir = tempdir
# Load config # Load config
...@@ -391,18 +400,24 @@ class OpenAIGPTPreTrainedModel(nn.Module): ...@@ -391,18 +400,24 @@ class OpenAIGPTPreTrainedModel(nn.Module):
logger.info("Model config {}".format(config)) logger.info("Model config {}".format(config))
# Instantiate model. # Instantiate model.
model = cls(config, *inputs, **kwargs) model = cls(config, *inputs, **kwargs)
if state_dict is None: if state_dict is None and not from_tf:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path) state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
if from_tf:
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
return load_tf_weights_in_openai_gpt(model, serialization_dir)
old_keys = [] old_keys = []
new_keys = [] new_keys = []
for key in state_dict.keys(): for key in state_dict.keys():
new_key = None new_key = None
if 'gamma' in key: if "gamma" in key:
new_key = key.replace('gamma', 'weight') new_key = key.replace("gamma", "weight")
if 'beta' in key: if "beta" in key:
new_key = key.replace('beta', 'bias') new_key = key.replace("beta", "bias")
if new_key: if new_key:
old_keys.append(key) old_keys.append(key)
new_keys.append(new_key) new_keys.append(new_key)
...@@ -413,34 +428,36 @@ class OpenAIGPTPreTrainedModel(nn.Module): ...@@ -413,34 +428,36 @@ class OpenAIGPTPreTrainedModel(nn.Module):
unexpected_keys = [] unexpected_keys = []
error_msgs = [] error_msgs = []
# copy state_dict so _load_from_state_dict can modify it # copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None) metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy() state_dict = state_dict.copy()
if metadata is not None: if metadata is not None:
state_dict._metadata = metadata state_dict._metadata = metadata
def load(module, prefix=''): def load(module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict( module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
)
for name, child in module._modules.items(): for name, child in module._modules.items():
if child is not None: if child is not None:
load(child, prefix + name + '.') load(child, prefix + name + ".")
load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
load(model.transformer if hasattr(model, "transformer") else model, prefix="")
if len(missing_keys) > 0: if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format( logger.info(
model.__class__.__name__, missing_keys)) "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
)
if len(unexpected_keys) > 0: if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format( logger.info(
model.__class__.__name__, unexpected_keys)) "Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
)
if len(error_msgs) > 0: if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( raise RuntimeError(
model.__class__.__name__, "\n\t".join(error_msgs))) "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
)
# Add additional embeddings for special tokens if needed # Add additional embeddings for special tokens if needed
if num_special_tokens != config.n_special: if num_special_tokens is not None and num_special_tokens != config.n_special:
model.set_num_special_tokens(num_special_tokens) model.set_num_special_tokens(num_special_tokens)
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return model return model
...@@ -495,6 +512,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -495,6 +512,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
hidden_states = model(input_ids) hidden_states = model(input_ids)
``` ```
""" """
def __init__(self, config): def __init__(self, config):
super(OpenAIGPTModel, self).__init__(config) super(OpenAIGPTModel, self).__init__(config)
total_embeddings_size = config.vocab_size + config.n_special + config.n_ctx total_embeddings_size = config.vocab_size + config.n_special + config.n_ctx
...@@ -516,8 +534,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -516,8 +534,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
# Initialize all new embeddings (in particular the special tokens) # Initialize all new embeddings (in particular the special tokens)
self.init_weights(self.embed) self.init_weights(self.embed)
# Copy word and positional embeddings from the previous weights # Copy word and positional embeddings from the previous weights
self.embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :] self.embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :]
self.embed.weight.data[-self.config.n_ctx:, :] = old_embed.weight.data[-self.config.n_ctx:, :] self.embed.weight.data[-self.config.n_ctx :, :] = old_embed.weight.data[-self.config.n_ctx :, :]
def forward(self, input_ids, position_ids=None, token_type_ids=None): def forward(self, input_ids, position_ids=None, token_type_ids=None):
if position_ids is None: if position_ids is None:
...@@ -544,6 +562,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -544,6 +562,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
hidden_states = block(hidden_states) hidden_states = block(hidden_states)
return hidden_states.view(*input_shape, hidden_states.size(-1)) return hidden_states.view(*input_shape, hidden_states.size(-1))
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
"""OpenAI GPT model with a Language Modeling head ("Improving Language Understanding by Generative Pre-Training"). """OpenAI GPT model with a Language Modeling head ("Improving Language Understanding by Generative Pre-Training").
...@@ -602,6 +621,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ...@@ -602,6 +621,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
lm_logits = model(input_ids) lm_logits = model(input_ids)
``` ```
""" """
def __init__(self, config): def __init__(self, config):
super(OpenAIGPTLMHeadModel, self).__init__(config) super(OpenAIGPTLMHeadModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config) self.transformer = OpenAIGPTModel(config)
...@@ -622,6 +642,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ...@@ -622,6 +642,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
return loss return loss
return lm_logits return lm_logits
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
"""OpenAI GPT model with a Language Modeling and a Multiple Choice heads ("Improving Language Understanding by Generative Pre-Training"). """OpenAI GPT model with a Language Modeling and a Multiple Choice heads ("Improving Language Understanding by Generative Pre-Training").
...@@ -653,7 +674,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -653,7 +674,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
Inputs: Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with the word BPE token indices selected in the range [0, config.vocab_size[ with the word BPE token indices selected in the range [0, config.vocab_size[
`multiple_choice_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] `mc_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with a value of 1 where the last hidden state is (usually the [CLS] token) and 0 otherwise. with a value of 1 where the last hidden state is (usually the [CLS] token) and 0 otherwise.
`position_ids`: an optional torch.LongTensor with the same shape as input_ids `position_ids`: an optional torch.LongTensor with the same shape as input_ids
with the position indices (selected in the range [config.vocab_size + config.n_special, with the position indices (selected in the range [config.vocab_size + config.n_special,
...@@ -678,14 +699,15 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -678,14 +699,15 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
```python ```python
# Already been converted into BPE token ids # Already been converted into BPE token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
multiple_choice_token_mask = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) mc_token_mask = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = modeling_openai.OpenAIGPTConfig() config = modeling_openai.OpenAIGPTConfig()
model = modeling_openai.OpenAIGPTDoubleHeadsModel(config) model = modeling_openai.OpenAIGPTDoubleHeadsModel(config)
lm_logits, multiple_choice_logits = model(input_ids, multiple_choice_token_mask) lm_logits, multiple_choice_logits = model(input_ids, mc_token_mask)
``` ```
""" """
def __init__(self, config): def __init__(self, config):
super(OpenAIGPTDoubleHeadsModel, self).__init__(config) super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config) self.transformer = OpenAIGPTModel(config)
...@@ -698,18 +720,17 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -698,18 +720,17 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
self.transformer.set_num_special_tokens(num_special_tokens) self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.embed.weight) self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
def forward(self, input_ids, mc_token_mask, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
    """Run the transformer and both heads; return losses when labels are given.

    Args:
        input_ids: token indices; per the class docstring, a LongTensor of
            shape [batch_size, num_choices, sequence_length].
        mc_token_mask: mask handed to the multiple-choice head together with
            the hidden states; per the class docstring, same shape as
            `input_ids`.
        lm_labels: optional language-modeling targets. Entries equal to -1
            are ignored by the loss.
        mc_labels: optional multiple-choice targets, flattened before use
            (presumably one class index per example -- confirm with callers).
        token_type_ids: optional, forwarded to the transformer.
        position_ids: optional, forwarded to the transformer.

    Returns:
        A list holding the language-modeling loss and/or the multiple-choice
        loss (in that order) when at least one label tensor is supplied,
        otherwise the tuple `(lm_logits, mc_logits)`.
    """
    # The transformer takes position_ids before token_type_ids, the reverse
    # of this method's own keyword order.
    hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
    lm_logits = self.lm_head(hidden_states)
    mc_logits = self.multiple_choice_head(hidden_states, mc_token_mask)

    losses = []
    if lm_labels is not None:
        # -1 marks positions that must not contribute to the LM loss.
        loss_fct = CrossEntropyLoss(ignore_index=-1)
        losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
    if mc_labels is not None:
        loss_fct = CrossEntropyLoss()
        losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
    if losses:
        return losses
    return lm_logits, mc_logits
...@@ -37,6 +37,7 @@ from torch.nn.parameter import Parameter ...@@ -37,6 +37,7 @@ from torch.nn.parameter import Parameter
from .modeling import BertLayerNorm as LayerNorm from .modeling import BertLayerNorm as LayerNorm
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
from .file_utils import cached_path from .file_utils import cached_path
from .convert_transfo_xl_checkpoint_to_pytorch import load_tf_weights_in_transfo_xl
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -48,6 +49,7 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = { ...@@ -48,6 +49,7 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
} }
CONFIG_NAME = 'transfo_xl_config.json' CONFIG_NAME = 'transfo_xl_config.json'
WEIGHTS_NAME = 'pytorch_model.bin' WEIGHTS_NAME = 'pytorch_model.bin'
TF_WEIGHTS_NAME = 'model.ckpt'
class TransfoXLConfig(object): class TransfoXLConfig(object):
"""Configuration class to store the configuration of a `TransfoXLModel`. """Configuration class to store the configuration of a `TransfoXLModel`.
...@@ -749,7 +751,7 @@ class TransfoXLPreTrainedModel(nn.Module): ...@@ -749,7 +751,7 @@ class TransfoXLPreTrainedModel(nn.Module):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None, def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
*inputs, **kwargs): from_tf=False, *inputs, **kwargs):
""" """
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict. Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed. Download and cache the pre-trained model file if needed.
...@@ -761,6 +763,10 @@ class TransfoXLPreTrainedModel(nn.Module): ...@@ -761,6 +763,10 @@ class TransfoXLPreTrainedModel(nn.Module):
- a path or url to a pretrained model archive containing: - a path or url to a pretrained model archive containing:
. `transfo_xl_config.json` a configuration file for the model . `transfo_xl_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
- a path or url to a pretrained model archive containing:
. `transfo_xl_config.json` a configuration file for the model
. `model.ckpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached. cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific TransfoXL class *inputs, **kwargs: additional input for the specific TransfoXL class
...@@ -799,9 +805,12 @@ class TransfoXLPreTrainedModel(nn.Module): ...@@ -799,9 +805,12 @@ class TransfoXLPreTrainedModel(nn.Module):
logger.info("Model config {}".format(config)) logger.info("Model config {}".format(config))
# Instantiate model. # Instantiate model.
model = cls(config, *inputs, **kwargs) model = cls(config, *inputs, **kwargs)
if state_dict is None: if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file) state_dict = torch.load(resolved_archive_file)
if from_tf:
# Directly load from a TensorFlow checkpoint
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
return load_tf_weights_in_transfo_xl(model, weights_path)
missing_keys = [] missing_keys = []
unexpected_keys = [] unexpected_keys = []
error_msgs = [] error_msgs = []
......
...@@ -130,6 +130,9 @@ class OpenAIGPTTokenizer(object): ...@@ -130,6 +130,9 @@ class OpenAIGPTTokenizer(object):
else: else:
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
def __len__(self):
    """Return the number of encoder entries plus the number of special tokens."""
    base_size = len(self.encoder)
    special_size = len(self.special_tokens)
    return base_size + special_size
def set_special_tokens(self, special_tokens): def set_special_tokens(self, special_tokens):
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment