Unverified Commit 0eac9cd0 authored by Ryan Dick, committed by GitHub

Make T2I-Adapter downscale padding match the UNet (#5435)



* Update get_dummy_inputs(...) in T2I-Adapter tests to take image height and width as params.

* Update the T2I-Adapter unit tests to run with the standard number of UNet down blocks so that all T2I-Adapter down blocks get exercised.

* Update the T2I-Adapter down blocks to better match the padding behavior of the UNet (see the padding sketch below).

* Revert "Update the T2I-Adapter unit tests to run with the standard number of UNet down blocks so that all T2I-Adapter down blocks get exercised."

This reverts commit 6d4a060a34415ec973a252944216f4fb8b9926cd.

* Create utility functions for testing the T2I-Adapter downscaling behavior.

* (minor) Improve readability with an intermediate named variable.

* Statically parameterize T2I-Adapter test dimensions rather than generating them dynamically.

* Fix static checks.

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent bc7a4d49
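The core behavioral change is easiest to see in isolation. The adapter's down blocks previously used `Downsample2D(in_channels)`, which with its default arguments average-pools and floors odd spatial dimensions, while the UNet's stride-2, padding-1 convolutions round them up; this commit switches the adapter to `nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)` so both halve an odd feature map the same way. A minimal sketch (not part of the diff; the 33x33 size mirrors the new tests, where a 264x264 image is pixel-unshuffled by x8):

```python
# Minimal sketch (not from this commit) of the padding mismatch the commit fixes.
# The 33x33 size mirrors the new tests: a 264x264 image pixel-unshuffled by x8.
import torch
import torch.nn as nn

x = torch.randn(1, 32, 33, 33)

# UNet-style downsample: a stride-2 conv with padding=1 rounds odd dims up (33 -> 17).
unet_down = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)
print(unet_down(x).shape[-2:])  # torch.Size([17, 17])

# Old adapter behavior (assumed Downsample2D default): plain average pooling floors (33 -> 16).
old_down = nn.AvgPool2d(kernel_size=2, stride=2)
print(old_down(x).shape[-2:])  # torch.Size([16, 16])

# New adapter behavior: ceil_mode=True pads the pooling grid, matching the UNet (33 -> 17).
new_down = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
print(new_down(x).shape[-2:])  # torch.Size([17, 17])
```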
@@ -20,7 +20,6 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .modeling_utils import ModelMixin
-from .resnet import Downsample2D

logger = logging.get_logger(__name__)
@@ -51,24 +50,28 @@ class MultiAdapter(ModelMixin):
        if len(adapters) == 1:
            raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")

-        # The outputs from each adapter are added together with a weight
-        # This means that the change in dimenstions from downsampling must
-        # be the same for all adapters. Inductively, it also means the total
-        # downscale factor must also be the same for all adapters.
+        # The outputs from each adapter are added together with a weight.
+        # This means that the change in dimensions from downsampling must
+        # be the same for all adapters. Inductively, it also means the
+        # downscale_factor and total_downscale_factor must be the same for all
+        # adapters.
        first_adapter_total_downscale_factor = adapters[0].total_downscale_factor
+        first_adapter_downscale_factor = adapters[0].downscale_factor
        for idx in range(1, len(adapters)):
-            adapter_idx_total_downscale_factor = adapters[idx].total_downscale_factor
-            if adapter_idx_total_downscale_factor != first_adapter_total_downscale_factor:
+            if (
+                adapters[idx].total_downscale_factor != first_adapter_total_downscale_factor
+                or adapters[idx].downscale_factor != first_adapter_downscale_factor
+            ):
                raise ValueError(
-                    f"Expecting all adapters to have the same total_downscale_factor, "
-                    f"but got adapters[0].total_downscale_factor={first_adapter_total_downscale_factor} and "
-                    f"adapter[`{idx}`]={adapter_idx_total_downscale_factor}"
+                    f"Expecting all adapters to have the same downscaling behavior, but got:\n"
+                    f"adapters[0].total_downscale_factor={first_adapter_total_downscale_factor}\n"
+                    f"adapters[0].downscale_factor={first_adapter_downscale_factor}\n"
+                    f"adapter[`{idx}`].total_downscale_factor={adapters[idx].total_downscale_factor}\n"
+                    f"adapter[`{idx}`].downscale_factor={adapters[idx].downscale_factor}"
                )

-        self.total_downscale_factor = adapters[0].total_downscale_factor
+        self.total_downscale_factor = first_adapter_total_downscale_factor
+        self.downscale_factor = first_adapter_downscale_factor

    def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
        r"""
@@ -274,6 +277,13 @@ class T2IAdapter(ModelMixin, ConfigMixin):
    def total_downscale_factor(self):
        return self.adapter.total_downscale_factor

+    @property
+    def downscale_factor(self):
+        """The downscale factor applied in the T2I-Adapter's initial pixel unshuffle operation. If an input image's dimensions are
+        not evenly divisible by the downscale_factor then an exception will be raised.
+        """
+        return self.adapter.unshuffle.downscale_factor


# full adapter
@@ -399,7 +409,7 @@ class AdapterBlock(nn.Module):
        self.downsample = None
        if down:
-            self.downsample = Downsample2D(in_channels)
+            self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.in_conv = None
        if in_channels != out_channels:
@@ -526,7 +536,7 @@ class LightAdapterBlock(nn.Module):
        self.downsample = None
        if down:
-            self.downsample = Downsample2D(in_channels)
+            self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
        self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
......
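The new `downscale_factor` property above exposes only the pixel-unshuffle factor, in contrast to `total_downscale_factor`, which also folds in the stride-2 down blocks; the `MultiAdapter` check now requires both to match across sub-adapters. A hedged usage sketch, reusing the small adapter configuration from the tests further down (the channel sizes and block counts are test values, not production settings), with expected values taken from the test comments:

```python
# Hedged sketch of the new property; expected values follow the test comments below.
from diffusers import T2IAdapter

adapter = T2IAdapter(
    in_channels=3,
    channels=[32, 32, 32, 64],
    num_res_blocks=2,
    downscale_factor=8,
    adapter_type="full_adapter",
)

# Factor applied by the initial pixel unshuffle; input height/width must be divisible by it.
print(adapter.downscale_factor)        # 8
# Cumulative factor including the three stride-2 down blocks that follow the unshuffle.
print(adapter.total_downscale_factor)  # 64
```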
@@ -568,8 +568,8 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
            elif isinstance(image, torch.Tensor):
                height = image.shape[-2]

-            # round down to nearest multiple of `self.adapter.total_downscale_factor`
-            height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
+            # round down to nearest multiple of `self.adapter.downscale_factor`
+            height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor

        if width is None:
            if isinstance(image, PIL.Image.Image):
@@ -577,8 +577,8 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
            elif isinstance(image, torch.Tensor):
                width = image.shape[-1]

-            # round down to nearest multiple of `self.adapter.total_downscale_factor`
-            width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
+            # round down to nearest multiple of `self.adapter.downscale_factor`
+            width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor

        return height, width
......
@@ -622,8 +622,8 @@ class StableDiffusionXLAdapterPipeline(
            elif isinstance(image, torch.Tensor):
                height = image.shape[-2]

-            # round down to nearest multiple of `self.adapter.total_downscale_factor`
-            height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
+            # round down to nearest multiple of `self.adapter.downscale_factor`
+            height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor

        if width is None:
            if isinstance(image, PIL.Image.Image):
@@ -631,8 +631,8 @@ class StableDiffusionXLAdapterPipeline(
            elif isinstance(image, torch.Tensor):
                width = image.shape[-1]

-            # round down to nearest multiple of `self.adapter.total_downscale_factor`
-            width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
+            # round down to nearest multiple of `self.adapter.downscale_factor`
+            width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor

        return height, width
......
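In both pipelines the default height/width are now rounded down to a multiple of `downscale_factor` instead of `total_downscale_factor`; because the adapter's down blocks now pad odd feature maps the same way the UNet does, only divisibility by the pixel-unshuffle factor still matters. A small arithmetic sketch (hypothetical input size; the x8 and x64 factors match the dummy test components below):

```python
# Hypothetical 264-pixel-tall input with the x8 / x64 factors used in the tests below.
height = 264
downscale_factor = 8          # adapter's pixel-unshuffle factor (new rounding unit)
total_downscale_factor = 64   # unshuffle x8 plus three x2 down blocks (old rounding unit)

# Old behavior: round down to a multiple of the total factor, shrinking the image to 256.
print((height // total_downscale_factor) * total_downscale_factor)  # 256

# New behavior: only divisibility by the unshuffle factor is required, so 264 is preserved.
print((height // downscale_factor) * downscale_factor)  # 264
```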
@@ -19,6 +19,7 @@ import unittest
import numpy as np
import torch
+from parameterized import parameterized
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

import diffusers
@@ -137,11 +138,100 @@ class AdapterTests:
        }
        return components

+    def get_dummy_components_with_full_downscaling(self, adapter_type):
+        """Get dummy components with x8 VAE downscaling and 4 UNet down blocks.
+
+        These dummy components are intended to fully-exercise the T2I-Adapter
+        downscaling behavior.
+        """
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(
+            block_out_channels=(32, 32, 32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
+            up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+            cross_attention_dim=32,
+        )
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=[32, 32, 32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+        )
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        torch.manual_seed(0)
+        if adapter_type == "full_adapter" or adapter_type == "light_adapter":
+            adapter = T2IAdapter(
+                in_channels=3,
+                channels=[32, 32, 32, 64],
+                num_res_blocks=2,
+                downscale_factor=8,
+                adapter_type=adapter_type,
+            )
+        elif adapter_type == "multi_adapter":
+            adapter = MultiAdapter(
+                [
+                    T2IAdapter(
+                        in_channels=3,
+                        channels=[32, 32, 32, 64],
+                        num_res_blocks=2,
+                        downscale_factor=8,
+                        adapter_type="full_adapter",
+                    ),
+                    T2IAdapter(
+                        in_channels=3,
+                        channels=[32, 32, 32, 64],
+                        num_res_blocks=2,
+                        downscale_factor=8,
+                        adapter_type="full_adapter",
+                    ),
+                ]
+            )
+        else:
+            raise ValueError(
+                f"Unknown adapter type: {adapter_type}, must be one of 'full_adapter', 'light_adapter', or 'multi_adapter''"
+            )
+        components = {
+            "adapter": adapter,
+            "unet": unet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "safety_checker": None,
+            "feature_extractor": None,
+        }
+        return components

-    def get_dummy_inputs(self, device, seed=0, num_images=1):
+    def get_dummy_inputs(self, device, seed=0, height=64, width=64, num_images=1):
        if num_images == 1:
-            image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
+            image = floats_tensor((1, 3, height, width), rng=random.Random(seed)).to(device)
        else:
-            image = [floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) for _ in range(num_images)]
+            image = [
+                floats_tensor((1, 3, height, width), rng=random.Random(seed)).to(device) for _ in range(num_images)
+            ]

        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
@@ -170,11 +260,45 @@ class AdapterTests:
    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(expected_max_diff=2e-3)

+    @parameterized.expand(
+        [
+            # (dim=264) The internal feature map will be 33x33 after initial pixel unshuffling (downscaled x8).
+            ((4 * 8 + 1) * 8),
+            # (dim=272) The internal feature map will be 17x17 after the first T2I down block (downscaled x16).
+            ((4 * 4 + 1) * 16),
+            # (dim=288) The internal feature map will be 9x9 after the second T2I down block (downscaled x32).
+            ((4 * 2 + 1) * 32),
+            # (dim=320) The internal feature map will be 5x5 after the third T2I down block (downscaled x64).
+            ((4 * 1 + 1) * 64),
+        ]
+    )
+    def test_multiple_image_dimensions(self, dim):
+        """Test that the T2I-Adapter pipeline supports any input dimension that
+        is divisible by the adapter's `downscale_factor`. This test was added in
+        response to an issue where the T2I Adapter's downscaling padding
+        behavior did not match the UNet's behavior.
+
+        Note that we have selected `dim` values to produce odd resolutions at
+        each downscaling level.
+        """
+        components = self.get_dummy_components_with_full_downscaling()
+        sd_pipe = StableDiffusionAdapterPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device, height=dim, width=dim)
+        image = sd_pipe(**inputs).images
+
+        assert image.shape == (1, dim, dim, 3)
class StableDiffusionFullAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
    def get_dummy_components(self):
        return super().get_dummy_components("full_adapter")

+    def get_dummy_components_with_full_downscaling(self):
+        return super().get_dummy_components_with_full_downscaling("full_adapter")
+
    def test_stable_diffusion_adapter_default_case(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()

@@ -195,6 +319,9 @@ class StableDiffusionLightAdapterPipelineFastTests(AdapterTests, PipelineTesterM
    def get_dummy_components(self):
        return super().get_dummy_components("light_adapter")

+    def get_dummy_components_with_full_downscaling(self):
+        return super().get_dummy_components_with_full_downscaling("light_adapter")
+
    def test_stable_diffusion_adapter_default_case(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()

@@ -215,8 +342,11 @@ class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterM
    def get_dummy_components(self):
        return super().get_dummy_components("multi_adapter")

-    def get_dummy_inputs(self, device, seed=0):
-        inputs = super().get_dummy_inputs(device, seed, num_images=2)
+    def get_dummy_components_with_full_downscaling(self):
+        return super().get_dummy_components_with_full_downscaling("multi_adapter")
+
+    def get_dummy_inputs(self, device, height=64, width=64, seed=0):
+        inputs = super().get_dummy_inputs(device, seed, height=height, width=width, num_images=2)
        inputs["adapter_conditioning_scale"] = [0.5, 0.5]
        return inputs
......
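The `dim` values in `test_multiple_image_dimensions` above can be sanity-checked with a few lines of arithmetic: each stage of the chain should come out odd so that the ceil-mode padding path is actually exercised. A quick sketch (the `feature_map_sizes` helper is hypothetical, not part of the test suite):

```python
import math

def feature_map_sizes(dim, unshuffle=8, num_down_blocks=3):
    """Spatial sizes after the initial pixel unshuffle and each ceil-mode down block."""
    sizes = [dim // unshuffle]                  # initial pixel unshuffle
    for _ in range(num_down_blocks):
        sizes.append(math.ceil(sizes[-1] / 2))  # AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
    return sizes

print(feature_map_sizes(264))  # [33, 17, 9, 5] -> odd at every level
print(feature_map_sizes(320))  # [40, 20, 10, 5] -> odd only at the final level
```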
@@ -153,11 +153,119 @@ class StableDiffusionXLAdapterPipelineFastTests(
        }
        return components

+    def get_dummy_components_with_full_downscaling(self, adapter_type="full_adapter_xl"):
+        """Get dummy components with x8 VAE downscaling and 3 UNet down blocks.
+
+        These dummy components are intended to fully-exercise the T2I-Adapter
+        downscaling behavior.
+        """
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(
+            block_out_channels=(32, 32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
+            # SD2-specific config below
+            attention_head_dim=2,
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=1,
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+        )
+        scheduler = EulerDiscreteScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            steps_offset=1,
+            beta_schedule="scaled_linear",
+            timestep_spacing="leading",
+        )
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=[32, 32, 32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+            sample_size=128,
+        )
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+            # SD2-specific config below
+            hidden_act="gelu",
+            projection_dim=32,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
+        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        if adapter_type == "full_adapter_xl":
+            adapter = T2IAdapter(
+                in_channels=3,
+                channels=[32, 32, 64],
+                num_res_blocks=2,
+                downscale_factor=16,
+                adapter_type=adapter_type,
+            )
+        elif adapter_type == "multi_adapter":
+            adapter = MultiAdapter(
+                [
+                    T2IAdapter(
+                        in_channels=3,
+                        channels=[32, 32, 64],
+                        num_res_blocks=2,
+                        downscale_factor=16,
+                        adapter_type="full_adapter_xl",
+                    ),
+                    T2IAdapter(
+                        in_channels=3,
+                        channels=[32, 32, 64],
+                        num_res_blocks=2,
+                        downscale_factor=16,
+                        adapter_type="full_adapter_xl",
+                    ),
+                ]
+            )
+        else:
+            raise ValueError(
+                f"Unknown adapter type: {adapter_type}, must be one of 'full_adapter_xl', or 'multi_adapter''"
+            )
+        components = {
+            "adapter": adapter,
+            "unet": unet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "text_encoder_2": text_encoder_2,
+            "tokenizer_2": tokenizer_2,
+            # "safety_checker": None,
+            # "feature_extractor": None,
+        }
+        return components

-    def get_dummy_inputs(self, device, seed=0, num_images=1):
+    def get_dummy_inputs(self, device, seed=0, height=64, width=64, num_images=1):
        if num_images == 1:
-            image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
+            image = floats_tensor((1, 3, height, width), rng=random.Random(seed)).to(device)
        else:
-            image = [floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) for _ in range(num_images)]
+            image = [
+                floats_tensor((1, 3, height, width), rng=random.Random(seed)).to(device) for _ in range(num_images)
+            ]

        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
@@ -190,6 +298,33 @@ class StableDiffusionXLAdapterPipelineFastTests(
        )
        assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3

+    @parameterized.expand(
+        [
+            # (dim=144) The internal feature map will be 9x9 after initial pixel unshuffling (downscaled x16).
+            ((4 * 2 + 1) * 16),
+            # (dim=160) The internal feature map will be 5x5 after the first T2I down block (downscaled x32).
+            ((4 * 1 + 1) * 32),
+        ]
+    )
+    def test_multiple_image_dimensions(self, dim):
+        """Test that the T2I-Adapter pipeline supports any input dimension that
+        is divisible by the adapter's `downscale_factor`. This test was added in
+        response to an issue where the T2I Adapter's downscaling padding
+        behavior did not match the UNet's behavior.
+
+        Note that we have selected `dim` values to produce odd resolutions at
+        each downscaling level.
+        """
+        components = self.get_dummy_components_with_full_downscaling()
+        sd_pipe = StableDiffusionXLAdapterPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device, height=dim, width=dim)
+        image = sd_pipe(**inputs).images
+
+        assert image.shape == (1, dim, dim, 3)
@parameterized.expand(["full_adapter", "full_adapter_xl", "light_adapter"]) @parameterized.expand(["full_adapter", "full_adapter_xl", "light_adapter"])
def test_total_downscale_factor(self, adapter_type): def test_total_downscale_factor(self, adapter_type):
"""Test that the T2IAdapter correctly reports its total_downscale_factor.""" """Test that the T2IAdapter correctly reports its total_downscale_factor."""
...@@ -231,8 +366,11 @@ class StableDiffusionXLMultiAdapterPipelineFastTests( ...@@ -231,8 +366,11 @@ class StableDiffusionXLMultiAdapterPipelineFastTests(
def get_dummy_components(self): def get_dummy_components(self):
return super().get_dummy_components("multi_adapter") return super().get_dummy_components("multi_adapter")
def get_dummy_inputs(self, device, seed=0): def get_dummy_components_with_full_downscaling(self):
inputs = super().get_dummy_inputs(device, seed, num_images=2) return super().get_dummy_components_with_full_downscaling("multi_adapter")
def get_dummy_inputs(self, device, seed=0, height=64, width=64):
inputs = super().get_dummy_inputs(device, seed, height, width, num_images=2)
inputs["adapter_conditioning_scale"] = [0.5, 0.5] inputs["adapter_conditioning_scale"] = [0.5, 0.5]
return inputs return inputs
......