Unverified Commit 4d1e4e24 authored by Pedro Cuenca, committed by GitHub

Flax support for Stable Diffusion 2 (#1423)



* Flax: start adapting to Stable Diffusion 2

* More changes.

* attention_head_dim can be a tuple.

* Fix typos

* Add simple SD 2 integration test.

Slice values taken from my Ampere GPU.

* Add simple UNet integration tests for Flax.

Note that the expected values are taken from the PyTorch results. This
ensures the Flax and PyTorch versions are not too far off.

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Typos and style

* Tests: verify jax is available.

* Style

* Make flake happy

* Remove typo.

* Simple Flax SD 2 pipeline tests.

* Import order

* Remove unused import.
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: @camenduru 
parent a808a853
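In practice, this change lets the existing Flax pipeline load and run Stable Diffusion 2. Below is a minimal usage sketch, mirroring the pipeline integration test added at the bottom of this diff; it assumes the `bf16` revision of `stabilityai/stable-diffusion-2` and at least one JAX-visible accelerator.

```python
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline

# Load the SD 2 weights in bfloat16 (same checkpoint/revision the tests below use).
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2", revision="bf16", dtype=jnp.bfloat16
)

# One prompt per device; replicate the params and shard the inputs for pmap.
prompts = jax.device_count() * ["A painting of a squirrel eating a burger"]
prompt_ids = shard(pipe.prepare_inputs(prompts))
params = replicate(params)
rng = jax.random.split(jax.random.PRNGKey(0), jax.device_count())

images = pipe(prompt_ids, params, rng, num_inference_steps=25, jit=True)[0]
# images has shape (num_devices, 1, 768, 768, 3) for SD 2.
```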
Flax attention modules (FlaxBasicTransformerBlock, FlaxTransformer2DModel):

@@ -104,6 +104,8 @@ class FlaxBasicTransformerBlock(nn.Module):
             Hidden states dimension inside each head
         dropout (:obj:`float`, *optional*, defaults to 0.0):
             Dropout rate
+        only_cross_attention (`bool`, defaults to `False`):
+            Whether to only apply cross attention.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
@@ -111,10 +113,11 @@ class FlaxBasicTransformerBlock(nn.Module):
     n_heads: int
     d_head: int
     dropout: float = 0.0
+    only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
-        # self attention
+        # self attention (or cross_attention if only_cross_attention is True)
         self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
         # cross attention
         self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
@@ -126,7 +129,10 @@ class FlaxBasicTransformerBlock(nn.Module):
     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
         residual = hidden_states
-        hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
+        if self.only_cross_attention:
+            hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
+        else:
+            hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual
 
         # cross attention
@@ -159,6 +165,8 @@ class FlaxTransformer2DModel(nn.Module):
             Number of transformers block
         dropout (:obj:`float`, *optional*, defaults to 0.0):
             Dropout rate
+        use_linear_projection (`bool`, defaults to `False`): tbd
+        only_cross_attention (`bool`, defaults to `False`): tbd
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
@@ -167,49 +175,70 @@ class FlaxTransformer2DModel(nn.Module):
     d_head: int
     depth: int = 1
     dropout: float = 0.0
+    use_linear_projection: bool = False
+    only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
         self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
 
         inner_dim = self.n_heads * self.d_head
-        self.proj_in = nn.Conv(
-            inner_dim,
-            kernel_size=(1, 1),
-            strides=(1, 1),
-            padding="VALID",
-            dtype=self.dtype,
-        )
+        if self.use_linear_projection:
+            self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
+        else:
+            self.proj_in = nn.Conv(
+                inner_dim,
+                kernel_size=(1, 1),
+                strides=(1, 1),
+                padding="VALID",
+                dtype=self.dtype,
+            )
 
         self.transformer_blocks = [
-            FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
+            FlaxBasicTransformerBlock(
+                inner_dim,
+                self.n_heads,
+                self.d_head,
+                dropout=self.dropout,
+                only_cross_attention=self.only_cross_attention,
+                dtype=self.dtype,
+            )
             for _ in range(self.depth)
         ]
 
-        self.proj_out = nn.Conv(
-            inner_dim,
-            kernel_size=(1, 1),
-            strides=(1, 1),
-            padding="VALID",
-            dtype=self.dtype,
-        )
+        if self.use_linear_projection:
+            self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
+        else:
+            self.proj_out = nn.Conv(
+                inner_dim,
+                kernel_size=(1, 1),
+                strides=(1, 1),
+                padding="VALID",
+                dtype=self.dtype,
+            )
 
     def __call__(self, hidden_states, context, deterministic=True):
         batch, height, width, channels = hidden_states.shape
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
-        hidden_states = self.proj_in(hidden_states)
-
-        hidden_states = hidden_states.reshape(batch, height * width, channels)
+        if self.use_linear_projection:
+            hidden_states = hidden_states.reshape(batch, height * width, channels)
+            hidden_states = self.proj_in(hidden_states)
+        else:
+            hidden_states = self.proj_in(hidden_states)
+            hidden_states = hidden_states.reshape(batch, height * width, channels)
 
         for transformer_block in self.transformer_blocks:
             hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
 
-        hidden_states = hidden_states.reshape(batch, height, width, channels)
-
-        hidden_states = self.proj_out(hidden_states)
+        if self.use_linear_projection:
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = hidden_states.reshape(batch, height, width, channels)
+        else:
+            hidden_states = hidden_states.reshape(batch, height, width, channels)
+            hidden_states = self.proj_out(hidden_states)
+
         hidden_states = hidden_states + residual
         return hidden_states
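The `use_linear_projection` flag swaps the 1x1 convolution used for the in/out projections for a dense layer applied after flattening the spatial dimensions, so the order of projection and reshape flips between the two paths. Below is a self-contained sketch of just that logic; the `ProjIn` module name and the shapes are illustrative, not the diffusers implementation.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class ProjIn(nn.Module):
    # Hypothetical module: reproduces only the projection/reshape ordering above.
    inner_dim: int
    use_linear_projection: bool = False

    @nn.compact
    def __call__(self, x):  # x: (batch, height, width, channels)
        batch, height, width, channels = x.shape
        if self.use_linear_projection:
            # flatten the spatial dims first, then project with a Dense layer
            x = x.reshape(batch, height * width, channels)
            x = nn.Dense(self.inner_dim)(x)
        else:
            # project with a 1x1 convolution first, then flatten
            x = nn.Conv(self.inner_dim, kernel_size=(1, 1), strides=(1, 1), padding="VALID")(x)
            x = x.reshape(batch, height * width, channels)
        return x


x = jnp.ones((1, 8, 8, 320))
for flag in (False, True):
    module = ProjIn(inner_dim=320, use_linear_projection=flag)
    variables = module.init(jax.random.PRNGKey(0), x)
    print(flag, module.apply(variables, x).shape)  # both paths yield (1, 64, 320)
```

Stable Diffusion 2's UNet config enables the linear projection path, which is why the flag needs to be threaded through the blocks below.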
Flax UNet blocks (FlaxCrossAttnDownBlock2D, FlaxCrossAttnUpBlock2D, FlaxUNetMidBlock2DCrossAttn) forward the new options:

@@ -46,6 +46,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
     num_layers: int = 1
     attn_num_head_channels: int = 1
     add_downsample: bool = True
+    use_linear_projection: bool = False
+    only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -68,6 +70,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
                 n_heads=self.attn_num_head_channels,
                 d_head=self.out_channels // self.attn_num_head_channels,
                 depth=1,
+                use_linear_projection=self.use_linear_projection,
+                only_cross_attention=self.only_cross_attention,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)
@@ -178,6 +182,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
     num_layers: int = 1
     attn_num_head_channels: int = 1
     add_upsample: bool = True
+    use_linear_projection: bool = False
+    only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -201,6 +207,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
                 n_heads=self.attn_num_head_channels,
                 d_head=self.out_channels // self.attn_num_head_channels,
                 depth=1,
+                use_linear_projection=self.use_linear_projection,
+                only_cross_attention=self.only_cross_attention,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)
@@ -310,6 +318,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
     dropout: float = 0.0
     num_layers: int = 1
     attn_num_head_channels: int = 1
+    use_linear_projection: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -331,6 +340,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
                 n_heads=self.attn_num_head_channels,
                 d_head=self.in_channels // self.attn_num_head_channels,
                 depth=1,
+                use_linear_projection=self.use_linear_projection,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)
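Each block above derives the per-head dimension as `out_channels // attn_num_head_channels`, so passing a per-block `attention_head_dim` tuple changes the head geometry at every resolution. A small arithmetic sketch follows; the head counts are what I believe the SD 2 UNet config uses and are meant as an illustration only.

```python
# Illustrative values (assumed SD 2 UNet config): one head count per down block.
block_out_channels = (320, 640, 1280, 1280)
attention_head_dim = (5, 10, 20, 20)  # n_heads per block

for out_channels, n_heads in zip(block_out_channels, attention_head_dim):
    d_head = out_channels // n_heads  # same formula as in the blocks above
    print(f"out_channels={out_channels:4d}  n_heads={n_heads:2d}  d_head={d_head}")

# Every block ends up with d_head == 64. With SD 1.x's single attention_head_dim=8,
# d_head instead grows with the block width (40, 80, 160, 160).
```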
FlaxUNet2DConditionModel exposes the new options in its config and accepts per-block attention_head_dim / only_cross_attention values:

@@ -79,7 +79,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             The tuple of output channels for each block.
         layers_per_block (`int`, *optional*, defaults to 2):
             The number of layers per block.
-        attention_head_dim (`int`, *optional*, defaults to 8):
+        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
             The dimension of the attention heads.
         cross_attention_dim (`int`, *optional*, defaults to 768):
             The dimension of the cross attention features.
@@ -97,11 +97,13 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         "DownBlock2D",
     )
     up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
+    only_cross_attention: Union[bool, Tuple[bool]] = False
     block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
     layers_per_block: int = 2
-    attention_head_dim: int = 8
+    attention_head_dim: Union[int, Tuple[int]] = 8
     cross_attention_dim: int = 1280
     dropout: float = 0.0
+    use_linear_projection: bool = False
     dtype: jnp.dtype = jnp.float32
     freq_shift: int = 0
@@ -134,6 +136,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
         self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
 
+        only_cross_attention = self.only_cross_attention
+        if isinstance(only_cross_attention, bool):
+            only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
+
+        attention_head_dim = self.attention_head_dim
+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
+
         # down
         down_blocks = []
         output_channel = block_out_channels[0]
@@ -148,8 +158,10 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                 out_channels=output_channel,
                 dropout=self.dropout,
                 num_layers=self.layers_per_block,
-                attn_num_head_channels=self.attention_head_dim,
+                attn_num_head_channels=attention_head_dim[i],
                 add_downsample=not is_final_block,
+                use_linear_projection=self.use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
                 dtype=self.dtype,
             )
         else:
@@ -169,13 +181,16 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         self.mid_block = FlaxUNetMidBlock2DCrossAttn(
             in_channels=block_out_channels[-1],
             dropout=self.dropout,
-            attn_num_head_channels=self.attention_head_dim,
+            attn_num_head_channels=attention_head_dim[-1],
+            use_linear_projection=self.use_linear_projection,
             dtype=self.dtype,
         )
 
         # up
         up_blocks = []
         reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_attention_head_dim = list(reversed(attention_head_dim))
+        only_cross_attention = list(reversed(only_cross_attention))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(self.up_block_types):
             prev_output_channel = output_channel
@@ -190,9 +205,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
                 num_layers=self.layers_per_block + 1,
-                attn_num_head_channels=self.attention_head_dim,
+                attn_num_head_channels=reversed_attention_head_dim[i],
                 add_upsample=not is_final_block,
                 dropout=self.dropout,
+                use_linear_projection=self.use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
                 dtype=self.dtype,
             )
         else:
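For reference, the scalar-to-tuple handling added in `setup()` above boils down to the following standalone sketch. The block types are copied from the class defaults; the head-dim tuple here is illustrative, chosen so the reversal is visible.

```python
down_block_types = (
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "DownBlock2D",
)

# A bare bool/int is broadcast to one entry per down block...
only_cross_attention = False
if isinstance(only_cross_attention, bool):
    only_cross_attention = (only_cross_attention,) * len(down_block_types)

attention_head_dim = (5, 10, 20, 20)  # illustrative per-block head counts
if isinstance(attention_head_dim, int):
    attention_head_dim = (attention_head_dim,) * len(down_block_types)

# ...and the up path consumes the same settings in reverse block order.
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention_up = list(reversed(only_cross_attention))

print(only_cross_attention)           # (False, False, False, False)
print(reversed_attention_head_dim)    # [20, 20, 10, 5]
```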
A Stable Diffusion 2 fp16 case is added to the existing PyTorch UNet2DConditionModel integration tests:

@@ -639,3 +639,29 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
         expected_output_slice = torch.tensor(expected_slice)
 
         assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
+
+    @parameterized.expand(
+        [
+            # fmt: off
+            [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
+            [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
+            [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
+            [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
+            # fmt: on
+        ]
+    )
+    @require_torch_gpu
+    def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
+        model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
+        latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
+        encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)
+
+        with torch.no_grad():
+            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
+
+        assert sample.shape == latents.shape
+
+        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
+        expected_output_slice = torch.tensor(expected_slice)
+
+        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
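For orientation, the expected values in these integration tests are just an eight-element slice of the UNet output taken from the last batch element. A tiny sketch (not part of the diff) of what `sample[-1, -2:, -2:, :2]` selects:

```python
import numpy as np

# sample.shape == latents.shape == (batch, channels, height, width), e.g. (4, 4, 96, 96) for SD 2
sample = np.zeros((4, 4, 96, 96), dtype=np.float32)

# last batch element, last two channels, last two rows, first two columns -> 2 * 2 * 2 values
output_slice = sample[-1, -2:, -2:, :2].flatten()
print(output_slice.shape)  # (8,)
```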
Flax UNet integration tests (the expected slices come from the PyTorch results, per the commit message):

import gc
import unittest

from diffusers import FlaxUNet2DConditionModel
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import load_hf_numpy, require_flax, slow
from parameterized import parameterized


if is_flax_available():
    import jax
    import jax.numpy as jnp


@slow
@require_flax
class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase):
    def get_file_format(self, seed, shape):
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()

    def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
        dtype = jnp.bfloat16 if fp16 else jnp.float32
        image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
        return image

    def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
        dtype = jnp.bfloat16 if fp16 else jnp.float32
        revision = "bf16" if fp16 else None
        model, params = FlaxUNet2DConditionModel.from_pretrained(
            model_id, subfolder="unet", dtype=dtype, revision=revision
        )
        return model, params

    def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
        dtype = jnp.bfloat16 if fp16 else jnp.float32
        hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
        return hidden_states

    @parameterized.expand(
        [
            # fmt: off
            [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
            [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
            [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
            [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
            # fmt: on
        ]
    )
    def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
        model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
        latents = self.get_latents(seed, fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

        sample = model.apply(
            {"params": params},
            latents,
            jnp.array(timestep, dtype=jnp.int32),
            encoder_hidden_states=encoder_hidden_states,
        ).sample

        assert sample.shape == latents.shape

        output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
        expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)

        # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, in the same hardware
        assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)

    @parameterized.expand(
        [
            # fmt: off
            [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
            [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
            [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
            [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
            # fmt: on
        ]
    )
    def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
        model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
        latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)

        sample = model.apply(
            {"params": params},
            latents,
            jnp.array(timestep, dtype=jnp.int32),
            encoder_hidden_states=encoder_hidden_states,
        ).sample

        assert sample.shape == latents.shape

        output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
        expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)

        # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
        assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)
Flax Stable Diffusion 2 pipeline integration tests:

# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

from diffusers import FlaxDPMSolverMultistepScheduler, FlaxStableDiffusionPipeline
from diffusers.utils import is_flax_available, slow
from diffusers.utils.testing_utils import require_flax


if is_flax_available():
    import jax
    import jax.numpy as jnp
    from flax.jax_utils import replicate
    from flax.training.common_utils import shard


@slow
@require_flax
class FlaxStableDiffusion2PipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()

    def test_stable_diffusion_flax(self):
        sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2",
            revision="bf16",
            dtype=jnp.bfloat16,
        )

        prompt = "A painting of a squirrel eating a burger"
        num_samples = jax.device_count()
        prompt = num_samples * [prompt]
        prompt_ids = sd_pipe.prepare_inputs(prompt)

        params = replicate(params)
        prompt_ids = shard(prompt_ids)

        prng_seed = jax.random.PRNGKey(0)
        prng_seed = jax.random.split(prng_seed, jax.device_count())

        images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
        assert images.shape == (jax.device_count(), 1, 768, 768, 3)

        images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
        image_slice = images[0, 253:256, 253:256, -1]

        output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
        expected_slice = jnp.array([0.4238, 0.4414, 0.4395, 0.4453, 0.4629, 0.4590, 0.4531, 0.45508, 0.4512])
        print(f"output_slice: {output_slice}")
        assert jnp.abs(output_slice - expected_slice).max() < 1e-2

    def test_stable_diffusion_dpm_flax(self):
        model_id = "stabilityai/stable-diffusion-2"
        scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
        sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
            model_id,
            scheduler=scheduler,
            revision="bf16",
            dtype=jnp.bfloat16,
        )
        params["scheduler"] = scheduler_params

        prompt = "A painting of a squirrel eating a burger"
        num_samples = jax.device_count()
        prompt = num_samples * [prompt]
        prompt_ids = sd_pipe.prepare_inputs(prompt)

        params = replicate(params)
        prompt_ids = shard(prompt_ids)

        prng_seed = jax.random.PRNGKey(0)
        prng_seed = jax.random.split(prng_seed, jax.device_count())

        images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
        assert images.shape == (jax.device_count(), 1, 768, 768, 3)

        images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
        image_slice = images[0, 253:256, 253:256, -1]

        output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
        expected_slice = jnp.array([0.4336, 0.42969, 0.4453, 0.4199, 0.4297, 0.4531, 0.4434, 0.4434, 0.4297])
        print(f"output_slice: {output_slice}")
        assert jnp.abs(output_slice - expected_slice).max() < 1e-2