Unverified Commit 29a11c2a authored by Sanchit Gandhi, committed by GitHub

[AudioLDM 2] Pipeline fixes (#4738)

* fix docs

* fix unet docs

* use image output for latents

* fix hub checkpoints

* fix pipeline example

* update example

* return_dict = False

* revert image pipeline output

* revert doc changes

* remove dtype test

* make style

* remove docstring updates

* remove unet docstring update

* Empty commit to re-trigger CI

* fix cpu offload

* fix dtype test

* add offload test
parent cdacd8f1
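
For context, a minimal usage sketch of the pipeline these fixes target, assuming the public `cvssp/audioldm2` checkpoint referenced by the updated slow tests; the prompt, step count, and output handling below are illustrative and not taken from this diff:

```python
import scipy.io.wavfile
import torch

from diffusers import AudioLDM2Pipeline

# Load the public Hub checkpoint referenced by the updated slow tests
pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2", torch_dtype=torch.float16)
pipe = pipe.to("cuda")  # assumes a CUDA GPU is available
# pipe.enable_model_cpu_offload()  # alternative to .to("cuda"), using the offload path touched by this PR

# Generate a short clip; prompt and parameters are illustrative
audio = pipe(
    prompt="The sound of a hammer hitting a wooden surface",
    num_inference_steps=200,
    audio_length_in_s=10.0,
).audios[0]

# The AudioLDM 2 vocoder outputs audio at 16 kHz
scipy.io.wavfile.write("output.wav", rate=16000, data=audio)
```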
@@ -208,12 +208,15 @@ class AudioLDM2Pipeline(DiffusionPipeline):
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        model_sequence = [
            self.text_encoder,
            self.text_encoder.text_model,
            self.text_encoder.text_projection,
            self.text_encoder_2,
            self.projection_model,
            self.language_model,
            self.unet,
            self.vae,
            self.vocoder,
            self.text_encoder,
        ]

        hook = None
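
The hunk above adjusts which CLAP components appear in the offload sequence, listing the text encoder's `text_model` and `text_projection` submodules explicitly, presumably so the offload hooks attach to the modules that actually run during prompt encoding. A minimal sketch of how sequential model CPU offload typically consumes such a list in diffusers pipelines, assuming accelerate's `cpu_offload_with_hook`; this is not the literal method body:

```python
import torch
from accelerate import cpu_offload_with_hook

# Sketch only: diffusers pipelines chain offload hooks over a sequence of sub-models.
# Each entry is moved to the GPU when it runs and back to the CPU when the next hook
# fires, so at most one large sub-model sits in GPU memory at a time.
def offload_sequence(model_sequence, gpu_id=0):
    device = torch.device(f"cuda:{gpu_id}")
    hook = None
    for cpu_offloaded_model in model_sequence:
        _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
    return hook  # the final hook can be kept to offload the last model after inference
```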
@@ -927,7 +930,8 @@ class AudioLDM2Pipeline(DiffusionPipeline):
                    encoder_hidden_states=generated_prompt_embeds,
                    encoder_hidden_states_1=prompt_embeds,
                    encoder_attention_mask_1=attention_mask,
                ).sample
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
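
The hunk above switches the UNet call to `return_dict=False` (the "return_dict = False" commit bullet), taking the noise prediction from the returned tuple instead of the `.sample` attribute. A minimal sketch of the pattern with simplified names; the extra AudioLDM 2 conditioning inputs are omitted here:

```python
import torch

# Sketch of the return_dict=False pattern used in the hunk above (names illustrative,
# AudioLDM 2-specific conditioning arguments omitted).
def denoise_step(unet, latent_model_input, t, generated_prompt_embeds, guidance_scale, do_classifier_free_guidance):
    # return_dict=False makes the UNet return a plain tuple instead of an output
    # object, so the prediction is indexed with [0] rather than read via .sample
    noise_pred = unet(
        latent_model_input,
        t,
        encoder_hidden_states=generated_prompt_embeds,
        return_dict=False,
    )[0]

    # classifier-free guidance combines unconditional and text-conditional predictions
    if do_classifier_free_guidance:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    return noise_pred
```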
@@ -44,7 +44,7 @@ from diffusers import (
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils import is_xformers_available, slow, torch_device
from diffusers.utils import is_accelerate_available, is_accelerate_version, is_xformers_available, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism

from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
@@ -477,7 +477,6 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        # The method component.dtype returns the dtype of the first parameter registered in the model, not the
        # dtype of the entire model. In the case of CLAP, the first parameter is a float64 constant (logit scale)
        model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
        self.assertTrue(model_dtypes["text_encoder"] == torch.float64)

        # Without the logit scale parameters, everything is float32
        model_dtypes.pop("text_encoder")
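
The comments above explain why the float64 assertion was dropped: `.dtype` on a model reports the dtype of the first registered parameter, and for CLAP that is the float64 logit-scale constant rather than the actual weights, which is also why `text_encoder` is popped before checking the rest. A toy, self-contained illustration of that behaviour using plain PyTorch; the class and property below are illustrative, not the CLAP implementation:

```python
import torch
import torch.nn as nn

# Toy illustration: reporting the dtype of the first registered parameter can
# disagree with the dtype of the actual weights.
class ToyClapLike(nn.Module):
    def __init__(self):
        super().__init__()
        # first parameter registered: a scalar constant kept in float64 (like CLAP's logit scale)
        self.logit_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float64))
        # the "real" weights are float32
        self.proj = nn.Linear(4, 4)

    @property
    def dtype(self):
        # mirrors the convention of reporting the first parameter's dtype
        return next(self.parameters()).dtype

model = ToyClapLike()
print(model.dtype)              # torch.float64, driven by the logit-scale constant
print(model.proj.weight.dtype)  # torch.float32, the actual weight dtype
```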
@@ -492,6 +491,26 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))

    @unittest.skipIf(
        torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
        reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
    )
    def test_model_cpu_offload(self, expected_max_diff=2e-4):
        components = self.get_dummy_components()
        audioldm_pipe = AudioLDM2Pipeline(**components)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_without_offload = audioldm_pipe(**inputs)[0]

        audioldm_pipe.enable_model_cpu_offload()
        inputs = self.get_dummy_inputs(torch_device)
        output_with_offload = audioldm_pipe(**inputs)[0]

        max_diff = np.abs(output_with_offload - output_without_offload).max()
        self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
@slow
class AudioLDM2PipelineSlowTests(unittest.TestCase):
@@ -514,7 +533,7 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase):
        return inputs

    def test_audioldm2(self):
        audioldm_pipe = AudioLDM2Pipeline.from_pretrained("/home/sanchit/convert-audioldm2/hub-audioldm2")
        audioldm_pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)
@@ -532,7 +551,7 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase):
        assert max_diff < 1e-3

    def test_audioldm2_lms(self):
        audioldm_pipe = AudioLDM2Pipeline.from_pretrained("/home/sanchit/convert-audioldm2/hub-audioldm2")
        audioldm_pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
        audioldm_pipe.scheduler = LMSDiscreteScheduler.from_config(audioldm_pipe.scheduler.config)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)
@@ -552,7 +571,7 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase):
        assert max_diff < 1e-3

    def test_audioldm2_large(self):
        audioldm_pipe = AudioLDM2Pipeline.from_pretrained("/home/sanchit/convert-audioldm2/hub-audioldm2-large")
        audioldm_pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2-large")
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)
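
The slow tests above now load the public Hub checkpoints (`cvssp/audioldm2`, `cvssp/audioldm2-large`) instead of local conversion paths, and `test_audioldm2_lms` additionally swaps in an LMS scheduler. A minimal sketch of that scheduler-swap pattern, with illustrative prompt, seed, and step count:

```python
import torch
from diffusers import AudioLDM2Pipeline, LMSDiscreteScheduler

# Load the public checkpoint and replace the default scheduler with LMS,
# reusing the existing scheduler config (same pattern as test_audioldm2_lms)
pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# A fixed seed keeps runs reproducible when comparing against reference audio slices
generator = torch.Generator(device="cpu").manual_seed(0)
audio = pipe("A duck quacking in a pond", num_inference_steps=25, generator=generator).audios[0]
```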