config.py 8.59 KB
Newer Older
dongchy920's avatar
dongchy920 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
from pathlib import Path
from dalle2_pytorch.train_configs import AdapterConfig as ClipConfig
from typing import List, Optional, Union
from enum import Enum
from pydantic import BaseModel, root_validator, ValidationError
from contextlib import contextmanager
import tempfile
import urllib.request
import json
import pdb

class LoadLocation(str, Enum):
    """
    Where a file is loaded from.

    Because this enum also subclasses `str`, members compare equal to their
    plain string values, so configs may simply write "local" or "url".
    """
    local = "local"  # The path points at a file on the local filesystem
    url = "url"  # The path is a url the file is downloaded from

class File(BaseModel):
    """
    Describes a file that is either already on disk or has to be downloaded from a url.

    Fields:
        load_type: Whether `path` is a local path or a url.
        path: The local path or the url of the file.
        checksum_file_path: Url of a checksum file used to detect remote updates.
            Filled in automatically for recognized Huggingface urls (see `add_default_checksum`).
        cache_dir: Directory downloads are kept in. If None, downloads go to a temporary directory.
        filename_override: Name to store the download under instead of the one derived from `path`.
    """
    load_type: LoadLocation
    path: str
    checksum_file_path: Optional[str] = None
    cache_dir: Optional[Path] = None
    filename_override: Optional[str] = None

    @root_validator(pre=True)
    def add_default_checksum(cls, values):
        """
        When loading from url, the checksum is the best way to see if there is an update to the model.
        If we are loading from specific places, we know it is already storing a checksum and we can read and compare those to check for updates.
        Sources we can do this with:
        1. Huggingface: If model is at https://huggingface.co/[ORG?]/[REPO]/resolve/main/[PATH_TO_MODEL.pth] we know the checksum is at https://huggingface.co/[ORG?]/[REPO]/raw/main/[PATH_TO_MODEL.pth]

        An explicitly configured `checksum_file_path` is never overwritten.
        """
        # This runs before field validation, so required keys may be missing or malformed.
        # Use .get() and leave the reporting of those problems to pydantic's own field validation.
        if values.get("load_type") == LoadLocation.url:
            filepath = values.get("path")
            existing_checksum = values.get("checksum_file_path")
            # Only derive a checksum url when the rewrite below actually applies. Otherwise the
            # "checksum" url would be the model url itself and the whole model would be fetched twice.
            if (
                isinstance(filepath, str)
                and existing_checksum is None
                and filepath.startswith("https://huggingface.co/")
                and "resolve/main/" in filepath
            ):
                values["checksum_file_path"] = filepath.replace("resolve/main/", "raw/main/")
        return values

    def download_to(self, path: Union[str, Path]):
        """
        Downloads the file to the given path.
        If a checksum url is known, the checksum is stored next to it as `<path>.checksum`.
        """
        assert self.load_type == LoadLocation.url
        urllib.request.urlretrieve(self.path, path)
        if self.checksum_file_path is not None:
            urllib.request.urlretrieve(self.checksum_file_path, str(path) + ".checksum")

    def download_checksum_to(self, path: Union[str, Path]):
        """
        Downloads the checksum to the given path
        """
        assert self.load_type == LoadLocation.url
        assert self.checksum_file_path is not None, "No checksum file path specified"
        urllib.request.urlretrieve(self.checksum_file_path, path)

    def get_remote_checksum(self):
        """
        Downloads the remote checksum as a tempfile and returns its content
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            checksum_path = Path(tmpdir) / "checksum"
            self.download_checksum_to(checksum_path)
            # Read before leaving the `with` block since the directory is deleted on exit
            return checksum_path.read_text()

    @property
    def filename(self):
        """
        The name the file is stored under: `filename_override` if set, otherwise the
        last path segment of `path` with any query string removed.
        """
        if self.filename_override is not None:
            return self.filename_override
        location = self.path
        if self.load_type == LoadLocation.url:
            # Drop the query string first: it may itself contain '/' (e.g. signed urls),
            # which would otherwise be mistaken for a path separator.
            location = location.split('?', 1)[0]
        # The filename is everything after the last '/' but before the '?' if it exists
        filename = location.split('/')[-1]
        if '?' in filename:
            filename = filename.split('?')[0]
        return filename

    @contextmanager
    def as_local_file(self, check_update: bool = True):
        """
        Context manager that yields a local path to the file, downloading it first if needed.

        - Local files: yields `path` unchanged.
        - Url with `cache_dir`: downloads into the cache directory (re-using an existing copy) and yields that path.
        - Url without `cache_dir`: downloads into a temporary directory that is removed when the context exits.

        If check_update is True, a cached file is downloaded again when its stored checksum is missing
        or differs from the remote one. If it is False, the cached file is always kept.
        """
        if self.load_type == LoadLocation.local:
            yield self.path
        elif self.cache_dir is not None:
            # Then we are caching the data in a local directory
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            file_path = self.cache_dir / self.filename
            # Same location `download_to` writes the checksum to
            cached_checksum_path = self.cache_dir / (self.filename + ".checksum")
            if not file_path.exists():
                print(f"Downloading {self.path} to {file_path}")
                self.download_to(file_path)
            elif self.checksum_file_path is None:
                print(f'{file_path} already exists. Skipping download. No checksum found so if you think this file should be re-downloaded, delete it and try again.')
            elif not cached_checksum_path.exists():
                # Then we don't know if the file is up to date so we should download it
                if check_update:
                    print(f"Checksum not found for {file_path}. Downloading it again.")
                    self.download_to(file_path)
                else:
                    print(f"Checksum not found for {file_path}, but updates are disabled. Skipping download.")
            else:
                # Then we should download and compare the checksums.
                # NOTE(review): the remote checksum is fetched even when check_update is False,
                # so this branch needs network access just to report a mismatch.
                new_checksum = self.get_remote_checksum()
                with open(cached_checksum_path, "r") as f:
                    old_checksum = f.read()
                if new_checksum != old_checksum:
                    if check_update:
                        print(f"Checksum mismatch. Deleting {file_path} and downloading again.")
                        file_path.unlink()
                        self.download_to(file_path)  # This automatically overwrites the checksum file
                    else:
                        print("Checksums mismatched, but updates are disabled. Skipping download.")

            yield file_path
        else:
            # Then we are not caching and the file should be stored in a temporary directory
            with tempfile.TemporaryDirectory() as tmpdir:
                tmpfile = tmpdir + "/" + self.filename
                self.download_to(tmpfile)
                # The temporary directory (and the file) only exist while the caller is inside the context
                yield tmpfile

class SingleDecoderLoadConfig(BaseModel):
    """
    Configuration for the single decoder load.
    Each instance describes one model file and the unet numbers associated with it.
    """
    # The unet numbers associated with this model file.
    # DecoderLoadConfig requires the numbers of all its sources to run from 1 upwards with no gaps or repeats.
    unet_numbers: List[int]
    # Optional default sampling timesteps - presumably one entry per unet in `unet_numbers`; TODO confirm against the loader
    default_sample_timesteps: Optional[List[int]] = None
    # Optional default conditioning scales - presumably one entry per unet in `unet_numbers`; TODO confirm against the loader
    default_cond_scale: Optional[List[float]] = None
    # Where to get the model weights from
    load_model_from: File
    load_config_from: Optional[File] = None  # The config may be defined within the model file if the version is high enough

def _collect_unet_numbers(values):
    """
    Returns the unet numbers of every entry in `values["unet_sources"]` as one flat list.

    Entries may be raw dicts (e.g. parsed from json) or already constructed config objects.
    Raises a ValueError, which pydantic turns into a validation error, if the sources are
    missing, malformed, or do not define any unet number.
    """
    if "unet_sources" not in values:
        raise ValueError("No unet sources defined. Make sure `unet_sources` is defined in the decoder config.")
    unet_numbers = []
    try:
        for source in values["unet_sources"]:
            numbers = source["unet_numbers"] if isinstance(source, dict) else source.unet_numbers
            unet_numbers.extend(numbers)
    except (KeyError, AttributeError, TypeError) as err:
        raise ValueError("Every entry of `unet_sources` must define a list of `unet_numbers`.") from err
    if not unet_numbers:
        raise ValueError("The decoder must define at least one unet number.")
    return unet_numbers


class DecoderLoadConfig(BaseModel):
    """
    Configuration for the decoder load.

    Fields:
        unet_sources: The model files that together provide all unets of the decoder.
        final_unet_number: The highest unet number over all sources. Always computed from
            `unet_sources`; a value passed in explicitly is overwritten.
    """
    unet_sources: List[SingleDecoderLoadConfig]

    final_unet_number: int

    @root_validator(pre=True)
    def compute_num_unets(cls, values):
        """
        Gets the final unet number
        """
        values["final_unet_number"] = max(_collect_unet_numbers(values))
        return values

    @root_validator(pre=True)
    def verify_unet_numbers_valid(cls, values):
        """
        The unets must go from 1 to some positive number not skipping any and not repeating any.
        """
        unet_numbers = sorted(_collect_unet_numbers(values))
        # Raise ValueError rather than ValidationError: pydantic wraps ValueError into a proper
        # validation error, while ValidationError itself cannot be built from just a message.
        if len(unet_numbers) != len(set(unet_numbers)):
            raise ValueError("The decoder unet numbers must not repeat.")
        if unet_numbers[0] != 1:
            raise ValueError("The decoder unet numbers must start from 1.")
        if any(current - previous != 1 for previous, current in zip(unet_numbers, unet_numbers[1:])):
            raise ValueError("The decoder unet numbers must not skip any.")
        return values
    
class PriorLoadConfig(BaseModel):
    """
    Configuration for the prior load.
    """
    # Optional default number of sampling timesteps for the prior
    default_sample_timesteps: Optional[int] = None
    # Optional default conditioning scale for the prior
    default_cond_scale: Optional[float] = None
    # Where to get the model weights from
    load_model_from: File
    # Explicit `= None` keeps this field optional on every pydantic version
    # (without a default, pydantic v2 treats an Optional field as required).
    load_config_from: Optional[File] = None  # The config may be defined within the model file if the version is high enough

class ModelLoadConfig(BaseModel):
    """
    Configuration for the model load.
    Every model section is optional, so a config may describe any subset of decoder, prior and clip.
    """
    decoder: Optional[DecoderLoadConfig] = None
    prior: Optional[PriorLoadConfig] = None
    clip: Optional[ClipConfig] = None

    devices: Union[List[str], str] = 'cuda:0'  # The device(s) to use for model inference. If a list, the first device is used for loading.
    load_on_cpu: bool = True  # Whether to load the state_dict on the first device or on the cpu
    strict_loading: bool = True  # Whether to error on loading if the model is not compatible with the current version of the code

    @classmethod
    def from_json_path(cls, json_path: Union[str, Path]) -> "ModelLoadConfig":
        """
        Builds the config from a json file.

        The top level of the json must be an object whose keys are the fields of this class.
        Raises the usual file/json errors if the file cannot be read or parsed, and a
        pydantic validation error if its content does not match the config schema.
        """
        # JSON is UTF-8 by specification, so do not depend on the platform's default encoding
        with open(json_path, encoding="utf-8") as f:
            config = json.load(f)
        return cls(**config)