"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "ebb9b4060b3785e79d08b061442a9f6864689359"
Unverified commit 571c7a11, authored by Patrick von Platen and committed by GitHub
Browse files

[Rag] Fix wrong usage of `num_beams` and `bos_token_id` in Rag Sequence generation (#7386)

* fix_rag_sequence

* add second bug fix
parent 415071b4
@@ -882,7 +882,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
 hypos = []
 kwargs["num_beams"] = num_beams
-kwargs["num_return_sequences"] = num_return_sequences
+kwargs["num_return_sequences"] = num_beams
 kwargs["attention_mask"] = None
 for index in range(len(input_ids)):
@@ -916,7 +916,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
 )
 # bos_token_id is None for T5
-use_bos = self.config.bos_token_id is not None and target[:, 0].eq(self.config.bos_token_id).all()
+bos_token_id = self.config.bos_token_id or self.config.generator.bos_token_id
+use_bos = bos_token_id is not None and target[:, 0].eq(bos_token_id).all()
 def _mask_pads(ll, smooth_obj):
 pad_mask = target.eq(self.config.generator.pad_token_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment