Unverified Commit 4a564490 authored by Joao Gante, committed by GitHub

Musicgen: CFG is manually added (#25173)

parent 05cda5df
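For context, the change below appends a ClassifierFreeGuidanceLogitsProcessor to the user-supplied logits processors whenever guidance_scale > 1, then sets generation_config.guidance_scale to None so the later _get_logits_processor call does not attach its own guidance processor on top. That processor blends the conditional and unconditional halves of a doubled batch of logits; a minimal, self-contained sketch of that blend (toy tensors, not the library's exact implementation):

import torch

def cfg_mix(scores: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # scores carries a doubled batch: first half conditional, second half unconditional
    cond, uncond = scores.split(scores.shape[0] // 2, dim=0)
    # push the distribution away from the unconditional logits by guidance_scale
    return uncond + guidance_scale * (cond - uncond)

# toy example: 2 prompts -> 4 rows of logits (conditional + unconditional), vocab size 8
scores = torch.randn(4, 8)
guided = cfg_mix(scores, guidance_scale=3.0)
print(guided.shape)  # torch.Size([2, 8])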
@@ -27,7 +27,7 @@ from torch.utils.checkpoint import checkpoint
 
 from ...activations import ACT2FN
 from ...generation.configuration_utils import GenerationConfig
-from ...generation.logits_process import LogitsProcessorList
+from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
 from ...generation.stopping_criteria import StoppingCriteriaList
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -1351,7 +1351,12 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             and generation_config.do_sample is True
         )
 
-        # 8. prepare distribution pre_processing samplers
+        # 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG)
+        if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+            logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+            generation_config.guidance_scale = None
+
+        # 9. prepare distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_seq_length,
@@ -1360,7 +1365,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             logits_processor=logits_processor,
         )
 
-        # 9. prepare stopping criteria
+        # 10. prepare stopping criteria
         stopping_criteria = self._get_stopping_criteria(
             generation_config=generation_config, stopping_criteria=stopping_criteria
         )
@@ -1372,7 +1377,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
                     f"but is {generation_config.num_return_sequences}."
                 )
 
-            # 8. run greedy search
+            # 11. run greedy search
             outputs = self.greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
@@ -1386,7 +1391,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             )
 
         elif is_sample_gen_mode:
-            # 9. prepare logits warper
+            # 11. prepare logits warper
             logits_warper = self._get_logits_warper(generation_config)
 
             # expand input_ids with `num_return_sequences` additional sequences per batch
@@ -1396,7 +1401,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
                 **model_kwargs,
             )
 
-            # 10. run sample
+            # 12. run sample
             outputs = self.sample(
                 input_ids,
                 logits_processor=logits_processor,
@@ -2375,7 +2380,12 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             and generation_config.do_sample is True
         )
 
-        # 8. prepare distribution pre_processing samplers
+        # 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG)
+        if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+            logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+            generation_config.guidance_scale = None
+
+        # 9. prepare distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_seq_length,
@@ -2384,7 +2394,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             logits_processor=logits_processor,
         )
 
-        # 9. prepare stopping criteria
+        # 10. prepare stopping criteria
         stopping_criteria = self._get_stopping_criteria(
             generation_config=generation_config, stopping_criteria=stopping_criteria
         )
@@ -2396,7 +2406,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
                     f"but is {generation_config.num_return_sequences}."
                 )
 
-            # 10. run greedy search
+            # 11. run greedy search
             outputs = self.greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
...
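With this in place, callers get batched CFG just by passing guidance_scale to generate(). A minimal usage sketch (checkpoint name, prompt, and generation parameters are illustrative):

from transformers import AutoProcessor, MusicgenForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

inputs = processor(text=["lo-fi beat with a mellow piano"], padding=True, return_tensors="pt")
# guidance_scale > 1 routes through the ClassifierFreeGuidanceLogitsProcessor added above
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)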