Unverified commit ce9825b5, authored by Sayak Paul and committed by GitHub

[LoRA] pop the LoRA scale so that it doesn't get propagated to the weeds (#7338)



* pop scale from the top-level unet instead of getting it.

* improve readability.

* Apply suggestions from code review
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* fix a little bit.

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 85f9d928
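For context, a minimal usage sketch (not part of this commit; the checkpoint ID and LoRA path are placeholders) of how callers typically pass the LoRA scale through `cross_attention_kwargs`. After this change the UNet pops `scale` once at the top level, so the rest of the dict reaches the attention processors without it:

```python
# Hypothetical usage sketch: checkpoint ID and LoRA path below are placeholders.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora")  # placeholder LoRA weights

# `scale` is read (now popped) once by the top-level UNet forward and applied
# via the PEFT backend; it is no longer forwarded to every attention processor.
image = pipe(
    "a photo of an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.7},
).images[0]
```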
@@ -1081,6 +1081,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
                 A tuple of tensors that if specified are added to the residuals of down unet blocks.
             mid_block_additional_residual: (`torch.Tensor`, *optional*):
                 A tensor that if specified is added to the residual of the middle unet block.
+            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
             encoder_attention_mask (`torch.Tensor`):
                 A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                 `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
@@ -1088,18 +1090,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                 tuple.
-            cross_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
-            added_cond_kwargs: (`dict`, *optional*):
-                A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
-                are passed along to the UNet blocks.
-            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
-                additional residuals to be added to UNet long skip connections from down blocks to up blocks for
-                example from ControlNet side model(s)
-            mid_block_additional_residual (`torch.Tensor`, *optional*):
-                additional residual to be added to UNet mid block output, for example from ControlNet side model
-            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
-                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)

         Returns:
             [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
@@ -1185,7 +1175,13 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
             cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

         # 3. down
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
+        # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
+        if cross_attention_kwargs is not None:
+            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
         if USE_PEFT_BACKEND:
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
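To make the motivation concrete, here is a minimal, self-contained sketch (hypothetical helper names, not the diffusers API) of why `pop` is used instead of `get`: with `get`, `scale` stays in the kwargs dict that is forwarded to the inner blocks, each of which would then emit its own deprecation warning.

```python
# Hypothetical sketch of the pop-vs-get behavior; `inner_block` and
# `top_level_forward` are stand-ins, not diffusers functions.
import warnings


def inner_block(hidden_states, cross_attention_kwargs=None):
    # Stand-in for an attention processor that no longer accepts `scale` directly.
    if cross_attention_kwargs and "scale" in cross_attention_kwargs:
        warnings.warn("passing `scale` here is deprecated", FutureWarning)
    return hidden_states


def top_level_forward(hidden_states, cross_attention_kwargs=None):
    # Pop `scale` once at the top level so it never reaches the inner blocks.
    if cross_attention_kwargs is not None:
        lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    # `lora_scale` would then be applied once, e.g. via scale_lora_layers(self, lora_scale).
    return inner_block(hidden_states, cross_attention_kwargs), lora_scale


out, scale = top_level_forward([0.0], {"scale": 0.7})  # no deprecation warning is raised
```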