Unverified Commit 29a11c2a authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[AudioLDM 2] Pipeline fixes (#4738)

* fix docs

* fix unet docs

* use image output for latents

* fix hub checkpoints

* fix pipeline example

* update example

* return_dict = False

* revert image pipeline output

* revert doc changes

* remove dtype test

* make style

* remove docstring updates

* remove unet docstring update

* Empty commit to re-trigger CI

* fix cpu offload

* fix dtype test

* add offload test
parent cdacd8f1
...@@ -208,12 +208,15 @@ class AudioLDM2Pipeline(DiffusionPipeline): ...@@ -208,12 +208,15 @@ class AudioLDM2Pipeline(DiffusionPipeline):
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
model_sequence = [ model_sequence = [
self.text_encoder, self.text_encoder.text_model,
self.text_encoder.text_projection,
self.text_encoder_2, self.text_encoder_2,
self.projection_model, self.projection_model,
self.language_model, self.language_model,
self.unet, self.unet,
self.vae, self.vae,
self.vocoder,
self.text_encoder,
] ]
hook = None hook = None
...@@ -927,7 +930,8 @@ class AudioLDM2Pipeline(DiffusionPipeline): ...@@ -927,7 +930,8 @@ class AudioLDM2Pipeline(DiffusionPipeline):
encoder_hidden_states=generated_prompt_embeds, encoder_hidden_states=generated_prompt_embeds,
encoder_hidden_states_1=prompt_embeds, encoder_hidden_states_1=prompt_embeds,
encoder_attention_mask_1=attention_mask, encoder_attention_mask_1=attention_mask,
).sample return_dict=False,
)[0]
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
......
...@@ -44,7 +44,7 @@ from diffusers import ( ...@@ -44,7 +44,7 @@ from diffusers import (
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
) )
from diffusers.utils import is_xformers_available, slow, torch_device from diffusers.utils import is_accelerate_available, is_accelerate_version, is_xformers_available, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism from diffusers.utils.testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
...@@ -477,7 +477,6 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -477,7 +477,6 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# The method component.dtype returns the dtype of the first parameter registered in the model, not the # The method component.dtype returns the dtype of the first parameter registered in the model, not the
# dtype of the entire model. In the case of CLAP, the first parameter is a float64 constant (logit scale) # dtype of the entire model. In the case of CLAP, the first parameter is a float64 constant (logit scale)
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")} model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
self.assertTrue(model_dtypes["text_encoder"] == torch.float64)
# Without the logit scale parameters, everything is float32 # Without the logit scale parameters, everything is float32
model_dtypes.pop("text_encoder") model_dtypes.pop("text_encoder")
...@@ -492,6 +491,26 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -492,6 +491,26 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")} model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values())) self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))
@unittest.skipIf(
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
)
def test_model_cpu_offload(self, expected_max_diff=2e-4):
components = self.get_dummy_components()
audioldm_pipe = AudioLDM2Pipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output_without_offload = audioldm_pipe(**inputs)[0]
audioldm_pipe.enable_model_cpu_offload()
inputs = self.get_dummy_inputs(torch_device)
output_with_offload = audioldm_pipe(**inputs)[0]
max_diff = np.abs(output_with_offload - output_without_offload).max()
self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
@slow @slow
class AudioLDM2PipelineSlowTests(unittest.TestCase): class AudioLDM2PipelineSlowTests(unittest.TestCase):
...@@ -514,7 +533,7 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase): ...@@ -514,7 +533,7 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase):
return inputs return inputs
def test_audioldm2(self): def test_audioldm2(self):
audioldm_pipe = AudioLDM2Pipeline.from_pretrained("/home/sanchit/convert-audioldm2/hub-audioldm2") audioldm_pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
audioldm_pipe = audioldm_pipe.to(torch_device) audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None) audioldm_pipe.set_progress_bar_config(disable=None)
...@@ -532,7 +551,7 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase): ...@@ -532,7 +551,7 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase):
assert max_diff < 1e-3 assert max_diff < 1e-3
def test_audioldm2_lms(self): def test_audioldm2_lms(self):
audioldm_pipe = AudioLDM2Pipeline.from_pretrained("/home/sanchit/convert-audioldm2/hub-audioldm2") audioldm_pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
audioldm_pipe.scheduler = LMSDiscreteScheduler.from_config(audioldm_pipe.scheduler.config) audioldm_pipe.scheduler = LMSDiscreteScheduler.from_config(audioldm_pipe.scheduler.config)
audioldm_pipe = audioldm_pipe.to(torch_device) audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None) audioldm_pipe.set_progress_bar_config(disable=None)
...@@ -552,7 +571,7 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase): ...@@ -552,7 +571,7 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase):
assert max_diff < 1e-3 assert max_diff < 1e-3
def test_audioldm2_large(self): def test_audioldm2_large(self):
audioldm_pipe = AudioLDM2Pipeline.from_pretrained("/home/sanchit/convert-audioldm2/hub-audioldm2-large") audioldm_pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2-large")
audioldm_pipe = audioldm_pipe.to(torch_device) audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None) audioldm_pipe.set_progress_bar_config(disable=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment