Unverified Commit de2f7221 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: remove near-duplicate sample/greedy copy (#30773)

parent ce87dca1
This diff is collapsed.
...@@ -1739,7 +1739,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel): ...@@ -1739,7 +1739,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
) )
# 11. run greedy search # 11. run greedy search
outputs = self._greedy_search( outputs = self._sample(
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
...@@ -2832,7 +2832,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel): ...@@ -2832,7 +2832,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
) )
# 11. run greedy search # 11. run greedy search
outputs = self._greedy_search( outputs = self._sample(
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
......
...@@ -1676,7 +1676,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): ...@@ -1676,7 +1676,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
) )
# 11. run greedy search # 11. run greedy search
outputs = self._greedy_search( outputs = self._sample(
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
...@@ -2691,7 +2691,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): ...@@ -2691,7 +2691,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
) )
# 11. run greedy search # 11. run greedy search
outputs = self._greedy_search( outputs = self._sample(
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
......
...@@ -1550,7 +1550,7 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1550,7 +1550,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
" greedy search." " greedy search."
) )
return self._greedy_search( return self._sample(
input_ids, input_ids,
logits_processor=pre_processor, logits_processor=pre_processor,
stopping_criteria=prepared_stopping_criteria, stopping_criteria=prepared_stopping_criteria,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment