Commit 24d80689 authored by thomwolf

weights loading script ok

parent 32da7548
@@ -18,23 +18,31 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import argparse
import torch
from pytorch_pretrained_bert.modeling_xlnet import XLNetConfig, XLNetRunConfig, XLNetLMHeadModel, load_tf_weights_in_xlnet
from pytorch_pretrained_bert.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME,
XLNetConfig, XLNetRunConfig,
XLNetLMHeadModel, load_tf_weights_in_xlnet)
def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path):
# Initialise PyTorch model
config = XLNetConfig.from_json_file(bert_config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = XLNetLMHeadModel(config)
# Load weights from tf checkpoint
load_tf_weights_in_xlnet(model, tf_checkpoint_path)
load_tf_weights_in_xlnet(model, config, tf_checkpoint_path)
# Save pytorch-model
print("Save PyTorch model to {}".format(pytorch_dump_path))
torch.save(model.state_dict(), pytorch_dump_path)
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
torch.save(model.state_dict(), pytorch_weights_dump_path)
print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
f.write(config.to_json_string())
if __name__ == "__main__":
@@ -50,13 +58,13 @@ if __name__ == "__main__":
type = str,
required = True,
help = "The config json file corresponding to the pre-trained XLNet model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_folder_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
help = "Path to the folder to store the PyTorch model or dataset/vocab.")
args = parser.parse_args()
convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path,
args.xlnet_config_file,
args.pytorch_dump_path)
args.pytorch_dump_folder_path)
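
For reference, a minimal sketch of driving the converter from Python instead of the command line. The paths and the module name in the import are placeholders, not part of this commit; they assume the script above is importable and that the checkpoint directory contains the usual model.ckpt prefix and xlnet_config.json.

import os
# Hypothetical import path for the conversion script above.
from convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch

tf_checkpoint_path = "xlnet_cased_L-24_H-1024_A-16/model.ckpt"        # TF checkpoint prefix (placeholder)
xlnet_config_file = "xlnet_cased_L-24_H-1024_A-16/xlnet_config.json"  # matching config file (placeholder)
pytorch_dump_folder_path = "xlnet-pytorch-dump"                       # output folder (placeholder)

os.makedirs(pytorch_dump_folder_path, exist_ok=True)
convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path,
                                    xlnet_config_file,
                                    pytorch_dump_folder_path)
# The folder then contains the weights file (WEIGHTS_NAME) and the configuration file (CONFIG_NAME).
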
@@ -45,70 +45,122 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
XLNET_CONFIG_NAME = 'xlnet_config.json'
TF_WEIGHTS_NAME = 'model.ckpt'
def load_tf_weights_in_xlnet(model, tf_checkpoint_path):
def build_tf_xlnet_to_pytorch_map(model, config):
""" A map of modules from TF to PyTorch.
We use an explicit map so the PyTorch module structure
can stay as close as possible to the original TensorFlow model.
"""
tf_to_pt_map = {}
if hasattr(model, 'transformer'):
# We are loading pre-trained weights into an XLNetLMHeadModel => we also load the output bias
tf_to_pt_map['model/lm_loss/bias'] = model.lm_loss.bias
# Now load the rest of the transformer
model = model.transformer
# Embeddings and output
tf_to_pt_map.update({'model/transformer/word_embedding/lookup_table': model.word_embedding.weight,
'model/transformer/mask_emb/mask_emb': model.mask_emb})
# Transformer blocks
for i, b in enumerate(model.layer):
layer_str = "model/transformer/layer_%d/" % i
tf_to_pt_map.update({
layer_str + "rel_attn/LayerNorm/gamma": b.rel_attn.layer_norm.weight,
layer_str + "rel_attn/LayerNorm/beta": b.rel_attn.layer_norm.bias,
layer_str + "rel_attn/o/kernel": b.rel_attn.o,
layer_str + "rel_attn/q/kernel": b.rel_attn.q,
layer_str + "rel_attn/k/kernel": b.rel_attn.k,
layer_str + "rel_attn/r/kernel": b.rel_attn.r,
layer_str + "rel_attn/v/kernel": b.rel_attn.v,
layer_str + "ff/LayerNorm/gamma": b.ff.layer_norm.weight,
layer_str + "ff/LayerNorm/beta": b.ff.layer_norm.bias,
layer_str + "ff/layer_1/kernel": b.ff.layer_1.weight,
layer_str + "ff/layer_1/bias": b.ff.layer_1.bias,
layer_str + "ff/layer_2/kernel": b.ff.layer_2.weight,
layer_str + "ff/layer_2/bias": b.ff.layer_2.bias,
})
# Relative positioning biases
if config.untie_r:
r_r_list = []
r_w_list = []
r_s_list = []
seg_embed_list = []
for b in model.layer:
r_r_list.append(b.rel_attn.r_r_bias)
r_w_list.append(b.rel_attn.r_w_bias)
r_s_list.append(b.rel_attn.r_s_bias)
seg_embed_list.append(b.rel_attn.seg_embed)
else:
r_r_list = [model.r_r_bias]
r_w_list = [model.r_w_bias]
r_s_list = [model.r_s_bias]
seg_embed_list = [model.seg_embed]
tf_to_pt_map.update({
'model/transformer/r_r_bias': r_r_list,
'model/transformer/r_w_bias': r_w_list,
'model/transformer/r_s_bias': r_s_list,
'model/transformer/seg_embed': seg_embed_list})
return tf_to_pt_map
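
Note that when config.untie_r is set, the map values above are lists of per-layer parameters while the corresponding TF variable is a single array stacked over layers; the loader below splits such arrays along axis 0. A minimal sketch of that splitting step, with shapes chosen purely for illustration:

import numpy as np
import torch

# Illustrative shapes only: n_layer=2, n_head=4, d_head=8.
stacked_tf_array = np.zeros((2, 4, 8), dtype=np.float32)             # one TF variable, stacked over layers
pointer = [torch.nn.Parameter(torch.zeros(4, 8)) for _ in range(2)]  # one PyTorch parameter per layer

assert len(pointer) == stacked_tf_array.shape[0]
for i, p_i in enumerate(pointer):
    p_i.data = torch.from_numpy(stacked_tf_array[i, ...])            # copy layer i's slice
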
def load_tf_weights_in_xlnet(model, config, tf_path):
""" Load tf checkpoints in a pytorch model
"""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
tf_path = os.path.abspath(tf_checkpoint_path)
print("Converting TensorFlow checkpoint from {}".format(tf_path))
# Build TF to PyTorch weights loading map
tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config)
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
tf_weights = {}
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
tf_weights[name] = array
for name, array in zip(names, arrays):
name = name.split('/')
for name, pointer in tf_to_pt_map.items():
print("Importing {}".format(name))
assert name in tf_weights
array = tf_weights[name]
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
# which are not required for using the pretrained model
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
print("Skipping {}".format("/".join(name)))
continue
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
l = re.split(r'_(\d+)', m_name)
else:
l = [m_name]
if l[0] == 'kernel' or l[0] == 'gamma':
pointer = getattr(pointer, 'weight')
elif l[0] == 'output_bias' or l[0] == 'beta':
pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight')
elif l[0] == 'squad':
pointer = getattr(pointer, 'classifier')
else:
try:
pointer = getattr(pointer, l[0])
except AttributeError:
print("Skipping {}".format("/".join(name)))
continue
if len(l) >= 2:
num = int(l[1])
pointer = pointer[num]
if m_name[-11:] == '_embeddings':
pointer = getattr(pointer, 'weight')
elif m_name == 'kernel':
if 'kernel' in name and 'ff' in name:
print("Transposing")
array = np.transpose(array)
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
if isinstance(pointer, list):
# Here we will split the TF weights
assert len(pointer) == array.shape[0]
for i, p_i in enumerate(pointer):
arr_i = array[i, ...]
try:
assert p_i.shape == arr_i.shape
except AssertionError as e:
e.args += (p_i.shape, arr_i.shape)
raise
print("Initialize PyTorch weight {} for layer {}".format(name, i))
p_i.data = torch.from_numpy(arr_i)
else:
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
tf_weights.pop(name, None)
tf_weights.pop(name + '/Adam', None)
tf_weights.pop(name + '/Adam_1', None)
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
return model
@@ -181,7 +233,18 @@ class XLNetConfig(XLNetBaseConfig):
max_position_embeddings=512,
initializer_range=0.02,
layer_norm_eps=1e-12):
layer_norm_eps=1e-12,
dropout=0.1,
dropatt=0.1,
init="normal",
init_range=0.1,
init_std=0.02,
mem_len=None,
reuse_len=None,
bi_data=False,
clamp_len=-1,
same_length=False):
"""Constructs XLNetConfig.
Args:
@@ -207,6 +270,22 @@ class XLNetConfig(XLNetBaseConfig):
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
layer_norm_eps: The epsilon used by LayerNorm.
dropout: float, dropout rate.
dropatt: float, dropout rate on attention probabilities.
init: str, the initialization scheme, either "normal" or "uniform".
init_range: float, initialize the parameters with a uniform distribution
in [-init_range, init_range]. Only effective when init="uniform".
init_std: float, initialize the parameters with a normal distribution
with mean 0 and stddev init_std. Only effective when init="normal".
mem_len: int, the number of tokens to cache.
reuse_len: int, the number of tokens in the current batch to be cached
and reused in the future.
bi_data: bool, whether to use a bidirectional input pipeline.
Usually set to True during pretraining and False during finetuning.
clamp_len: int, clamp all relative distances larger than clamp_len.
-1 means no clamping.
same_length: bool, whether to use the same attention length for each token.
"""
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
@@ -215,7 +294,7 @@ class XLNetConfig(XLNetBaseConfig):
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.n_token = vocab_size_or_config_json_file
self.d_model = d_model
self.n_layer = n_layer
self.n_head = n_head
@@ -225,9 +304,21 @@ class XLNetConfig(XLNetBaseConfig):
self.d_inner = d_inner
self.untie_r = untie_r
self.attn_type = attn_type
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.init = init
self.init_range = init_range
self.init_std = init_std
self.dropout = dropout
self.dropatt = dropatt
self.mem_len = mem_len
self.reuse_len = reuse_len
self.bi_data = bi_data
self.clamp_len = clamp_len
self.same_length = same_length
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@@ -327,7 +418,7 @@ class XLNetRelativeAttention(nn.Module):
self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
self.seg_embed = nn.Parameter(torch.Tensor(2, self.n_head, self.d_head))
self.LayerNorm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
self.layer_norm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.dropout)
def prune_heads(self, heads):
@@ -385,7 +476,7 @@ class XLNetRelativeAttention(nn.Module):
attn_out = self.dropout(attn_out)
if residual:
attn_out = attn_out + h
output = self.LayerNorm(attn_out)
output = self.layer_norm(attn_out)
return output
@@ -483,7 +574,7 @@ class XLNetRelativeAttention(nn.Module):
class XLNetFeedForward(nn.Module):
def __init__(self, config):
super(XLNetFeedForward, self).__init__()
self.LayerNorm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
self.layer_norm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
self.layer_1 = nn.Linear(config.d_model, config.d_inner)
self.layer_2 = nn.Linear(config.d_inner, config.d_model)
self.dropout = nn.Dropout(config.dropout)
@@ -499,7 +590,7 @@ class XLNetFeedForward(nn.Module):
output = self.dropout(output)
output = self.layer_2(output)
output = self.dropout(output)
output = self.LayerNorm(output + inp)
output = self.layer_norm(output + inp)
return output
class XLNetLayer(nn.Module):
@@ -691,11 +782,26 @@ class XLNetModel(XLNetPreTrainedModel):
self.bi_data = config.bi_data
self.clamp_len = config.clamp_len
self.word_embedding = nn.Embedding(config.n_token, config.d_model)
self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
layer = XLNetLayer(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
self.dropout = nn.Dropout(config.dropout)
def prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for layer, heads in heads_to_prune.items():
self.layer[layer].attention.prune_heads(heads)
def get_multihead_outputs(self):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return [layer.attention.self.multihead_output for layer in self.layer]
def create_mask(self, qlen, mlen):
""" create causal attention mask.
float mask where 1.0 indicate masked, 0.0 indicated not-masked.
@@ -783,12 +889,12 @@ class XLNetModel(XLNetPreTrainedModel):
pos_emb = pos_emb.to(next(self.parameters()))
return pos_emb
def forward(self, word_emb_k, seg_id=None, input_mask=None,
def forward(self, inp_k, seg_id=None, input_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
output_all_encoded_layers=True, head_mask=None):
"""
Args:
word_emb_k: float32 Tensor in shape [len, bsz, d_model], the input token embeddings.
inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
input_mask: [optional] float32 Tensor in shape [len, bsz], the input mask.
0 for real tokens and 1 for padding.
@@ -820,11 +926,12 @@ class XLNetModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
"""
qlen, bsz = word_emb_k.shape[0], word_emb_k.shape[1]
qlen, bsz = inp_k.shape[0], inp_k.shape[1]
mlen = mems[0].shape[0] if mems is not None else 0
klen = mlen + qlen
dtype_float = word_emb_k.dtype
device = word_emb_k.device
dtype_float = next(self.parameters()).dtype
device = next(self.parameters()).device
##### Attention mask
# causal attention mask
@@ -865,7 +972,8 @@ class XLNetModel(XLNetPreTrainedModel):
else:
non_tgt_mask = None
##### Process Word embeddings and prepare h & g hidden states
##### Word embeddings and prepare h & g hidden states
word_emb_k = self.word_embedding(inp_k)
output_h = self.dropout(word_emb_k)
if inp_q is not None:
if target_mapping is not None:
@@ -983,30 +1091,19 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
self.attn_type = config.attn_type
self.same_length = config.same_length
self.word_embedding = nn.Embedding(config.vocab_size, config.d_model)
self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
self.transformer = XLNetModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.lm_loss = nn.Linear(config.d_model, config.vocab_size, bias=True)
self.dropout = nn.Dropout(config.dropout)
self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
# Tie weights
self.lm_loss.weight = self.word_embedding.weight
self.apply(self.init_xlnet_weights)
self.tie_weights()
def prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
def tie_weights(self):
""" Make sure we are sharing the embeddings
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def get_multihead_outputs(self):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return [layer.attention.self.multihead_output for layer in self.encoder.layer]
self.lm_loss.weight = self.transformer.word_embedding.weight
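
A small sketch, under the assumption that the default config values build a working model, showing what tie_weights guarantees: the LM head and the transformer's input embedding share the same Parameter object.

from pytorch_pretrained_bert.modeling_xlnet import XLNetConfig, XLNetLMHeadModel

config = XLNetConfig(vocab_size_or_config_json_file=96, d_model=64)  # illustrative sizes
model = XLNetLMHeadModel(config)

# Same Parameter object: updating one updates the other.
assert model.lm_loss.weight is model.transformer.word_embedding.weight
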
def forward(self, inp_k, seg_id=None, input_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
@@ -1037,9 +1134,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
"""
word_emb_k = self.word_embedding(inp_k)
output, new_mems = self.transformer(word_emb_k, seg_id, input_mask,
output, new_mems = self.transformer(inp_k, seg_id, input_mask,
mems, perm_mask, target_mapping, inp_q,
output_all_encoded_layers, head_mask)
@@ -1059,5 +1154,5 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
# if not output_all_encoded_layers:
# encoded_layers = encoded_layers[-1]
# if self.output_attentions:
# return all_attentions, encoded_layers, pooled_output
return logits, new_mems
# return all_attentions, encoded_layers, pooled_output
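
With this change the model embeds the token IDs itself, so XLNetLMHeadModel.forward now takes an int tensor inp_k of shape [len, bsz] instead of pre-computed embeddings. A hedged sketch of a forward pass under that signature (sizes and config values are illustrative, and the exact output handling may still change):

import torch
from pytorch_pretrained_bert.modeling_xlnet import XLNetConfig, XLNetLMHeadModel

config = XLNetConfig(vocab_size_or_config_json_file=96, d_model=64)  # illustrative
model = XLNetLMHeadModel(config)
model.eval()

qlen, bsz = 10, 2
inp_k = torch.randint(0, 96, (qlen, bsz), dtype=torch.long)  # token IDs, shape [len, bsz]

with torch.no_grad():
    # Per the return statement above: logits over the n_token vocabulary, plus new memories.
    logits, new_mems = model(inp_k)
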
@@ -186,13 +186,13 @@ class XLNetModelTest(unittest.TestCase):
self.run_tester(XLNetModelTest.XLNetModelTester(self))
def test_config_to_json_string(self):
config = XLNetConfig(vocab_size_or_config_json_file=96, d_model=37)
config = XLNetConfig(vocab_size_or_config_json_file=96, d_model=16*4)
obj = json.loads(config.to_json_string())
self.assertEqual(obj["n_token"], 96)
self.assertEqual(obj["d_model"], 37)
self.assertEqual(obj["d_model"], 16*4)
def test_config_to_json_file(self):
config_first = XLNetConfig(vocab_size_or_config_json_file=96, d_model=37)
config_first = XLNetConfig(vocab_size_or_config_json_file=96, d_model=16*4)
json_file_path = "/tmp/config.json"
config_first.to_json_file(json_file_path)
config_second = XLNetConfig.from_json_file(json_file_path)