Unverified Commit 43f46aa7 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[RAG] Fix rag from pretrained question encoder generator behavior (#11962)

* fix_torch_device_generate_test

* remove @

* fix rag from pretrained loading

* add test

* upload

* finish
parent 6db3a87d
...@@ -245,7 +245,6 @@ class RagPreTrainedModel(PreTrainedModel): ...@@ -245,7 +245,6 @@ class RagPreTrainedModel(PreTrainedModel):
question_encoder_pretrained_model_name_or_path: str = None, question_encoder_pretrained_model_name_or_path: str = None,
generator_pretrained_model_name_or_path: str = None, generator_pretrained_model_name_or_path: str = None,
retriever: RagRetriever = None, retriever: RagRetriever = None,
*model_args,
**kwargs **kwargs
) -> PreTrainedModel: ) -> PreTrainedModel:
r""" r"""
...@@ -310,7 +309,7 @@ class RagPreTrainedModel(PreTrainedModel): ...@@ -310,7 +309,7 @@ class RagPreTrainedModel(PreTrainedModel):
""" """
kwargs_question_encoder = { kwargs_question_encoder = {
argument[len("question_question_encoder_") :]: value argument[len("question_encoder_") :]: value
for argument, value in kwargs.items() for argument, value in kwargs.items()
if argument.startswith("question_encoder_") if argument.startswith("question_encoder_")
} }
...@@ -340,11 +339,15 @@ class RagPreTrainedModel(PreTrainedModel): ...@@ -340,11 +339,15 @@ class RagPreTrainedModel(PreTrainedModel):
if "config" not in kwargs_question_encoder: if "config" not in kwargs_question_encoder:
from ..auto.configuration_auto import AutoConfig from ..auto.configuration_auto import AutoConfig
question_encoder_config = AutoConfig.from_pretrained(question_encoder_pretrained_model_name_or_path) question_encoder_config, kwargs_question_encoder = AutoConfig.from_pretrained(
question_encoder_pretrained_model_name_or_path,
**kwargs_question_encoder,
return_unused_kwargs=True,
)
kwargs_question_encoder["config"] = question_encoder_config kwargs_question_encoder["config"] = question_encoder_config
question_encoder = AutoModel.from_pretrained( question_encoder = AutoModel.from_pretrained(
question_encoder_pretrained_model_name_or_path, *model_args, **kwargs_question_encoder question_encoder_pretrained_model_name_or_path, **kwargs_question_encoder
) )
generator = kwargs_generator.pop("model", None) generator = kwargs_generator.pop("model", None)
...@@ -357,7 +360,10 @@ class RagPreTrainedModel(PreTrainedModel): ...@@ -357,7 +360,10 @@ class RagPreTrainedModel(PreTrainedModel):
if "config" not in kwargs_generator: if "config" not in kwargs_generator:
from ..auto.configuration_auto import AutoConfig from ..auto.configuration_auto import AutoConfig
generator_config = AutoConfig.from_pretrained(generator_pretrained_model_name_or_path) generator_config, kwargs_generator = AutoConfig.from_pretrained(
generator_pretrained_model_name_or_path, **kwargs_generator, return_unused_kwargs=True
)
kwargs_generator["config"] = generator_config kwargs_generator["config"] = generator_config
generator = AutoModelForSeq2SeqLM.from_pretrained( generator = AutoModelForSeq2SeqLM.from_pretrained(
......
...@@ -1132,12 +1132,17 @@ class RagModelSaveLoadTests(unittest.TestCase): ...@@ -1132,12 +1132,17 @@ class RagModelSaveLoadTests(unittest.TestCase):
"facebook/bart-large-cnn", "facebook/bart-large-cnn",
retriever=rag_retriever, retriever=rag_retriever,
config=rag_config, config=rag_config,
question_encoder_max_length=200,
generator_max_length=200,
).to(torch_device) ).to(torch_device)
# check that the from pretrained methods work # check that the from pretrained methods work
rag_token.save_pretrained(tmp_dirname) rag_token.save_pretrained(tmp_dirname)
rag_token.from_pretrained(tmp_dirname, retriever=rag_retriever) rag_token.from_pretrained(tmp_dirname, retriever=rag_retriever)
rag_token.to(torch_device) rag_token.to(torch_device)
self.assertTrue(rag_token.question_encoder.config.max_length == 200)
self.assertTrue(rag_token.generator.config.max_length == 200)
with torch.no_grad(): with torch.no_grad():
output = rag_token( output = rag_token(
input_ids, input_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment