Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4a564490
Unverified
Commit
4a564490
authored
Jul 31, 2023
by
Joao Gante
Committed by
GitHub
Jul 31, 2023
Browse files
Musicgen: CFG is manually added (#25173)
parent
05cda5df
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
9 deletions
+19
-9
src/transformers/models/musicgen/modeling_musicgen.py
src/transformers/models/musicgen/modeling_musicgen.py
+19
-9
No files found.
src/transformers/models/musicgen/modeling_musicgen.py
View file @
4a564490
...
@@ -27,7 +27,7 @@ from torch.utils.checkpoint import checkpoint
...
@@ -27,7 +27,7 @@ from torch.utils.checkpoint import checkpoint
from
...activations
import
ACT2FN
from
...activations
import
ACT2FN
from
...generation.configuration_utils
import
GenerationConfig
from
...generation.configuration_utils
import
GenerationConfig
from
...generation.logits_process
import
LogitsProcessorList
from
...generation.logits_process
import
ClassifierFreeGuidanceLogitsProcessor
,
LogitsProcessorList
from
...generation.stopping_criteria
import
StoppingCriteriaList
from
...generation.stopping_criteria
import
StoppingCriteriaList
from
...modeling_outputs
import
(
from
...modeling_outputs
import
(
BaseModelOutput
,
BaseModelOutput
,
...
@@ -1351,7 +1351,12 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
...
@@ -1351,7 +1351,12 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
and
generation_config
.
do_sample
is
True
and
generation_config
.
do_sample
is
True
)
)
# 8. prepare distribution pre_processing samplers
# 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG)
if
generation_config
.
guidance_scale
is
not
None
and
generation_config
.
guidance_scale
>
1
:
logits_processor
.
append
(
ClassifierFreeGuidanceLogitsProcessor
(
generation_config
.
guidance_scale
))
generation_config
.
guidance_scale
=
None
# 9. prepare distribution pre_processing samplers
logits_processor
=
self
.
_get_logits_processor
(
logits_processor
=
self
.
_get_logits_processor
(
generation_config
=
generation_config
,
generation_config
=
generation_config
,
input_ids_seq_length
=
input_ids_seq_length
,
input_ids_seq_length
=
input_ids_seq_length
,
...
@@ -1360,7 +1365,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
...
@@ -1360,7 +1365,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
)
)
#
9
. prepare stopping criteria
#
10
. prepare stopping criteria
stopping_criteria
=
self
.
_get_stopping_criteria
(
stopping_criteria
=
self
.
_get_stopping_criteria
(
generation_config
=
generation_config
,
stopping_criteria
=
stopping_criteria
generation_config
=
generation_config
,
stopping_criteria
=
stopping_criteria
)
)
...
@@ -1372,7 +1377,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
...
@@ -1372,7 +1377,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
f
"but is
{
generation_config
.
num_return_sequences
}
."
f
"but is
{
generation_config
.
num_return_sequences
}
."
)
)
#
8
. run greedy search
#
11
. run greedy search
outputs
=
self
.
greedy_search
(
outputs
=
self
.
greedy_search
(
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
...
@@ -1386,7 +1391,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
...
@@ -1386,7 +1391,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
)
)
elif
is_sample_gen_mode
:
elif
is_sample_gen_mode
:
#
9
. prepare logits warper
#
11
. prepare logits warper
logits_warper
=
self
.
_get_logits_warper
(
generation_config
)
logits_warper
=
self
.
_get_logits_warper
(
generation_config
)
# expand input_ids with `num_return_sequences` additional sequences per batch
# expand input_ids with `num_return_sequences` additional sequences per batch
...
@@ -1396,7 +1401,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
...
@@ -1396,7 +1401,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
**
model_kwargs
,
**
model_kwargs
,
)
)
# 1
0
. run sample
# 1
2
. run sample
outputs
=
self
.
sample
(
outputs
=
self
.
sample
(
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
...
@@ -2375,7 +2380,12 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
...
@@ -2375,7 +2380,12 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
and
generation_config
.
do_sample
is
True
and
generation_config
.
do_sample
is
True
)
)
# 8. prepare distribution pre_processing samplers
# 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG)
if
generation_config
.
guidance_scale
is
not
None
and
generation_config
.
guidance_scale
>
1
:
logits_processor
.
append
(
ClassifierFreeGuidanceLogitsProcessor
(
generation_config
.
guidance_scale
))
generation_config
.
guidance_scale
=
None
# 9. prepare distribution pre_processing samplers
logits_processor
=
self
.
_get_logits_processor
(
logits_processor
=
self
.
_get_logits_processor
(
generation_config
=
generation_config
,
generation_config
=
generation_config
,
input_ids_seq_length
=
input_ids_seq_length
,
input_ids_seq_length
=
input_ids_seq_length
,
...
@@ -2384,7 +2394,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
...
@@ -2384,7 +2394,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
)
)
#
9
. prepare stopping criteria
#
10
. prepare stopping criteria
stopping_criteria
=
self
.
_get_stopping_criteria
(
stopping_criteria
=
self
.
_get_stopping_criteria
(
generation_config
=
generation_config
,
stopping_criteria
=
stopping_criteria
generation_config
=
generation_config
,
stopping_criteria
=
stopping_criteria
)
)
...
@@ -2396,7 +2406,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
...
@@ -2396,7 +2406,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
f
"but is
{
generation_config
.
num_return_sequences
}
."
f
"but is
{
generation_config
.
num_return_sequences
}
."
)
)
# 1
0
. run greedy search
# 1
1
. run greedy search
outputs
=
self
.
greedy_search
(
outputs
=
self
.
greedy_search
(
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment