Unverified Commit 0c6d1bc9 authored by Robert Dargavel Smith's avatar Robert Dargavel Smith Committed by GitHub
Browse files

fix audio_diffusion tests (#3850)

parent 13e781f9
...@@ -99,7 +99,10 @@ class PipelineFastTests(unittest.TestCase): ...@@ -99,7 +99,10 @@ class PipelineFastTests(unittest.TestCase):
@slow @slow
def test_audio_diffusion(self): def test_audio_diffusion(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator device = "cpu" # ensure determinism for the device-dependent torch.Generator
mel = Mel() mel = Mel(
x_res=self.dummy_unet.config.sample_size[1],
y_res=self.dummy_unet.config.sample_size[0],
)
scheduler = DDPMScheduler() scheduler = DDPMScheduler()
pipe = AudioDiffusionPipeline(vqvae=None, unet=self.dummy_unet, mel=mel, scheduler=scheduler) pipe = AudioDiffusionPipeline(vqvae=None, unet=self.dummy_unet, mel=mel, scheduler=scheduler)
...@@ -127,6 +130,11 @@ class PipelineFastTests(unittest.TestCase): ...@@ -127,6 +130,11 @@ class PipelineFastTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() == 0 assert np.abs(image_slice.flatten() - expected_slice).max() == 0
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() == 0 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() == 0
mel = Mel(
x_res=self.dummy_vqvae_and_unet[0].config.sample_size[1],
y_res=self.dummy_vqvae_and_unet[0].config.sample_size[0],
)
scheduler = DDIMScheduler() scheduler = DDIMScheduler()
dummy_vqvae_and_unet = self.dummy_vqvae_and_unet dummy_vqvae_and_unet = self.dummy_vqvae_and_unet
pipe = AudioDiffusionPipeline( pipe = AudioDiffusionPipeline(
...@@ -154,13 +162,15 @@ class PipelineFastTests(unittest.TestCase): ...@@ -154,13 +162,15 @@ class PipelineFastTests(unittest.TestCase):
pipe = AudioDiffusionPipeline( pipe = AudioDiffusionPipeline(
vqvae=self.dummy_vqvae_and_unet[0], unet=dummy_unet_condition, mel=mel, scheduler=scheduler vqvae=self.dummy_vqvae_and_unet[0], unet=dummy_unet_condition, mel=mel, scheduler=scheduler
) )
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
np.random.seed(0) np.random.seed(0)
encoding = torch.rand((1, 1, 10)) encoding = torch.rand((1, 1, 10))
output = pipe(generator=generator, encoding=encoding) output = pipe(generator=generator, encoding=encoding)
image = output.images[0] image = output.images[0]
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10] image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
expected_slice = np.array([120, 139, 147, 123, 124, 96, 115, 121, 126, 144]) expected_slice = np.array([107, 103, 120, 127, 142, 122, 113, 122, 97, 111])
assert np.abs(image_slice.flatten() - expected_slice).max() == 0 assert np.abs(image_slice.flatten() - expected_slice).max() == 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment