Unverified Commit caa5884e authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Update Flax TPU tests (#3069)



Update Flax TPU tests.
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent fa736e32
...@@ -78,11 +78,10 @@ class FlaxPipelineTests(unittest.TestCase): ...@@ -78,11 +78,10 @@ class FlaxPipelineTests(unittest.TestCase):
assert images.shape == (num_samples, 1, 64, 64, 3) assert images.shape == (num_samples, 1, 64, 64, 3)
if jax.device_count() == 8: if jax.device_count() == 8:
assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 3.1111548) < 1e-3 assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.1514745) < 1e-3
assert np.abs(np.abs(images, dtype=np.float32).sum() - 199746.95) < 5e-1 assert np.abs(np.abs(images, dtype=np.float32).sum() - 49947.875) < 5e-1
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
assert len(images_pil) == num_samples assert len(images_pil) == num_samples
def test_stable_diffusion_v1_4(self): def test_stable_diffusion_v1_4(self):
...@@ -140,8 +139,8 @@ class FlaxPipelineTests(unittest.TestCase): ...@@ -140,8 +139,8 @@ class FlaxPipelineTests(unittest.TestCase):
assert images.shape == (num_samples, 1, 512, 512, 3) assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8: if jax.device_count() == 8:
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3 assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1 assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
def test_stable_diffusion_v1_4_bfloat_16_with_safety(self): def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
...@@ -169,8 +168,8 @@ class FlaxPipelineTests(unittest.TestCase): ...@@ -169,8 +168,8 @@ class FlaxPipelineTests(unittest.TestCase):
assert images.shape == (num_samples, 1, 512, 512, 3) assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8: if jax.device_count() == 8:
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3 assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1 assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
def test_stable_diffusion_v1_4_bfloat_16_ddim(self): def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
scheduler = FlaxDDIMScheduler( scheduler = FlaxDDIMScheduler(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment