"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "85fff78c2ddea175cc04ff8399e690de0cead686"
Commit b12616fd authored by thomwolf

updating code organization to fix imports

parent d77dd62f
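The reorganization removes a circular dependency: previously each modeling module imported its load_tf_weights_in_* helper from the matching conversion script, while the conversion script imported the model classes back from the modeling module. After this commit the dependency runs one way (conversion script -> modeling module). A minimal sketch of the cycle and the fix, using module names from this diff (layout simplified):

    # Before: a circular import
    #   modeling.py:                          from .convert_tf_checkpoint_to_pytorch import load_tf_weights_in_bert
    #   convert_tf_checkpoint_to_pytorch.py:  from .modeling import BertConfig, BertForPreTraining
    # After: one-way dependency
    #   modeling.py:                          defines load_tf_weights_in_bert itself
    #   convert_tf_checkpoint_to_pytorch.py:  from .modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert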
convert_openai_checkpoint_to_pytorch.py
@@ -24,7 +24,7 @@ import argparse
 import torch
 import numpy as np

-from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
+from .modeling_openai import load_tf_weights_in_openai_gpt, OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME

 def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
     # Construct model
@@ -46,66 +46,6 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
     with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
         f.write(config.to_json_string())
-def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
-    """ Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
-    """
-    print("Loading weights...")
-    names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
-    shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
-    offsets = np.cumsum([np.prod(shape) for shape in shapes])
-    init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
-    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
-    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
-    init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
-    del init_params[1]
-    init_params = [arr.squeeze() for arr in init_params]
-
-    try:
-        assert model.embed.weight.shape == init_params[0].shape
-    except AssertionError as e:
-        e.args += (model.embed.weight.shape, init_params[0].shape)
-        raise
-
-    model.embed.weight.data = torch.from_numpy(init_params[0])
-    names.pop(0)
-    init_params.pop(0)
-
-    for name, array in zip(names, init_params):  # names[1:n_transfer], init_params[1:n_transfer]
-        name = name[6:]  # skip "model/"
-        assert name[-2:] == ":0"
-        name = name[:-2]
-        name = name.split('/')
-        pointer = model
-        for m_name in name:
-            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
-                l = re.split(r'(\d+)', m_name)
-            else:
-                l = [m_name]
-            if l[0] == 'g':
-                pointer = getattr(pointer, 'weight')
-            elif l[0] == 'b':
-                pointer = getattr(pointer, 'bias')
-            elif l[0] == 'w':
-                pointer = getattr(pointer, 'weight')
-            else:
-                pointer = getattr(pointer, l[0])
-            if len(l) >= 2:
-                num = int(l[1])
-                pointer = pointer[num]
-        try:
-            assert pointer.shape == array.shape
-        except AssertionError as e:
-            e.args += (pointer.shape, array.shape)
-            raise
-        print("Initialize PyTorch weight {}".format(name))
-        pointer.data = torch.from_numpy(array)
-    return model
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
convert_tf_checkpoint_to_pytorch.py
@@ -25,7 +25,7 @@ import tensorflow as tf
 import torch
 import numpy as np

-from .modeling import BertConfig, BertForPreTraining
+from .modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert

 def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
     # Initialise PyTorch model
@@ -40,57 +40,6 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
     print("Save PyTorch model to {}".format(pytorch_dump_path))
     torch.save(model.state_dict(), pytorch_dump_path)
-def load_tf_weights_in_bert(model, tf_checkpoint_path):
-    """ Load tf checkpoints in a pytorch model
-    """
-    tf_path = os.path.abspath(tf_checkpoint_path)
-    print("Converting TensorFlow checkpoint from {}".format(tf_path))
-    # Load weights from TF model
-    init_vars = tf.train.list_variables(tf_path)
-    names = []
-    arrays = []
-    for name, shape in init_vars:
-        print("Loading TF weight {} with shape {}".format(name, shape))
-        array = tf.train.load_variable(tf_path, name)
-        names.append(name)
-        arrays.append(array)
-
-    for name, array in zip(names, arrays):
-        name = name.split('/')
-        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
-        # which are not required for using a pretrained model
-        if any(n in ["adam_v", "adam_m"] for n in name):
-            print("Skipping {}".format("/".join(name)))
-            continue
-        pointer = model
-        for m_name in name:
-            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
-                l = re.split(r'_(\d+)', m_name)
-            else:
-                l = [m_name]
-            if l[0] == 'kernel' or l[0] == 'gamma':
-                pointer = getattr(pointer, 'weight')
-            elif l[0] == 'output_bias' or l[0] == 'beta':
-                pointer = getattr(pointer, 'bias')
-            elif l[0] == 'output_weights':
-                pointer = getattr(pointer, 'weight')
-            else:
-                pointer = getattr(pointer, l[0])
-            if len(l) >= 2:
-                num = int(l[1])
-                pointer = pointer[num]
-        if m_name[-11:] == '_embeddings':
-            pointer = getattr(pointer, 'weight')
-        elif m_name == 'kernel':
-            array = np.transpose(array)
-        try:
-            assert pointer.shape == array.shape
-        except AssertionError as e:
-            e.args += (pointer.shape, array.shape)
-            raise
-        print("Initialize PyTorch weight {}".format(name))
-        pointer.data = torch.from_numpy(array)
-    return model
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
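Worth noting: once this script has written the dump, the model can be rebuilt without TensorFlow installed. A minimal round-trip sketch (file paths illustrative):

    import torch
    from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining

    config = BertConfig.from_json_file("bert_config.json")
    model = BertForPreTraining(config)
    model.load_state_dict(torch.load("pytorch_model.bin"))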
convert_transfo_xl_checkpoint_to_pytorch.py
@@ -27,7 +27,7 @@ import tensorflow as tf
 import torch
 import numpy as np

-from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME
+from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME, load_tf_weights_in_transfo_xl
 from pytorch_pretrained_bert.tokenization_transfo_xl import VOCAB_NAME, CORPUS_NAME

 # We do this to be able to load the python 2 datasets pickles
@@ -38,74 +38,6 @@ data_utils.Corpus = data_utils.TransfoXLCorpus
 sys.modules['data_utils'] = data_utils
 sys.modules['vocabulary'] = data_utils
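Aside: the sys.modules aliasing above works because pickle records each class as "module.qualname" and resolves the module through the import system at load time. A standalone toy demonstration of the idiom (module name invented; protocol 0 stores the module name as plain text, so it can be rewritten to simulate a stale path):

    import collections, pickle, sys

    payload = pickle.dumps(collections.OrderedDict(a=1), protocol=0)
    payload = payload.replace(b'collections', b'legacy_utils')  # simulate a stale module path
    sys.modules['legacy_utils'] = collections                   # the aliasing trick
    print(pickle.loads(payload))                                # OrderedDict([('a', 1)])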
-def build_tf_to_pytorch_map(model, config):
-    """ A map of modules from TF to PyTorch.
-        This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
-    """
-    tf_to_pt_map = {}
-    # Embeddings cutoffs
-    for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
-        layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
-        tf_to_pt_map.update({
-            layer_str + 'lookup_table': embed_l.weight,
-            layer_str + 'proj_W': proj_l
-        })
-    # Transformer blocks
-    for i, b in enumerate(model.layers):
-        layer_str = "transformer/layer_%d/" % i
-        tf_to_pt_map.update({
-            layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
-            layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
-            layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
-            layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
-            layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
-            layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
-            layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
-            layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
-            layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
-            layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
-            layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
-        })
-    # Adaptive Softmax
-    tf_to_pt_map.update({
-        "transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
-        "transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias})
-    for i, (out_l, proj_l, tie_proj) in enumerate(zip(
-            model.crit.out_layers,
-            model.crit.out_projs,
-            config.tie_projs)):
-        layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
-        if config.tie_weight:
-            tf_to_pt_map.update({
-                layer_str + 'b': out_l.bias})
-        else:
-            raise NotImplementedError
-            # I don't think this is implemented in the TF code
-            tf_to_pt_map.update({
-                layer_str + 'lookup_table': out_l.weight,
-                layer_str + 'b': out_l.bias})
-        if not tie_proj:
-            tf_to_pt_map.update({
-                layer_str + 'proj': proj_l
-            })
-    # Relative positioning biases
-    if config.untie_r:
-        r_r_list = []
-        r_w_list = []
-        for b in model.layers:
-            r_r_list.append(b.dec_attn.r_r_bias)
-            r_w_list.append(b.dec_attn.r_w_bias)
-    else:
-        r_r_list = [model.r_r_bias]
-        r_w_list = [model.r_w_bias]
-    tf_to_pt_map.update({
-        'transformer/r_r_bias': r_r_list,
-        'transformer/r_w_bias': r_w_list})
-    return tf_to_pt_map
 def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
                                              transfo_xl_config_file,
                                              pytorch_dump_folder_path,
@@ -150,54 +82,6 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
         f.write(config.to_json_string())
-def load_tf_weights_in_transfo_xl(model, config, tf_path):
-    """ Load tf checkpoints in a pytorch model
-    """
-    # Build TF to PyTorch weights loading map
-    tf_to_pt_map = build_tf_to_pytorch_map(model, config)
-
-    # Load weights from TF model
-    init_vars = tf.train.list_variables(tf_path)
-    tf_weights = {}
-    for name, shape in init_vars:
-        print("Loading TF weight {} with shape {}".format(name, shape))
-        array = tf.train.load_variable(tf_path, name)
-        tf_weights[name] = array
-
-    for name, pointer in tf_to_pt_map.items():
-        assert name in tf_weights
-        array = tf_weights[name]
-        # TF stores dense kernels transposed relative to PyTorch Linear weights
-        if 'kernel' in name or 'proj' in name:
-            array = np.transpose(array)
-        if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
-            # Here we will split the TF weights
-            assert len(pointer) == array.shape[0]
-            for i, p_i in enumerate(pointer):
-                arr_i = array[i, ...]
-                try:
-                    assert p_i.shape == arr_i.shape
-                except AssertionError as e:
-                    e.args += (p_i.shape, arr_i.shape)
-                    raise
-                print("Initialize PyTorch weight {} for layer {}".format(name, i))
-                p_i.data = torch.from_numpy(arr_i)
-        else:
-            try:
-                assert pointer.shape == array.shape
-            except AssertionError as e:
-                e.args += (pointer.shape, array.shape)
-                raise
-            print("Initialize PyTorch weight {}".format(name))
-            pointer.data = torch.from_numpy(array)
-        # the /Adam and /Adam_1 slots are optimizer accumulators saved in the
-        # checkpoint, not model weights, so drop them along with the copied variable
-        tf_weights.pop(name, None)
-        tf_weights.pop(name + '/Adam', None)
-        tf_weights.pop(name + '/Adam_1', None)
-
-    print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
-    return model
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ## Required parameters
modeling.py
@@ -33,7 +33,6 @@ from torch import nn
 from torch.nn import CrossEntropyLoss

 from .file_utils import cached_path
-from .convert_tf_checkpoint_to_pytorch import load_tf_weights_in_bert

 logger = logging.getLogger(__name__)
@@ -50,6 +49,59 @@ CONFIG_NAME = 'bert_config.json'
 WEIGHTS_NAME = 'pytorch_model.bin'
 TF_WEIGHTS_NAME = 'model.ckpt'
+def load_tf_weights_in_bert(model, tf_checkpoint_path):
+    """ Load tf checkpoints in a pytorch model
+    """
+    tf_path = os.path.abspath(tf_checkpoint_path)
+    print("Converting TensorFlow checkpoint from {}".format(tf_path))
+    # Load weights from TF model
+    init_vars = tf.train.list_variables(tf_path)
+    names = []
+    arrays = []
+    for name, shape in init_vars:
+        print("Loading TF weight {} with shape {}".format(name, shape))
+        array = tf.train.load_variable(tf_path, name)
+        names.append(name)
+        arrays.append(array)
+
+    for name, array in zip(names, arrays):
+        name = name.split('/')
+        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
+        # which are not required for using a pretrained model
+        if any(n in ["adam_v", "adam_m"] for n in name):
+            print("Skipping {}".format("/".join(name)))
+            continue
+        pointer = model
+        for m_name in name:
+            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
+                l = re.split(r'_(\d+)', m_name)
+            else:
+                l = [m_name]
+            if l[0] == 'kernel' or l[0] == 'gamma':
+                pointer = getattr(pointer, 'weight')
+            elif l[0] == 'output_bias' or l[0] == 'beta':
+                pointer = getattr(pointer, 'bias')
+            elif l[0] == 'output_weights':
+                pointer = getattr(pointer, 'weight')
+            else:
+                pointer = getattr(pointer, l[0])
+            if len(l) >= 2:
+                num = int(l[1])
+                pointer = pointer[num]
+        if m_name[-11:] == '_embeddings':
+            pointer = getattr(pointer, 'weight')
+        elif m_name == 'kernel':
+            array = np.transpose(array)
+        try:
+            assert pointer.shape == array.shape
+        except AssertionError as e:
+            e.args += (pointer.shape, array.shape)
+            raise
+        print("Initialize PyTorch weight {}".format(name))
+        pointer.data = torch.from_numpy(array)
+    return model
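The core of the loader above is the name-walking loop: each '/'-separated segment of a TF variable name is resolved with getattr, segments like layer_0 index into a module list, and kernel/gamma/beta are translated to weight/bias. A self-contained toy sketch of that resolution (the Toy class is invented for illustration, not a real BERT module):

    import re
    from torch import nn

    class Toy(nn.Module):
        def __init__(self):
            super(Toy, self).__init__()
            self.layer = nn.ModuleList([nn.Linear(4, 4)])  # stands in for encoder.layer

    pointer = Toy()
    for m_name in "layer_0/kernel".split('/'):
        l = re.split(r'_(\d+)', m_name) if re.fullmatch(r'[A-Za-z]+_\d+', m_name) else [m_name]
        pointer = getattr(pointer, 'weight') if l[0] == 'kernel' else getattr(pointer, l[0])
        if len(l) >= 2:
            pointer = pointer[int(l[1])]  # layer_0 -> layer[0]
    print(pointer.shape)  # torch.Size([4, 4]): resolved to Toy().layer[0].weight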
 def gelu(x):
     """Implementation of the gelu activation function.
         For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
modeling_openai.py
@@ -32,7 +32,6 @@ from torch.nn.parameter import Parameter
 from .modeling import BertLayerNorm as LayerNorm
 from .file_utils import cached_path
-from .convert_openai_checkpoint_to_pytorch import load_tf_weights_in_openai_gpt

 logger = logging.getLogger(__name__)
@@ -40,6 +39,67 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.h
 CONFIG_NAME = "openai_gpt_config.json"
 WEIGHTS_NAME = "pytorch_model.bin"
+def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
+    """ Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
+    """
+    print("Loading weights...")
+    names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
+    shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
+    offsets = np.cumsum([np.prod(shape) for shape in shapes])
+    init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
+    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
+    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
+    init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
+    del init_params[1]
+    init_params = [arr.squeeze() for arr in init_params]
+
+    try:
+        assert model.embed.weight.shape == init_params[0].shape
+    except AssertionError as e:
+        e.args += (model.embed.weight.shape, init_params[0].shape)
+        raise
+
+    model.embed.weight.data = torch.from_numpy(init_params[0])
+    names.pop(0)
+    init_params.pop(0)
+
+    for name, array in zip(names, init_params):  # names[1:n_transfer], init_params[1:n_transfer]
+        name = name[6:]  # skip "model/"
+        assert name[-2:] == ":0"
+        name = name[:-2]
+        name = name.split('/')
+        pointer = model
+        for m_name in name:
+            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
+                l = re.split(r'(\d+)', m_name)
+            else:
+                l = [m_name]
+            if l[0] == 'g':
+                pointer = getattr(pointer, 'weight')
+            elif l[0] == 'b':
+                pointer = getattr(pointer, 'bias')
+            elif l[0] == 'w':
+                pointer = getattr(pointer, 'weight')
+            else:
+                pointer = getattr(pointer, l[0])
+            if len(l) >= 2:
+                num = int(l[1])
+                pointer = pointer[num]
+        try:
+            assert pointer.shape == array.shape
+        except AssertionError as e:
+            e.args += (pointer.shape, array.shape)
+            raise
+        print("Initialize PyTorch weight {}".format(name))
+        pointer.data = torch.from_numpy(array)
+    return model
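The first half of the loader reassembles the OpenAI release format: all parameters are stored flattened across ten params_{n}.npy shards, and the declared shapes give cumulative offsets at which to cut the concatenation back into per-parameter arrays. A toy sketch of that idiom (shapes invented):

    import numpy as np

    shapes = [[2, 3], [4]]                                       # declared parameter shapes
    flat = np.arange(10, dtype=np.float32)                       # stands in for the concatenated shards
    offsets = np.cumsum([np.prod(shape) for shape in shapes])    # array([ 6, 10])
    params = np.split(flat, offsets)[:-1]                        # [:-1] drops the empty tail split
    params = [p.reshape(shape) for p, shape in zip(params, shapes)]
    print([p.shape for p in params])                             # [(2, 3), (4,)]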
 def gelu(x):
     return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
modeling_transfo_xl.py
@@ -37,7 +37,6 @@ from torch.nn.parameter import Parameter
 from .modeling import BertLayerNorm as LayerNorm
 from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
 from .file_utils import cached_path
-from .convert_transfo_xl_checkpoint_to_pytorch import load_tf_weights_in_transfo_xl

 logger = logging.getLogger(__name__)
@@ -51,6 +50,123 @@ CONFIG_NAME = 'transfo_xl_config.json'
 WEIGHTS_NAME = 'pytorch_model.bin'
 TF_WEIGHTS_NAME = 'model.ckpt'
+def build_tf_to_pytorch_map(model, config):
+    """ A map of modules from TF to PyTorch.
+        This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
+    """
+    tf_to_pt_map = {}
+    # Embeddings cutoffs
+    for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
+        layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
+        tf_to_pt_map.update({
+            layer_str + 'lookup_table': embed_l.weight,
+            layer_str + 'proj_W': proj_l
+        })
+    # Transformer blocks
+    for i, b in enumerate(model.layers):
+        layer_str = "transformer/layer_%d/" % i
+        tf_to_pt_map.update({
+            layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
+            layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
+            layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
+            layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
+            layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
+            layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
+            layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
+            layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
+            layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
+            layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
+            layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
+        })
+    # Adaptive Softmax
+    tf_to_pt_map.update({
+        "transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
+        "transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias})
+    for i, (out_l, proj_l, tie_proj) in enumerate(zip(
+            model.crit.out_layers,
+            model.crit.out_projs,
+            config.tie_projs)):
+        layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
+        if config.tie_weight:
+            tf_to_pt_map.update({
+                layer_str + 'b': out_l.bias})
+        else:
+            raise NotImplementedError
+            # I don't think this is implemented in the TF code
+            tf_to_pt_map.update({
+                layer_str + 'lookup_table': out_l.weight,
+                layer_str + 'b': out_l.bias})
+        if not tie_proj:
+            tf_to_pt_map.update({
+                layer_str + 'proj': proj_l
+            })
+    # Relative positioning biases
+    if config.untie_r:
+        r_r_list = []
+        r_w_list = []
+        for b in model.layers:
+            r_r_list.append(b.dec_attn.r_r_bias)
+            r_w_list.append(b.dec_attn.r_w_bias)
+    else:
+        r_r_list = [model.r_r_bias]
+        r_w_list = [model.r_w_bias]
+    tf_to_pt_map.update({
+        'transformer/r_r_bias': r_r_list,
+        'transformer/r_w_bias': r_w_list})
+    return tf_to_pt_map
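For orientation, the function above returns a plain dict from TF variable names to the PyTorch tensors they populate; values are single parameters, or lists of per-layer parameters for the stacked relative-position biases. Illustrative entries:

    # 'transformer/adaptive_embed/cutoff_0/lookup_table' -> model.word_emb.emb_layers[0].weight
    # 'transformer/layer_0/rel_attn/qkv/kernel'          -> model.layers[0].dec_attn.qkv_net.weight
    # 'transformer/r_r_bias'                             -> [b.dec_attn.r_r_bias for b in model.layers]  (when config.untie_r)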
+def load_tf_weights_in_transfo_xl(model, config, tf_path):
+    """ Load tf checkpoints in a pytorch model
+    """
+    # Build TF to PyTorch weights loading map
+    tf_to_pt_map = build_tf_to_pytorch_map(model, config)
+
+    # Load weights from TF model
+    init_vars = tf.train.list_variables(tf_path)
+    tf_weights = {}
+    for name, shape in init_vars:
+        print("Loading TF weight {} with shape {}".format(name, shape))
+        array = tf.train.load_variable(tf_path, name)
+        tf_weights[name] = array
+
+    for name, pointer in tf_to_pt_map.items():
+        assert name in tf_weights
+        array = tf_weights[name]
+        # TF stores dense kernels transposed relative to PyTorch Linear weights
+        if 'kernel' in name or 'proj' in name:
+            array = np.transpose(array)
+        if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
+            # Here we will split the TF weights
+            assert len(pointer) == array.shape[0]
+            for i, p_i in enumerate(pointer):
+                arr_i = array[i, ...]
+                try:
+                    assert p_i.shape == arr_i.shape
+                except AssertionError as e:
+                    e.args += (p_i.shape, arr_i.shape)
+                    raise
+                print("Initialize PyTorch weight {} for layer {}".format(name, i))
+                p_i.data = torch.from_numpy(arr_i)
+        else:
+            try:
+                assert pointer.shape == array.shape
+            except AssertionError as e:
+                e.args += (pointer.shape, array.shape)
+                raise
+            print("Initialize PyTorch weight {}".format(name))
+            pointer.data = torch.from_numpy(array)
+        # the /Adam and /Adam_1 slots are optimizer accumulators saved in the
+        # checkpoint, not model weights, so drop them along with the copied variable
+        tf_weights.pop(name, None)
+        tf_weights.pop(name + '/Adam', None)
+        tf_weights.pop(name + '/Adam_1', None)
+
+    print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
+    return model
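A usage sketch for the map-based loader (checkpoint prefix illustrative; assuming TransfoXLConfig's defaults match the checkpoint being loaded):

    from pytorch_pretrained_bert.modeling_transfo_xl import (
        TransfoXLConfig, TransfoXLModel, load_tf_weights_in_transfo_xl)

    config = TransfoXLConfig()                                  # assumed-default hyperparameters
    model = TransfoXLModel(config)
    load_tf_weights_in_transfo_xl(model, config, "model.ckpt")  # illustrative TF checkpoint prefix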
 class TransfoXLConfig(object):
     """Configuration class to store the configuration of a `TransfoXLModel`.
     """
modeling_transfo_xl_utilities.py
@@ -291,7 +291,7 @@ if __name__ == '__main__':
     # sampler = LogUniformSampler(n_vocab, unique=False)
     # new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
-    sampler = LogUniformSampler(n_vocab, unique=True)
+    sampler = LogUniformSampler(n_vocab, n_sample)#, unique=True)
     # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
     # print('true_probs', true_probs.numpy().tolist())