BasicInpainting.py 2.99 KB
Newer Older
dongchy920's avatar
dongchy920 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
This script takes images, masks, and text, and in-paints the regions of the images indicated by the masks.
The script is basically the same as BasicInference, but it also takes images and masks.
"""

from dalle2_laion.scripts import InferenceScript
from typing import List, Dict
from PIL import Image as PILImage
import torch

class BasicInpainting(InferenceScript):
    """Inference script that in-paints images guided by text prompts.

    Image embeddings are first sampled from the prior for every text, then
    the decoder is sampled with those embeddings together with the images
    and masks to in-paint.
    """

    def run(
        self,
        images: List[PILImage.Image],
        masks: List[torch.Tensor],  # Boolean tensor of same size as image
        text: List[str],
        prior_cond_scale: float = None,  # Use default cond scale from config by default
        decoder_cond_scale: float = None,
        sample_count: int = 1,
        prior_batch_size: int = 100,
        decoder_batch_size: int = 10
    ) -> Dict[int, List[PILImage.Image]]:
        """In-paint ``images`` under ``masks`` using embeddings sampled from ``text``.

        Args:
            images: Images forwarded to the decoder as ``inpaint_images``.
            masks: Masks forwarded to the decoder as ``inpaint_image_masks``.
            text: Prompts to sample prior embeddings for. A single string is
                accepted and treated as a one-element list.
            prior_cond_scale: Conditioning scale passed to the prior; ``None``
                uses the default from the config.
            decoder_cond_scale: Conditioning scale passed to the decoder.
            sample_count: Sample count passed to the prior for each text.
            prior_batch_size: Batch size passed to the prior.
            decoder_batch_size: Batch size passed to the decoder.

        Returns:
            A dict mapping each text index to the list of images generated
            from that text's embeddings.
        """
        if isinstance(text, str):
            text = [text]
        self.print("Generating prior embeddings...")
        image_embedding_map = self._sample_prior(text, cond_scale=prior_cond_scale, sample_count=sample_count, batch_size=prior_batch_size, num_samples_per_batch=2)
        self.print("Finished generating prior embeddings.")

        # image_embedding_map maps a text index to the image embeddings generated for it.
        # Flatten it into parallel lists for the decoder, and remember which flat
        # positions belong to which text so the outputs can be regrouped afterwards.
        image_embeddings: List[torch.Tensor] = []
        decoder_text: List[str] = []  # The text is repeated once per embedding generated from it
        image_embedding_index_reverse_map: Dict[int, List[int]] = {}
        for text_index, original_text in enumerate(text):
            embeddings_for_text = image_embedding_map[text_index]
            first_flat_index = len(image_embeddings)
            image_embeddings.extend(embeddings_for_text)
            decoder_text.extend([original_text] * len(embeddings_for_text))
            image_embedding_index_reverse_map[text_index] = list(
                range(first_flat_index, len(image_embeddings))
            )

        self.print(f"Grouped {len(text)} texts into {len(image_embeddings)} embeddings.")
        self.print("Sampling from decoder...")
        # NOTE(review): images and masks are forwarded unchanged while the text is
        # repeated per embedding; confirm how _sample_decoder pairs them with the
        # embeddings when sample_count > 1.
        image_map = self._sample_decoder(
            text=decoder_text,
            image_embed=image_embeddings,
            cond_scale=decoder_cond_scale,
            inpaint_images=images, inpaint_image_masks=masks,
            sample_count=1, batch_size=decoder_batch_size
        )
        self.print("Finished sampling from decoder.")

        # Regroup the decoder outputs by text index. The reverse map is keyed by
        # text index, so iterate it directly rather than over the images: iterating
        # len(images) raises KeyError or drops results when the counts differ.
        output_map: Dict[int, List[PILImage.Image]] = {}
        for text_index, embedding_indices in image_embedding_index_reverse_map.items():
            images_for_text: List[PILImage.Image] = []
            for embedding_index in embedding_indices:
                images_for_text.extend(image_map[embedding_index])
            output_map[text_index] = images_for_text
        return output_map