Unverified Commit 90750042 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: handle `logits_warper` update in models with custom generate fn (#31957)

handle logits_warper update in models with custom generate fn
parent 454bc14d
......@@ -2219,7 +2219,7 @@ class GenerationMixin:
generation_config: GenerationConfig,
synced_gpus: bool,
streamer: "BaseStreamer",
logits_warper: LogitsProcessorList,
logits_warper: Optional[LogitsProcessorList],
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
r"""
......@@ -2826,7 +2826,7 @@ class GenerationMixin:
generation_config: GenerationConfig,
synced_gpus: bool,
streamer: Optional["BaseStreamer"],
logits_warper: LogitsProcessorList,
logits_warper: Optional[LogitsProcessorList],
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
r"""
......@@ -3033,7 +3033,7 @@ class GenerationMixin:
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
logits_warper: LogitsProcessorList,
logits_warper: Optional[LogitsProcessorList],
**model_kwargs,
) -> Union[GenerateBeamOutput, torch.LongTensor]:
r"""
......
......@@ -26,7 +26,7 @@ import torch.nn as nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation.configuration_utils import GenerationConfig
from ...generation.configuration_utils import GenerationConfig, GenerationMode
from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
from ...generation.stopping_criteria import StoppingCriteriaList
from ...modeling_attn_mask_utils import (
......@@ -1618,16 +1618,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
model_kwargs["delay_pattern_mask"] = delay_pattern_mask
# 7. determine generation mode
is_greedy_gen_mode = (
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is False
)
is_sample_gen_mode = (
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is True
)
generation_mode = generation_config.get_generation_mode()
# 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG)
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
......@@ -1649,27 +1640,13 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
generation_config=generation_config, stopping_criteria=stopping_criteria
)
if is_greedy_gen_mode:
if generation_config.num_return_sequences > 1:
raise ValueError(
"num_return_sequences has to be 1 when doing greedy search, "
f"but is {generation_config.num_return_sequences}."
)
# 11. run greedy search
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)
elif is_sample_gen_mode:
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
prepared_logits_warper = (
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)
# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
......@@ -1682,7 +1659,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,
logits_warper=prepared_logits_warper,
stopping_criteria=stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
......@@ -2714,16 +2691,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
streamer.put(input_ids.cpu())
# 7. determine generation mode
is_greedy_gen_mode = (
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is False
)
is_sample_gen_mode = (
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is True
)
generation_mode = generation_config.get_generation_mode()
# 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG)
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
......@@ -2745,27 +2713,13 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
generation_config=generation_config, stopping_criteria=stopping_criteria
)
if is_greedy_gen_mode:
if generation_config.num_return_sequences > 1:
raise ValueError(
"num_return_sequences has to be 1 when doing greedy search, "
f"but is {generation_config.num_return_sequences}."
)
# 11. run greedy search
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)
elif is_sample_gen_mode:
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
prepared_logits_warper = (
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)
# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
......@@ -2779,7 +2733,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,
logits_warper=prepared_logits_warper,
stopping_criteria=stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
......
......@@ -26,7 +26,7 @@ import torch.nn as nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation.configuration_utils import GenerationConfig
from ...generation.configuration_utils import GenerationConfig, GenerationMode
from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
from ...generation.stopping_criteria import StoppingCriteriaList
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
......@@ -1539,16 +1539,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
model_kwargs["delay_pattern_mask"] = delay_pattern_mask
# 7. determine generation mode
is_greedy_gen_mode = (
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is False
)
is_sample_gen_mode = (
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is True
)
generation_mode = generation_config.get_generation_mode()
# 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG)
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
......@@ -1570,27 +1561,13 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
generation_config=generation_config, stopping_criteria=stopping_criteria
)
if is_greedy_gen_mode:
if generation_config.num_return_sequences > 1:
raise ValueError(
"num_return_sequences has to be 1 when doing greedy search, "
f"but is {generation_config.num_return_sequences}."
)
# 11. run greedy search
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)
elif is_sample_gen_mode:
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
prepared_logits_warper = (
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)
# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
......@@ -1603,7 +1580,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,
logits_warper=prepared_logits_warper,
stopping_criteria=stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
......@@ -2557,16 +2534,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
streamer.put(input_ids.cpu())
# 7. determine generation mode
is_greedy_gen_mode = (
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is False
)
is_sample_gen_mode = (
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is True
)
generation_mode = generation_config.get_generation_mode()
# 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG)
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
......@@ -2588,27 +2556,13 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
generation_config=generation_config, stopping_criteria=stopping_criteria
)
if is_greedy_gen_mode:
if generation_config.num_return_sequences > 1:
raise ValueError(
"num_return_sequences has to be 1 when doing greedy search, "
f"but is {generation_config.num_return_sequences}."
)
# 11. run greedy search
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)
elif is_sample_gen_mode:
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
prepared_logits_warper = (
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)
# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
......@@ -2622,7 +2576,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,
logits_warper=prepared_logits_warper,
stopping_criteria=stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
......
......@@ -1558,6 +1558,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
generation_config=generation_config,
synced_gpus=False,
streamer=None,
logits_warper=None,
**model_kwargs,
)
elif generation_config.num_beams > 1:
......@@ -1579,6 +1580,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
synced_gpus=False,
logits_warper=None,
**model_kwargs,
)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment