Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
04287a4d
Commit
04287a4d
authored
Nov 03, 2018
by
thomwolf
Browse files
special edition script
parent
25f73add
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
108 additions
and
4 deletions
+108
-4
convert_tf_checkpoint_to_pytorch_special_edition.py
convert_tf_checkpoint_to_pytorch_special_edition.py
+99
-0
modeling_pytorch.py
modeling_pytorch.py
+6
-1
run_classifier_pytorch.py
run_classifier_pytorch.py
+3
-3
No files found.
convert_tf_checkpoint_to_pytorch_special_edition.py
0 → 100644
View file @
04287a4d
# coding=utf-8
"""Convert BERT checkpoint."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
re
import
argparse
import
tensorflow
as
tf
import
torch
import
numpy
as
np
from
modeling_pytorch
import
BertConfig
,
BertForSequenceClassification
parser
=
argparse
.
ArgumentParser
()
## Required parameters
parser
.
add_argument
(
"--tf_checkpoint_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path the TensorFlow checkpoint path."
)
parser
.
add_argument
(
"--bert_config_file"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained BERT model.
\n
"
"This specifies the model architecture."
)
parser
.
add_argument
(
"--pytorch_dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output PyTorch model."
)
args
=
parser
.
parse_args
()
def
convert
():
# Initialise PyTorch model
config
=
BertConfig
.
from_json_file
(
args
.
bert_config_file
)
model
=
BertForSequenceClassification
(
config
,
num_labels
=
2
)
# Load weights from TF model
path
=
args
.
tf_checkpoint_path
print
(
"Converting TensorFlow checkpoint from {}"
.
format
(
path
))
init_vars
=
tf
.
train
.
list_variables
(
path
)
names
=
[]
arrays
=
[]
for
name
,
shape
in
init_vars
:
print
(
"Loading {} with shape {}"
.
format
(
name
,
shape
))
array
=
tf
.
train
.
load_variable
(
path
,
name
)
print
(
"Numpy array shape {}"
.
format
(
array
.
shape
))
names
.
append
(
name
)
arrays
.
append
(
array
)
for
name
,
array
in
zip
(
names
,
arrays
):
# name = name[5:] # skip "bert/"
print
(
"Loading {} or shape {}"
.
format
(
name
,
array
.
shape
))
name
=
name
.
split
(
'/'
)
if
name
[
0
]
in
[
'cls'
]:
if
name
[
1
]
in
[
'predictions'
]:
print
(
"Skipping"
)
continue
elif
name
[
1
]
in
[
'seq_relationship'
]:
name
=
name
[
2
:]
assert
len
(
name
)
==
1
name
[
0
]
=
name
[
0
][
7
:]
pointer
=
model
.
classifier
else
:
pointer
=
model
for
m_name
in
name
:
if
re
.
fullmatch
(
r
'[A-Za-z]+_\d+'
,
m_name
):
l
=
re
.
split
(
r
'_(\d+)'
,
m_name
)
else
:
l
=
[
m_name
]
if
l
[
0
]
in
[
'kernel'
,
'weights'
]:
pointer
=
getattr
(
pointer
,
'weight'
)
else
:
pointer
=
getattr
(
pointer
,
l
[
0
])
if
len
(
l
)
>=
2
:
num
=
int
(
l
[
1
])
pointer
=
pointer
[
num
]
if
m_name
[
-
11
:]
==
'_embeddings'
:
pointer
=
getattr
(
pointer
,
'weight'
)
elif
m_name
==
'kernel'
:
array
=
np
.
transpose
(
array
)
try
:
assert
pointer
.
shape
==
array
.
shape
except
AssertionError
as
e
:
e
.
args
+=
(
pointer
.
shape
,
array
.
shape
)
raise
pointer
.
data
=
torch
.
from_numpy
(
array
)
# Save pytorch-model
torch
.
save
(
model
.
state_dict
(),
args
.
pytorch_dump_path
)
if
__name__
==
"__main__"
:
convert
()
modeling_pytorch.py
View file @
04287a4d
...
...
@@ -482,9 +482,14 @@ class BertForQuestionAnswering(nn.Module):
def
init_weights
(
m
):
if
isinstance
(
m
,
(
nn
.
Linear
,
nn
.
Embedding
)):
print
(
"Initializing {}"
.
format
(
m
))
# Slight difference here with the TF version which uses truncated_normal
# Slight difference here with the TF version which uses truncated_normal
for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
m
.
weight
.
data
.
normal_
(
config
.
initializer_range
)
elif
isinstance
(
m
,
BERTLayerNorm
):
m
.
beta
.
data
.
normal_
(
config
.
initializer_range
)
m
.
gamme
.
data
.
normal_
(
config
.
initializer_range
)
if
isinstance
(
m
,
nn
.
Linear
):
m
.
bias
.
data
.
zero_
()
self
.
apply
(
init_weights
)
def
forward
(
self
,
input_ids
,
token_type_ids
,
attention_mask
,
start_positions
=
None
,
end_positions
=
None
):
...
...
run_classifier_pytorch.py
View file @
04287a4d
...
...
@@ -480,9 +480,9 @@ def main():
model
=
BertForSequenceClassification
(
bert_config
,
len
(
label_list
))
if
args
.
init_checkpoint
is
not
None
:
model
.
bert
.
load_state_dict
(
torch
.
load
(
args
.
init_checkpoint
,
map_location
=
'cpu'
))
model
.
load_state_dict
(
torch
.
load
(
args
.
init_checkpoint
,
map_location
=
'cpu'
))
model
.
to
(
device
)
if
n_gpu
>
1
:
model
=
torch
.
nn
.
DataParallel
(
model
)
...
...
@@ -575,7 +575,7 @@ def main():
eval_loss
+=
tmp_eval_loss
.
item
()
eval_accuracy
+=
tmp_eval_accuracy
nb_eval_examples
+=
input_ids
.
size
(
0
)
eval_loss
=
eval_loss
/
nb_eval_examples
#len(eval_dataloader)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment