Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ab0e8932
Commit
ab0e8932
authored
Nov 01, 2018
by
thomwolf
Browse files
convertion script WIP
parent
5581edb4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
20 deletions
+25
-20
convert_tf_checkpoint.py
convert_tf_checkpoint.py
+15
-10
modeling_pytorch.py
modeling_pytorch.py
+10
-10
No files found.
convert_tf_checkpoint.py
View file @
ab0e8932
...
@@ -10,7 +10,7 @@ import argparse
...
@@ -10,7 +10,7 @@ import argparse
import
tensorflow
as
tf
import
tensorflow
as
tf
import
torch
import
torch
from
.
modeling_pytorch
import
BertConfig
,
BertModel
from
modeling_pytorch
import
BertConfig
,
BertModel
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
...
@@ -35,6 +35,10 @@ parser.add_argument("--pytorch_dump_path",
...
@@ -35,6 +35,10 @@ parser.add_argument("--pytorch_dump_path",
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
def
convert
():
def
convert
():
# Initialise PyTorch model
config
=
BertConfig
.
from_json_file
(
args
.
bert_config_file
)
model
=
BertModel
(
config
)
# Load weights from TF model
# Load weights from TF model
path
=
args
.
tf_checkpoint_path
path
=
args
.
tf_checkpoint_path
print
(
"Converting TensorFlow checkpoint from {}"
.
format
(
path
))
print
(
"Converting TensorFlow checkpoint from {}"
.
format
(
path
))
...
@@ -49,24 +53,26 @@ def convert():
...
@@ -49,24 +53,26 @@ def convert():
names
.
append
(
name
)
names
.
append
(
name
)
arrays
.
append
(
array
)
arrays
.
append
(
array
)
# Initialise PyTorch model and fill weights-in
config
=
BertConfig
.
from_json_file
(
args
.
bert_config_file
)
model
=
BertModel
(
config
)
for
name
,
array
in
zip
(
names
,
arrays
):
for
name
,
array
in
zip
(
names
,
arrays
):
name
=
name
[
5
:]
# skip "bert/"
name
=
name
[
5
:]
# skip "bert/"
assert
name
[
-
2
:]
==
":0"
name
=
name
[:
-
2
]
name
=
name
.
split
(
'/'
)
name
=
name
.
split
(
'/'
)
pointer
=
model
pointer
=
model
for
m_name
in
name
:
for
m_name
in
name
:
if
re
.
fullmatch
(
r
'[A-Za-z]+\d+'
,
m_name
):
if
re
.
fullmatch
(
r
'[A-Za-z]+
_
\d+'
,
m_name
):
l
=
re
.
split
(
r
'(\d+)'
,
m_name
)
l
=
re
.
split
(
r
'
_
(\d+)'
,
m_name
)
else
:
else
:
l
=
[
m_name
]
l
=
[
m_name
]
pointer
=
getattr
(
pointer
,
l
[
0
])
if
l
[
0
]
==
'kernel'
:
pointer
=
getattr
(
pointer
,
'weight'
)
else
:
pointer
=
getattr
(
pointer
,
l
[
0
])
if
len
(
l
)
>=
2
:
if
len
(
l
)
>=
2
:
num
=
int
(
l
[
1
])
num
=
int
(
l
[
1
])
pointer
=
pointer
[
num
]
pointer
=
pointer
[
num
]
if
m_name
[
-
11
:]
==
'_embeddings'
:
pointer
=
getattr
(
pointer
,
'weight'
)
# elif m_name == 'kernel':
# pointer = getattr(pointer, 'weight')
try
:
try
:
assert
pointer
.
shape
==
array
.
shape
assert
pointer
.
shape
==
array
.
shape
except
AssertionError
as
e
:
except
AssertionError
as
e
:
...
@@ -79,4 +85,3 @@ def convert():
...
@@ -79,4 +85,3 @@ def convert():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
convert
()
convert
()
return
None
modeling_pytorch.py
View file @
ab0e8932
...
@@ -129,8 +129,8 @@ class BERTLayerNorm(nn.Module):
...
@@ -129,8 +129,8 @@ class BERTLayerNorm(nn.Module):
class
BERTEmbeddings
(
nn
.
Module
):
class
BERTEmbeddings
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
BERTEmbeddings
,
self
).
__init__
()
self
.
word_embeddings
=
nn
.
Embedding
(
config
.
vocab_size
,
config
.
embedding
_size
)
self
.
word_embeddings
=
nn
.
Embedding
(
config
.
vocab_size
,
config
.
hidden
_size
)
# Position embeddings are (normally) a contiguous range so we could use a slice
# Position embeddings are (normally) a contiguous range so we could use a slice
# Since the position embedding table is a learned variable, we create it
# Since the position embedding table is a learned variable, we create it
...
@@ -142,12 +142,12 @@ class BERTEmbeddings(nn.Module):
...
@@ -142,12 +142,12 @@ class BERTEmbeddings(nn.Module):
# for position [0, 1, 2, ..., max_position_embeddings-1], and the current
# for position [0, 1, 2, ..., max_position_embeddings-1], and the current
# sequence has positions [0, 1, 2, ... seq_length-1], so we can just
# sequence has positions [0, 1, 2, ... seq_length-1], so we can just
# perform a slice.
# perform a slice.
self
.
position_embeddings
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
config
.
embedding
_size
)
self
.
position_embeddings
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
config
.
hidden
_size
)
# token_type_embeddings vocabulary is very small. TF used one-hot embeddings to speedup.
# token_type_embeddings vocabulary is very small. TF used one-hot embeddings to speedup.
self
.
token_type_embeddings
=
nn
.
Embedding
(
config
.
token_
type_vocab_size
,
config
.
embedding
_size
)
self
.
token_type_embeddings
=
nn
.
Embedding
(
config
.
type_vocab_size
,
config
.
hidden
_size
)
self
.
LayerNorm
=
BERTLayerNorm
()
# Not snake-cased to stick with TF model variable name
self
.
LayerNorm
=
BERTLayerNorm
(
config
)
# Not snake-cased to stick with TF model variable name
self
.
dropout
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
self
.
dropout
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
):
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
):
...
@@ -185,7 +185,7 @@ class BERTSelfAttention(nn.Module):
...
@@ -185,7 +185,7 @@ class BERTSelfAttention(nn.Module):
self
.
dropout
=
nn
.
Dropout
(
config
.
attention_probs_dropout_prob
)
self
.
dropout
=
nn
.
Dropout
(
config
.
attention_probs_dropout_prob
)
def
transpose_for_scores
(
self
,
input_tensor
,
num_attention_heads
,
is_key_tensor
=
False
):
def
transpose_for_scores
(
self
,
x
,
is_key_tensor
=
False
):
new_x_shape
=
x
.
size
()[:
-
1
]
+
(
self
.
num_attention_heads
,
self
.
attention_head_size
)
new_x_shape
=
x
.
size
()[:
-
1
]
+
(
self
.
num_attention_heads
,
self
.
attention_head_size
)
x
=
x
.
view
(
*
new_x_shape
)
x
=
x
.
view
(
*
new_x_shape
)
if
is_key_tensor
:
if
is_key_tensor
:
...
@@ -270,7 +270,7 @@ class BERTAttention(nn.Module):
...
@@ -270,7 +270,7 @@ class BERTAttention(nn.Module):
class
BERTIntermediate
(
nn
.
Module
):
class
BERTIntermediate
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
BERT
Output
,
self
).
__init__
()
super
(
BERT
Intermediate
,
self
).
__init__
()
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
intermediate_size
)
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
intermediate_size
)
self
.
intermediate_act_fn
=
gelu
self
.
intermediate_act_fn
=
gelu
...
@@ -305,13 +305,13 @@ class BERTLayer(nn.Module):
...
@@ -305,13 +305,13 @@ class BERTLayer(nn.Module):
attention_output
=
self
.
attention
(
hidden_states
,
attention_mask
)
attention_output
=
self
.
attention
(
hidden_states
,
attention_mask
)
intermediate_output
=
self
.
intermediate
(
attention_output
)
intermediate_output
=
self
.
intermediate
(
attention_output
)
layer_output
=
self
.
output
(
intermediate_output
,
attention_output
)
layer_output
=
self
.
output
(
intermediate_output
,
attention_output
)
return
hidden_states
return
layer_output
class
BERTEncoder
(
nn
.
Module
):
class
BERTEncoder
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
BERTEncoder
,
self
).
__init__
()
super
(
BERTEncoder
,
self
).
__init__
()
layer
=
BERTLayer
(
n_ctx
,
cfg
,
scale
=
True
)
layer
=
BERTLayer
(
config
)
self
.
layer
=
nn
.
ModuleList
([
copy
.
deepcopy
(
layer
)
for
_
in
range
(
config
.
num_hidden_layers
)])
self
.
layer
=
nn
.
ModuleList
([
copy
.
deepcopy
(
layer
)
for
_
in
range
(
config
.
num_hidden_layers
)])
def
forward
(
self
,
hidden_states
,
attention_mask
):
def
forward
(
self
,
hidden_states
,
attention_mask
):
...
@@ -383,7 +383,7 @@ class BertModel(nn.Module):
...
@@ -383,7 +383,7 @@ class BertModel(nn.Module):
ValueError: The config is invalid or one of the input tensor shapes
ValueError: The config is invalid or one of the input tensor shapes
is invalid.
is invalid.
"""
"""
super
(
BertModel
).
__init__
()
super
(
BertModel
,
self
).
__init__
()
self
.
embeddings
=
BERTEmbeddings
(
config
)
self
.
embeddings
=
BERTEmbeddings
(
config
)
self
.
encoder
=
BERTEncoder
(
config
)
self
.
encoder
=
BERTEncoder
(
config
)
self
.
pooler
=
BERTPooler
(
config
)
self
.
pooler
=
BERTPooler
(
config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment