"""
This module contains an abstract class for inference scripts.
"""

from typing import List, Tuple, Union, Dict, TypeVar, Type
from dalle2_pytorch.tokenizer import tokenizer
from dalle2_laion import DalleModelManager, ModelLoadConfig
from torchvision.transforms import ToPILImage, ToTensor
from PIL import Image as PILImage
import torch
from contextlib import contextmanager

ClassType = TypeVar('ClassType')

class InferenceScript:
    def __init__(self, model_manager: DalleModelManager, verbose: bool = False):
        self.model_manager = model_manager
        self.verbose = verbose
        self.device = model_manager.devices[0] if model_manager is not None else 'cpu'

    @classmethod
    def create(cls: Type[ClassType], config: Union[str, ModelLoadConfig], *args, verbose: bool = False, check_updates: bool = True, **kwargs) -> ClassType:
        """
        Creates an instance of the inference script directly from a config.
        Useful if only one inference script will be run at a time.
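
        Example (illustrative; the subclass name and config path are placeholders):
            script = MyInferenceScript.create("path/to/model_config.json", verbose=True)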
        """
        if isinstance(config, str):
            config = ModelLoadConfig.from_json_path(config)
        model_manager = DalleModelManager(config, check_updates=check_updates)
        return cls(model_manager, *args, **kwargs, verbose=verbose)
    
    def print(self, *args, **kwargs):
        if self.verbose:
            print(*args, **kwargs)

    @contextmanager
    def _clip_in_decoder(self):
        assert self.model_manager.decoder_info is not None, "Cannot use the decoder without a decoder model."
        decoder = self.model_manager.decoder_info.model
        clip = self.model_manager.clip
        decoder.clip = clip
        try:
            yield decoder
        finally:
            decoder.clip = None
        
    @contextmanager
    def _clip_in_prior(self):
        assert self.model_manager.prior_info is not None, "Cannot use the prior without a prior model."
        prior = self.model_manager.prior_info.model
        clip = self.model_manager.clip
        prior.clip = clip
        try:
            yield prior
        finally:
            prior.clip = None

    @contextmanager
    def _decoder_in_gpu(self):
        # Moves the decoder to the GPU (and the prior to the CPU); the decoder is moved back to the CPU when the context exits.
        assert self.model_manager.decoder_info is not None, "Cannot use the decoder without a decoder model."
        if self.model_manager.prior_info is not None:
            prior = self.model_manager.prior_info.model
            prior.to('cpu')
        with self._clip_in_decoder() as decoder:
            decoder.to(self.device)
            try:
                yield decoder
            finally:
                decoder.to('cpu')

    @contextmanager
    def _prior_in_gpu(self):
        # Moves the prior to the GPU (and the decoder to the CPU); the prior is moved back to the CPU when the context exits.
        assert self.model_manager.prior_info is not None, "Cannot use the prior without a prior model."
        if self.model_manager.decoder_info is not None:
            decoder = self.model_manager.decoder_info.model
            decoder.to('cpu')
        with self._clip_in_prior() as prior:
            prior.to(self.device)
            try:
                yield prior
            finally:
                prior.to('cpu')

    @contextmanager
    def _clip_in_gpu(self):
        # Moves the clip model to the GPU without touching the other models. If clip was not already on the target device, it is moved back to its original device when the context exits.
        clip = self.model_manager.clip
        assert clip is not None, "Cannot use clip without a clip model."
        original_device = next(iter(clip.parameters())).device
        not_on_device = original_device != self.device
        if not_on_device:
            clip.to(self.device)
        try:
            yield clip
        finally:
            if not_on_device:
                clip.to(original_device)
    
    def _pil_to_torch(self, image: Union[PILImage.Image, List[PILImage.Image]], resize_for_clip: bool = True):
        """
        Convert a PIL image into a torch tensor.
        Tensor is of dimension 3 if one image is passed in, and of dimension 4 if a list of images is passed in.
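        If resize_for_clip is True, each image is resized to the clip model's expected image size.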
        """
        def process_image(image: PILImage.Image) -> PILImage.Image:
            # Optionally resize for clip, and remove the alpha channel if the image has one.
            if resize_for_clip:
                clip_size = self.model_manager.clip.image_size
                image = image.resize((clip_size, clip_size), resample=PILImage.LANCZOS)
            if image.mode == 'RGBA':
                return image.convert('RGB')
            else:
                return image
        
        if isinstance(image, PILImage.Image):
            return ToTensor()(process_image(image))
        else:
            return torch.stack([ToTensor()(process_image(image[i])) for i in range(len(image))])

    def _torch_to_pil(self, image: torch.Tensor):
        """
        Converts a torch tensor into a PIL image.
        If the tensor is a batch of images, then we return a list of PIL images.
        """
        if len(image.shape) == 4:
            return [ToPILImage()(image[i]) for i in range(image.shape[0])]
        else:
            return ToPILImage()(image)

    def _repeat_tensor_and_batch(self, tensor: Union[torch.Tensor, List[torch.Tensor]], repeat_num: int, batch_size: int) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Repeats each element of the first dimension of a tensor repeat_num times then batches the result into a list of tensors.
        Also returns a map from the repeated tensor to the index of the original tensor.
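        For example, a tensor of shape (2, D) with repeat_num=2 and batch_size=3 yields batches of
        shapes [(3, D), (1, D)] and a map of [[0, 0, 1], [1]].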
        """
        if isinstance(tensor, list):
            tensor = torch.stack(tensor, dim=0)
        batched_repeat = tensor.repeat_interleave(repeat_num, dim=0).split(batch_size, dim=0)
        batched_map = torch.arange(0, tensor.shape[0]).repeat_interleave(repeat_num, dim=0).split(batch_size, dim=0)
        return list(batched_repeat), [t.tolist() for t in batched_map]

    def _embed_images(self, images: List[PILImage.Image]) -> torch.Tensor:
        """
        Generates the clip embeddings for a list of images
        """
        assert self.model_manager.clip is not None, "Cannot generate embeddings for this model."
        images_tensor = self._pil_to_torch(images, resize_for_clip=True).to(self.device)
        with self._clip_in_gpu() as clip:
            image_embed = clip.embed_image(images_tensor)
        return image_embed.image_embed

    def _encode_text(self, text: List[str]) -> torch.Tensor:
        """
        Generates the clip embeddings for a list of text
        """
        assert self.model_manager.clip is not None, "Cannot generate embeddings for this model."
        text_tokens = self._tokenize_text(text)
        with self._clip_in_gpu() as clip:
            text_embed = clip.embed_text(text_tokens.to(self.device))
        return text_embed.text_encodings

    def _tokenize_text(self, text: List[str]) -> torch.Tensor:
        """
        Tokenizes a list of text
        """
        return tokenizer.tokenize(text)

    def _sample_decoder(
        self,
        images: List[PILImage.Image] = None, image_embed: List[torch.Tensor] = None,
        text: List[str] = None, text_encoding: List[torch.Tensor] = None,
        inpaint_images: List[PILImage.Image] = None, inpaint_image_masks: List[torch.Tensor] = None,
        cond_scale: Union[float, List[float]] = None, sample_count: int = 1, batch_size: int = 10,
    ):
        """
        Samples images from the decoder
        Capable of doing basic generation with a list of image embeddings (possibly also conditioned with a list of strings or text embeddings)
        Also capable of two more advanced generation techniques:
        1. Variation generation: If images are passed in the image embeddings will be generated based on those.
        2. In-painting generation: If images and masks are passed in, the images will be in-painted using the masks and the image embeddings.
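        Returns a dict mapping the index of each input embedding to its list of generated PIL images.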
        """
        if cond_scale is None:
            # Then we use the default scale
            load_config = self.model_manager.model_config.decoder
            unet_configs = load_config.unet_sources
            cond_scale = [1.0] * load_config.final_unet_number
            for unet_config in unet_configs:
                if unet_config.default_cond_scale is not None:
                    for unet_number, new_cond_scale in zip(unet_config.unet_numbers, unet_config.default_cond_scale):
                        cond_scale[unet_number - 1] = new_cond_scale
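            # For example, with final_unet_number=2 and a unet source with unet_numbers=[1] and
            # default_cond_scale=[3.0], the resulting cond_scale is [3.0, 1.0].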
        self.print(f"Sampling decoder with cond_scale: {cond_scale}")
            
        decoder_info = self.model_manager.decoder_info
        assert decoder_info is not None, "No decoder loaded."
        data_requirements = decoder_info.data_requirements
        min_image_size = min(min(image.size) for image in images) if images is not None else None
        is_valid, errors = data_requirements.is_valid(
            has_image_emb=image_embed is not None, has_image=images is not None,
            has_text_encoding=text_encoding is not None, has_text=text is not None,
            image_size=min_image_size
        )
        assert is_valid, f"The data requirements for the decoder are not satisfied: {errors}"

        # Prepare the data
        image_embeddings = []  # Null case; in practice the decoder always requires image embeddings (see the assert below), so this is only kept for stylistic consistency.
        if data_requirements.image_embedding:
            if image_embed is None:
                # Then we need to use clip to generate the image embedding
                image_embed = self._embed_images(images)
            # Then we need to group these tensors into batches of size batch_size such that the total number of samples is sample_count
            image_embeddings, image_embeddings_map = self._repeat_tensor_and_batch(image_embed, repeat_num=sample_count, batch_size=batch_size)
            self.print(f"Decoder batched inputs into {len(image_embeddings)} batches. Total number of samples: {sum(len(t) for t in image_embeddings)}.")
        
        if data_requirements.text_encoding:
            if text_encoding is None:
                text_encoding = self._encode_text(text)
            text_encodings, text_encodings_map = self._repeat_tensor_and_batch(text_encoding, repeat_num=sample_count, batch_size=batch_size)

        assert len(image_embeddings) > 0, "No data provided for decoder inference."
        output_image_map: Dict[int, List[PILImage.Image]] = {}
        with self._decoder_in_gpu() as decoder:
            for i in range(len(image_embeddings)):
                args = {}
                embeddings_map = []
                if data_requirements.image_embedding:
                    args["image_embed"] = image_embeddings[i].to(self.device)
                    embeddings_map = image_embeddings_map[i]
                if data_requirements.text_encoding:
                    args["text_encodings"] = text_encodings[i].to(self.device)
                    embeddings_map = text_encodings_map[i]
                if inpaint_images is not None:
                    assert len(inpaint_images) == len(inpaint_image_masks), "Number of inpaint images and masks must match."
                    inpaint_image_tensors = self._pil_to_torch(inpaint_images, resize_for_clip=False)
                    args["inpaint_image"] = inpaint_image_tensors.to(self.device)
                    args["inpaint_mask"] = torch.stack(inpaint_image_masks).to(self.device)
                    self.print(f"image tensor shape: {args['inpaint_image'].shape}. mask shape: {args['inpaint_mask'].shape}")
                output_images = decoder.sample(**args, cond_scale=cond_scale)
                for output_image, input_embedding_number in zip(output_images, embeddings_map):
                    if input_embedding_number not in output_image_map:
                        output_image_map[input_embedding_number] = []
                    output_image_map[input_embedding_number].append(self._torch_to_pil(output_image))
            return output_image_map

    def _sample_prior(self, text_or_tokens: Union[torch.Tensor, List[str]], cond_scale: float = None, sample_count: int = 1, batch_size: int = 100, num_samples_per_batch: int = 2):
        """
        Samples image embeddings from the prior
        :param text_or_tokens: A list of strings to use as input to the prior or a tensor of tokens generated from strings.
        :param cond_scale: The scale of the conditioning.
        :param sample_count: The number of samples to generate for each input.
        :param batch_size: The max number of samples to run in parallel.
        :param num_samples_per_batch: The number of samples to rerank for each output sample.
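        :return: A dict mapping the index of each input to its list of sampled image embedding tensors.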
        """
        if cond_scale is None:
            # Then we use the default scale
            cond_scale = self.model_manager.model_config.prior.default_cond_scale
            if cond_scale is None:
                # Fallback
                cond_scale = 1.0
        self.print(f"Sampling prior with cond_scale: {cond_scale}")

        assert self.model_manager.prior_info is not None
        data_requirements = self.model_manager.prior_info.data_requirements
        is_valid, errors = data_requirements.is_valid(
            has_text_encoding=False, has_text=text_or_tokens is not None,
            has_image_emb=False, has_image=False,
            image_size=None
        )
        assert is_valid, f"The data requirements for the prior are not satisfied. {errors}"
        if isinstance(text_or_tokens, list):
            text_tokens = self._tokenize_text(text_or_tokens)
        else:
            text_tokens = text_or_tokens
        text_batches, text_batches_map = self._repeat_tensor_and_batch(text_tokens, repeat_num=sample_count, batch_size=batch_size)
        self.print(f"Prior batched inputs into {len(text_batches)} batches. Total number of samples: {sum(len(t) for t in text_batches)}.")
        embedding_map: Dict[int, List[torch.Tensor]] = {}
        # Oddly, the prior requires that clip be attached to it in order to sample, so we insert it here.
        with self._prior_in_gpu() as prior:
            for text_batch, batch_map in zip(text_batches, text_batches_map):
                text_batch = text_batch.to(self.device)
                embeddings = prior.sample(text_batch, cond_scale=cond_scale, num_samples_per_batch=num_samples_per_batch)
                for embedding, embedding_number in zip(embeddings, batch_map):
                    if embedding_number not in embedding_map:
                        embedding_map[embedding_number] = []
                    embedding_map[embedding_number].append(embedding)
        return embedding_map

class CliInferenceScript(InferenceScript):
    def __init__(self, model_manager: DalleModelManager):
        super().__init__(model_manager)
        raise NotImplementedError("CliInferenceScript is not implemented because I have no idea how I'm going to do it yet.")
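

# Illustrative sketch (not part of the original module API): a minimal subclass showing how the
# prior and decoder helpers could be combined for text-to-image generation. The class and method
# names here are placeholders.
class ExampleTextToImageScript(InferenceScript):
    def dream(self, prompts: List[str], sample_count: int = 1) -> Dict[int, List[PILImage.Image]]:
        # Sample one image embedding per prompt from the prior, then decode sample_count images per embedding.
        embedding_map = self._sample_prior(prompts, sample_count=1)
        image_embed = [embedding_map[i][0] for i in range(len(prompts))]
        return self._sample_decoder(image_embed=image_embed, text=prompts, sample_count=sample_count)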