Unverified Commit e3095c5f authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Fix invocation of some slow Flax tests (#3058)

* Fix invocation of some slow tests.

We use __call__ rather than pmapping the generation function ourselves
because the number of static arguments is different now.

* style
parent 526827c3
......@@ -28,7 +28,6 @@ if is_flax_available():
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from jax import pmap
from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline
......@@ -70,14 +69,12 @@ class FlaxPipelineTests(unittest.TestCase):
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 64, 64, 3)
if jax.device_count() == 8:
......@@ -105,14 +102,12 @@ class FlaxPipelineTests(unittest.TestCase):
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
......@@ -136,14 +131,12 @@ class FlaxPipelineTests(unittest.TestCase):
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
......@@ -211,14 +204,12 @@ class FlaxPipelineTests(unittest.TestCase):
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment