Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4e521884
Commit
4e521884
authored
Nov 06, 2018
by
lukovnikov
Browse files
bert weight loading from tf
parent
907d3569
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
102 additions
and
27 deletions
+102
-27
convert_tf_checkpoint_to_pytorch.py
convert_tf_checkpoint_to_pytorch.py
+30
-26
modeling.py
modeling.py
+1
-1
tests/mytest.py
tests/mytest.py
+71
-0
No files found.
convert_tf_checkpoint_to_pytorch.py
View file @
4e521884
...
@@ -26,35 +26,14 @@ import numpy as np
...
@@ -26,35 +26,14 @@ import numpy as np
from
modeling
import
BertConfig
,
BertModel
from
modeling
import
BertConfig
,
BertModel
parser
=
argparse
.
ArgumentParser
()
## Required parameters
def
convert
(
config_path
,
ckpt_path
,
out_path
=
None
):
parser
.
add_argument
(
"--tf_checkpoint_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path the TensorFlow checkpoint path."
)
parser
.
add_argument
(
"--bert_config_file"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained BERT model.
\n
"
"This specifies the model architecture."
)
parser
.
add_argument
(
"--pytorch_dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output PyTorch model."
)
args
=
parser
.
parse_args
()
def
convert
():
# Initialise PyTorch model
# Initialise PyTorch model
config
=
BertConfig
.
from_json_file
(
args
.
bert_
config_
file
)
config
=
BertConfig
.
from_json_file
(
config_
path
)
model
=
BertModel
(
config
)
model
=
BertModel
(
config
)
# Load weights from TF model
# Load weights from TF model
path
=
args
.
tf_checkpoin
t_path
path
=
ckp
t_path
print
(
"Converting TensorFlow checkpoint from {}"
.
format
(
path
))
print
(
"Converting TensorFlow checkpoint from {}"
.
format
(
path
))
init_vars
=
tf
.
train
.
list_variables
(
path
)
init_vars
=
tf
.
train
.
list_variables
(
path
)
...
@@ -99,7 +78,32 @@ def convert():
...
@@ -99,7 +78,32 @@ def convert():
pointer
.
data
=
torch
.
from_numpy
(
array
)
pointer
.
data
=
torch
.
from_numpy
(
array
)
# Save pytorch-model
# Save pytorch-model
torch
.
save
(
model
.
state_dict
(),
args
.
pytorch_dump_path
)
if
out_path
is
not
None
:
torch
.
save
(
model
.
state_dict
(),
out_path
)
return
model
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
convert
()
parser
=
argparse
.
ArgumentParser
()
## Required parameters
parser
.
add_argument
(
"--tf_checkpoint_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path the TensorFlow checkpoint path."
)
parser
.
add_argument
(
"--bert_config_file"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained BERT model.
\n
"
"This specifies the model architecture."
)
parser
.
add_argument
(
"--pytorch_dump_path"
,
default
=
None
,
type
=
str
,
required
=
False
,
help
=
"Path to the output PyTorch model."
)
args
=
parser
.
parse_args
()
print
(
args
)
convert
(
args
.
bert_config_file
,
args
.
tf_checkpoint_path
,
args
.
pytorch_dump_path
)
modeling.py
View file @
4e521884
...
@@ -355,7 +355,7 @@ class BertModel(nn.Module):
...
@@ -355,7 +355,7 @@ class BertModel(nn.Module):
all_encoder_layers
=
self
.
encoder
(
embedding_output
,
extended_attention_mask
)
all_encoder_layers
=
self
.
encoder
(
embedding_output
,
extended_attention_mask
)
sequence_output
=
all_encoder_layers
[
-
1
]
sequence_output
=
all_encoder_layers
[
-
1
]
pooled_output
=
self
.
pooler
(
sequence_output
)
pooled_output
=
self
.
pooler
(
sequence_output
)
return
all_encoder_layers
,
pooled_output
return
[
embedding_output
]
+
all_encoder_layers
,
pooled_output
class
BertForSequenceClassification
(
nn
.
Module
):
class
BertForSequenceClassification
(
nn
.
Module
):
"""BERT model for classification.
"""BERT model for classification.
...
...
tests/mytest.py
0 → 100644
View file @
4e521884
import
unittest
import
json
import
random
import
torch
import
numpy
as
np
import
modeling
import
convert_tf_checkpoint_to_pytorch
import
grouch
class
MyTest
(
unittest
.
TestCase
):
def
test_loading_and_running
(
self
):
bertpath
=
"../../grouch/data/bert/bert-base/"
configpath
=
bertpath
+
"bert_config.json"
ckptpath
=
bertpath
+
"bert_model.ckpt"
m
=
convert_tf_checkpoint_to_pytorch
.
convert
(
configpath
,
ckptpath
)
m
.
eval
()
# print(m)
input_ids
=
torch
.
LongTensor
([[
31
,
51
,
99
],
[
15
,
5
,
0
]])
input_mask
=
torch
.
LongTensor
([[
1
,
1
,
1
],
[
1
,
1
,
0
]])
token_type_ids
=
torch
.
LongTensor
([[
0
,
0
,
1
],
[
0
,
1
,
0
]])
all_y
,
pool_y
=
m
(
input_ids
,
token_type_ids
,
input_mask
)
print
(
pool_y
.
shape
)
# np.save("_bert_ref_pool_out.npy", pool_y.detach().numpy())
# np.save("_bert_ref_all_out.npy", torch.stack(all_y, 0).detach().numpy())
config
=
grouch
.
TransformerBERT
.
load_config
(
configpath
)
gm
=
grouch
.
TransformerBERT
.
init_from_config
(
config
)
gm
.
load_weights_from_tf_checkpoint
(
ckptpath
)
gm
.
eval
()
g_all_y
,
g_pool_y
=
gm
(
input_ids
,
token_type_ids
,
input_mask
)
print
(
g_pool_y
.
shape
)
# check embeddings
# print(m.embeddings)
# print(gm.emb)
# hugging_emb = m.embeddings(input_ids, token_type_ids)
# grouch_emb = gm.emb(input_ids, token_type_ids)
print
((
all_y
[
0
]
-
g_all_y
[
0
]).
norm
())
# print(all_y[0][:, :, :10] - g_all_y[0][:, :, :10])
self
.
assertTrue
(
np
.
allclose
(
all_y
[
0
].
detach
().
numpy
(),
g_all_y
[
0
].
detach
().
numpy
(),
atol
=
1e-7
))
print
(
"embeddings good"
)
print
(
m
.
encoder
.
layer
[
0
])
print
(
gm
.
encoder
.
layers
[
0
])
print
(
"norm of diff at layer 1"
,
(
all_y
[
1
]
-
g_all_y
[
1
]).
norm
())
# print(all_y[1][:, :, :10] - g_all_y[1][:, :, :10])
self
.
assertTrue
(
np
.
allclose
(
all_y
[
1
].
detach
().
numpy
(),
g_all_y
[
1
].
detach
().
numpy
(),
atol
=
1e-6
))
# hugging_layer = m.encoder.layer[0]
# grouch_layer = gm.encoder.layers[0]
# print("comparing weights")
# print((hugging_layer.attention.self.query.weight - grouch_layer.slf_attn.q_proj.weight).norm())
# print((hugging_layer.attention.self.query.bias - grouch_layer.slf_attn.q_proj.bias).norm())
# print((hugging_layer.attention.self.key.weight - grouch_layer.slf_attn.k_proj.weight).norm())
# print((hugging_layer.attention.self.key.bias - grouch_layer.slf_attn.k_proj.bias).norm())
# print((hugging_layer.attention.self.value.weight - grouch_layer.slf_attn.v_proj.weight).norm())
# print((hugging_layer.attention.self.value.bias - grouch_layer.slf_attn.v_proj.bias).norm())
# print((hugging_layer.attention.output.dense.weight - grouch_layer.slf_attn.vw_proj.weight).norm())
# print((hugging_layer.attention.output.dense.bias - grouch_layer.slf_attn.vw_proj.bias).norm())
print
(
"norm of diff at last layer"
,
(
all_y
[
-
1
]
-
g_all_y
[
-
1
]).
norm
())
# print(all_y[-1][:, :, :10] - g_all_y[-1][:, :, :10])
self
.
assertTrue
(
np
.
allclose
(
all_y
[
-
1
].
detach
().
numpy
(),
g_all_y
[
-
1
].
detach
().
numpy
(),
atol
=
1e-4
))
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment