Unverified Commit 9ee3dd38 authored by hlky, committed by GitHub

AudioLDM2 Fixes (#11244)

parent fd02aad4
@@ -20,7 +20,7 @@ import torch
 from transformers import (
     ClapFeatureExtractor,
     ClapModel,
-    GPT2Model,
+    GPT2LMHeadModel,
     RobertaTokenizer,
     RobertaTokenizerFast,
     SpeechT5HifiGan,
@@ -196,7 +196,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
         text_encoder: ClapModel,
         text_encoder_2: Union[T5EncoderModel, VitsModel],
         projection_model: AudioLDM2ProjectionModel,
-        language_model: GPT2Model,
+        language_model: GPT2LMHeadModel,
         tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
         tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer],
         feature_extractor: ClapFeatureExtractor,
@@ -259,7 +259,10 @@ class AudioLDM2Pipeline(DiffusionPipeline):
             )
 
         device_type = torch_device.type
-        device = torch.device(f"{device_type}:{gpu_id or torch_device.index}")
+        device_str = device_type
+        if gpu_id or torch_device.index:
+            device_str = f"{device_str}:{gpu_id or torch_device.index}"
+        device = torch.device(device_str)
 
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
@@ -316,9 +319,9 @@ class AudioLDM2Pipeline(DiffusionPipeline):
             model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
 
             # forward pass to get next hidden states
-            output = self.language_model(**model_inputs, return_dict=True)
+            output = self.language_model(**model_inputs, output_hidden_states=True, return_dict=True)
 
-            next_hidden_states = output.last_hidden_state
+            next_hidden_states = output.hidden_states[-1]
 
             # Update the model input
             inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
...
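A minimal sketch (not part of this diff) of why the language-model forward call above changes: `GPT2Model` returns its final hidden states as `last_hidden_state`, while `GPT2LMHeadModel` wraps the transformer with an LM head and only exposes hidden states when `output_hidden_states=True` is requested, as the last entry of the `hidden_states` tuple. The tiny config below is arbitrary, chosen only to keep the example fast.

import torch
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Model

# Arbitrary tiny config, purely for illustration.
config = GPT2Config(vocab_size=100, n_embd=32, n_head=2, n_layer=2, n_positions=16)
inputs_embeds = torch.randn(1, 4, config.n_embd)

base = GPT2Model(config).eval()
lm = GPT2LMHeadModel(config).eval()

with torch.no_grad():
    base_out = base(inputs_embeds=inputs_embeds, return_dict=True)
    lm_out = lm(inputs_embeds=inputs_embeds, output_hidden_states=True, return_dict=True)

# GPT2Model exposes the final hidden states directly ...
print(base_out.last_hidden_state.shape)  # torch.Size([1, 4, 32])
# ... GPT2LMHeadModel has no `last_hidden_state`; the pipeline now reads the last
# element of `hidden_states`, i.e. the final transformer layer output (what
# GPT2Model would have returned as `last_hidden_state`).
print(lm_out.hidden_states[-1].shape)  # torch.Size([1, 4, 32])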
@@ -26,7 +26,7 @@ from transformers import (
     ClapModel,
     ClapTextConfig,
     GPT2Config,
-    GPT2Model,
+    GPT2LMHeadModel,
     RobertaTokenizer,
     SpeechT5HifiGan,
     SpeechT5HifiGanConfig,
@@ -162,7 +162,7 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             n_ctx=99,
             n_positions=99,
         )
-        language_model = GPT2Model(language_model_config)
+        language_model = GPT2LMHeadModel(language_model_config)
         language_model.config.max_new_tokens = 8
 
         torch.manual_seed(0)
@@ -516,6 +516,18 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def test_encode_prompt_works_in_isolation(self):
         pass
 
+    @unittest.skip("Not supported yet due to CLAPModel.")
+    def test_sequential_offload_forward_pass_twice(self):
+        pass
+
+    @unittest.skip("Not supported yet, the second forward has mixed devices and `vocoder` is not offloaded.")
+    def test_cpu_offload_forward_pass_twice(self):
+        pass
+
+    @unittest.skip("Not supported yet. `vocoder` is not offloaded.")
+    def test_model_cpu_offload_forward_pass(self):
+        pass
+
 
 @nightly
 class AudioLDM2PipelineSlowTests(unittest.TestCase):
...
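A minimal sketch (not part of this diff) of the offload device-string issue fixed in the `@@ -259,7 +259,10 @@` hunk above: `torch.device("cuda").index` is `None` when no index is given, so formatting it straight into an f-string yields the invalid spec "cuda:None". The helper name below is hypothetical and only mirrors the guarded construction.

import torch

def build_offload_device(torch_device: torch.device, gpu_id=None) -> torch.device:
    # Hypothetical helper mirroring the fixed logic: only append an index
    # when one was actually supplied, otherwise keep the bare device type.
    device_type = torch_device.type
    device_str = device_type
    if gpu_id or torch_device.index:
        device_str = f"{device_str}:{gpu_id or torch_device.index}"
    return torch.device(device_str)

print(build_offload_device(torch.device("cuda:1")))  # cuda:1
print(build_offload_device(torch.device("cuda")))    # cuda
# The previous one-liner would have built "cuda:None" here and raised:
# torch.device(f"cuda:{torch.device('cuda').index}")  -> RuntimeError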