chenpangpang / transformers · Commits · 0c64b188

Unverified commit 0c64b188, authored Oct 14, 2020 by Quentin Lhoest, committed by GitHub on Oct 14, 2020
Parent: 7968051a

Fix bert position ids in DPR convert script (#7776)

* fix bert position ids in DPR convert script
* style
Changes: 1 changed file with 8 additions and 3 deletions

src/transformers/convert_dpr_original_checkpoint_to_pytorch.py  (+8, -3)
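Background for the fix, and the reason the converted state dicts below are pre-seeded with a position_ids entry: the transformers commit referenced in the comments added by this diff (614fef1) made BertEmbeddings register position_ids as a buffer, so that key appears in the model's state_dict() and is expected by a strict load_state_dict() call, while a state dict built only from the original DPR weights does not contain it. A minimal, self-contained sketch of that failure mode (TinyEmbeddings is a made-up stand-in, not the transformers class):

import torch

class TinyEmbeddings(torch.nn.Module):
    """Made-up stand-in for BertEmbeddings: it registers position_ids as a buffer."""

    def __init__(self, max_position_embeddings=512, hidden_size=8, vocab_size=32):
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
        # Registered buffers are part of state_dict(), just like parameters.
        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)))


emb = TinyEmbeddings()
print(sorted(emb.state_dict()))  # ['position_ids', 'word_embeddings.weight']

# A state dict converted only from the original DPR weights holds parameters,
# not the buffer, so a strict load reports position_ids as missing.
converted = {"word_embeddings.weight": torch.zeros(32, 8)}
try:
    emb.load_state_dict(converted)  # strict=True by default
except RuntimeError as err:
    print(err)  # Missing key(s) in state_dict: "position_ids".

# The fix in this commit: seed the converted dict with the model's own buffer.
converted["position_ids"] = emb.position_ids
emb.load_state_dict(converted)  # now succeeds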
@@ -44,7 +44,8 @@ class DPRContextEncoderState(DPRState):
         print("Loading DPR biencoder from {}".format(self.src_file))
         saved_state = load_states_from_checkpoint(self.src_file)
         encoder, prefix = model.ctx_encoder, "ctx_model."
-        state_dict = {}
+        # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3
+        state_dict = {"bert_model.embeddings.position_ids": model.ctx_encoder.bert_model.embeddings.position_ids}
         for key, value in saved_state.model_dict.items():
             if key.startswith(prefix):
                 key = key[len(prefix) :]
@@ -61,7 +62,8 @@ class DPRQuestionEncoderState(DPRState):
         print("Loading DPR biencoder from {}".format(self.src_file))
         saved_state = load_states_from_checkpoint(self.src_file)
         encoder, prefix = model.question_encoder, "question_model."
-        state_dict = {}
+        # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3
+        state_dict = {"bert_model.embeddings.position_ids": model.question_encoder.bert_model.embeddings.position_ids}
         for key, value in saved_state.model_dict.items():
             if key.startswith(prefix):
                 key = key[len(prefix) :]
@@ -77,7 +79,10 @@ class DPRReaderState(DPRState):
         model = DPRReader(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
         print("Loading DPR reader from {}".format(self.src_file))
         saved_state = load_states_from_checkpoint(self.src_file)
-        state_dict = {}
+        # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3
+        state_dict = {
+            "encoder.bert_model.embeddings.position_ids": model.span_predictor.encoder.bert_model.embeddings.position_ids
+        }
         for key, value in saved_state.model_dict.items():
             if key.startswith("encoder.") and not key.startswith("encoder.encode_proj"):
                 key = "encoder.bert_model." + key[len("encoder.") :]
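All three hunks apply the same pattern: the freshly built transformers model already owns the position_ids buffer, so the converted state dict is seeded with it before the keys from the original DPR checkpoint are remapped and loaded strictly. A rough, self-contained illustration of the surrounding remapping, using made-up checkpoint keys and string placeholders instead of tensors (the real keys come from saved_state.model_dict, and the tail of each loop below the hunks is not shown in this diff):

# Made-up keys standing in for an original DPR biencoder checkpoint.
model_dict = {
    "ctx_model.embeddings.word_embeddings.weight": "ctx tensor",
    "ctx_model.encoder.layer.0.attention.self.query.weight": "ctx tensor",
    "question_model.embeddings.word_embeddings.weight": "question tensor",  # skipped for the ctx encoder
}

prefix = "ctx_model."
# Seed the converted dict with the buffer from the freshly built model
# (placeholder string here), exactly as the commit does above.
state_dict = {"bert_model.embeddings.position_ids": "position_ids buffer"}
for key, value in model_dict.items():
    if key.startswith(prefix):
        key = key[len(prefix):]
        # Simplified: the original loop (continuing below the hunk) re-roots
        # non-projection keys under "bert_model." before storing them.
        state_dict["bert_model." + key] = value

for key in sorted(state_dict):
    print(key)
# bert_model.embeddings.position_ids
# bert_model.embeddings.word_embeddings.weight
# bert_model.encoder.layer.0.attention.self.query.weight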