Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
fea15cc9
"examples/trials/sklearn/vscode:/vscode.git/clone" did not exist on "ae72aec87dbce3a3f328c9a70a82132f825634d6"
Commit
fea15cc9
authored
Jan 16, 2019
by
thomwolf
Browse files
update model conversion
parent
a28dfc86
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
23 deletions
+15
-23
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
...etrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
+15
-9
pytorch_pretrained_bert/modeling_transfo_xl.py
pytorch_pretrained_bert/modeling_transfo_xl.py
+0
-14
No files found.
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
View file @
fea15cc9
...
...
@@ -68,7 +68,10 @@ def build_tf_to_pytorch_map(model, config):
layer_str
+
"ff/layer_2/bias"
:
b
.
pos_ff
.
CoreNet
[
3
].
bias
,
})
# Softmax cutoffs
# Adaptive Softmax
tf_to_pt_map
.
update
({
"transformer/adaptive_softmax/cutoff_0/cluster_W"
:
model
.
crit
.
cluster_weight
,
"transformer/adaptive_softmax/cutoff_0/cluster_b"
:
model
.
crit
.
cluster_bias
})
for
i
,
(
out_l
,
proj_l
,
tie_proj
)
in
enumerate
(
zip
(
model
.
crit
.
out_layers
,
model
.
crit
.
out_projs
,
...
...
@@ -169,7 +172,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
raise
print
(
"Initialize PyTorch weight {} for layer {}"
.
format
(
name
,
i
))
p_i
.
data
=
torch
.
from_numpy
(
arr_i
)
continue
else
:
try
:
assert
pointer
.
shape
==
array
.
shape
except
AssertionError
as
e
:
...
...
@@ -177,6 +180,9 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
raise
print
(
"Initialize PyTorch weight {}"
.
format
(
name
))
pointer
.
data
=
torch
.
from_numpy
(
array
)
del
tf_weights
[
name
]
print
(
"Weights not copied to PyTorch model: {}"
.
format
(
', '
.
join
(
tf_weights
.
keys
())))
# Save pytorch-model
pytorch_weights_dump_path
=
pytorch_dump_folder_path
+
'/'
+
WEIGHTS_NAME
...
...
pytorch_pretrained_bert/modeling_transfo_xl.py
View file @
fea15cc9
...
...
@@ -802,20 +802,6 @@ class TransfoXLPreTrainedModel(nn.Module):
if
state_dict
is
None
:
state_dict
=
torch
.
load
(
resolved_archive_file
)
old_keys
=
[]
new_keys
=
[]
for
key
in
state_dict
.
keys
():
new_key
=
None
if
'gamma'
in
key
:
new_key
=
key
.
replace
(
'gamma'
,
'weight'
)
if
'beta'
in
key
:
new_key
=
key
.
replace
(
'beta'
,
'bias'
)
if
new_key
:
old_keys
.
append
(
key
)
new_keys
.
append
(
new_key
)
for
old_key
,
new_key
in
zip
(
old_keys
,
new_keys
):
state_dict
[
new_key
]
=
state_dict
.
pop
(
old_key
)
missing_keys
=
[]
unexpected_keys
=
[]
error_msgs
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment