"src/vscode:/vscode.git/clone" did not exist on "a971c598b59532671a271520227cfd2fd54b1cd0"
Unverified Commit 532f41c9 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

Deprecate Flax support (#12151)



* start removing flax stuff.

* add deprecation warning.

* add warning messages.

* more warnings.

* remove dockerfiles.

* remove more.

* Update src/diffusers/models/attention_flax.py
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>

* up

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent 5fcd5f56
name: Run Flax dependency tests
on:
pull_request:
branches:
- main
paths:
- "src/diffusers/**.py"
push:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
check_flax_dependencies:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pip install --upgrade pip uv
python -m uv pip install -e .
python -m uv pip install "jax[cpu]>=0.2.16,!=0.3.2"
python -m uv pip install "flax>=0.4.1"
python -m uv pip install "jaxlib>=0.1.65"
python -m uv pip install pytest
- name: Check for soft dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
pytest tests/others/test_dependencies.py
FROM ubuntu:20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get -y update \
&& apt-get install -y software-properties-common \
&& add-apt-repository ppa:deadsnakes/ppa
RUN apt install -y bash \
build-essential \
git \
git-lfs \
curl \
ca-certificates \
libsndfile1-dev \
libgl1 \
python3.10 \
python3-pip \
python3.10-venv && \
rm -rf /var/lib/apt/lists
# make sure to use venv
RUN python3.10 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3 -m uv pip install --upgrade --no-cache-dir \
clu \
"jax[cpu]>=0.2.16,!=0.3.2" \
"flax>=0.4.1" \
"jaxlib>=0.1.65" && \
python3 -m uv pip install --no-cache-dir \
accelerate \
datasets \
hf-doc-builder \
huggingface-hub \
Jinja2 \
librosa \
numpy==1.26.4 \
scipy \
tensorboard \
transformers \
hf_transfer
CMD ["/bin/bash"]
\ No newline at end of file
FROM ubuntu:20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get -y update \
&& apt-get install -y software-properties-common \
&& add-apt-repository ppa:deadsnakes/ppa
RUN apt install -y bash \
build-essential \
git \
git-lfs \
curl \
ca-certificates \
libsndfile1-dev \
libgl1 \
python3.10 \
python3-pip \
python3.10-venv && \
rm -rf /var/lib/apt/lists
# make sure to use venv
RUN python3.10 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3 -m pip install --no-cache-dir \
"jax[tpu]>=0.2.16,!=0.3.2" \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
python3 -m uv pip install --upgrade --no-cache-dir \
clu \
"flax>=0.4.1" \
"jaxlib>=0.1.65" && \
python3 -m uv pip install --no-cache-dir \
accelerate \
datasets \
hf-doc-builder \
huggingface-hub \
Jinja2 \
librosa \
numpy==1.26.4 \
scipy \
tensorboard \
transformers \
hf_transfer
CMD ["/bin/bash"]
\ No newline at end of file
......@@ -19,6 +19,11 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
from ..utils import logging
logger = logging.get_logger(__name__)
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
"""Multi-head dot product attention with a limited number of queries."""
......@@ -151,6 +156,11 @@ class FlaxAttention(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
inner_dim = self.dim_head * self.heads
self.scale = self.dim_head**-0.5
......@@ -277,6 +287,11 @@ class FlaxBasicTransformerBlock(nn.Module):
split_head_dim: bool = False
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
# self attention (or cross_attention if only_cross_attention is True)
self.attn1 = FlaxAttention(
self.dim,
......@@ -365,6 +380,11 @@ class FlaxTransformer2DModel(nn.Module):
split_head_dim: bool = False
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
inner_dim = self.n_heads * self.d_head
......@@ -454,6 +474,11 @@ class FlaxFeedForward(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
# The second linear layer needs to be called
# net_2 for now to match the index of the Sequential layer
self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
......@@ -484,6 +509,11 @@ class FlaxGEGLU(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
inner_dim = self.dim * 4
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.dropout)
......
......@@ -20,7 +20,7 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ...configuration_utils import ConfigMixin, flax_register_to_config
from ...utils import BaseOutput
from ...utils import BaseOutput, logging
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from ..modeling_flax_utils import FlaxModelMixin
from ..unets.unet_2d_blocks_flax import (
......@@ -30,6 +30,9 @@ from ..unets.unet_2d_blocks_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class FlaxControlNetOutput(BaseOutput):
"""
......@@ -50,6 +53,11 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.conv_in = nn.Conv(
self.block_out_channels[0],
kernel_size=(3, 3),
......@@ -184,6 +192,11 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
def setup(self) -> None:
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
......
......@@ -16,6 +16,11 @@ import math
import flax.linen as nn
import jax.numpy as jnp
from ..utils import logging
logger = logging.get_logger(__name__)
def get_sinusoidal_embeddings(
timesteps: jnp.ndarray,
......@@ -76,6 +81,11 @@ class FlaxTimestepEmbedding(nn.Module):
The data type for the embedding parameters.
"""
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
time_embed_dim: int = 32
dtype: jnp.dtype = jnp.float32
......@@ -104,6 +114,11 @@ class FlaxTimesteps(nn.Module):
flip_sin_to_cos: bool = False
freq_shift: float = 1
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
@nn.compact
def __call__(self, timesteps):
return get_sinusoidal_embeddings(
......
......@@ -290,6 +290,10 @@ class FlaxModelMixin(PushToHubMixin):
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
......
......@@ -15,12 +15,22 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
from ..utils import logging
logger = logging.get_logger(__name__)
class FlaxUpsample2D(nn.Module):
out_channels: int
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
......@@ -45,6 +55,11 @@ class FlaxDownsample2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
......@@ -68,6 +83,11 @@ class FlaxResnetBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
out_channels = self.in_channels if self.out_channels is None else self.out_channels
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
......
......@@ -15,10 +15,14 @@
import flax.linen as nn
import jax.numpy as jnp
from ...utils import logging
from ..attention_flax import FlaxTransformer2DModel
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
logger = logging.get_logger(__name__)
class FlaxCrossAttnDownBlock2D(nn.Module):
r"""
Cross Attention 2D Downsizing block - original architecture from Unet transformers:
......@@ -60,6 +64,11 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
transformer_layers_per_block: int = 1
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
attentions = []
......@@ -135,6 +144,11 @@ class FlaxDownBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
for i in range(self.num_layers):
......@@ -208,6 +222,11 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
transformer_layers_per_block: int = 1
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
attentions = []
......@@ -288,6 +307,11 @@ class FlaxUpBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
for i in range(self.num_layers):
......@@ -356,6 +380,11 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
transformer_layers_per_block: int = 1
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
# there is always at least one resnet
resnets = [
FlaxResnetBlock2D(
......
......@@ -20,7 +20,7 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ...configuration_utils import ConfigMixin, flax_register_to_config
from ...utils import BaseOutput
from ...utils import BaseOutput, logging
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from ..modeling_flax_utils import FlaxModelMixin
from .unet_2d_blocks_flax import (
......@@ -32,6 +32,9 @@ from .unet_2d_blocks_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class FlaxUNet2DConditionOutput(BaseOutput):
"""
......@@ -163,6 +166,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
def setup(self) -> None:
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
......
......@@ -25,10 +25,13 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, logging
from .modeling_flax_utils import FlaxModelMixin
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class FlaxDecoderOutput(BaseOutput):
"""
......@@ -73,6 +76,10 @@ class FlaxUpsample2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.conv = nn.Conv(
self.in_channels,
kernel_size=(3, 3),
......@@ -107,6 +114,11 @@ class FlaxDownsample2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.conv = nn.Conv(
self.in_channels,
kernel_size=(3, 3),
......@@ -149,6 +161,11 @@ class FlaxResnetBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
out_channels = self.in_channels if self.out_channels is None else self.out_channels
self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
......@@ -221,6 +238,11 @@ class FlaxAttentionBlock(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1
dense = partial(nn.Dense, self.channels, dtype=self.dtype)
......@@ -302,6 +324,11 @@ class FlaxDownEncoderBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
......@@ -359,6 +386,11 @@ class FlaxUpDecoderBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
......@@ -413,6 +445,11 @@ class FlaxUNetMidBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
# there is always at least one resnet
......@@ -504,6 +541,11 @@ class FlaxEncoder(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
block_out_channels = self.block_out_channels
# in
self.conv_in = nn.Conv(
......@@ -616,6 +658,11 @@ class FlaxDecoder(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
block_out_channels = self.block_out_channels
# z to block_in
......@@ -788,6 +835,11 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.encoder = FlaxEncoder(
in_channels=self.config.in_channels,
out_channels=self.config.latent_channels,
......
......@@ -312,6 +312,11 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
>>> dpm_params["scheduler"] = dpmpp_state
```
"""
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
cache_dir = kwargs.pop("cache_dir", None)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
......
......@@ -22,9 +22,11 @@ import flax
import jax.numpy as jnp
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import BaseOutput, PushToHubMixin
from ..utils import BaseOutput, PushToHubMixin, logging
logger = logging.get_logger(__name__)
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
......@@ -133,6 +135,10 @@ class FlaxSchedulerMixin(PushToHubMixin):
</Tip>
"""
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
config, kwargs = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder,
......
import unittest
from diffusers import FlaxAutoencoderKL
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax
from ..test_modeling_common_flax import FlaxModelTesterMixin
if is_flax_available():
import jax
@require_flax
class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
model_class = FlaxAutoencoderKL
@property
def dummy_input(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
prng_key = jax.random.PRNGKey(0)
image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes))
return {"sample": image, "prng_key": prng_key}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": [32, 64],
"in_channels": 3,
"out_channels": 3,
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
"latent_channels": 4,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
import inspect
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax
if is_flax_available():
import jax
@require_flax
class FlaxModelTesterMixin:
def test_output(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
jax.lax.stop_gradient(variables)
output = model.apply(variables, inputs_dict["sample"])
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)
model = self.model_class(**init_dict)
variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
jax.lax.stop_gradient(variables)
output = model.apply(variables, inputs_dict["sample"])
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_deprecated_kwargs(self):
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
if has_kwarg_in_model_class and not has_deprecated_kwarg:
raise ValueError(
f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
" no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)
if not has_kwarg_in_model_class and has_deprecated_kwarg:
raise ValueError(
f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)
import gc
import unittest
from parameterized import parameterized
from diffusers import FlaxUNet2DConditionModel
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import load_hf_numpy, require_flax, slow
if is_flax_available():
import jax
import jax.numpy as jnp
@slow
@require_flax
class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase):
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
dtype = jnp.bfloat16 if fp16 else jnp.float32
image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
return image
def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
dtype = jnp.bfloat16 if fp16 else jnp.float32
revision = "bf16" if fp16 else None
model, params = FlaxUNet2DConditionModel.from_pretrained(
model_id, subfolder="unet", dtype=dtype, revision=revision
)
return model, params
def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
dtype = jnp.bfloat16 if fp16 else jnp.float32
hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
return hidden_states
@parameterized.expand(
[
# fmt: off
[83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
[17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
[8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
[3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
# fmt: on
]
)
def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
latents = self.get_latents(seed, fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)
sample = model.apply(
{"params": params},
latents,
jnp.array(timestep, dtype=jnp.int32),
encoder_hidden_states=encoder_hidden_states,
).sample
assert sample.shape == latents.shape
output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)
# Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, in the same hardware
assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)
@parameterized.expand(
[
# fmt: off
[83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
[17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
[8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
[3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
# fmt: on
]
)
def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)
sample = model.apply(
{"params": params},
latents,
jnp.array(timestep, dtype=jnp.int32),
encoder_hidden_states=encoder_hidden_states,
).sample
assert sample.shape == latents.shape
output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)
# Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
from diffusers.utils import is_flax_available, load_image
from diffusers.utils.testing_utils import require_flax, slow
if is_flax_available():
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
@slow
@require_flax
class FlaxControlNetPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
def test_canny(self):
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.bfloat16
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16
)
params["controlnet"] = controlnet_params
prompts = "bird"
num_samples = jax.device_count()
prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
canny_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, jax.device_count())
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
processed_image = shard(processed_image)
images = pipe(
prompt_ids=prompt_ids,
image=processed_image,
params=p_params,
prng_seed=rng,
num_inference_steps=50,
jit=True,
).images
assert images.shape == (jax.device_count(), 1, 768, 512, 3)
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
image_slice = images[0, 253:256, 253:256, -1]
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
expected_slice = jnp.array(
[0.167969, 0.116699, 0.081543, 0.154297, 0.132812, 0.108887, 0.169922, 0.169922, 0.205078]
)
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
def test_pose(self):
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-openpose", from_pt=True, dtype=jnp.bfloat16
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16
)
params["controlnet"] = controlnet_params
prompts = "Chef in the kitchen"
num_samples = jax.device_count()
prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
pose_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png"
)
processed_image = pipe.prepare_image_inputs([pose_image] * num_samples)
rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, jax.device_count())
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
processed_image = shard(processed_image)
images = pipe(
prompt_ids=prompt_ids,
image=processed_image,
params=p_params,
prng_seed=rng,
num_inference_steps=50,
jit=True,
).images
assert images.shape == (jax.device_count(), 1, 768, 512, 3)
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
image_slice = images[0, 253:256, 253:256, -1]
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
expected_slice = jnp.array(
[[0.271484, 0.261719, 0.275391, 0.277344, 0.279297, 0.291016, 0.294922, 0.302734, 0.302734]]
)
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
from diffusers import FlaxDPMSolverMultistepScheduler, FlaxStableDiffusionPipeline
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import nightly, require_flax
if is_flax_available():
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
@nightly
@require_flax
class FlaxStableDiffusion2PipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
def test_stable_diffusion_flax(self):
sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2",
variant="bf16",
dtype=jnp.bfloat16,
)
prompt = "A painting of a squirrel eating a burger"
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = sd_pipe.prepare_inputs(prompt)
params = replicate(params)
prompt_ids = shard(prompt_ids)
prng_seed = jax.random.PRNGKey(0)
prng_seed = jax.random.split(prng_seed, jax.device_count())
images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
assert images.shape == (jax.device_count(), 1, 768, 768, 3)
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
image_slice = images[0, 253:256, 253:256, -1]
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
expected_slice = jnp.array([0.4238, 0.4414, 0.4395, 0.4453, 0.4629, 0.4590, 0.4531, 0.45508, 0.4512])
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
@nightly
@require_flax
class FlaxStableDiffusion2PipelineNightlyTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
def test_stable_diffusion_dpm_flax(self):
model_id = "stabilityai/stable-diffusion-2"
scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
model_id,
scheduler=scheduler,
variant="bf16",
dtype=jnp.bfloat16,
)
params["scheduler"] = scheduler_params
prompt = "A painting of a squirrel eating a burger"
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = sd_pipe.prepare_inputs(prompt)
params = replicate(params)
prompt_ids = shard(prompt_ids)
prng_seed = jax.random.PRNGKey(0)
prng_seed = jax.random.split(prng_seed, jax.device_count())
images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
assert images.shape == (jax.device_count(), 1, 768, 768, 3)
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
image_slice = images[0, 253:256, 253:256, -1]
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
expected_slice = jnp.array([0.4336, 0.42969, 0.4453, 0.4199, 0.4297, 0.4531, 0.4434, 0.4434, 0.4297])
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
from diffusers import FlaxStableDiffusionInpaintPipeline
from diffusers.utils import is_flax_available, load_image
from diffusers.utils.testing_utils import require_flax, slow
if is_flax_available():
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
@slow
@require_flax
class FlaxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
def test_stable_diffusion_inpaint_pipeline(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/sd2-inpaint/init_image.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
)
model_id = "xvjiarui/stable-diffusion-2-inpainting"
pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
init_image = num_samples * [init_image]
mask_image = num_samples * [mask_image]
prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(prompt, init_image, mask_image)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)
processed_masked_images = shard(processed_masked_images)
processed_masks = shard(processed_masks)
output = pipeline(
prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True
)
images = output.images.reshape(num_samples, 512, 512, 3)
image_slice = images[0, 253:256, 253:256, -1]
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
expected_slice = jnp.array(
[0.3611307, 0.37649736, 0.3757408, 0.38213953, 0.39295167, 0.3841631, 0.41554978, 0.4137475, 0.4217084]
)
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax, slow
if is_flax_available():
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline
@require_flax
class DownloadTests(unittest.TestCase):
def test_download_only_pytorch(self):
with tempfile.TemporaryDirectory() as tmpdirname:
# pipeline has Flax weights
_ = FlaxDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a PyTorch file even if we have some here:
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_pytorch_model.bin
assert not any(f.endswith(".bin") for f in files)
@slow
@require_flax
class FlaxPipelineTests(unittest.TestCase):
def test_dummy_all_tpus(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
)
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 4
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 64, 64, 3)
if jax.device_count() == 8:
assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.1514745) < 1e-3
assert np.abs(np.abs(images, dtype=np.float32).sum() - 49947.875) < 5e-1
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
assert len(images_pil) == num_samples
def test_stable_diffusion_v1_4(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None
)
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-2
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 5e-1
def test_stable_diffusion_v1_4_bfloat_16(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", variant="bf16", dtype=jnp.bfloat16, safety_checker=None
)
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 5e-2
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", variant="bf16", dtype=jnp.bfloat16
)
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 5e-2
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
scheduler = FlaxDDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
set_alpha_to_one=False,
steps_offset=1,
)
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
variant="bf16",
dtype=jnp.bfloat16,
scheduler=scheduler,
safety_checker=None,
)
scheduler_state = scheduler.create_state()
params["scheduler"] = scheduler_state
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 5e-2
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1
def test_jax_memory_efficient_attention(self):
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples)
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
variant="bf16",
dtype=jnp.bfloat16,
safety_checker=None,
)
params = replicate(params)
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
slice = images[2, 0, 256, 10:17, 1]
# With memory efficient attention
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
variant="bf16",
dtype=jnp.bfloat16,
safety_checker=None,
use_memory_efficient_attention=True,
)
params = replicate(params)
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids = shard(prompt_ids)
images_eff = pipeline(prompt_ids, params, prng_seed, jit=True).images
assert images_eff.shape == (num_samples, 1, 512, 512, 3)
slice_eff = images[2, 0, 256, 10:17, 1]
# I checked the results visually and they are very similar. However, I saw that the max diff is `1` and the `sum`
# over the 8 images is exactly `256`, which is very suspicious. Testing a random slice for now.
assert abs(slice_eff - slice).max() < 1e-2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment