Unverified Commit 41509333 authored by Ryan Dick, committed by GitHub

Fix the total_downscale_factor returned by FullAdapterXL T2IAdapters (#5134)



* Fix FullAdapterXL.total_downscale_factor.

* Fix incorrect error message in T2IAdapter.__init__(...).

* Move T2I-Adapter test_total_downscale_factor(...) to the pipeline test file (requested in code review).

* Add more info to error message about an unsupported T2I-Adapter adapter_type.

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent dfdf85d3
@@ -252,7 +252,10 @@ class T2IAdapter(ModelMixin, ConfigMixin):
         elif adapter_type == "light_adapter":
             self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
         else:
-            raise ValueError(f"unknown adapter_type: {type}. Choose either 'full_adapter' or 'simple_adapter'")
+            raise ValueError(
+                f"Unsupported adapter_type: '{adapter_type}'. Choose either 'full_adapter' or "
+                "'full_adapter_xl' or 'light_adapter'."
+            )
 
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         return self.adapter(x)
@@ -331,8 +334,8 @@ class FullAdapterXL(nn.Module):
             self.body.append(AdapterBlock(channels[i], channels[i], num_res_blocks))
 
         self.body = nn.ModuleList(self.body)
-        # XL has one fewer downsampling
-        self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 2)
+        # XL has only one downsampling AdapterBlock.
+        self.total_downscale_factor = downscale_factor * 2
 
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         x = self.unshuffle(x)
...
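For context, the arithmetic behind this fix can be sketched as follows, using the SDXL-style configuration from the test below (downscale_factor=8, four channel stages); the snippet is illustrative and not part of this diff:

# Illustrative sketch, not part of this diff: before/after values of
# FullAdapterXL.total_downscale_factor for an SDXL-style configuration.
downscale_factor = 8  # pixel-unshuffle factor applied at the adapter input
channels = [320, 640, 1280, 1280]  # channel stages, as in the test below

# The old formula assumed one fewer downsampling stage than the non-XL adapter:
old_factor = downscale_factor * 2 ** (len(channels) - 2)  # 8 * 2**2 = 32

# FullAdapterXL contains only a single downsampling AdapterBlock, so the
# true end-to-end factor is:
new_factor = downscale_factor * 2  # 8 * 2 = 16

# A 512x512 control image thus produces a deepest feature map of
# 512 // 16 = 32, not the 512 // 32 = 16 that the old attribute implied.
print(old_factor, new_factor)  # 32 16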
@@ -18,6 +18,7 @@ import unittest
 
 import numpy as np
 import torch
+from parameterized import parameterized
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
 import diffusers
@@ -184,6 +185,37 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
 
+    @parameterized.expand(["full_adapter", "full_adapter_xl", "light_adapter"])
+    def test_total_downscale_factor(self, adapter_type):
+        """Test that the T2IAdapter correctly reports its total_downscale_factor."""
+        batch_size = 1
+        in_channels = 3
+        out_channels = [320, 640, 1280, 1280]
+        in_image_size = 512
+
+        adapter = T2IAdapter(
+            in_channels=in_channels,
+            channels=out_channels,
+            num_res_blocks=2,
+            downscale_factor=8,
+            adapter_type=adapter_type,
+        )
+        adapter.to(torch_device)
+
+        in_image = floats_tensor((batch_size, in_channels, in_image_size, in_image_size)).to(torch_device)
+
+        adapter_state = adapter(in_image)
+
+        # Assume that the last element in `adapter_state` has been downsampled the most, and check
+        # that it matches the `total_downscale_factor`.
+        expected_out_image_size = in_image_size // adapter.total_downscale_factor
+        assert adapter_state[-1].shape == (
+            batch_size,
+            out_channels[-1],
+            expected_out_image_size,
+            expected_out_image_size,
+        )
+
 
 class StableDiffusionXLMultiAdapterPipelineFastTests(
     StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase
 ):
...
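The corrected attribute matters beyond this test because callers size control images against it; a minimal sketch of that kind of validation (hypothetical helper, not part of this PR):

def check_control_image_size(height: int, width: int, adapter) -> None:
    # Hypothetical helper, not from this PR: control image dimensions are
    # typically required to be divisible by the adapter's total downscale
    # factor so the deepest adapter feature map aligns with the UNet's.
    factor = adapter.total_downscale_factor
    if height % factor != 0 or width % factor != 0:
        raise ValueError(
            f"`height` and `width` must be divisible by {factor}, got {height}x{width}."
        )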