"lm_eval/tasks/bbh/cot_fewshot/date_understanding.yaml" did not exist on "ed53d51c5c9c5f88abe4d55e379b76f501118b43"
dalle2_laion.py 18.7 KB
Newer Older
dongchy920's avatar
dongchy920 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
from dataclasses import dataclass
from typing import Any, Tuple, Optional, TypeVar, Generic, List
from dalle2_laion.config import DecoderLoadConfig, SingleDecoderLoadConfig, PriorLoadConfig, ModelLoadConfig
from dalle2_pytorch import __version__ as Dalle2Version, Decoder, DiffusionPrior, Unet
from dalle2_pytorch.train_configs import TrainDecoderConfig, TrainDiffusionPriorConfig, DecoderConfig, UnetConfig, DiffusionPriorConfig
import torch
import torch.nn as nn
from packaging import version

def exists(obj: Any) -> bool:
    return obj is not None

@dataclass
class DataRequirements:
    image_embedding: bool
    text_encoding: bool
    image: bool
    text: bool
    can_generate_embedding: bool
    image_size: int

    def has_clip(self):
        # Called once a clip model has been loaded: embeddings can then be
        # generated on the fly from raw images or text instead of being passed in.
        self.can_generate_embedding = True

    def is_valid(
        self,
        has_image_emb: bool = False, has_text_encoding: bool = False,
        has_image: bool = False, has_text: bool = False,
        image_size: Optional[int] = None
    ):
        # The image size must be equal to or greater than the required size
        # Verify that the text input is valid
        errors = []
        is_valid = True
        if self.text_encoding:
            # Then we need some way to get the text encoding
            if not (has_text_encoding or (self.can_generate_embedding and has_text)):
                errors.append('Text encoding is required, but no text encoding or text was provided')
                is_valid = False
        if self.text:
            # Then this requires text be passed in explicitly
            if not has_text:
                errors.append('Text is required, but no text was provided')
                is_valid = False

        # Verify that the image input is valid
        image_size_greater = exists(image_size) and image_size >= self.image_size
        if self.image_embedding:
            # Then we need some way to get the image embedding
            # In this case, we also need to make sure that the image size is big enough to generate the embedding
            if not (has_image_emb or (self.can_generate_embedding and has_image and image_size_greater)):
                errors.append('Image embedding is required, but no image embedding or image was provided or the image was too small')
                is_valid = False
        if self.image:
            # Then this requires an image be passed in explicitly
            # In this case we also need to make sure the image is big enough to be used
            if not (has_image and image_size_greater):
                errors.append('Image is required, but no image was provided or the image was too small')
                is_valid = False
        return is_valid, errors

    def __add__(self, other: 'DataRequirements') -> 'DataRequirements':
        return DataRequirements(
            image_embedding=self.image_embedding or other.image_embedding,  # If either needs an image embedding, the combination needs one
            text_encoding=self.text_encoding or other.text_encoding,  # If either needs a text encoding, the combination needs one
            image=self.image or other.image,  # If either needs an image, the combination needs it  
            text=self.text or other.text,  # If either needs a text, the combination needs it
            can_generate_embedding=self.can_generate_embedding and other.can_generate_embedding,  # If either cannot generate an embedding, we know that trying to replace an embedding with raw data will not work
            image_size=max(self.image_size, other.image_size)  # We can downsample without loss of information, so we use the larger image size
        )
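
# A quick sketch of the requirements algebra above (values are hypothetical,
# not taken from any real model):
#   decoder_req = DataRequirements(image_embedding=True, text_encoding=False,
#                                  image=False, text=False,
#                                  can_generate_embedding=False, image_size=64)
#   prior_req = DataRequirements(image_embedding=False, text_encoding=True,
#                                image=False, text=False,
#                                can_generate_embedding=False, image_size=-1)
#   combined = decoder_req + prior_req  # needs both an image embedding and a text encoding
#   ok, errors = combined.is_valid(has_image_emb=True, has_text_encoding=True)  # -> True, []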

ModelType = TypeVar('ModelType', Decoder, DiffusionPrior)

@dataclass
class ModelInfo(Generic[ModelType]):
    model: ModelType
    model_version: Optional[version.Version]
    requires_clip: bool
    data_requirements: DataRequirements

class DalleModelManager:
    """
    Used to load priors and decoders and to provide a simple interface to run general scripts against
    """
    def __init__(self, model_load_config: ModelLoadConfig, check_updates: bool = True):
        """
        Downloads the models and loads them into memory.
        If check_updates is True, then the models will be re-downloaded if checksums do not match.
        """
        self.check_updates = check_updates
        self.model_config = model_load_config
        self.current_version = version.parse(Dalle2Version)
        self.single_device = isinstance(model_load_config.devices, str)
        self.devices = [torch.device(model_load_config.devices)] if self.single_device else [torch.device(d) for d in model_load_config.devices]
        self.load_device = torch.device('cpu') if model_load_config.load_on_cpu else self.devices[0]
        self.strict_loading = model_load_config.strict_loading

        if model_load_config.decoder is not None:
            self.decoder_info = self.load_decoder(model_load_config.decoder)
        else:
            self.decoder_info = None

        if model_load_config.prior is not None:
            self.prior_info = self.load_prior(model_load_config.prior)
        else:
            self.prior_info = None

        if (exists(self.decoder_info) and self.decoder_info.requires_clip) or (exists(self.prior_info) and self.prior_info.requires_clip):
            assert model_load_config.clip is not None, 'Your model requires clip to be loaded. Please provide a clip config.'
            self.clip = model_load_config.clip.create()
            # Update the data requirements to include the clip model
            if exists(self.decoder_info):
                self.decoder_info.data_requirements.has_clip()
            if exists(self.prior_info):
                self.prior_info.data_requirements.has_clip()
        else:
            if model_load_config.clip is not None:
                print('WARNING: Your model does not require clip, but you provided a clip config. This will be ignored.')

    def _get_decoder_data_requirements(self, decoder_config: DecoderConfig, min_unet_number: int = 1) -> DataRequirements:
        """
        Returns the data requirements for a decoder
        """
        return DataRequirements(
            image_embedding=True,
            text_encoding=any(unet_config.cond_on_text_encodings for unet_config in decoder_config.unets[min_unet_number - 1:]),
            image=min_unet_number > 1,  # If this is an upsampler we need an image
            text=False,  # Text is never required for anything
            can_generate_embedding=False,  # This might be added later if clip is being used
            image_size=decoder_config.image_sizes[min_unet_number - 1]  # The input image size is the input to the first unet we are using
        )

    def _load_single_decoder(self, load_config: SingleDecoderLoadConfig) -> Tuple[Decoder, DecoderConfig, Optional[version.Version], bool]:
        """
        Loads a single decoder from a model and a config file
        """
        unet_sample_timesteps = load_config.default_sample_timesteps
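        # Note: unet_numbers are 1-indexed, so this scatters the source's
        # sample timestep overrides into a per-unet list sized to the decoder.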
        def apply_default_config(config: DecoderConfig):
            if unet_sample_timesteps is not None:
                base_sample_timesteps = [None] * len(config.unets)
                for unet_number, timesteps in zip(load_config.unet_numbers, unet_sample_timesteps):
                    base_sample_timesteps[unet_number - 1] = timesteps
                config.sample_timesteps = base_sample_timesteps
        
        with load_config.load_model_from.as_local_file(check_update=self.check_updates) as model_file:
            model_state_dict = torch.load(model_file, map_location=self.load_device)
            if 'version' in model_state_dict:
                model_version = model_state_dict['version']
                if model_version != self.current_version:
                    print(f'WARNING: This decoder was trained on version {model_version} but the current version is {self.current_version}. This may result in the model failing to load.')
                    print(f'FIX: Switch to this version with `pip install DALLE2-pytorch=={model_version}`. If different models suggest different versions, you may just need to choose one.')
            else:
                print('WARNING: This decoder was trained on an old version of Dalle2. This may result in the model failing to load or producing garbage results.')
                model_version = None  # No version info in the model
            
            requires_clip = False
            if 'config' in model_state_dict:
                # Then we define the decoder config from this object
                decoder_config = TrainDecoderConfig(**model_state_dict['config']).decoder
                apply_default_config(decoder_config)
                if decoder_config.clip is not None:
                    # We don't want to load clip with the model
                    requires_clip = True
                    decoder_config.clip = None
                decoder = decoder_config.create().eval()
                decoder.load_state_dict(model_state_dict['model'], strict=self.strict_loading)  # If the model has a config included, then we know the model_state_dict['model'] is the actual model
            else:
                # In this case, the state_dict is the model itself. This means we also must load the config from an external file
                assert load_config.load_config_from is not None
                with load_config.load_config_from.as_local_file(check_update=self.check_updates) as config_file:
                    decoder_config = TrainDecoderConfig.from_json_path(config_file).decoder
                    apply_default_config(decoder_config)
                    if decoder_config.clip is not None:
                        # We don't want to load clip with the model
                        requires_clip = True
                        decoder_config.clip = None
                decoder = decoder_config.create().eval()
                decoder.load_state_dict(model_state_dict, strict=self.strict_loading)

            return decoder, decoder_config, model_version, requires_clip

    def load_decoder(self, load_config: DecoderLoadConfig) -> 'ModelInfo[Decoder]':
        """
        Loads a decoder from a model and a config file
        """
        if len(load_config.unet_sources) == 1:
            # Then we are loading only one model
            decoder, decoder_config, decoder_version, requires_clip = self._load_single_decoder(load_config.unet_sources[0])
            decoder_data_requirements = self._get_decoder_data_requirements(decoder_config)
            decoder.to(torch.float32)
            return ModelInfo(decoder, decoder_version, requires_clip, decoder_data_requirements)
        else:
            true_unets: List[Optional[Unet]] = [None] * load_config.final_unet_number  # Stores the unets that will replace the ones in the true decoder
            true_unet_configs: List[Optional[UnetConfig]] = [None] * load_config.final_unet_number  # Stores the unet configs that will replace the ones in the true decoder config
            true_upsampling_sizes: List[Optional[Tuple[Optional[int], int]]] = [None] * load_config.final_unet_number  # Stores the progression of upsampling sizes for each unet so that we can validate these unets actually work together
            true_train_timesteps: List[Optional[int]] = [None] * load_config.final_unet_number  # Stores the number of timesteps each unet was trained with
            true_beta_schedules: List[Optional[str]] = [None] * load_config.final_unet_number  # Stores the beta schedule each unet used
            true_uses_learned_variance: List[Optional[bool]] = [None] * load_config.final_unet_number  # Stores whether each unet uses learned variance
            true_sample_timesteps: List[Optional[int]] = [None] * load_config.final_unet_number  # Stores the number of timesteps each unet uses for sampling

            requires_clip = False
            for source in load_config.unet_sources:
                decoder, decoder_config, decoder_version, unets_requires_clip = self._load_single_decoder(source)
                if unets_requires_clip:
                    requires_clip = True
                if source.default_sample_timesteps is not None:
                    assert len(source.default_sample_timesteps) == len(source.unet_numbers)
                for i, unet_number in enumerate(source.unet_numbers):
                    unet_index = unet_number - 1
                    # Now we need to insert the unet into the true unets and the unet config into the true config
                    if source.default_sample_timesteps is not None:
                        true_sample_timesteps[unet_index] = source.default_sample_timesteps[i]
                    true_unets[unet_index] = decoder.unets[unet_index]
                    true_unet_configs[unet_index] = decoder_config.unets[unet_index]
                    true_upsampling_sizes[unet_index] = (None if unet_index == 0 else decoder_config.image_sizes[unet_index - 1], decoder_config.image_sizes[unet_index])
                    true_train_timesteps[unet_index] = decoder_config.timesteps
                    true_beta_schedules[unet_index] = decoder_config.beta_schedule[unet_index]
                    true_uses_learned_variance[unet_index] = decoder_config.learned_variance if isinstance(decoder_config.learned_variance, bool) else decoder_config.learned_variance[unet_index]

            true_decoder_config_obj = {}
            # Insert the true configs into the true decoder config
            true_decoder_config_obj['unets'] = true_unet_configs
            true_image_sizes = []
            for i in range(load_config.final_unet_number):
                if i == 0:
                    true_image_sizes.append(true_upsampling_sizes[i][1])
                else:
                    assert true_upsampling_sizes[i - 1][1] == true_upsampling_sizes[i][0], f"The upsampling sizes for unet {i} are not compatible with unet {i - 1}."
                    true_image_sizes.append(true_upsampling_sizes[i][1])
            true_decoder_config_obj['image_sizes'] = true_image_sizes
            # All unets must have been trained with the same number of timesteps in order to be compatible
            assert all(true_train_timesteps[0] == t for t in true_train_timesteps), "All unets must have been trained with the same number of timesteps in order to be compatible."
            true_decoder_config_obj['timesteps'] = true_train_timesteps[0]
            true_decoder_config_obj['beta_schedule'] = true_beta_schedules
            true_decoder_config_obj['learned_variance'] = true_uses_learned_variance
            # If any of the sample_timesteps are not None, then we need to insert them into the true decoder config
            if any(exists(t) for t in true_sample_timesteps):
                true_decoder_config_obj['sample_timesteps'] = true_sample_timesteps

            # Now we can create the decoder and substitute the unets
            true_decoder_config = DecoderConfig(**true_decoder_config_obj)
            decoder_data_requirements = self._get_decoder_data_requirements(true_decoder_config)
            decoder = true_decoder_config.create().eval()
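            # The decoder created from the merged config holds freshly initialized
            # unets; swap in the trained unets gathered from the checkpoints above.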
            decoder.unets = nn.ModuleList(true_unets)
            decoder.to(torch.float32)
            return ModelInfo(decoder, decoder_version, requires_clip, decoder_data_requirements)
            
    def _get_prior_data_requirements(self, config: DiffusionPriorConfig) -> DataRequirements:
        """
        Returns the data requirements for a diffusion prior
        """
        return DataRequirements(
            image_embedding=False,  # This is kinda the whole point
            text_encoding=True,  # This is also kinda the whole point
            image=False,  # The prior is never conditioned on the image
            text=False,  # Text is never required for anything
            can_generate_embedding=False,  # This might be added later if clip is being used
            image_size=-1  # Not used: the prior takes no image input, so no minimum size applies
        )

    def load_prior(self, load_config: PriorLoadConfig) -> 'ModelInfo[DiffusionPrior]':
        """
        Loads a prior from a model and a config file
        """
        sample_timesteps = load_config.default_sample_timesteps
        def apply_default_config(config: DiffusionPriorConfig) -> None:
            """
            Applies the default sample timesteps to the given config in place
            """
            if sample_timesteps is not None:
                config.sample_timesteps = sample_timesteps

        with load_config.load_model_from.as_local_file(check_update=self.check_updates) as model_file:
            model_state_dict = torch.load(model_file, map_location=self.load_device)
            if 'version' in model_state_dict:
                model_version = model_state_dict['version']
                if model_version != self.current_version:
                    print(f'WARNING: This prior was trained on version {model_version} but the current version is {self.current_version}. This may result in the model failing to load.')
                    print(f'FIX: Switch to this version with `pip install DALLE2-pytorch=={model_version}`. If different models suggest different versions, you may just need to choose one.')
            else:
                print('WARNING: This prior was trained on an old version of Dalle2. This may result in the model failing to load or it may produce garbage results.')
                model_version = None

            requires_clip = False
            if 'config' in model_state_dict:
                # Then we define the prior config from this object
                prior_config = TrainDiffusionPriorConfig(**model_state_dict['config']).prior
                apply_default_config(prior_config)
                if prior_config.clip is not None:
                    # We don't want to load clip with the model
                    prior_config.clip = None
                    requires_clip = True
                prior = prior_config.create().eval()
                prior.load_state_dict(model_state_dict['model'], strict=self.strict_loading)
            else:
                # In this case, the state_dict is the model itself. This means we also must load the config from an external file
                assert load_config.load_config_from is not None
                with load_config.load_config_from.as_local_file(check_update=self.check_updates) as config_file:
                    prior_config = TrainDiffusionPriorConfig.from_json_path(config_file).prior
                    apply_default_config(prior_config)
                    if prior_config.clip is not None:
                        # We don't want to load clip with the model
                        prior_config.clip = None
                        requires_clip = True
                prior = prior_config.create().eval()
                prior.load_state_dict(model_state_dict, strict=self.strict_loading)

            data_requirements = self._get_prior_data_requirements(prior_config)
            prior.to(torch.float32)
            return ModelInfo(prior, model_version, requires_clip, data_requirements)
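
# Minimal usage sketch (hypothetical: assumes `cfg` is a populated ModelLoadConfig):
#   manager = DalleModelManager(cfg, check_updates=False)
#   if manager.decoder_info is not None:
#       ok, errors = manager.decoder_info.data_requirements.is_valid(
#           has_image_emb=True, has_text_encoding=True)
#       if not ok:
#           print('\n'.join(errors))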