Unverified Commit 4a7e4cec authored by Robert Dargavel Smith's avatar Robert Dargavel Smith Committed by GitHub
Browse files

Add condtional generation to AudioDiffusionPipeline (#1826)

* Add condtional generation

* add fast test for conditional audio generation
parent f45c675d
......@@ -89,9 +89,11 @@ class AudioDiffusionPipeline(DiffusionPipeline):
step_generator: torch.Generator = None,
eta: float = 0,
noise: torch.Tensor = None,
encoding: torch.Tensor = None,
return_dict=True,
) -> Union[
Union[AudioPipelineOutput, ImagePipelineOutput], Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]]
Union[AudioPipelineOutput, ImagePipelineOutput],
Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]],
]:
"""Generate random mel spectrogram from audio input and convert to audio.
......@@ -108,6 +110,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
step_generator (`torch.Generator`): random number generator used to de-noise or None
eta (`float`): parameter between 0 and 1 used with DDIM scheduler
noise (`torch.Tensor`): noise tensor of shape (batch_size, 1, height, width) or None
encoding (`torch.Tensor`): for UNet2DConditionModel shape (batch_size, seq_length, cross_attention_dim)
return_dict (`bool`): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple
Returns:
......@@ -124,7 +127,12 @@ class AudioDiffusionPipeline(DiffusionPipeline):
self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
if noise is None:
noise = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size[0], self.unet.sample_size[1]),
(
batch_size,
self.unet.in_channels,
self.unet.sample_size[0],
self.unet.sample_size[1],
),
generator=generator,
device=self.device,
)
......@@ -157,15 +165,25 @@ class AudioDiffusionPipeline(DiffusionPipeline):
mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:]))
for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])):
model_output = self.unet(images, t)["sample"]
if isinstance(self.unet, UNet2DConditionModel):
model_output = self.unet(images, t, encoding)["sample"]
else:
model_output = self.unet(images, t)["sample"]
if isinstance(self.scheduler, DDIMScheduler):
images = self.scheduler.step(
model_output=model_output, timestep=t, sample=images, eta=eta, generator=step_generator
model_output=model_output,
timestep=t,
sample=images,
eta=eta,
generator=step_generator,
)["prev_sample"]
else:
images = self.scheduler.step(
model_output=model_output, timestep=t, sample=images, generator=step_generator
model_output=model_output,
timestep=t,
sample=images,
generator=step_generator,
)["prev_sample"]
if mask is not None:
......
......@@ -26,6 +26,7 @@ from diffusers import (
DDPMScheduler,
DiffusionPipeline,
Mel,
UNet2DConditionModel,
UNet2DModel,
)
from diffusers.utils import slow, torch_device
......@@ -56,6 +57,21 @@ class PipelineFastTests(unittest.TestCase):
)
return model
@property
def dummy_unet_condition(self):
torch.manual_seed(0)
model = UNet2DConditionModel(
sample_size=(64, 32),
in_channels=1,
out_channels=1,
layers_per_block=2,
block_out_channels=(128, 128),
down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
cross_attention_dim=10,
)
return model
@property
def dummy_vqvae_and_unet(self):
torch.manual_seed(0)
......@@ -128,6 +144,19 @@ class PipelineFastTests(unittest.TestCase):
expected_slice = np.array([120, 117, 110, 109, 138, 167, 138, 148, 132, 121])
assert np.abs(image_slice.flatten() - expected_slice).max() == 0
dummy_unet_condition = self.dummy_unet_condition
pipe = AudioDiffusionPipeline(
vqvae=self.dummy_vqvae_and_unet[0], unet=dummy_unet_condition, mel=mel, scheduler=scheduler
)
np.random.seed(0)
encoding = torch.rand((1, 1, 10))
output = pipe(generator=generator, encoding=encoding)
image = output.images[0]
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
expected_slice = np.array([120, 139, 147, 123, 124, 96, 115, 121, 126, 144])
assert np.abs(image_slice.flatten() - expected_slice).max() == 0
@slow
@require_torch_gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment