BasicInference.py 3.05 KB
Newer Older
dongchy920's avatar
dongchy920 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""
This inference script is used to do basic inference without any bells and whistles.
Pass in text, get out image.
"""

from typing import Dict, List, Optional, Union

import torch
from PIL import Image as PILImage

from dalle2_laion.scripts import InferenceScript

class BasicInference(InferenceScript):
    """Plain text-to-image inference: sample image embeddings from the prior, then decode them into images."""

    def run(
        self,
        text: Union[str, List[str]],
        prior_cond_scale: Optional[float] = None, decoder_cond_scale: Optional[float] = None,  # Use defaults from config by default
        prior_sample_count: int = 1, decoder_sample_count: int = 1,
        prior_batch_size: int = 100, decoder_batch_size: int = 10,
        prior_num_samples_per_batch: int = 2
    ) -> Dict[str, Dict[int, List[PILImage.Image]]]:
        """
        Takes text and generates images.

        Args:
            text: A single prompt or a list of prompts.
            prior_cond_scale: Conditioning scale forwarded to the prior sampler; None defers
                to the configured default.
            decoder_cond_scale: Conditioning scale forwarded to the decoder sampler; None defers
                to the configured default.
            prior_sample_count: Forwarded to the prior sampler as ``sample_count``.
            decoder_sample_count: Forwarded to the decoder sampler as ``sample_count``.
            prior_batch_size: Forwarded to the prior sampler as ``batch_size``.
            decoder_batch_size: Forwarded to the decoder sampler as ``batch_size``.
            prior_num_samples_per_batch: Forwarded to the prior sampler as ``num_samples_per_batch``.

        Returns:
            A map from each prompt string to a map from image embedding index to the list of
            images generated for that embedding. Embedding indices are positions in the flat
            list of all embeddings handed to the decoder, so they are unique across prompts.
            If the same prompt appears more than once in ``text``, the entries from every
            occurrence are merged under that prompt rather than overwritten.
        """
        if isinstance(text, str):
            text = [text]
        self.print(f"Generating images for texts: {text}")
        self.print("Generating prior embeddings...")
        image_embedding_map = self._sample_prior(text, cond_scale=prior_cond_scale, sample_count=prior_sample_count, batch_size=prior_batch_size, num_samples_per_batch=prior_num_samples_per_batch)
        self.print("Finished generating prior embeddings.")
        # image_embedding_map is a map between the text index and the generated image embeddings.
        # Flatten it into one list for the decoder. The decoder also needs the text for each
        # embedding, so repeat each prompt once per embedding. While flattening, record which
        # flat embedding indices belong to which input text.
        image_embeddings: List[torch.Tensor] = []
        decoder_text: List[str] = []
        embedding_indices_by_text: List[List[int]] = []
        for i, original_text in enumerate(text):
            embeddings_for_text = image_embedding_map[i]
            start = len(image_embeddings)
            image_embeddings.extend(embeddings_for_text)
            stop = len(image_embeddings)
            decoder_text.extend([original_text] * (stop - start))
            embedding_indices_by_text.append(list(range(start, stop)))
        # Now we can use the image embeddings to generate the images
        self.print(f"Grouped {len(text)} texts into {len(image_embeddings)} embeddings.")
        self.print("Sampling from decoder...")
        image_map = self._sample_decoder(text=decoder_text, image_embed=image_embeddings, cond_scale=decoder_cond_scale, sample_count=decoder_sample_count, batch_size=decoder_batch_size)
        self.print("Finished sampling from decoder.")
        # Rebuild a map from prompt -> {embedding index -> images}.
        # NOTE(review): assumes image_map is keyed by position in image_embed, as the original
        # indexing implies — confirm against _sample_decoder.
        output_map: Dict[str, Dict[int, List[PILImage.Image]]] = {}
        for prompt, embedding_indices in zip(text, embedding_indices_by_text):
            # setdefault (not `= {}`) so a repeated prompt does not discard the images
            # generated for its earlier occurrence.
            images_for_prompt = output_map.setdefault(prompt, {})
            for embedding_index in embedding_indices:
                images_for_prompt[embedding_index] = image_map[embedding_index]
        return output_map