Unverified Commit 4a7e4cec authored by Robert Dargavel Smith, committed by GitHub

Add conditional generation to AudioDiffusionPipeline (#1826)

* Add conditional generation

* add fast test for conditional audio generation
parent f45c675d
@@ -89,9 +89,11 @@ class AudioDiffusionPipeline(DiffusionPipeline):
         step_generator: torch.Generator = None,
         eta: float = 0,
         noise: torch.Tensor = None,
+        encoding: torch.Tensor = None,
         return_dict=True,
     ) -> Union[
-        Union[AudioPipelineOutput, ImagePipelineOutput], Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]]
+        Union[AudioPipelineOutput, ImagePipelineOutput],
+        Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]],
     ]:
         """Generate random mel spectrogram from audio input and convert to audio.
@@ -108,6 +110,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
             step_generator (`torch.Generator`): random number generator used to de-noise or None
             eta (`float`): parameter between 0 and 1 used with DDIM scheduler
             noise (`torch.Tensor`): noise tensor of shape (batch_size, 1, height, width) or None
+            encoding (`torch.Tensor`): for UNet2DConditionModel shape (batch_size, seq_length, cross_attention_dim)
             return_dict (`bool`): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple

         Returns:
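
For reference, a minimal sketch of how the new `encoding` argument might be passed at call time. It assumes `pipe` is an AudioDiffusionPipeline whose unet is a UNet2DConditionModel with cross_attention_dim=10; the random tensor is a stand-in for a real conditioning embedding:

# Sketch only: `pipe` is assumed to wrap a UNet2DConditionModel with
# cross_attention_dim=10; a real conditioning embedding would replace torch.rand.
import torch

encoding = torch.rand((1, 1, 10))  # (batch_size, seq_length, cross_attention_dim)
generator = torch.Generator(device="cpu").manual_seed(0)
output = pipe(generator=generator, encoding=encoding)
image = output.images[0]  # mel spectrogram rendered as a PIL image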
@@ -124,7 +127,12 @@ class AudioDiffusionPipeline(DiffusionPipeline):
         self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
         if noise is None:
             noise = torch.randn(
-                (batch_size, self.unet.in_channels, self.unet.sample_size[0], self.unet.sample_size[1]),
+                (
+                    batch_size,
+                    self.unet.in_channels,
+                    self.unet.sample_size[0],
+                    self.unet.sample_size[1],
+                ),
                 generator=generator,
                 device=self.device,
             )
@@ -157,15 +165,25 @@ class AudioDiffusionPipeline(DiffusionPipeline):
             mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:]))

         for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])):
-            model_output = self.unet(images, t)["sample"]
+            if isinstance(self.unet, UNet2DConditionModel):
+                model_output = self.unet(images, t, encoding)["sample"]
+            else:
+                model_output = self.unet(images, t)["sample"]

             if isinstance(self.scheduler, DDIMScheduler):
                 images = self.scheduler.step(
-                    model_output=model_output, timestep=t, sample=images, eta=eta, generator=step_generator
+                    model_output=model_output,
+                    timestep=t,
+                    sample=images,
+                    eta=eta,
+                    generator=step_generator,
                 )["prev_sample"]
             else:
                 images = self.scheduler.step(
-                    model_output=model_output, timestep=t, sample=images, generator=step_generator
+                    model_output=model_output,
+                    timestep=t,
+                    sample=images,
+                    generator=step_generator,
                 )["prev_sample"]

             if mask is not None:
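
The isinstance dispatch above keeps unconditional pipelines working unchanged; only a UNet2DConditionModel receives the encoding, passed as its encoder hidden states. A standalone sketch of the same pattern (the helper name is illustrative, not part of the pipeline's API):

# Illustrative helper mirroring the dispatch in the diff; `denoise_step` is a
# hypothetical name, not part of diffusers.
from diffusers import UNet2DConditionModel

def denoise_step(unet, images, t, encoding=None):
    # Conditional UNets take the encoding as a third positional argument
    # (encoder_hidden_states); unconditional UNets are called with
    # (sample, timestep) only.
    if isinstance(unet, UNet2DConditionModel):
        return unet(images, t, encoding)["sample"]
    return unet(images, t)["sample"]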
@@ -26,6 +26,7 @@ from diffusers import (
     DDPMScheduler,
     DiffusionPipeline,
     Mel,
+    UNet2DConditionModel,
     UNet2DModel,
 )
 from diffusers.utils import slow, torch_device
@@ -56,6 +57,21 @@ class PipelineFastTests(unittest.TestCase):
         )
         return model

+    @property
+    def dummy_unet_condition(self):
+        torch.manual_seed(0)
+        model = UNet2DConditionModel(
+            sample_size=(64, 32),
+            in_channels=1,
+            out_channels=1,
+            layers_per_block=2,
+            block_out_channels=(128, 128),
+            down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
+            up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
+            cross_attention_dim=10,
+        )
+        return model
+
     @property
     def dummy_vqvae_and_unet(self):
         torch.manual_seed(0)
@@ -128,6 +144,19 @@ class PipelineFastTests(unittest.TestCase):
         expected_slice = np.array([120, 117, 110, 109, 138, 167, 138, 148, 132, 121])
         assert np.abs(image_slice.flatten() - expected_slice).max() == 0

+        dummy_unet_condition = self.dummy_unet_condition
+        pipe = AudioDiffusionPipeline(
+            vqvae=self.dummy_vqvae_and_unet[0], unet=dummy_unet_condition, mel=mel, scheduler=scheduler
+        )
+        np.random.seed(0)
+        encoding = torch.rand((1, 1, 10))
+        output = pipe(generator=generator, encoding=encoding)
+        image = output.images[0]
+        image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
+        expected_slice = np.array([120, 139, 147, 123, 124, 96, 115, 121, 126, 144])
+        assert np.abs(image_slice.flatten() - expected_slice).max() == 0
+
     @slow
     @require_torch_gpu
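
In practice, the encoding can come from any module that outputs a (batch_size, seq_length, cross_attention_dim) tensor, matching the shape the test exercises with torch.rand. A toy sketch under that assumption (ToyEncoder is a made-up stand-in, not part of diffusers; a real setup would use e.g. a text or audio encoder):

# Purely illustrative: any encoder producing (batch, seq_length,
# cross_attention_dim) can drive the conditional UNet.
import torch

class ToyEncoder(torch.nn.Module):
    def __init__(self, in_features=16, cross_attention_dim=10):
        super().__init__()
        self.proj = torch.nn.Linear(in_features, cross_attention_dim)

    def forward(self, features):  # features: (batch, seq_length, in_features)
        return self.proj(features)

encoder = ToyEncoder()
encoding = encoder(torch.rand((1, 1, 16)))  # -> shape (1, 1, 10)
# output = pipe(generator=generator, encoding=encoding)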