chenpangpang / transformers · Commits · 314bc6bb
"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "4ecf9dd62d8e53bfca08bc38a46019b2c9f2b995"
Commit 314bc6bb, authored May 27, 2019 by Chris

added transposes to attention.self.[query,key,value]

Parent: 8de1faea
Showing 1 changed file with 21 additions and 10 deletions.
pytorch_pretrained_bert/convert_pytorch_checkpoint_to_tf.py (+21, -10)
@@ -39,6 +39,24 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
         N BertForQuestionAnswering
     """
 
+    tensors_to_transopse = (
+        "dense.weight",
+        "attention.self.query",
+        "attention.self.key",
+        "attention.self.value"
+    )
+
+    var_map = (
+        ('layer.', 'layer_'),
+        ('word_embeddings.weight', 'word_embeddings'),
+        ('position_embeddings.weight', 'position_embeddings'),
+        ('token_type_embeddings.weight', 'token_type_embeddings'),
+        ('.', '/'),
+        ('LayerNorm/weight', 'LayerNorm/gamma'),
+        ('LayerNorm/bias', 'LayerNorm/beta'),
+        ('weight', 'kernel')
+    )
+
     if not os.path.isdir(ckpt_dir):
         os.makedirs(ckpt_dir)
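This hunk adds two lookup tuples: tensors_to_transopse (the transposed spelling is in the source itself) lists name fragments whose tensors must be transposed on export, and var_map holds ordered (pattern, replacement) pairs for renaming PyTorch parameters to TensorFlow variable names. A minimal sketch of how those ordered substitutions rewrite a name, restating the commit's var_map; the sample parameter names below are illustrative, not taken from the commit:

    var_map = (
        ('layer.', 'layer_'),
        ('word_embeddings.weight', 'word_embeddings'),
        ('position_embeddings.weight', 'position_embeddings'),
        ('token_type_embeddings.weight', 'token_type_embeddings'),
        ('.', '/'),
        ('LayerNorm/weight', 'LayerNorm/gamma'),
        ('LayerNorm/bias', 'LayerNorm/beta'),
        ('weight', 'kernel')
    )

    def rename(name):
        # Apply each (pattern, replacement) pair in order; later rules see
        # the output of earlier ones, so ordering is significant.
        for patt, repl in var_map:
            name = name.replace(patt, repl)
        return 'bert/{}'.format(name)

    print(rename('encoder.layer.0.attention.self.query.weight'))
    # bert/encoder/layer_0/attention/self/query/kernel
    print(rename('embeddings.LayerNorm.weight'))
    # bert/embeddings/LayerNorm/gamma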
@@ -47,15 +65,8 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
     tf_vars = []
 
     def to_tf_var_name(name:str):
         """todo: compile as regex"""
-        name = name.replace('layer.', 'layer_')
-        name = name.replace('word_embeddings.weight', 'word_embeddings')
-        name = name.replace('position_embeddings.weight', 'position_embeddings')
-        name = name.replace('token_type_embeddings.weight', 'token_type_embeddings')
-        name = name.replace('.', '/')
-        name = name.replace('LayerNorm/weight', 'LayerNorm/gamma')
-        name = name.replace('LayerNorm/bias', 'LayerNorm/beta')
-        name = name.replace('weight', 'kernel')
+        for patt, repl in iter(var_map):
+            name = name.replace(patt, repl)
         return 'bert/{}'.format(name)
 
     def assign_tf_var(tensor:np.ndarray, name:str):
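This hunk collapses the eight hard-coded replace calls into a loop over var_map (the iter() wrapper is redundant for a tuple, but harmless). The behavior is unchanged only because the pair ordering is preserved; a quick sketch of what would go wrong if a hypothetical mis-ordered map ran the generic 'weight' rule before the LayerNorm-specific one:

    # Hypothetical mis-ordered map: 'weight' -> 'kernel' fires first,
    # so the LayerNorm-specific rule can never match afterwards.
    bad_map = (('.', '/'), ('weight', 'kernel'), ('LayerNorm/weight', 'LayerNorm/gamma'))

    name = 'embeddings.LayerNorm.weight'
    for patt, repl in bad_map:
        name = name.replace(patt, repl)
    print(name)  # embeddings/LayerNorm/kernel, not .../LayerNorm/gamma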
@@ -69,7 +80,7 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
     for var_name in state_dict:
         tf_name = to_tf_var_name(var_name)
         torch_tensor = state_dict[var_name].numpy()
-        if var_name.endswith('dense.weight'):
+        if any([x in var_name for x in tensors_to_transopse]):
             torch_tensor = torch_tensor.T
         tf_tensor = assign_tf_var(tensor=torch_tensor, name=tf_name)
         tf_vars.append(tf_tensor)
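This is the change the commit message refers to: previously only tensors whose names end in 'dense.weight' were transposed, while now any tensor whose name contains one of the tensors_to_transopse fragments is. The likely motivation: torch.nn.Linear stores its weight as (out_features, in_features), whereas a TensorFlow dense kernel is (in_features, out_features), so the query/key/value projection weights need the same transpose as the dense layers. (The substring match also catches the query/key/value biases, but .T on a 1-D numpy array is a no-op, so that is harmless.) A minimal sketch verifying the layout equivalence, under standard PyTorch/TF conventions:

    # torch.nn.Linear computes y = x @ W.T + b with W of shape (out, in),
    # while a TF dense layer computes y = x @ K + b with K of shape (in, out).
    # Exporting W to TF therefore requires K = W.T.
    import numpy as np
    import torch

    linear = torch.nn.Linear(4, 3)
    x = torch.randn(2, 4)

    kernel = linear.weight.detach().numpy().T   # (3, 4) -> (4, 3), TF layout
    bias = linear.bias.detach().numpy()

    y_torch = linear(x).detach().numpy()
    y_tf_style = x.numpy() @ kernel + bias
    print(np.allclose(y_torch, y_tf_style, atol=1e-6))  # True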