Unverified Commit c4d28236 authored by Patrick von Platen, committed by GitHub

[SDXL Lora] Fix last ben sdxl lora (#4797)

* Fix last ben sdxl lora

* Correct typo

* make style
parent 4f8853e4
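
In summary, this commit renames `_map_sgm_blocks_to_diffusers` to `_maybe_map_sgm_blocks_to_diffusers`: the helper now first checks whether the checkpoint actually uses SGM-style block names (`input_blocks`, `middle_block`, `output_blocks`) and returns the state dict untouched if not, moves text-encoder keys into the remapped dict as it converts, and drops the `is_all_unet` escape hatch so that any leftover keys always raise. An integration test for the TheLastBen Papercut SDXL LoRA is added. As a minimal sketch (not part of the diff; model and LoRA ids are taken from the new test, the prompt and step count are illustrative), the user-facing flow this fixes looks like:

```python
# Minimal sketch of the flow exercised by the new integration test below.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.enable_model_cpu_offload()

# Loading this LoRA previously failed because its SGM-style block names
# were not remapped to diffusers block names.
pipe.load_lora_weights("TheLastBen/Papercut_SDXL", weight_name="papercut.safetensors")

generator = torch.Generator().manual_seed(0)
image = pipe("papercut", generator=generator, num_inference_steps=25).images[0]
```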
@@ -1084,7 +1084,7 @@ class LoraLoaderMixin:
         # Map SDXL blocks correctly.
         if unet_config is not None:
             # use unet config to remap block numbers
-            state_dict = cls._map_sgm_blocks_to_diffusers(state_dict, unet_config)
+            state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
         state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)

         return state_dict, network_alphas
@@ -1121,24 +1121,41 @@ class LoraLoaderMixin:
         return weight_name

     @classmethod
-    def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
-        is_all_unet = all(k.startswith("lora_unet") for k in state_dict)
+    def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
+        # 1. get all state_dict keys (as a list, since entries are popped below)
+        all_keys = list(state_dict.keys())
+        sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
+
+        # 2. check if remapping is needed; if not, return the original dict
+        is_in_sgm_format = False
+        for key in all_keys:
+            if any(p in key for p in sgm_patterns):
+                is_in_sgm_format = True
+                break
+
+        if not is_in_sgm_format:
+            return state_dict
+
+        # 3. Else remap from SGM patterns
         new_state_dict = {}
         inner_block_map = ["resnets", "attentions", "upsamplers"]

         # Retrieves # of down, mid and up blocks
         input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
-        for layer in state_dict:
-            if "text" not in layer:
+
+        for layer in all_keys:
+            if "text" in layer:
+                new_state_dict[layer] = state_dict.pop(layer)
+            else:
                 layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
-                if "input_blocks" in layer:
+                if sgm_patterns[0] in layer:
                     input_block_ids.add(layer_id)
-                elif "middle_block" in layer:
+                elif sgm_patterns[1] in layer:
                     middle_block_ids.add(layer_id)
-                elif "output_blocks" in layer:
+                elif sgm_patterns[2] in layer:
                     output_block_ids.add(layer_id)
                 else:
-                    raise ValueError("Checkpoint not supported")
+                    raise ValueError(f"Checkpoint not supported because layer {layer} is not supported.")

         input_blocks = {
             layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
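
For concreteness, here is how the loop above slices the block id out of an SGM-style key with the default `delimiter="_"` and `block_slice_pos=5`; the key name is a made-up example, not taken from any particular checkpoint:

```python
# Illustrative only: extracting the block id from an SGM-style LoRA key.
delimiter, block_slice_pos = "_", 5
key = "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_k.lora_down.weight"

# key.split("_")[:5] == ["lora", "unet", "input", "blocks", "4"],
# so the last element of the slice is the block id.
layer_id = int(key.split(delimiter)[:block_slice_pos][-1])
assert layer_id == 4
```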
@@ -1201,12 +1218,8 @@ class LoraLoaderMixin:
             )
             new_state_dict[new_key] = state_dict.pop(key)

-        if is_all_unet and len(state_dict) > 0:
+        if len(state_dict) > 0:
             raise ValueError("At this point all state dict entries have to be converted.")
-        else:
-            # Remaining is the text encoder state dict.
-            for k, v in state_dict.items():
-                new_state_dict.update({k: v})

         return new_state_dict
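
This simplification works because the remap loop now drains `state_dict` via `pop`, text-encoder keys included, so a non-empty remainder always means unconvertible entries and the old `is_all_unet` branch is unnecessary. A toy sketch of the pattern (keys are made up):

```python
# Toy sketch of the drain-via-pop pattern used above.
state_dict = {"lora_te_text_model_q": 0.1, "lora_unet_input_blocks_1_0_k": 0.2}
new_state_dict = {}
for key in list(state_dict):
    new_state_dict[key] = state_dict.pop(key)  # the real code renames unet keys here

# Anything still left in state_dict could not be converted,
# which is now always an error.
assert len(state_dict) == 0
```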
@@ -942,3 +942,19 @@ class LoraIntegrationTests(unittest.TestCase):
         expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])

         self.assertTrue(np.allclose(images, expected, atol=1e-4))
+
+    def test_sdxl_1_0_last_ben(self):
+        generator = torch.Generator().manual_seed(0)
+
+        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+        pipe.enable_model_cpu_offload()
+        lora_model_id = "TheLastBen/Papercut_SDXL"
+        lora_filename = "papercut.safetensors"
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+
+        images = pipe("papercut.safetensors", output_type="np", generator=generator, num_inference_steps=2).images
+        images = images[0, -3:, -3:, -1].flatten()
+
+        expected = np.array([0.5244, 0.4347, 0.4312, 0.4246, 0.4398, 0.4409, 0.4884, 0.4938, 0.4094])
+
+        self.assertTrue(np.allclose(images, expected, atol=1e-3))