# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import numpy as np

from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax, slow


if is_flax_available():
    import jax
    import jax.numpy as jnp
    from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline
    from flax.jax_utils import replicate
    from flax.training.common_utils import shard
    from jax import pmap


@require_flax
class DownloadTests(unittest.TestCase):
    """Checks which files ``FlaxDiffusionPipeline.from_pretrained`` pulls into the cache."""

    def test_download_only_pytorch(self):
        """Loading the Flax pipeline must not download any PyTorch ``.bin`` weights.

        The test repo also hosts PyTorch weights (see the URL below), so this verifies
        that none of them end up in the downloaded snapshot.
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            # pipeline has Flax weights
            _ = FlaxDiffusionPipeline.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
            )

            # ``os.listdir`` order is arbitrary and the cache dir is not guaranteed to hold
            # only the repo folder (NOTE(review): newer huggingface_hub versions may add
            # e.g. a lock directory next to it — confirm), so inspect the ``snapshots``
            # folder of every entry instead of whichever entry happens to be listed first.
            snapshot_dirs = [
                os.path.join(tmpdirname, entry, "snapshots")
                for entry in os.listdir(tmpdirname)
                if os.path.isdir(os.path.join(tmpdirname, entry, "snapshots"))
            ]
            files = []
            for snapshot_dir in snapshot_dirs:
                for _, _, filenames in os.walk(snapshot_dir):
                    files.extend(filenames)

            # Without this, an empty listing would make the check below pass vacuously.
            assert files, f"no downloaded snapshot files found under {tmpdirname}"

            # None of the downloaded files should be a PyTorch file even if we have some here:
            # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_pytorch_model.bin
            assert not any(f.endswith(".bin") for f in files)


@slow
@require_flax
class FlaxPipelineTests(unittest.TestCase):
    """Slow end-to-end tests for ``FlaxStableDiffusionPipeline``, one sample per local device."""

    # Prompt shared by every test; one copy is generated per device.
    _PROMPT = (
        "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
        " field, close up, split lighting, cinematic"
    )

    def _prepare_sharded_inputs(self, pipeline, params):
        """Build per-device inputs for a data-parallel run.

        Tokenizes one copy of ``_PROMPT`` per device and shards it, replicates ``params``
        across devices, and splits ``PRNGKey(0)`` into one key per device.

        Returns:
            ``(prompt_ids, params, prng_seed, num_samples)``, where ``num_samples`` is
            ``jax.device_count()``.
        """
        num_samples = jax.device_count()
        prompt_ids = pipeline.prepare_inputs(num_samples * [self._PROMPT])

        # shard inputs and rng
        params = replicate(params)
        prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples)
        prompt_ids = shard(prompt_ids)
        return prompt_ids, params, prng_seed, num_samples

    @staticmethod
    def _run_pmapped(pipeline, prompt_ids, params, prng_seed, num_inference_steps):
        """Run ``pipeline.__call__`` under ``pmap`` and return the generated ``images``.

        Positional argument 3 (``num_inference_steps``) is marked static via
        ``static_broadcasted_argnums``, so it is broadcast to every device, not sharded.
        """
        p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
        return p_sample(prompt_ids, params, prng_seed, num_inference_steps).images

    @staticmethod
    def _assert_reference_sums(images, expected_slice_sum, expected_total_sum):
        """Compare absolute-value sums of ``images`` against recorded reference values.

        Only checked when exactly 8 devices are present (presumably the setup the
        reference values were recorded on); otherwise this is a no-op.
        """
        if jax.device_count() != 8:
            return
        slice_sum = np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum()
        total_sum = np.abs(images, dtype=np.float32).sum()
        assert np.abs(slice_sum - expected_slice_sum) < 1e-3
        assert np.abs(total_sum - expected_total_sum) < 5e-1

    def test_dummy_all_tpus(self):
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
        )
        prompt_ids, params, prng_seed, num_samples = self._prepare_sharded_inputs(pipeline, params)

        images = self._run_pmapped(pipeline, prompt_ids, params, prng_seed, num_inference_steps=4)

        assert images.shape == (num_samples, 1, 64, 64, 3)
        self._assert_reference_sums(images, 3.1111548, 199746.95)

        # Drop the per-device axis before converting to PIL images.
        images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

        assert len(images_pil) == num_samples

    def test_stable_diffusion_v1_4(self):
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None
        )
        prompt_ids, params, prng_seed, num_samples = self._prepare_sharded_inputs(pipeline, params)

        images = self._run_pmapped(pipeline, prompt_ids, params, prng_seed, num_inference_steps=50)

        assert images.shape == (num_samples, 1, 512, 512, 3)
        self._assert_reference_sums(images, 0.05652401, 2383808.2)

    def test_stable_diffusion_v1_4_bfloat_16(self):
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16, safety_checker=None
        )
        prompt_ids, params, prng_seed, num_samples = self._prepare_sharded_inputs(pipeline, params)

        images = self._run_pmapped(pipeline, prompt_ids, params, prng_seed, num_inference_steps=50)

        assert images.shape == (num_samples, 1, 512, 512, 3)
        self._assert_reference_sums(images, 0.06652832, 2384849.8)

    def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16
        )
        prompt_ids, params, prng_seed, num_samples = self._prepare_sharded_inputs(pipeline, params)
        num_inference_steps = 50

        # Unlike the other tests, let the pipeline handle parallelism itself via ``jit=True``.
        images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images

        assert images.shape == (num_samples, 1, 512, 512, 3)
        self._assert_reference_sums(images, 0.06652832, 2384849.8)

    def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
        scheduler = FlaxDDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            set_alpha_to_one=False,
            steps_offset=1,
        )

        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            revision="bf16",
            dtype=jnp.bfloat16,
            scheduler=scheduler,
            safety_checker=None,
        )
        # The scheduler state must be in ``params`` before they are replicated across devices.
        scheduler_state = scheduler.create_state()
        params["scheduler"] = scheduler_state

        prompt_ids, params, prng_seed, num_samples = self._prepare_sharded_inputs(pipeline, params)

        images = self._run_pmapped(pipeline, prompt_ids, params, prng_seed, num_inference_steps=50)

        assert images.shape == (num_samples, 1, 512, 512, 3)
        self._assert_reference_sums(images, 0.045043945, 2347693.5)