Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4e521884
Commit
4e521884
authored
Nov 06, 2018
by
lukovnikov
Browse files
bert weight loading from tf
parent
907d3569
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
102 additions
and
27 deletions
+102
-27
convert_tf_checkpoint_to_pytorch.py
convert_tf_checkpoint_to_pytorch.py
+30
-26
modeling.py
modeling.py
+1
-1
tests/mytest.py
tests/mytest.py
+71
-0
No files found.
convert_tf_checkpoint_to_pytorch.py
View file @
4e521884
...
@@ -26,35 +26,14 @@ import numpy as np
...
@@ -26,35 +26,14 @@ import numpy as np
from
modeling
import
BertConfig
,
BertModel
from
modeling
import
BertConfig
,
BertModel
parser
=
argparse
.
ArgumentParser
()
## Required parameters
def
convert
(
config_path
,
ckpt_path
,
out_path
=
None
):
parser
.
add_argument
(
"--tf_checkpoint_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path the TensorFlow checkpoint path."
)
parser
.
add_argument
(
"--bert_config_file"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained BERT model.
\n
"
"This specifies the model architecture."
)
parser
.
add_argument
(
"--pytorch_dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output PyTorch model."
)
args
=
parser
.
parse_args
()
def
convert
():
# Initialise PyTorch model
# Initialise PyTorch model
config
=
BertConfig
.
from_json_file
(
args
.
bert_
config_
file
)
config
=
BertConfig
.
from_json_file
(
config_
path
)
model
=
BertModel
(
config
)
model
=
BertModel
(
config
)
# Load weights from TF model
# Load weights from TF model
path
=
args
.
tf_checkpoin
t_path
path
=
ckp
t_path
print
(
"Converting TensorFlow checkpoint from {}"
.
format
(
path
))
print
(
"Converting TensorFlow checkpoint from {}"
.
format
(
path
))
init_vars
=
tf
.
train
.
list_variables
(
path
)
init_vars
=
tf
.
train
.
list_variables
(
path
)
...
@@ -99,7 +78,32 @@ def convert():
...
@@ -99,7 +78,32 @@ def convert():
pointer
.
data
=
torch
.
from_numpy
(
array
)
pointer
.
data
=
torch
.
from_numpy
(
array
)
# Save pytorch-model
# Save pytorch-model
torch
.
save
(
model
.
state_dict
(),
args
.
pytorch_dump_path
)
if
out_path
is
not
None
:
torch
.
save
(
model
.
state_dict
(),
out_path
)
return
model
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--tf_checkpoint_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the TensorFlow checkpoint path.")
    parser.add_argument("--bert_config_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The config json file corresponding to the pre-trained BERT model.\n"
                             "This specifies the model architecture.")
    # Optional parameters
    # NOTE: the dump path is optional — when omitted, convert() only builds
    # the in-memory PyTorch model and does not write a state dict to disk.
    parser.add_argument("--pytorch_dump_path",
                        default=None,
                        type=str,
                        required=False,
                        help="Path to the output PyTorch model.")
    args = parser.parse_args()
    print(args)
    convert(args.bert_config_file, args.tf_checkpoint_path, args.pytorch_dump_path)
modeling.py
View file @
4e521884
...
@@ -355,7 +355,7 @@ class BertModel(nn.Module):
...
@@ -355,7 +355,7 @@ class BertModel(nn.Module):
all_encoder_layers
=
self
.
encoder
(
embedding_output
,
extended_attention_mask
)
all_encoder_layers
=
self
.
encoder
(
embedding_output
,
extended_attention_mask
)
sequence_output
=
all_encoder_layers
[
-
1
]
sequence_output
=
all_encoder_layers
[
-
1
]
pooled_output
=
self
.
pooler
(
sequence_output
)
pooled_output
=
self
.
pooler
(
sequence_output
)
return
all_encoder_layers
,
pooled_output
return
[
embedding_output
]
+
all_encoder_layers
,
pooled_output
class
BertForSequenceClassification
(
nn
.
Module
):
class
BertForSequenceClassification
(
nn
.
Module
):
"""BERT model for classification.
"""BERT model for classification.
...
...
tests/mytest.py
0 → 100644
View file @
4e521884
import
unittest
import
json
import
random
import
torch
import
numpy
as
np
import
modeling
import
convert_tf_checkpoint_to_pytorch
import
grouch
class MyTest(unittest.TestCase):
    """Compares the HuggingFace BERT (loaded from a TF checkpoint via
    convert_tf_checkpoint_to_pytorch.convert) against the grouch
    TransformerBERT implementation loading the same checkpoint."""

    # assumes the grouch BERT checkpoint layout below exists relative to the
    # test working directory — TODO make configurable / skip when absent
    def test_loading_and_running(self):
        """Load both models from the same TF checkpoint, run identical inputs
        through them, and assert the layer outputs match within tolerance."""
        bertpath = "../../grouch/data/bert/bert-base/"
        configpath = bertpath + "bert_config.json"
        ckptpath = bertpath + "bert_model.ckpt"

        # Reference: HuggingFace model built by the conversion script
        # (out_path omitted so nothing is written to disk).
        m = convert_tf_checkpoint_to_pytorch.convert(configpath, ckptpath)
        m.eval()
        # print(m)

        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

        all_y, pool_y = m(input_ids, token_type_ids, input_mask)
        print(pool_y.shape)
        # np.save("_bert_ref_pool_out.npy", pool_y.detach().numpy())
        # np.save("_bert_ref_all_out.npy", torch.stack(all_y, 0).detach().numpy())

        # Candidate: grouch model loading weights from the same checkpoint.
        config = grouch.TransformerBERT.load_config(configpath)
        gm = grouch.TransformerBERT.init_from_config(config)
        gm.load_weights_from_tf_checkpoint(ckptpath)
        gm.eval()

        g_all_y, g_pool_y = gm(input_ids, token_type_ids, input_mask)
        print(g_pool_y.shape)

        # check embeddings (all_y[0] is the embedding output — the converted
        # model returns [embedding_output] + encoder layer outputs)
        self._assert_close(all_y[0], g_all_y[0], atol=1e-7, label="embeddings")
        print("embeddings good")

        print(m.encoder.layer[0])
        print(gm.encoder.layers[0])

        # Tolerances loosen with depth: small float differences accumulate
        # through the encoder layers.
        self._assert_close(all_y[1], g_all_y[1], atol=1e-6, label="layer 1")
        self._assert_close(all_y[-1], g_all_y[-1], atol=1e-4, label="last layer")

    def _assert_close(self, a, b, atol, label):
        """Print the norm of the difference and assert elementwise closeness."""
        print("norm of diff at {}".format(label), (a - b).norm())
        self.assertTrue(
            np.allclose(a.detach().numpy(), b.detach().numpy(), atol=atol))
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment