Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
0589d6c6
Commit
0589d6c6
authored
May 09, 2024
by
sanchit-gandhi
Browse files
use private generation methods
parent
83d4a719
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
30 deletions
+14
-30
parler_tts/modeling_parler_tts.py
parler_tts/modeling_parler_tts.py
+14
-30
No files found.
parler_tts/modeling_parler_tts.py
View file @
0589d6c6
...
@@ -1386,8 +1386,6 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
...
@@ -1386,8 +1386,6 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
batch_size
=
input_ids
.
shape
[
0
]
//
self
.
num_codebooks
batch_size
=
input_ids
.
shape
[
0
]
//
self
.
num_codebooks
# 4. Define other model kwargs
# 4. Define other model kwargs
model_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
model_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
...
@@ -1481,14 +1479,11 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
...
@@ -1481,14 +1479,11 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
)
)
# 11. run greedy search
# 11. run greedy search
outputs
=
self
.
greedy_search
(
outputs
=
self
.
_
greedy_search
(
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
@@ -1506,15 +1501,12 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
...
@@ -1506,15 +1501,12 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
)
)
# 12. run sample
# 12. run sample
outputs
=
self
.
sample
(
outputs
=
self
.
_
sample
(
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
logits_warper
=
logits_warper
,
logits_warper
=
logits_warper
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
@@ -2198,8 +2190,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
...
@@ -2198,8 +2190,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
self
,
self
,
inputs_tensor
:
torch
.
Tensor
,
inputs_tensor
:
torch
.
Tensor
,
model_kwargs
,
model_kwargs
,
model_input_name
:
Optional
[
str
]
=
None
,
model_input_name
:
Optional
[
str
],
g
uidance_scale
:
Optional
[
float
]
=
None
,
g
eneration_config
:
GenerationConfig
,
)
->
Dict
[
str
,
Any
]:
)
->
Dict
[
str
,
Any
]:
# 1. get text encoder
# 1. get text encoder
encoder
=
self
.
get_text_encoder
()
encoder
=
self
.
get_text_encoder
()
...
@@ -2221,6 +2213,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
...
@@ -2221,6 +2213,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
encoder_kwargs
=
{
encoder_kwargs
=
{
argument
:
value
for
argument
,
value
in
encoder_kwargs
.
items
()
if
argument
in
encoder_signature
argument
:
value
for
argument
,
value
in
encoder_kwargs
.
items
()
if
argument
in
encoder_signature
}
}
encoder_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
encoder_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
guidance_scale
=
generation_config
.
guidance_scale
# 3. make sure that encoder returns `ModelOutput`
# 3. make sure that encoder returns `ModelOutput`
model_input_name
=
model_input_name
if
model_input_name
is
not
None
else
self
.
text_encoder
.
main_input_name
model_input_name
=
model_input_name
if
model_input_name
is
not
None
else
self
.
text_encoder
.
main_input_name
...
@@ -2452,8 +2447,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
...
@@ -2452,8 +2447,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
batch_size
=
inputs_tensor
.
shape
[
0
]
batch_size
=
inputs_tensor
.
shape
[
0
]
# 4. Define other model kwargs
# 4. Define other model kwargs
model_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
model_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
...
@@ -2467,10 +2460,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
...
@@ -2467,10 +2460,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
if
"encoder_outputs"
not
in
model_kwargs
:
if
"encoder_outputs"
not
in
model_kwargs
:
# encoder_outputs are created and added to `model_kwargs`
# encoder_outputs are created and added to `model_kwargs`
model_kwargs
=
self
.
_prepare_text_encoder_kwargs_for_generation
(
model_kwargs
=
self
.
_prepare_text_encoder_kwargs_for_generation
(
inputs_tensor
,
inputs_tensor
,
model_kwargs
,
model_input_name
,
generation_config
,
model_kwargs
,
model_input_name
,
guidance_scale
=
generation_config
.
guidance_scale
,
)
)
if
"prompt_hidden_states"
not
in
model_kwargs
and
"prompt_input_ids"
in
model_kwargs
:
if
"prompt_hidden_states"
not
in
model_kwargs
and
"prompt_input_ids"
in
model_kwargs
:
...
@@ -2579,14 +2569,11 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
...
@@ -2579,14 +2569,11 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
)
)
# 11. run greedy search
# 11. run greedy search
outputs
=
self
.
greedy_search
(
outputs
=
self
.
_
greedy_search
(
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
@@ -2605,15 +2592,12 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
...
@@ -2605,15 +2592,12 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
)
)
# 12. run sample
# 12. run sample
outputs
=
self
.
sample
(
outputs
=
self
.
_
sample
(
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
logits_warper
=
logits_warper
,
logits_warper
=
logits_warper
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment