Commit fea15cc9 authored by thomwolf's avatar thomwolf
Browse files

update model conversion

parent a28dfc86
...@@ -68,7 +68,10 @@ def build_tf_to_pytorch_map(model, config): ...@@ -68,7 +68,10 @@ def build_tf_to_pytorch_map(model, config):
layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias, layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
}) })
# Softmax cutoffs # Adaptive Softmax
tf_to_pt_map.update({
"transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
"transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias})
for i, (out_l, proj_l, tie_proj) in enumerate(zip( for i, (out_l, proj_l, tie_proj) in enumerate(zip(
model.crit.out_layers, model.crit.out_layers,
model.crit.out_projs, model.crit.out_projs,
...@@ -169,14 +172,17 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, ...@@ -169,14 +172,17 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
raise raise
print("Initialize PyTorch weight {} for layer {}".format(name, i)) print("Initialize PyTorch weight {} for layer {}".format(name, i))
p_i.data = torch.from_numpy(arr_i) p_i.data = torch.from_numpy(arr_i)
continue else:
try: try:
assert pointer.shape == array.shape assert pointer.shape == array.shape
except AssertionError as e: except AssertionError as e:
e.args += (pointer.shape, array.shape) e.args += (pointer.shape, array.shape)
raise raise
print("Initialize PyTorch weight {}".format(name)) print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array) pointer.data = torch.from_numpy(array)
del tf_weights[name]
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
# Save pytorch-model # Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
......
...@@ -802,20 +802,6 @@ class TransfoXLPreTrainedModel(nn.Module): ...@@ -802,20 +802,6 @@ class TransfoXLPreTrainedModel(nn.Module):
if state_dict is None: if state_dict is None:
state_dict = torch.load(resolved_archive_file) state_dict = torch.load(resolved_archive_file)
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = [] missing_keys = []
unexpected_keys = [] unexpected_keys = []
error_msgs = [] error_msgs = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment