"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "913986afa53ace2b0becc20535ef7c32cb15276a"
Unverified Commit e3095c5f authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Fix invocation of some slow Flax tests (#3058)

* Fix invocation of some slow tests.

We use __call__ rather than pmapping the generation function ourselves
because the number of static arguments is different now.

* style
parent 526827c3
...@@ -28,7 +28,6 @@ if is_flax_available(): ...@@ -28,7 +28,6 @@ if is_flax_available():
import jax.numpy as jnp import jax.numpy as jnp
from flax.jax_utils import replicate from flax.jax_utils import replicate
from flax.training.common_utils import shard from flax.training.common_utils import shard
from jax import pmap
from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline
...@@ -70,14 +69,12 @@ class FlaxPipelineTests(unittest.TestCase): ...@@ -70,14 +69,12 @@ class FlaxPipelineTests(unittest.TestCase):
prompt = num_samples * [prompt] prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt) prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng # shard inputs and rng
params = replicate(params) params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples) prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids) prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 64, 64, 3) assert images.shape == (num_samples, 1, 64, 64, 3)
if jax.device_count() == 8: if jax.device_count() == 8:
...@@ -105,14 +102,12 @@ class FlaxPipelineTests(unittest.TestCase): ...@@ -105,14 +102,12 @@ class FlaxPipelineTests(unittest.TestCase):
prompt = num_samples * [prompt] prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt) prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng # shard inputs and rng
params = replicate(params) params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples) prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids) prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3) assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8: if jax.device_count() == 8:
...@@ -136,14 +131,12 @@ class FlaxPipelineTests(unittest.TestCase): ...@@ -136,14 +131,12 @@ class FlaxPipelineTests(unittest.TestCase):
prompt = num_samples * [prompt] prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt) prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng # shard inputs and rng
params = replicate(params) params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples) prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids) prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3) assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8: if jax.device_count() == 8:
...@@ -211,14 +204,12 @@ class FlaxPipelineTests(unittest.TestCase): ...@@ -211,14 +204,12 @@ class FlaxPipelineTests(unittest.TestCase):
prompt = num_samples * [prompt] prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt) prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng # shard inputs and rng
params = replicate(params) params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples) prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids) prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3) assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8: if jax.device_count() == 8:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment