Commit 6089d39b authored by sanchit-gandhi

style

parent ef1c723d
@@ -14,8 +14,8 @@
 # limitations under the License.
 """ Stable Speech model configuration"""
-from transformers.configuration_utils import PretrainedConfig
 from transformers import AutoConfig, logging
+from transformers.configuration_utils import PretrainedConfig
 logger = logging.get_logger(__name__)
...
@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
+from transformers import AutoConfig, AutoModel
 from transformers.activations import ACT2FN
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
@@ -43,7 +43,7 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
-from transformers import AutoConfig, AutoModel
 from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
@@ -1091,7 +1091,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         # fill the shifted ids with the prompt entries, offset by the codebook idx
         for codebook in range(num_codebooks):
             # mono channel - loop over the codebooks one-by-one
-            input_ids_shifted[:, codebook, codebook: seq_len + codebook] = input_ids[:, codebook]
+            input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
         # construct a pattern mask that indicates the positions of padding tokens for each codebook
         # first fill the upper triangular part (the EOS padding)
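For context on this hunk: below is a minimal, standalone sketch of the per-codebook shift whose slice spacing is being reformatted. The toy shapes, the buffer length, and the -1 pad value are illustrative assumptions, not values taken from the model code.

import torch

# toy dimensions, chosen only for illustration
bsz, num_codebooks, seq_len = 1, 4, 6
input_ids = torch.arange(bsz * num_codebooks * seq_len).reshape(bsz, num_codebooks, seq_len)

# the shifted buffer is wider than the sequence so codebook k can start k steps later
input_ids_shifted = torch.full((bsz, num_codebooks, seq_len + num_codebooks - 1), -1)
for codebook in range(num_codebooks):
    # each codebook is written with an offset equal to its index, leaving a
    # triangular region of -1 padding at the start and end of the buffer
    input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]

print(input_ids_shifted[0])

Printing the result shows the staircase layout the surrounding comments describe: codebook k begins with k pad slots, which is what the pattern mask built afterwards marks as padding positions.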
@@ -1419,7 +1419,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
                 "Either a configuration has to be provided, or all three of text encoder, audio encoder and Stable Speech decoder."
             )
         if config is None:
-            config = StableSpeechConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config)
+            config = StableSpeechConfig.from_sub_models_config(
+                text_encoder.config, audio_encoder.config, decoder.config
+            )
         else:
             if not isinstance(config, self.config_class):
                 raise ValueError(f"Config: {config} has to be of type {self.config_class}")
...
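As a usage note for this hunk: a hedged sketch of how the reformatted from_sub_models_config call might be exercised directly. The T5 and EnCodec checkpoints are illustrative assumptions, StableSpeechDecoderConfig is assumed to be default-constructible, and the relative import only resolves when run from inside the package.

from transformers import AutoConfig

from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig

# sub-model configs: any text encoder and audio codec config could stand in here
text_encoder_config = AutoConfig.from_pretrained("t5-base")
audio_encoder_config = AutoConfig.from_pretrained("facebook/encodec_24khz")
decoder_config = StableSpeechDecoderConfig()

# merge the three sub-configs into the composite config, mirroring what the
# constructor does when it is given sub-models instead of a ready-made config
config = StableSpeechConfig.from_sub_models_config(
    text_encoder_config, audio_encoder_config, decoder_config
)
assert isinstance(config, StableSpeechConfig)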