Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
04287a4d
Commit
04287a4d
authored
Nov 03, 2018
by
thomwolf
Browse files
special edition script
parent
25f73add
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
108 additions
and
4 deletions
+108
-4
convert_tf_checkpoint_to_pytorch_special_edition.py
convert_tf_checkpoint_to_pytorch_special_edition.py
+99
-0
modeling_pytorch.py
modeling_pytorch.py
+6
-1
run_classifier_pytorch.py
run_classifier_pytorch.py
+3
-3
No files found.
convert_tf_checkpoint_to_pytorch_special_edition.py
0 → 100644
View file @
04287a4d
# coding=utf-8
"""Convert BERT checkpoint."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
re
import
argparse
import
tensorflow
as
tf
import
torch
import
numpy
as
np
from
modeling_pytorch
import
BertConfig
,
BertForSequenceClassification
## Required parameters
parser = argparse.ArgumentParser()

# (flag, help text) for every CLI option; each one is a mandatory string path.
_CLI_OPTIONS = [
    ("--tf_checkpoint_path",
     "Path the TensorFlow checkpoint path."),
    ("--bert_config_file",
     "The config json file corresponding to the pre-trained BERT model.\n"
     "This specifies the model architecture."),
    ("--pytorch_dump_path",
     "Path to the output PyTorch model."),
]
for _flag, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, default=None, type=str, required=True, help=_help)

# Parsed at import time: this module is a one-shot conversion script.
args = parser.parse_args()
def convert(num_labels=2):
    """Convert a TensorFlow BERT checkpoint into a PyTorch state dict.

    Reads the checkpoint at ``args.tf_checkpoint_path``, copies every TF
    variable into a freshly initialised ``BertForSequenceClassification``
    built from ``args.bert_config_file``, and saves the resulting
    ``state_dict`` to ``args.pytorch_dump_path``.

    Args:
        num_labels: size of the classification head. Defaults to 2, the
            value previously hard-coded here, so existing callers are
            unaffected.
    """
    # Initialise PyTorch model
    config = BertConfig.from_json_file(args.bert_config_file)
    model = BertForSequenceClassification(config, num_labels=num_labels)

    # Load weights from TF model
    path = args.tf_checkpoint_path
    print("Converting TensorFlow checkpoint from {}".format(path))
    init_vars = tf.train.list_variables(path)
    names = []
    arrays = []
    for name, shape in init_vars:
        print("Loading {} with shape {}".format(name, shape))
        array = tf.train.load_variable(path, name)
        print("Numpy array shape {}".format(array.shape))
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        # name = name[5:]  # skip "bert/"
        print("Loading {} with shape {}".format(name, array.shape))
        name = name.split('/')
        if name[0] in ['cls']:
            if name[1] in ['predictions']:
                # Masked-LM head variables have no counterpart in the
                # sequence-classification model: drop them.
                print("Skipping")
                continue
            elif name[1] in ['seq_relationship']:
                # Map the next-sentence head onto `model.classifier`: keep
                # only the last path component and strip its leading
                # "output_" prefix (7 characters).
                name = name[2:]
                assert len(name) == 1
                name[0] = name[0][7:]
                pointer = model.classifier
            # NOTE(review): any other 'cls/...' variable falls through with
            # `pointer` stale or unbound — presumably no such variable
            # exists in BERT checkpoints; confirm before reuse.
        else:
            pointer = model
        # Walk the remaining scope path down the module tree.
        for m_name in name:
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                # Numbered scope, e.g. "layer_11" -> ["layer", "11", ""].
                scope_names = re.split(r'_(\d+)', m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] in ['kernel', 'weights']:
                # TF "kernel"/"weights" correspond to nn.Module "weight".
                pointer = getattr(pointer, 'weight')
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                # Descend into the numbered sub-module (e.g. encoder layer).
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name.endswith('_embeddings'):
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            # TF stores dense kernels transposed relative to nn.Linear.
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            # Attach both shapes so the mismatch is visible in the traceback.
            e.args += (pointer.shape, array.shape)
            raise
        pointer.data = torch.from_numpy(array)

    # Save pytorch-model
    torch.save(model.state_dict(), args.pytorch_dump_path)
# Script entry point: runs the one-shot checkpoint conversion.
if __name__ == "__main__":
    convert()
modeling_pytorch.py
View file @
04287a4d
...
...
@@ -482,9 +482,14 @@ class BertForQuestionAnswering(nn.Module):
def init_weights(m):
    """Initialise module *m* following the TF BERT scheme.

    Applied via ``self.apply`` so it is called once per sub-module.
    Reads ``config.initializer_range`` from the enclosing scope.
    """
    if isinstance(m, (nn.Linear, nn.Embedding)):
        print("Initializing {}".format(m))
        # Slight difference here with the TF version which uses
        # truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        # Bug fix: Tensor.normal_ takes (mean, std); passing
        # initializer_range positionally made it the *mean* with std 1.0.
        m.weight.data.normal_(mean=0.0, std=config.initializer_range)
    elif isinstance(m, BERTLayerNorm):
        m.beta.data.normal_(mean=0.0, std=config.initializer_range)
        # Bug fix: the attribute is presumably `gamma`, not `gamme`
        # (the misspelling would raise AttributeError on every
        # BERTLayerNorm) — confirm against the BERTLayerNorm definition.
        m.gamma.data.normal_(mean=0.0, std=config.initializer_range)
    if isinstance(m, nn.Linear):
        m.bias.data.zero_()
self
.
apply
(
init_weights
)
def
forward
(
self
,
input_ids
,
token_type_ids
,
attention_mask
,
start_positions
=
None
,
end_positions
=
None
):
...
...
run_classifier_pytorch.py
View file @
04287a4d
...
...
@@ -480,9 +480,9 @@ def main():
model
=
BertForSequenceClassification
(
bert_config
,
len
(
label_list
))
if
args
.
init_checkpoint
is
not
None
:
model
.
bert
.
load_state_dict
(
torch
.
load
(
args
.
init_checkpoint
,
map_location
=
'cpu'
))
model
.
load_state_dict
(
torch
.
load
(
args
.
init_checkpoint
,
map_location
=
'cpu'
))
model
.
to
(
device
)
if
n_gpu
>
1
:
model
=
torch
.
nn
.
DataParallel
(
model
)
...
...
@@ -575,7 +575,7 @@ def main():
eval_loss
+=
tmp_eval_loss
.
item
()
eval_accuracy
+=
tmp_eval_accuracy
nb_eval_examples
+=
input_ids
.
size
(
0
)
eval_loss
=
eval_loss
/
nb_eval_examples
#len(eval_dataloader)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment