"tests/models/mt5/test_modeling_mt5.py" did not exist on "29c10a41d04f855c433a6cde7797b325651417d2"
Unverified Commit 4a564490 authored by Joao Gante, committed by GitHub

Musicgen: CFG is manually added (#25173)

parent 05cda5df
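
With this change, both `MusicgenForCausalLM.generate` and `MusicgenForConditionalGeneration.generate` append the batched `ClassifierFreeGuidanceLogitsProcessor` themselves whenever `guidance_scale > 1`, and then reset `guidance_scale` to `None` before calling `_get_logits_processor`, so the generic generation pipeline does not add its own (unbatched) CFG processor on top. The numbered step comments in both methods are renumbered to account for the extra step.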
@@ -27,7 +27,7 @@ from torch.utils.checkpoint import checkpoint
 from ...activations import ACT2FN
 from ...generation.configuration_utils import GenerationConfig
-from ...generation.logits_process import LogitsProcessorList
+from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
 from ...generation.stopping_criteria import StoppingCriteriaList
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -1351,7 +1351,12 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             and generation_config.do_sample is True
         )
-        # 8. prepare distribution pre_processing samplers
+        # 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
+        if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+            logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+            generation_config.guidance_scale = None
+        # 9. prepare distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_seq_length,
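
For context on what the appended processor does: batched CFG runs the model on a doubled batch (conditional inputs in the first half, unconditional in the second) and mixes the two sets of next-token scores. Below is a minimal sketch of that combination step, not the library implementation, assuming the conditional half comes first in the batch:

```python
import torch

def apply_cfg(scores: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # scores: next-token logits for a doubled batch of size 2 * batch_size,
    # with the conditional half first and the unconditional half second.
    cond, uncond = scores.split(scores.shape[0] // 2, dim=0)
    # Push the distribution away from the unconditional scores and toward the
    # conditional ones; guidance_scale == 1 recovers the conditional scores.
    return uncond + guidance_scale * (cond - uncond)
```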
@@ -1360,7 +1365,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             logits_processor=logits_processor,
         )
-        # 9. prepare stopping criteria
+        # 10. prepare stopping criteria
         stopping_criteria = self._get_stopping_criteria(
             generation_config=generation_config, stopping_criteria=stopping_criteria
         )
@@ -1372,7 +1377,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
                     f"but is {generation_config.num_return_sequences}."
                 )
-            # 8. run greedy search
+            # 11. run greedy search
             outputs = self.greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
@@ -1386,7 +1391,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             )
         elif is_sample_gen_mode:
-            # 9. prepare logits warper
+            # 11. prepare logits warper
             logits_warper = self._get_logits_warper(generation_config)
             # expand input_ids with `num_return_sequences` additional sequences per batch
@@ -1396,7 +1401,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
                 **model_kwargs,
             )
-            # 10. run sample
+            # 12. run sample
            outputs = self.sample(
                 input_ids,
                 logits_processor=logits_processor,
@@ -2375,7 +2380,12 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             and generation_config.do_sample is True
         )
-        # 8. prepare distribution pre_processing samplers
+        # 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
+        if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+            logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+            generation_config.guidance_scale = None
+        # 9. prepare distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_seq_length,
@@ -2384,7 +2394,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             logits_processor=logits_processor,
         )
-        # 9. prepare stopping criteria
+        # 10. prepare stopping criteria
         stopping_criteria = self._get_stopping_criteria(
             generation_config=generation_config, stopping_criteria=stopping_criteria
         )
@@ -2396,7 +2406,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
                     f"but is {generation_config.num_return_sequences}."
                 )
-            # 10. run greedy search
+            # 11. run greedy search
             outputs = self.greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
......
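
From the user side the API is unchanged: passing `guidance_scale` to `generate` still enables classifier-free guidance; it is simply applied by the manually appended batched processor. A minimal usage sketch (the checkpoint name and generation parameters are illustrative, not part of this commit):

```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

inputs = processor(text=["80s pop track with bassy drums and synth"], padding=True, return_tensors="pt")
# guidance_scale > 1 triggers the batched CFG processor appended in generate().
audio_values = model.generate(**inputs, guidance_scale=3.0, max_new_tokens=256)
```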