Unverified Commit c81a88b2 authored by Sayak Paul, committed by GitHub

[Core] LoRA improvements pt. 3 (#4842)



* throw warning when more than one lora is attempted to be fused.

* introduce support for lora scale during fusion.

* change test name

* changes

* change to _lora_scale

* lora_scale to call whenever applicable.

* debugging

* lora_scale additional.

* cross_attention_kwargs

* lora_scale -> scale.

* lora_scale fix

* lora_scale in patched projection.

* debugging

* styling.

* debugging

* remove unneeded prints.

* assign cross_attention_kwargs.

* debugging

* clean up.

* refactor scale retrieval logic a bit.

* fix NoneType

* fix: tests

* add more tests

* more fixes.

* figure out a way to pass lora_scale.

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* unify the retrieval logic of lora_scale.

* move adjust_lora_scale_text_encoder to lora.py.

* introduce dynamic LoRA scale adjustment support to SD

* fix up copies

* Empty-Commit

* add: test to check fusion equivalence on different scales.

* handle lora fusion warning.

* make lora smaller

* make lora smaller

* make lora smaller

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 2c1677ee
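For orientation before the diff: a minimal usage sketch of the two knobs this patch wires through. The checkpoint and LoRA ids and the 0.7 value are illustrative placeholders, not taken from this commit.

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe.load_lora_weights("some-user/some-lora")  # hypothetical LoRA repository

    # Either bake the LoRA into the base weights at a chosen strength ...
    pipe.fuse_lora(lora_scale=0.7)
    pipe.unfuse_lora()

    # ... or keep it unfused and scale it per call; the "scale" entry is read by
    # the UNet blocks, the attention processors, and (now) the text encoder.
    image = pipe("a prompt", cross_attention_kwargs={"scale": 0.7}).images[0]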
@@ -95,7 +95,7 @@ class PatchedLoraProjection(nn.Module):
         return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
 
-    def _fuse_lora(self):
+    def _fuse_lora(self, lora_scale=1.0):
         if self.lora_linear_layer is None:
             return
@@ -108,7 +108,7 @@ class PatchedLoraProjection(nn.Module):
         if self.lora_linear_layer.network_alpha is not None:
             w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
 
-        fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
+        fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
         self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
 
         # we can drop the lora layer now
@@ -117,6 +117,7 @@ class PatchedLoraProjection(nn.Module):
         # offload the up and down matrices to CPU to not blow the memory
         self.w_up = w_up.cpu()
         self.w_down = w_down.cpu()
+        self.lora_scale = lora_scale
 
     def _unfuse_lora(self):
         if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
@@ -128,16 +129,19 @@ class PatchedLoraProjection(nn.Module):
         w_up = self.w_up.to(device=device).float()
         w_down = self.w_down.to(device).float()
 
-        unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
+        unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
         self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
 
         self.w_up = None
         self.w_down = None
 
     def forward(self, input):
+        # print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}")
+        if self.lora_scale is None:
+            self.lora_scale = 1.0
         if self.lora_linear_layer is None:
             return self.regular_linear_layer(input)
-        return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input)
+        return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input))
 
 
 def text_encoder_attn_modules(text_encoder):
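The fusion above folds the scaled low-rank update into the base weight, and unfusion subtracts the same term back out; a standalone check of that arithmetic with plain tensors (shapes and the 0.7 scale are arbitrary):

    import torch

    out_f, in_f, rank, scale = 32, 16, 4, 0.7
    w = torch.randn(out_f, in_f)
    up, down = torch.randn(out_f, rank), torch.randn(rank, in_f)

    # W_fused = W + scale * (up @ down); unfusing subtracts the identical term.
    w_fused = w + scale * torch.bmm(up[None, :], down[None, :])[0]
    w_unfused = w_fused - scale * torch.bmm(up[None, :], down[None, :])[0]

    assert torch.allclose(w, w_unfused, atol=1e-5)  # round-trip recovers the base weight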
@@ -576,12 +580,13 @@ class UNet2DConditionLoadersMixin:
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
-    def fuse_lora(self):
+    def fuse_lora(self, lora_scale=1.0):
+        self.lora_scale = lora_scale
         self.apply(self._fuse_lora_apply)
 
     def _fuse_lora_apply(self, module):
         if hasattr(module, "_fuse_lora"):
-            module._fuse_lora()
+            module._fuse_lora(self.lora_scale)
 
     def unfuse_lora(self):
         self.apply(self._unfuse_lora_apply)
@@ -924,6 +929,7 @@ class LoraLoaderMixin:
     """
     text_encoder_name = TEXT_ENCODER_NAME
     unet_name = UNET_NAME
+    num_fused_loras = 0
 
     def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         """
@@ -1807,7 +1813,7 @@ class LoraLoaderMixin:
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
 
-    def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True):
+    def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0):
         r"""
         Fuses the LoRA parameters into the original parameters of the corresponding blocks.
@@ -1822,22 +1828,31 @@ class LoraLoaderMixin:
             fuse_text_encoder (`bool`, defaults to `True`):
                 Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                 LoRA parameters then it won't have any effect.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
         """
+        if fuse_unet or fuse_text_encoder:
+            self.num_fused_loras += 1
+            if self.num_fused_loras > 1:
+                logger.warn(
+                    "The current API is supported for operating with a single LoRA file. You are trying to load and fuse more than one LoRA which is not well-supported.",
+                )
+
         if fuse_unet:
-            self.unet.fuse_lora()
+            self.unet.fuse_lora(lora_scale)
 
         def fuse_text_encoder_lora(text_encoder):
             for _, attn_module in text_encoder_attn_modules(text_encoder):
                 if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                    attn_module.q_proj._fuse_lora()
-                    attn_module.k_proj._fuse_lora()
-                    attn_module.v_proj._fuse_lora()
-                    attn_module.out_proj._fuse_lora()
+                    attn_module.q_proj._fuse_lora(lora_scale)
+                    attn_module.k_proj._fuse_lora(lora_scale)
+                    attn_module.v_proj._fuse_lora(lora_scale)
+                    attn_module.out_proj._fuse_lora(lora_scale)
 
             for _, mlp_module in text_encoder_mlp_modules(text_encoder):
                 if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                    mlp_module.fc1._fuse_lora()
-                    mlp_module.fc2._fuse_lora()
+                    mlp_module.fc1._fuse_lora(lora_scale)
+                    mlp_module.fc2._fuse_lora(lora_scale)
 
         if fuse_text_encoder:
             if hasattr(self, "text_encoder"):
@@ -1884,6 +1899,8 @@ class LoraLoaderMixin:
             if hasattr(self, "text_encoder_2"):
                 unfuse_text_encoder_lora(self.text_encoder_2)
 
+        self.num_fused_loras -= 1
 
 
 class FromSingleFileMixin:
     """
......
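A short sketch of how the extended pipeline-level `fuse_lora` API is meant to behave after this change; the model and LoRA ids are placeholders:

    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe.load_lora_weights("some-user/some-lora")  # hypothetical LoRA repository

    pipe.fuse_lora(lora_scale=0.5)  # fuses UNet and text encoder LoRA layers at half strength
    pipe.unfuse_lora()              # subtracts the same scaled update; num_fused_loras drops back to 0

    pipe.fuse_lora()
    pipe.fuse_lora()  # a second fuse without unfusing pushes num_fused_loras above 1 and logs the warning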
@@ -177,7 +177,7 @@ class BasicTransformerBlock(nn.Module):
         class_labels: Optional[torch.LongTensor] = None,
     ):
         # Notice that normalization is always applied before the real computation in the following blocks.
-        # 1. Self-Attention
+        # 0. Self-Attention
         if self.use_ada_layer_norm:
             norm_hidden_states = self.norm1(hidden_states, timestep)
         elif self.use_ada_layer_norm_zero:
@@ -187,7 +187,10 @@ class BasicTransformerBlock(nn.Module):
         else:
             norm_hidden_states = self.norm1(hidden_states)
 
-        # 0. Prepare GLIGEN inputs
+        # 1. Retrieve lora scale.
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+        # 2. Prepare GLIGEN inputs
         cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
         gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
@@ -201,12 +204,12 @@ class BasicTransformerBlock(nn.Module):
             attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = attn_output + hidden_states
 
-        # 1.5 GLIGEN Control
+        # 2.5 GLIGEN Control
         if gligen_kwargs is not None:
             hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
-        # 1.5 ends
+        # 2.5 ends
 
-        # 2. Cross-Attention
+        # 3. Cross-Attention
         if self.attn2 is not None:
             norm_hidden_states = (
                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
@@ -220,7 +223,7 @@ class BasicTransformerBlock(nn.Module):
             )
             hidden_states = attn_output + hidden_states
 
-        # 3. Feed-forward
+        # 4. Feed-forward
         norm_hidden_states = self.norm3(hidden_states)
 
         if self.use_ada_layer_norm_zero:
@@ -235,11 +238,14 @@ class BasicTransformerBlock(nn.Module):
             num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
             ff_output = torch.cat(
-                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+                [
+                    self.ff(hid_slice, scale=lora_scale)
+                    for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
+                ],
                 dim=self._chunk_dim,
             )
         else:
-            ff_output = self.ff(norm_hidden_states)
+            ff_output = self.ff(norm_hidden_states, scale=lora_scale)
 
         if self.use_ada_layer_norm_zero:
             ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -295,8 +301,11 @@ class FeedForward(nn.Module):
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, scale: float = 1.0):
         for module in self.net:
-            hidden_states = module(hidden_states)
+            if isinstance(module, (LoRACompatibleLinear, GEGLU)):
+                hidden_states = module(hidden_states, scale)
+            else:
+                hidden_states = module(hidden_states)
         return hidden_states
@@ -342,8 +351,8 @@ class GEGLU(nn.Module):
         # mps: gelu is not implemented for float16
         return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
 
-    def forward(self, hidden_states):
-        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+    def forward(self, hidden_states, scale: float = 1.0):
+        hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
         return hidden_states * self.gelu(gate)
......
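The `scale` threaded through `BasicTransformerBlock`, `FeedForward`, and `GEGLU` above only modulates LoRA branches; on a module without LoRA layers set it is a no-op. A minimal sketch, assuming a diffusers build that includes this patch (module sizes are arbitrary):

    import torch
    from diffusers.models.attention import FeedForward

    ff = FeedForward(dim=64, activation_fn="geglu")  # GEGLU + LoRACompatibleLinear, no LoRA layer attached
    x = torch.randn(2, 16, 64)

    assert torch.allclose(ff(x, scale=1.0), ff(x, scale=0.5))  # without LoRA layers, scale changes nothing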
@@ -570,15 +570,15 @@ class AttnProcessor:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, lora_scale=scale)
+        query = attn.to_q(hidden_states, scale=scale)
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        key = attn.to_k(encoder_hidden_states, lora_scale=scale)
-        value = attn.to_v(encoder_hidden_states, lora_scale=scale)
+        key = attn.to_k(encoder_hidden_states, scale=scale)
+        value = attn.to_v(encoder_hidden_states, scale=scale)
 
         query = attn.head_to_batch_dim(query)
         key = attn.head_to_batch_dim(key)
@@ -589,7 +589,7 @@ class AttnProcessor:
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, scale=scale)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -722,17 +722,17 @@ class AttnAddedKVProcessor:
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, lora_scale=scale)
+        query = attn.to_q(hidden_states, scale=scale)
         query = attn.head_to_batch_dim(query)
 
-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, lora_scale=scale)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, lora_scale=scale)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
 
         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, lora_scale=scale)
-            value = attn.to_v(hidden_states, lora_scale=scale)
+            key = attn.to_k(hidden_states, scale=scale)
+            value = attn.to_v(hidden_states, scale=scale)
             key = attn.head_to_batch_dim(key)
             value = attn.head_to_batch_dim(value)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -746,7 +746,7 @@ class AttnAddedKVProcessor:
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, scale=scale)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -782,7 +782,7 @@ class AttnAddedKVProcessor2_0:
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, lora_scale=scale)
+        query = attn.to_q(hidden_states, scale=scale)
         query = attn.head_to_batch_dim(query, out_dim=4)
 
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -791,8 +791,8 @@ class AttnAddedKVProcessor2_0:
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
 
         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, lora_scale=scale)
-            value = attn.to_v(hidden_states, lora_scale=scale)
+            key = attn.to_k(hidden_states, scale=scale)
+            value = attn.to_v(hidden_states, scale=scale)
             key = attn.head_to_batch_dim(key, out_dim=4)
             value = attn.head_to_batch_dim(value, out_dim=4)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -809,7 +809,7 @@ class AttnAddedKVProcessor2_0:
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, scale=scale)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -937,15 +937,15 @@ class XFormersAttnProcessor:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, lora_scale=scale)
+        query = attn.to_q(hidden_states, scale=scale)
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        key = attn.to_k(encoder_hidden_states, lora_scale=scale)
-        value = attn.to_v(encoder_hidden_states, lora_scale=scale)
+        key = attn.to_k(encoder_hidden_states, scale=scale)
+        value = attn.to_v(encoder_hidden_states, scale=scale)
 
         query = attn.head_to_batch_dim(query).contiguous()
         key = attn.head_to_batch_dim(key).contiguous()
@@ -958,7 +958,7 @@ class XFormersAttnProcessor:
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, scale=scale)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -1015,15 +1015,15 @@ class AttnProcessor2_0:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, lora_scale=scale)
+        query = attn.to_q(hidden_states, scale=scale)
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        key = attn.to_k(encoder_hidden_states, lora_scale=scale)
-        value = attn.to_v(encoder_hidden_states, lora_scale=scale)
+        key = attn.to_k(encoder_hidden_states, scale=scale)
+        value = attn.to_v(encoder_hidden_states, scale=scale)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1043,7 +1043,7 @@ class AttnProcessor2_0:
         hidden_states = hidden_states.to(query.dtype)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, scale=scale)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
......
@@ -18,12 +18,27 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
 
+from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
 from ..utils import logging
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
+    for _, attn_module in text_encoder_attn_modules(text_encoder):
+        if isinstance(attn_module.q_proj, PatchedLoraProjection):
+            attn_module.q_proj.lora_scale = lora_scale
+            attn_module.k_proj.lora_scale = lora_scale
+            attn_module.v_proj.lora_scale = lora_scale
+            attn_module.out_proj.lora_scale = lora_scale
+
+    for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+        if isinstance(mlp_module.fc1, PatchedLoraProjection):
+            mlp_module.fc1.lora_scale = lora_scale
+            mlp_module.fc2.lora_scale = lora_scale
+
+
 class LoRALinearLayer(nn.Module):
     def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
         super().__init__()
@@ -97,12 +112,11 @@ class LoRACompatibleConv(nn.Conv2d):
     def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
         self.lora_layer = lora_layer
 
-    def _fuse_lora(self):
+    def _fuse_lora(self, lora_scale=1.0):
         if self.lora_layer is None:
             return
 
         dtype, device = self.weight.data.dtype, self.weight.data.device
-        logger.info(f"Fusing LoRA weights for {self.__class__}")
 
         w_orig = self.weight.data.float()
         w_up = self.lora_layer.up.weight.data.float()
@@ -113,7 +127,7 @@ class LoRACompatibleConv(nn.Conv2d):
         fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
         fusion = fusion.reshape((w_orig.shape))
-        fused_weight = w_orig + fusion
+        fused_weight = w_orig + (lora_scale * fusion)
         self.weight.data = fused_weight.to(device=device, dtype=dtype)
 
         # we can drop the lora layer now
@@ -122,33 +136,35 @@ class LoRACompatibleConv(nn.Conv2d):
         # offload the up and down matrices to CPU to not blow the memory
         self.w_up = w_up.cpu()
         self.w_down = w_down.cpu()
+        self._lora_scale = lora_scale
 
     def _unfuse_lora(self):
         if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
             return
-        logger.info(f"Unfusing LoRA weights for {self.__class__}")
 
         fused_weight = self.weight.data
         dtype, device = fused_weight.data.dtype, fused_weight.data.device
 
-        self.w_up = self.w_up.to(device=device, dtype=dtype)
-        self.w_down = self.w_down.to(device, dtype=dtype)
+        self.w_up = self.w_up.to(device=device).float()
+        self.w_down = self.w_down.to(device).float()
 
         fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
         fusion = fusion.reshape((fused_weight.shape))
-        unfused_weight = fused_weight - fusion
+        unfused_weight = fused_weight.float() - (self._lora_scale * fusion)
         self.weight.data = unfused_weight.to(device=device, dtype=dtype)
 
         self.w_up = None
         self.w_down = None
 
-    def forward(self, x):
+    def forward(self, hidden_states, scale: float = 1.0):
         if self.lora_layer is None:
             # make sure to the functional Conv2D function as otherwise torch.compile's graph will break
             # see: https://github.com/huggingface/diffusers/pull/4315
-            return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+            return F.conv2d(
+                hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+            )
         else:
-            return super().forward(x) + self.lora_layer(x)
+            return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
 
 
 class LoRACompatibleLinear(nn.Linear):
@@ -163,7 +179,7 @@ class LoRACompatibleLinear(nn.Linear):
     def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
         self.lora_layer = lora_layer
 
-    def _fuse_lora(self):
+    def _fuse_lora(self, lora_scale=1.0):
         if self.lora_layer is None:
             return
@@ -176,7 +192,7 @@ class LoRACompatibleLinear(nn.Linear):
         if self.lora_layer.network_alpha is not None:
             w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
 
-        fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
+        fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
         self.weight.data = fused_weight.to(device=device, dtype=dtype)
 
         # we can drop the lora layer now
@@ -185,6 +201,7 @@ class LoRACompatibleLinear(nn.Linear):
         # offload the up and down matrices to CPU to not blow the memory
         self.w_up = w_up.cpu()
         self.w_down = w_down.cpu()
+        self._lora_scale = lora_scale
 
     def _unfuse_lora(self):
         if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
@@ -196,14 +213,16 @@ class LoRACompatibleLinear(nn.Linear):
         w_up = self.w_up.to(device=device).float()
         w_down = self.w_down.to(device).float()
 
-        unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
+        unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
         self.weight.data = unfused_weight.to(device=device, dtype=dtype)
 
         self.w_up = None
         self.w_down = None
 
-    def forward(self, hidden_states, lora_scale: int = 1):
+    def forward(self, hidden_states, scale: float = 1.0):
         if self.lora_layer is None:
-            return super().forward(hidden_states)
+            out = super().forward(hidden_states)
+            return out
         else:
-            return super().forward(hidden_states) + lora_scale * self.lora_layer(hidden_states)
+            out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
+            return out
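A small sketch of the `LoRACompatibleLinear` behaviour defined above: `scale` multiplies only the low-rank branch, which is also the quantity the new `adjust_lora_scale_text_encoder` helper tweaks on the patched text encoder projections (shapes are arbitrary):

    import torch
    from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer

    layer = LoRACompatibleLinear(32, 64)
    layer.set_lora_layer(LoRALinearLayer(32, 64, rank=4))
    torch.nn.init.normal_(layer.lora_layer.up.weight)  # up is zero-initialized, make the branch non-trivial

    x = torch.randn(2, 32)
    base = torch.nn.functional.linear(x, layer.weight, layer.bias)
    out = layer(x, scale=0.5)

    # the deviation from the frozen projection is exactly half the LoRA branch
    assert torch.allclose(out - base, 0.5 * layer.lora_layer(x), atol=1e-5)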
@@ -135,7 +135,7 @@ class Upsample2D(nn.Module):
         else:
             self.Conv2d_0 = conv
 
-    def forward(self, hidden_states, output_size=None):
+    def forward(self, hidden_states, output_size=None, scale: float = 1.0):
         assert hidden_states.shape[1] == self.channels
 
         if self.use_conv_transpose:
@@ -166,7 +166,13 @@ class Upsample2D(nn.Module):
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if self.use_conv:
             if self.name == "conv":
+                if isinstance(self.conv, LoRACompatibleConv):
+                    hidden_states = self.conv(hidden_states, scale)
+                else:
                     hidden_states = self.conv(hidden_states)
             else:
+                if isinstance(self.Conv2d_0, LoRACompatibleConv):
+                    hidden_states = self.Conv2d_0(hidden_states, scale)
+                else:
                     hidden_states = self.Conv2d_0(hidden_states)
@@ -211,13 +217,16 @@ class Downsample2D(nn.Module):
         else:
             self.conv = conv
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, scale: float = 1.0):
         assert hidden_states.shape[1] == self.channels
 
         if self.use_conv and self.padding == 0:
             pad = (0, 1, 0, 1)
             hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
 
         assert hidden_states.shape[1] == self.channels
+        if isinstance(self.conv, LoRACompatibleConv):
+            hidden_states = self.conv(hidden_states, scale)
+        else:
             hidden_states = self.conv(hidden_states)
 
         return hidden_states
@@ -588,7 +597,7 @@ class ResnetBlock2D(nn.Module):
                 in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
             )
 
-    def forward(self, input_tensor, temb):
+    def forward(self, input_tensor, temb, scale: float = 1.0):
         hidden_states = input_tensor
 
         if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
@@ -603,18 +612,34 @@ class ResnetBlock2D(nn.Module):
             if hidden_states.shape[0] >= 64:
                 input_tensor = input_tensor.contiguous()
                 hidden_states = hidden_states.contiguous()
-            input_tensor = self.upsample(input_tensor)
-            hidden_states = self.upsample(hidden_states)
+            input_tensor = (
+                self.upsample(input_tensor, scale=scale)
+                if isinstance(self.upsample, Upsample2D)
+                else self.upsample(input_tensor)
+            )
+            hidden_states = (
+                self.upsample(hidden_states, scale=scale)
+                if isinstance(self.upsample, Upsample2D)
+                else self.upsample(hidden_states)
+            )
         elif self.downsample is not None:
-            input_tensor = self.downsample(input_tensor)
-            hidden_states = self.downsample(hidden_states)
+            input_tensor = (
+                self.downsample(input_tensor, scale=scale)
+                if isinstance(self.downsample, Downsample2D)
+                else self.downsample(input_tensor)
+            )
+            hidden_states = (
+                self.downsample(hidden_states, scale=scale)
+                if isinstance(self.downsample, Downsample2D)
+                else self.downsample(hidden_states)
+            )
 
-        hidden_states = self.conv1(hidden_states)
+        hidden_states = self.conv1(hidden_states, scale)
 
         if self.time_emb_proj is not None:
             if not self.skip_time_act:
                 temb = self.nonlinearity(temb)
-            temb = self.time_emb_proj(temb)[:, :, None, None]
+            temb = self.time_emb_proj(temb, scale)[:, :, None, None]
 
         if temb is not None and self.time_embedding_norm == "default":
             hidden_states = hidden_states + temb
@@ -631,10 +656,10 @@ class ResnetBlock2D(nn.Module):
         hidden_states = self.nonlinearity(hidden_states)
 
         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.conv2(hidden_states)
+        hidden_states = self.conv2(hidden_states, scale)
 
         if self.conv_shortcut is not None:
-            input_tensor = self.conv_shortcut(input_tensor)
+            input_tensor = self.conv_shortcut(input_tensor, scale)
 
         output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
......
@@ -274,6 +274,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
             encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
 
+        # Retrieve lora scale.
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
         # 1. Input
         if self.is_input_continuous:
             batch, _, height, width = hidden_states.shape
@@ -281,13 +284,14 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             hidden_states = self.norm(hidden_states)
 
             if not self.use_linear_projection:
-                hidden_states = self.proj_in(hidden_states)
+                hidden_states = self.proj_in(hidden_states, lora_scale)
                 inner_dim = hidden_states.shape[1]
                 hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
             else:
                 inner_dim = hidden_states.shape[1]
                 hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-                hidden_states = self.proj_in(hidden_states)
+                hidden_states = self.proj_in(hidden_states, scale=lora_scale)
+
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
         elif self.is_input_patches:
@@ -322,9 +326,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         if self.is_input_continuous:
             if not self.use_linear_projection:
                 hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-                hidden_states = self.proj_out(hidden_states)
+                hidden_states = self.proj_out(hidden_states, scale=lora_scale)
             else:
-                hidden_states = self.proj_out(hidden_states)
+                hidden_states = self.proj_out(hidden_states, scale=lora_scale)
                 hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
 
             output = hidden_states + residual
......
This diff is collapsed.
@@ -934,6 +934,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
 
         # 3. down
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
 
         is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
         is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
@@ -956,7 +957,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                     **additional_residuals,
                 )
             else:
-                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
 
             if is_adapter and len(down_block_additional_residuals) > 0:
                 sample += down_block_additional_residuals.pop(0)
@@ -1020,7 +1021,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 )
             else:
                 sample = upsample_block(
-                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    upsample_size=upsample_size,
+                    scale=lora_scale,
                 )
 
         # 6. post-process
......
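On the UNet side the scale enters once through `cross_attention_kwargs` and is forwarded to every down/mid/up block. A sketch of the entry point with a tiny randomly initialized UNet (the config mirrors the small test-style configs, not a pretrained model):

    import torch
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel(
        sample_size=32,
        in_channels=4,
        out_channels=4,
        layers_per_block=2,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=32,
    )
    sample = torch.randn(1, 4, 32, 32)
    encoder_hidden_states = torch.randn(1, 8, 32)

    out = unet(
        sample,
        timestep=torch.tensor([10]),
        encoder_hidden_states=encoder_hidden_states,
        cross_attention_kwargs={"scale": 0.5},  # picked up as `lora_scale` in the forward above
    ).sample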
@@ -25,6 +25,7 @@
 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline
@@ -322,6 +323,9 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale
 
+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
......
@@ -27,6 +27,7 @@
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline
@@ -320,6 +321,9 @@ class AltDiffusionImg2ImgPipeline(
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale
 
+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
......
@@ -25,6 +25,7 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     deprecate,
@@ -312,6 +313,9 @@ class StableDiffusionControlNetPipeline(
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale
 
+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
......
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -25,6 +24,7 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     deprecate,
@@ -337,6 +337,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale
 
+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
......
@@ -26,6 +26,7 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     deprecate,
@@ -463,6 +464,9 @@ class StableDiffusionControlNetInpaintPipeline(
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale
 
+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
......
@@ -33,6 +33,7 @@
 from ...models.attention_processor import (
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     is_accelerate_available,
@@ -342,6 +343,10 @@ class StableDiffusionXLControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMi
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale
 
+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
......
@@ -34,6 +34,7 @@
 from ...models.attention_processor import (
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     is_accelerate_available,
@@ -315,6 +316,10 @@ class StableDiffusionXLControlNetPipeline(
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale
 
+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
......
@@ -33,6 +33,7 @@
 from ...models.attention_processor import (
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     is_accelerate_available,
@@ -352,6 +353,10 @@ class StableDiffusionXLControlNetImg2ImgPipeline(DiffusionPipeline, TextualInver
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale
 
+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
......
@@ -27,6 +27,7 @@
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import DDIMScheduler
 from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
@@ -329,6 +330,9 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale
 
+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
......
@@ -23,6 +23,7 @@
 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
 from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     deprecate,
@@ -321,6 +322,9 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale
 
+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
......
@@ -25,6 +25,7 @@
 from ...image_processor import VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import Attention
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline
@@ -321,6 +322,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale
 
+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
......
@@ -26,6 +26,7 @@
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -203,6 +204,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale
 
+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
......
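All of the pipeline changes above do the same thing: when `encode_prompt` receives a `lora_scale` (extracted from `cross_attention_kwargs`), the text encoder LoRA layers are re-scaled on the fly. Together with the fusion change, that is the equivalence the newly added test targets; a sketch with placeholder ids and an indicative tolerance:

    import numpy as np
    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe.load_lora_weights("some-user/some-lora")  # hypothetical LoRA repository

    generator = torch.manual_seed(0)
    unfused = pipe(
        "a prompt", num_inference_steps=2, generator=generator,
        cross_attention_kwargs={"scale": 0.5}, output_type="np",
    ).images

    pipe.fuse_lora(lora_scale=0.5)
    generator = torch.manual_seed(0)
    fused = pipe("a prompt", num_inference_steps=2, generator=generator, output_type="np").images

    assert np.abs(unfused - fused).max() < 1e-3  # the unfused-with-scale and fused paths should closely agree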