Unverified commit 41509333, authored by Ryan Dick, committed by GitHub
Browse files

Fix the total_downscale_factor returned by FullAdapterXL T2IAdapters (#5134)



* Fix FullAdapterXL.total_downscale_factor.

* Fix incorrect error message in T2IAdapter.__init__(...).

* Move T2I-Adapter test_total_downscale_factor(...) to pipeline test file (requested in code review).

* Add more info to error message about an unsupported T2I-Adapter adapter_type.

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent dfdf85d3
......@@ -252,7 +252,10 @@ class T2IAdapter(ModelMixin, ConfigMixin):
elif adapter_type == "light_adapter":
self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
else:
raise ValueError(f"unknown adapter_type: {type}. Choose either 'full_adapter' or 'simple_adapter'")
raise ValueError(
f"Unsupported adapter_type: '{adapter_type}'. Choose either 'full_adapter' or "
"'full_adapter_xl' or 'light_adapter'."
)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
    """Run `x` through the wrapped adapter and hand back its result unchanged.

    Args:
        x: Input tensor forwarded as-is to `self.adapter`.

    Returns:
        Whatever `self.adapter(x)` produces (annotated as a list of tensors).
    """
    adapter_features = self.adapter(x)
    return adapter_features
......@@ -331,8 +334,8 @@ class FullAdapterXL(nn.Module):
self.body.append(AdapterBlock(channels[i], channels[i], num_res_blocks))
self.body = nn.ModuleList(self.body)
# XL has one fewer downsampling
self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 2)
# XL has only one downsampling AdapterBlock.
self.total_downscale_factor = downscale_factor * 2
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
x = self.unshuffle(x)
......
......@@ -18,6 +18,7 @@ import unittest
import numpy as np
import torch
from parameterized import parameterized
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
import diffusers
......@@ -184,6 +185,37 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
@parameterized.expand(["full_adapter", "full_adapter_xl", "light_adapter"])
def test_total_downscale_factor(self, adapter_type):
    """Check that `total_downscale_factor` agrees with the adapter's actual output size.

    Runs once per supported `adapter_type`. A square input image is pushed through a
    freshly built `T2IAdapter`, and the spatial size of the final feature map must equal
    the input size divided by the factor the adapter reports.
    """
    n_samples = 1
    n_in_channels = 3
    block_channels = [320, 640, 1280, 1280]
    side = 512

    adapter = T2IAdapter(
        in_channels=n_in_channels,
        channels=block_channels,
        num_res_blocks=2,
        downscale_factor=8,
        adapter_type=adapter_type,
    )
    adapter.to(torch_device)

    image = floats_tensor((n_samples, n_in_channels, side, side)).to(torch_device)
    features = adapter(image)

    # The last feature map is assumed to be the most heavily downsampled one, so it is
    # the one whose size should reflect the adapter's total downscale factor.
    expected_side = side // adapter.total_downscale_factor
    expected_shape = (n_samples, block_channels[-1], expected_side, expected_side)
    assert features[-1].shape == expected_shape
class StableDiffusionXLMultiAdapterPipelineFastTests(
StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase
):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment