Unverified Commit 830a9d1f authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

[fix] pipeline_unclip generator (#1751)

* [fix] pipeline_unclip generator

pass generator to all schedulers

* fix fast tests test data
parent 2dcf64b7
...@@ -292,7 +292,7 @@ class UnCLIPPipeline(DiffusionPipeline): ...@@ -292,7 +292,7 @@ class UnCLIPPipeline(DiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
decoder_latents = self.decoder_scheduler.step( decoder_latents = self.decoder_scheduler.step(
noise_pred, t, decoder_latents, prev_timestep=prev_timestep noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
).prev_sample ).prev_sample
decoder_latents = decoder_latents.clamp(-1, 1) decoder_latents = decoder_latents.clamp(-1, 1)
...@@ -348,7 +348,7 @@ class UnCLIPPipeline(DiffusionPipeline): ...@@ -348,7 +348,7 @@ class UnCLIPPipeline(DiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
super_res_latents = self.super_res_scheduler.step( super_res_latents = self.super_res_scheduler.step(
noise_pred, t, super_res_latents, prev_timestep=prev_timestep noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
).prev_sample ).prev_sample
image = super_res_latents image = super_res_latents
......
...@@ -233,15 +233,15 @@ class UnCLIPPipelineFastTests(unittest.TestCase): ...@@ -233,15 +233,15 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
expected_slice = np.array( expected_slice = np.array(
[ [
0.0009, 0.0011,
0.0002,
0.9962,
0.9940,
0.0002,
0.9997, 0.9997,
0.0003, 0.0003,
0.9991, 0.9987,
0.9967, 0.9989,
0.0003,
0.9997,
0.0003,
0.0004,
] ]
) )
...@@ -261,7 +261,7 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase): ...@@ -261,7 +261,7 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
def test_unclip_karlo(self): def test_unclip_karlo(self):
expected_image = load_numpy( expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/karlo_v1_alpha/horse.npy" "/unclip/karlo_v1_alpha_horse.npy"
) )
pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha") pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment