You need to sign in or sign up before continuing.
Unverified Commit 4a7e4cec authored by Robert Dargavel Smith's avatar Robert Dargavel Smith Committed by GitHub
Browse files

Add conditional generation to AudioDiffusionPipeline (#1826)

* Add conditional generation

* add fast test for conditional audio generation
parent f45c675d
...@@ -89,9 +89,11 @@ class AudioDiffusionPipeline(DiffusionPipeline): ...@@ -89,9 +89,11 @@ class AudioDiffusionPipeline(DiffusionPipeline):
step_generator: torch.Generator = None, step_generator: torch.Generator = None,
eta: float = 0, eta: float = 0,
noise: torch.Tensor = None, noise: torch.Tensor = None,
encoding: torch.Tensor = None,
return_dict=True, return_dict=True,
) -> Union[ ) -> Union[
Union[AudioPipelineOutput, ImagePipelineOutput], Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]] Union[AudioPipelineOutput, ImagePipelineOutput],
Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]],
]: ]:
"""Generate random mel spectrogram from audio input and convert to audio. """Generate random mel spectrogram from audio input and convert to audio.
...@@ -108,6 +110,7 @@ class AudioDiffusionPipeline(DiffusionPipeline): ...@@ -108,6 +110,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
step_generator (`torch.Generator`): random number generator used to de-noise or None step_generator (`torch.Generator`): random number generator used to de-noise or None
eta (`float`): parameter between 0 and 1 used with DDIM scheduler eta (`float`): parameter between 0 and 1 used with DDIM scheduler
noise (`torch.Tensor`): noise tensor of shape (batch_size, 1, height, width) or None noise (`torch.Tensor`): noise tensor of shape (batch_size, 1, height, width) or None
encoding (`torch.Tensor`): for UNet2DConditionModel shape (batch_size, seq_length, cross_attention_dim)
return_dict (`bool`): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple return_dict (`bool`): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple
Returns: Returns:
...@@ -124,7 +127,12 @@ class AudioDiffusionPipeline(DiffusionPipeline): ...@@ -124,7 +127,12 @@ class AudioDiffusionPipeline(DiffusionPipeline):
self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0]) self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
if noise is None: if noise is None:
noise = torch.randn( noise = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size[0], self.unet.sample_size[1]), (
batch_size,
self.unet.in_channels,
self.unet.sample_size[0],
self.unet.sample_size[1],
),
generator=generator, generator=generator,
device=self.device, device=self.device,
) )
...@@ -157,15 +165,25 @@ class AudioDiffusionPipeline(DiffusionPipeline): ...@@ -157,15 +165,25 @@ class AudioDiffusionPipeline(DiffusionPipeline):
mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:])) mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:]))
for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])): for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])):
model_output = self.unet(images, t)["sample"] if isinstance(self.unet, UNet2DConditionModel):
model_output = self.unet(images, t, encoding)["sample"]
else:
model_output = self.unet(images, t)["sample"]
if isinstance(self.scheduler, DDIMScheduler): if isinstance(self.scheduler, DDIMScheduler):
images = self.scheduler.step( images = self.scheduler.step(
model_output=model_output, timestep=t, sample=images, eta=eta, generator=step_generator model_output=model_output,
timestep=t,
sample=images,
eta=eta,
generator=step_generator,
)["prev_sample"] )["prev_sample"]
else: else:
images = self.scheduler.step( images = self.scheduler.step(
model_output=model_output, timestep=t, sample=images, generator=step_generator model_output=model_output,
timestep=t,
sample=images,
generator=step_generator,
)["prev_sample"] )["prev_sample"]
if mask is not None: if mask is not None:
......
...@@ -26,6 +26,7 @@ from diffusers import ( ...@@ -26,6 +26,7 @@ from diffusers import (
DDPMScheduler, DDPMScheduler,
DiffusionPipeline, DiffusionPipeline,
Mel, Mel,
UNet2DConditionModel,
UNet2DModel, UNet2DModel,
) )
from diffusers.utils import slow, torch_device from diffusers.utils import slow, torch_device
...@@ -56,6 +57,21 @@ class PipelineFastTests(unittest.TestCase): ...@@ -56,6 +57,21 @@ class PipelineFastTests(unittest.TestCase):
) )
return model return model
@property
def dummy_unet_condition(self):
    """Return a small, deterministically-seeded UNet2DConditionModel fixture.

    Mirrors the shape of the unconditional dummy UNet used elsewhere in these
    tests, but with cross-attention blocks so conditional generation paths
    (the `encoding` argument of AudioDiffusionPipeline) can be exercised.
    """
    # Fixed seed so the randomly initialized weights are reproducible
    # across test runs.
    torch.manual_seed(0)
    config = dict(
        sample_size=(64, 32),
        in_channels=1,
        out_channels=1,
        layers_per_block=2,
        block_out_channels=(128, 128),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        # Must match the last dim of the `encoding` tensor passed to the pipeline.
        cross_attention_dim=10,
    )
    return UNet2DConditionModel(**config)
@property @property
def dummy_vqvae_and_unet(self): def dummy_vqvae_and_unet(self):
torch.manual_seed(0) torch.manual_seed(0)
...@@ -128,6 +144,19 @@ class PipelineFastTests(unittest.TestCase): ...@@ -128,6 +144,19 @@ class PipelineFastTests(unittest.TestCase):
expected_slice = np.array([120, 117, 110, 109, 138, 167, 138, 148, 132, 121]) expected_slice = np.array([120, 117, 110, 109, 138, 167, 138, 148, 132, 121])
assert np.abs(image_slice.flatten() - expected_slice).max() == 0 assert np.abs(image_slice.flatten() - expected_slice).max() == 0
dummy_unet_condition = self.dummy_unet_condition
pipe = AudioDiffusionPipeline(
vqvae=self.dummy_vqvae_and_unet[0], unet=dummy_unet_condition, mel=mel, scheduler=scheduler
)
np.random.seed(0)
encoding = torch.rand((1, 1, 10))
output = pipe(generator=generator, encoding=encoding)
image = output.images[0]
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
expected_slice = np.array([120, 139, 147, 123, 124, 96, 115, 121, 126, 144])
assert np.abs(image_slice.flatten() - expected_slice).max() == 0
@slow @slow
@require_torch_gpu @require_torch_gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment