Unverified Commit 73bb97ad authored by Kashif Rasul, committed by GitHub

[LoRA] fix typo in attention_processor.py (#5066)

* [LoRA] fix typo in attention_processor.py

fixes #5062

* make style

* make fix-copies; logger commented out for torch compile
parent 38a664a3
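
Most of the hunks below come from the `make style` pass, which rewrites exact-type checks (`type(x) is T`) as `isinstance(x, T)`. A minimal sketch of the difference (toy subclass, not from the diff): `isinstance` also accepts subclasses and can take a tuple of types, which replaces chained `or` checks like the one in `get_equalizer`.

    class MyStr(str):
        pass

    s = MyStr("hello")
    print(type(s) is str)             # False: exact-type check rejects subclasses
    print(isinstance(s, str))         # True: isinstance accepts subclasses
    print(isinstance(s, (int, str)))  # True: a type tuple replaces chained or-checks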
@@ -501,7 +501,7 @@ class LocalBlend:
         alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words)
         for i, (prompt, words_) in enumerate(zip(prompts, words)):
-            if type(words_) is str:
+            if isinstance(words_, str):
                 words_ = [words_]
             for word in words_:
                 ind = get_word_inds(prompt, word, tokenizer)
@@ -565,7 +565,7 @@ class AttentionControlEdit(AttentionStore, abc.ABC):
         self.cross_replace_alpha = get_time_words_attention_alpha(
             prompts, num_steps, cross_replace_steps, self.tokenizer
         ).to(self.device)
-        if type(self_replace_steps) is float:
+        if isinstance(self_replace_steps, float):
             self_replace_steps = 0, self_replace_steps
         self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
         self.local_blend = local_blend  # defined outside and passed in
@@ -645,7 +645,7 @@ class AttentionReweight(AttentionControlEdit):
 def update_alpha_time_word(
     alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor] = None
 ):
-    if type(bounds) is float:
+    if isinstance(bounds, float):
         bounds = 0, bounds
     start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
     if word_inds is None:
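
The `bounds` handling above normalizes a bare float `f` into the interval `(0, f)` before scaling by the number of steps; the same pattern appears for `self_replace_steps` in the previous hunk. A standalone sketch of that normalization (`normalize_bounds` is a hypothetical helper, not part of the pipeline):

    def normalize_bounds(bounds, num_steps):
        # A bare float f means "from step 0 until fraction f of the schedule".
        if isinstance(bounds, float):
            bounds = 0, bounds
        return int(bounds[0] * num_steps), int(bounds[1] * num_steps)

    assert normalize_bounds(0.4, 50) == (0, 20)
    assert normalize_bounds((0.2, 0.8), 50) == (10, 40)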
@@ -659,7 +659,7 @@ def update_alpha_time_word(
 def get_time_words_attention_alpha(
     prompts, num_steps, cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], tokenizer, max_num_words=77
 ):
-    if type(cross_replace_steps) is not dict:
+    if not isinstance(cross_replace_steps, dict):
         cross_replace_steps = {"default_": cross_replace_steps}
     if "default_" not in cross_replace_steps:
         cross_replace_steps["default_"] = (0.0, 1.0)
@@ -679,9 +679,9 @@ def get_time_words_attention_alpha(
 ### util functions for LocalBlend and ReplacementEdit
 def get_word_inds(text: str, word_place: int, tokenizer):
     split_text = text.split(" ")
-    if type(word_place) is str:
+    if isinstance(word_place, str):
         word_place = [i for i, word in enumerate(split_text) if word_place == word]
-    elif type(word_place) is int:
+    elif isinstance(word_place, int):
         word_place = [word_place]
     out = []
     if len(word_place) > 0:
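
In `get_word_inds`, the `elif` must test `int`: a second `str` check after the `str` branch would be dead code, and integer indices would never be wrapped into a list. A standalone sketch of the normalization (`word_positions` is a hypothetical helper, not the pipeline function):

    def word_positions(text, word_place):
        # A string selects every matching word; an int is wrapped as a single index.
        split_text = text.split(" ")
        if isinstance(word_place, str):
            word_place = [i for i, w in enumerate(split_text) if w == word_place]
        elif isinstance(word_place, int):
            word_place = [word_place]
        return word_place

    assert word_positions("a cat on a mat", "a") == [0, 3]
    assert word_positions("a cat on a mat", 2) == [2]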
@@ -750,7 +750,7 @@ def get_replacement_mapper(prompts, tokenizer, max_len=77):
 def get_equalizer(
     text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]], tokenizer
 ):
-    if type(word_select) is int or type(word_select) is str:
+    if isinstance(word_select, (int, str)):
         word_select = (word_select,)
     equalizer = torch.ones(len(values), 77)
     values = torch.tensor(values, dtype=torch.float32)
...
@@ -8,7 +8,6 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import numpy as np
 import PIL.Image
 import torch
-from diffusers.utils.torch_utils import randn_tensor
 from PIL import Image
 from transformers import CLIPTokenizer

@@ -22,6 +21,7 @@ from diffusers.utils import (
     logging,
     replace_example_docstring,
 )
+from diffusers.utils.torch_utils import randn_tensor

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
...
@@ -11,7 +11,6 @@ import PIL.Image
 import pycuda.driver as cuda
 import tensorrt as trt
 import torch
-from diffusers.utils.torch_utils import randn_tensor
 from PIL import Image
 from pycuda.tools import make_default_context
 from transformers import CLIPTokenizer

@@ -26,6 +25,7 @@ from diffusers.utils import (
     logging,
     replace_example_docstring,
 )
+from diffusers.utils.torch_utils import randn_tensor

 # Initialize CUDA
...
@@ -382,7 +382,7 @@ class Attention(nn.Module):
             }
             if hasattr(self.processor, "attention_op"):
-                kwargs["attention_op"] = self.prcoessor.attention_op
+                kwargs["attention_op"] = self.processor.attention_op
             lora_processor = lora_processor_cls(hidden_size, **kwargs)
             lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
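
This hunk is the actual fix for #5062: Python resolves attribute names at runtime, so the misspelled `self.prcoessor` only raised `AttributeError` when the guarded branch executed, i.e. when converting a processor that carries an `attention_op` (the xFormers path). A minimal sketch of the failure mode (toy classes, not the diffusers API):

    class Processor:
        attention_op = None

    class Attn:
        def __init__(self):
            self.processor = Processor()

    attn = Attn()
    kwargs = {}
    if hasattr(attn.processor, "attention_op"):
        # With a misspelled receiver (attn.prcoessor), this line would raise
        # AttributeError, but only when this branch is actually taken.
        kwargs["attention_op"] = attn.processor.attention_op
    print(kwargs)  # {'attention_op': None}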
...
@@ -992,7 +992,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         upsample_size = None
         if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-            logger.info("Forward upsample size to force interpolation output size.")
+            # Forward upsample size to force interpolation output size.
             forward_upsample_size = True
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
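
The commit message's "logger commented out for torch compile" refers to this hunk: the `logger.info` call sits on the UNet forward path, and Python-level calls such as logging can trigger graph breaks under `torch.compile`, so the message is kept as a comment instead. A minimal sketch of the idea (toy `forward`, assuming PyTorch >= 2.0; not the UNet code):

    import torch

    def forward(sample):
        forward_upsample_size = False
        if any(s % 8 != 0 for s in sample.shape[-2:]):
            # Forward upsample size to force interpolation output size.
            # (a logger.info call here could break the compiled graph)
            forward_upsample_size = True
        return sample * 2.0

    compiled = torch.compile(forward)
    print(compiled(torch.ones(1, 4, 16, 16)).shape)  # torch.Size([1, 4, 16, 16])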
...