Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
6089d39b
Commit
6089d39b
authored
Feb 13, 2024
by
sanchit-gandhi
Browse files
style
parent
ef1c723d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
5 deletions
+7
-5
stable_speech/configuration_stable_speech.py
stable_speech/configuration_stable_speech.py
+1
-1
stable_speech/modeling_stable_speech.py
stable_speech/modeling_stable_speech.py
+6
-4
No files found.
stable_speech/configuration_stable_speech.py
View file @
6089d39b
...
...
@@ -14,8 +14,8 @@
# limitations under the License.
""" Stable Speech model configuration"""
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers
import
AutoConfig
,
logging
from
transformers.configuration_utils
import
PretrainedConfig
logger
=
logging
.
get_logger
(
__name__
)
...
...
stable_speech/modeling_stable_speech.py
View file @
6089d39b
...
...
@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import
torch
import
torch.nn
as
nn
from
torch.nn
import
CrossEntropyLoss
from
transformers
import
AutoConfig
,
AutoModel
from
transformers.activations
import
ACT2FN
from
transformers.generation.configuration_utils
import
GenerationConfig
from
transformers.generation.logits_process
import
ClassifierFreeGuidanceLogitsProcessor
,
LogitsProcessorList
...
...
@@ -43,7 +43,7 @@ from transformers.utils import (
logging
,
replace_return_docstrings
,
)
from
transformers
import
AutoConfig
,
AutoModel
from
.configuration_stable_speech
import
StableSpeechConfig
,
StableSpeechDecoderConfig
...
...
@@ -1091,7 +1091,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
# fill the shifted ids with the prompt entries, offset by the codebook idx
for
codebook
in
range
(
num_codebooks
):
# mono channel - loop over the codebooks one-by-one
input_ids_shifted
[:,
codebook
,
codebook
:
seq_len
+
codebook
]
=
input_ids
[:,
codebook
]
input_ids_shifted
[:,
codebook
,
codebook
:
seq_len
+
codebook
]
=
input_ids
[:,
codebook
]
# construct a pattern mask that indicates the positions of padding tokens for each codebook
# first fill the upper triangular part (the EOS padding)
...
...
@@ -1419,7 +1419,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
"Either a configuration has to be provided, or all three of text encoder, audio encoder and Stable Speech decoder."
)
if
config
is
None
:
config
=
StableSpeechConfig
.
from_sub_models_config
(
text_encoder
.
config
,
audio_encoder
.
config
,
decoder
.
config
)
config
=
StableSpeechConfig
.
from_sub_models_config
(
text_encoder
.
config
,
audio_encoder
.
config
,
decoder
.
config
)
else
:
if
not
isinstance
(
config
,
self
.
config_class
):
raise
ValueError
(
f
"Config:
{
config
}
has to be of type
{
self
.
config_class
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment