chenpangpang / transformers

Commit c5d532e5, authored Nov 01, 2018 by thomwolf

added conversion script

parent 90d360a7
Showing 2 changed files with 100 additions and 21 deletions:

    convert_tf_checkpoint.py    +82  -0
    modeling_pytorch.py         +18  -21
convert_tf_checkpoint.py  (new file, mode 0 → 100644)
# coding=utf-8
"""Convert BERT checkpoint."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import argparse

import tensorflow as tf
import torch

from .modeling_pytorch import BertConfig, BertModel

parser = argparse.ArgumentParser()

## Required parameters
parser.add_argument("--tf_checkpoint_path",
                    default=None,
                    type=str,
                    required=True,
                    help="Path to the TensorFlow checkpoint.")
parser.add_argument("--bert_config_file",
                    default=None,
                    type=str,
                    required=True,
                    help="The config json file corresponding to the pre-trained BERT model.\n"
                         "This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
                    default=None,
                    type=str,
                    required=True,
                    help="Path to the output PyTorch model.")

args = parser.parse_args()


def convert():
    # Load weights from the TF model
    path = args.tf_checkpoint_path
    print("Converting TensorFlow checkpoint from {}".format(path))

    init_vars = tf.train.list_variables(path)
    names = []
    arrays = []
    for name, shape in init_vars:
        print("Loading {} with shape {}".format(name, shape))
        array = tf.train.load_variable(path, name)
        print("Numpy array shape {}".format(array.shape))
        names.append(name)
        arrays.append(array)

    # Initialise the PyTorch model and fill in the weights
    config = BertConfig.from_json_file(args.bert_config_file)
    model = BertModel(config)

    for name, array in zip(names, arrays):
        name = name[5:]  # skip "bert/"
        assert name[-2:] == ":0"
        name = name[:-2]
        name = name.split('/')
        pointer = model
        for m_name in name:
            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
                l = re.split(r'(\d+)', m_name)
            else:
                l = [m_name]
            pointer = getattr(pointer, l[0])
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        pointer.data = torch.from_numpy(array)

    # Save the PyTorch model
    torch.save(model.state_dict(), args.pytorch_dump_path)


if __name__ == "__main__":
    convert()
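The core of the script is the name-resolution loop: each slash-separated segment of a TensorFlow variable name becomes a getattr on the PyTorch model, and a segment with a trailing digit run additionally becomes a list index. A minimal, self-contained sketch of that logic (the toy Node class and the names encoder/layer1/kernel are illustrative stand-ins, not from the commit):

import re

class Node:
    """Toy stand-in for a module with named sub-attributes."""
    pass

blocks = [Node(), Node()]
blocks[1].kernel = "parameter lives here"

root = Node()
root.encoder = Node()
root.encoder.layer = blocks

pointer = root
for m_name in "encoder/layer1/kernel".split('/'):
    if re.fullmatch(r'[A-Za-z]+\d+', m_name):
        l = re.split(r'(\d+)', m_name)   # "layer1" -> ['layer', '1', '']
    else:
        l = [m_name]
    pointer = getattr(pointer, l[0])
    if len(l) >= 2:
        pointer = pointer[int(l[1])]     # digit suffix indexes into the list

print(pointer)  # -> parameter lives here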
modeling_pytorch.py
@@ -119,7 +119,7 @@ class BERTLayerNorm(nn.Module):
         self.variance_epsilon = variance_epsilon

     def forward(self, x):
-        # TODO check it's identical to TF implementation in details
+        # TODO check it's identical to TF implementation in details (epsilon and axes)
         u = x.mean(-1, keepdim=True)
         s = (x - u).pow(2).mean(-1, keepdim=True)
         x = (x - u) / torch.sqrt(s + self.variance_epsilon)
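The forward pass in this hunk is the standard layer normalization, y = (x - u) / sqrt(s + eps), with the mean u and biased variance s taken over the last axis. A quick numerical sanity check of that formula (a sketch, not part of the commit; the eps value of 1e-12 is an assumption, since the TODO above leaves the TF comparison open):

import torch

x = torch.randn(2, 4, 8)
eps = 1e-12  # assumed value, not specified in this hunk

u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)  # biased variance, as in the diff
y = (x - u) / torch.sqrt(s + eps)

print(y.mean(-1).abs().max())            # ~0: zero mean along the last axis
print(y.var(-1, unbiased=False).mean())  # ~1: unit (biased) variance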
@@ -128,9 +128,7 @@ class BERTLayerNorm(nn.Module):
         # inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)

 class BERTEmbeddings(nn.Module):
-    def __init__(self, embedding_size, vocab_size,
-                 token_type_vocab_size, max_position_embeddings,
-                 config):
+    def __init__(self, config):
         self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size)
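The signature change collapses four positional hyperparameters into a single config object, so every sub-module reads its sizes from one place. A small sketch of the pattern (SimpleNamespace stands in for the real BertConfig; the field values are assumptions):

from types import SimpleNamespace

import torch.nn as nn

config = SimpleNamespace(vocab_size=30522, embedding_size=768)
word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size)
print(word_embeddings.weight.shape)  # torch.Size([30522, 768])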
@@ -323,27 +321,32 @@ class BERTEncoder(nn.Module):
         Return:
             float Tensor of shape [batch_size, seq_length, hidden_size]
         """
+        all_encoder_layers = []
         for layer_module in self.layer:
             hidden_states = layer_module(hidden_states, attention_mask)
-        return hidden_states
+            all_encoder_layers.append(hidden_states)
+        return all_encoder_layers

 class BERTPooler(nn.Module):
     def __init__(self, config):
         super(BERTPooler, self).__init__()
-        layer = BERTLayer(n_ctx, cfg, scale=True)
-        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()

-    def forward(self, hidden_states, attention_mask):
+    def forward(self, hidden_states):
         """
         Args:
             hidden_states: float Tensor of shape [batch_size, seq_length, hidden_size]
         Return:
-            float Tensor of shape [batch_size, seq_length, hidden_size]
+            float Tensor of shape [batch_size, hidden_size]
         """
-        for layer_module in self.layer:
-            hidden_states = layer_module(hidden_states, attention_mask)
-        return hidden_states
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token. We assume that this has been pre-trained
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output

 class BertModel(nn.Module):
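The rewritten BERTPooler no longer runs its own transformer stack; it passes the first-token hidden state through a dense layer and a tanh. A self-contained sketch of that path (hidden_size=8 and the random input are illustrative assumptions):

import torch
import torch.nn as nn

hidden_size = 8
hidden_states = torch.randn(2, 5, hidden_size)  # [batch, seq_len, hidden]

dense = nn.Linear(hidden_size, hidden_size)
first_token_tensor = hidden_states[:, 0]        # first-token state, [2, 8]
pooled_output = torch.tanh(dense(first_token_tensor))
print(pooled_output.shape)                      # torch.Size([2, 8])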
@@ -381,14 +384,6 @@ class BertModel(nn.Module):
             is invalid.
         """
         super(BertModel).__init__()
-        config = copy.deepcopy(config)
-        if not is_training:
-            config.hidden_dropout_prob = 0.0
-            config.attention_probs_dropout_prob = 0.0
-
-        batch_size = input_ids.size(0)
-        seq_length = input_ids.size(1)
-
         self.embeddings = BERTEmbeddings(config)
         self.encoder = BERTEncoder(config)
         self.pooler = BERTPooler(config)
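The deleted branch disabled dropout by mutating the config whenever is_training was false. The idiomatic PyTorch replacement is to toggle the module's mode at call time rather than at construction time; a minimal sketch (not from the commit):

import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.1))
model.eval()   # inference mode: dropout becomes a no-op
model.train()  # switches dropout back on for training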
@@ -396,4 +391,6 @@ class BertModel(nn.Module):
     def forward(self, input_ids, token_type_ids, attention_mask):
         embedding_output = self.embeddings(input_ids, token_type_ids)
         all_encoder_layers = self.encoder(embedding_output, attention_mask)
-        return all_encoder_layers
+        sequence_output = all_encoder_layers[-1]
+        pooled_output = self.pooler(sequence_output)
+        return all_encoder_layers, pooled_output
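forward now returns the per-layer hidden states plus the pooled first-token representation. A hedged sketch of unpacking the new return value, with random tensors standing in for a real BertModel call (shapes assumed: batch 2, sequence 5, hidden 8, 12 layers):

import torch

# Stand-ins for `model(input_ids, token_type_ids, attention_mask)`:
all_encoder_layers = [torch.randn(2, 5, 8) for _ in range(12)]
pooled_output = torch.randn(2, 8)

sequence_output = all_encoder_layers[-1]  # final-layer states, [2, 5, 8]
print(sequence_output.shape, pooled_output.shape)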