Commit 6089d39b authored by sanchit-gandhi's avatar sanchit-gandhi
Browse files

style

parent ef1c723d
......@@ -14,8 +14,8 @@
# limitations under the License.
""" Stable Speech model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers import AutoConfig, logging
from transformers.configuration_utils import PretrainedConfig
logger = logging.get_logger(__name__)
......
......@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModel
from transformers.activations import ACT2FN
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
......@@ -43,7 +43,7 @@ from transformers.utils import (
logging,
replace_return_docstrings,
)
from transformers import AutoConfig, AutoModel
from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
......@@ -1091,7 +1091,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
# fill the shifted ids with the prompt entries, offset by the codebook idx
for codebook in range(num_codebooks):
# mono channel - loop over the codebooks one-by-one
input_ids_shifted[:, codebook, codebook: seq_len + codebook] = input_ids[:, codebook]
input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
# construct a pattern mask that indicates the positions of padding tokens for each codebook
# first fill the upper triangular part (the EOS padding)
......@@ -1419,7 +1419,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
"Either a configuration has to be provided, or all three of text encoder, audio encoder and Stable Speech decoder."
)
if config is None:
config = StableSpeechConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config)
config = StableSpeechConfig.from_sub_models_config(
text_encoder.config, audio_encoder.config, decoder.config
)
else:
if not isinstance(config, self.config_class):
raise ValueError(f"Config: {config} has to be of type {self.config_class}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment