Unverified Commit 0c64b188 authored by Quentin Lhoest's avatar Quentin Lhoest Committed by GitHub
Browse files

Fix bert position ids in DPR convert script (#7776)

* fix bert position ids in DPR convert script

* style
parent 7968051a
@@ -44,7 +44,8 @@ class DPRContextEncoderState(DPRState):
     print("Loading DPR biencoder from {}".format(self.src_file))
     saved_state = load_states_from_checkpoint(self.src_file)
     encoder, prefix = model.ctx_encoder, "ctx_model."
-    state_dict = {}
+    # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3
+    state_dict = {"bert_model.embeddings.position_ids": model.ctx_encoder.bert_model.embeddings.position_ids}
     for key, value in saved_state.model_dict.items():
         if key.startswith(prefix):
             key = key[len(prefix) :]
@@ -61,7 +62,8 @@ class DPRQuestionEncoderState(DPRState):
     print("Loading DPR biencoder from {}".format(self.src_file))
     saved_state = load_states_from_checkpoint(self.src_file)
     encoder, prefix = model.question_encoder, "question_model."
-    state_dict = {}
+    # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3
+    state_dict = {"bert_model.embeddings.position_ids": model.question_encoder.bert_model.embeddings.position_ids}
     for key, value in saved_state.model_dict.items():
         if key.startswith(prefix):
             key = key[len(prefix) :]
@@ -77,7 +79,10 @@ class DPRReaderState(DPRState):
     model = DPRReader(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
     print("Loading DPR reader from {}".format(self.src_file))
     saved_state = load_states_from_checkpoint(self.src_file)
-    state_dict = {}
+    # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3
+    state_dict = {
+        "encoder.bert_model.embeddings.position_ids": model.span_predictor.encoder.bert_model.embeddings.position_ids
+    }
     for key, value in saved_state.model_dict.items():
         if key.startswith("encoder.") and not key.startswith("encoder.encode_proj"):
             key = "encoder.bert_model." + key[len("encoder.") :]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment