# supported_models_base.py
import torch
from . import model_base
from . import utils
from . import latent_formats

class ClipTarget:
    """Hold a tokenizer and a clip object together with a ``params`` dict.

    Each instance gets its own, initially empty ``params`` dict; it is not
    shared between instances.
    """

    def __init__(self, tokenizer, clip):
        self.tokenizer = tokenizer
        self.clip = clip
        self.params = dict()

class BASE:
    """Base model-configuration class.

    Class-level attributes act as defaults that subclasses are expected to
    override (no subclasses are visible in this file). An instance wraps a
    UNet config dict, builds the model object via ``get_model`` and rewrites
    component state-dict key prefixes for saving.
    """

    # Key/value pairs a candidate UNet config must contain for `matches` to
    # accept it.
    unet_config = {}
    # Entries merged into every instance's unet_config by __init__; they
    # override any values already present in the config passed in.
    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
    }

    clip_prefix = []  # not referenced within this class
    # Key prefix applied when saving CLIP-vision weights; None leaves the
    # prefix mapping empty.
    clip_vision_prefix = None
    # When not None, get_model builds a model_base.SD21UNCLIP with this config
    # instead of a plain model_base.BaseModel.
    noise_aug_config = None
    # Copied per instance in __init__; not otherwise read within this class.
    sampling_settings = {}
    # A latent-format *class*; __init__ replaces it with an instance of it.
    latent_format = latent_formats.LatentFormat
    vae_key_prefix = ["first_stage_model."]  # not referenced within this class

    # Not referenced within this class.
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    # Set per instance by set_inference_dtype.
    manual_cast_dtype = None

    @classmethod
    def matches(cls, unet_config):
        """Return True if *unet_config* has every key of ``cls.unet_config``
        with an equal value (extra keys in *unet_config* are ignored)."""
        return all(
            k in unet_config and v == unet_config[k]
            for k, v in cls.unet_config.items()
        )

    def model_type(self, state_dict, prefix=""):
        """Return the model type for this config; the base class always
        reports ``model_base.ModelType.EPS`` regardless of its arguments."""
        return model_base.ModelType.EPS

    def inpaint_model(self):
        """Return True when the UNet config declares more than 4 input channels.

        A config without an ``in_channels`` entry is treated as a
        non-inpainting model instead of raising ``KeyError``.
        """
        return self.unet_config.get("in_channels", -1) > 4

    def __init__(self, unet_config):
        # Take a private copy: merging unet_extra_config below (and the later
        # 'dtype' write in set_inference_dtype) must not mutate the caller's dict.
        self.unet_config = dict(unet_config)
        # Per-instance copy so that mutating it cannot leak into the shared
        # class-level default.
        self.sampling_settings = dict(self.sampling_settings)
        # The class attribute holds a class; each instance gets its own object.
        self.latent_format = self.latent_format()
        self.unet_config.update(self.unet_extra_config)

    def get_model(self, state_dict, prefix="", device=None):
        """Build the model object for this config.

        Uses ``model_base.SD21UNCLIP`` when ``noise_aug_config`` is set, and
        ``model_base.BaseModel`` otherwise. Calls ``set_inpaint()`` on the
        result when ``inpaint_model()`` is true.
        """
        model_type = self.model_type(state_dict, prefix)
        if self.noise_aug_config is not None:
            out = model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=model_type, device=device)
        else:
            out = model_base.BaseModel(self, model_type=model_type, device=device)
        if self.inpaint_model():
            out.set_inpaint()
        return out

    def process_clip_state_dict(self, state_dict):
        """Hook for adjusting CLIP weights on load; the base class returns them unchanged."""
        return state_dict

    def process_unet_state_dict(self, state_dict):
        """Hook for adjusting UNet weights on load; the base class returns them unchanged."""
        return state_dict

    def process_vae_state_dict(self, state_dict):
        """Hook for adjusting VAE weights on load; the base class returns them unchanged."""
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        """Map the empty key prefix to "cond_stage_model." via
        utils.state_dict_prefix_replace (presumably prepending it to every key)."""
        replace_prefix = {"": "cond_stage_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_clip_vision_state_dict_for_saving(self, state_dict):
        """Map the empty key prefix to ``clip_vision_prefix`` when one is set.

        With no prefix configured, an empty mapping is passed to
        utils.state_dict_prefix_replace (presumably leaving keys unchanged).
        """
        replace_prefix = {}
        if self.clip_vision_prefix is not None:
            replace_prefix[""] = self.clip_vision_prefix
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_unet_state_dict_for_saving(self, state_dict):
        """Map the empty key prefix to "model.diffusion_model." via
        utils.state_dict_prefix_replace (presumably prepending it to every key)."""
        replace_prefix = {"": "model.diffusion_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_vae_state_dict_for_saving(self, state_dict):
        """Map the empty key prefix to "first_stage_model." via
        utils.state_dict_prefix_replace (presumably prepending it to every key)."""
        replace_prefix = {"": "first_stage_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def set_inference_dtype(self, dtype, manual_cast_dtype):
        """Record *dtype* under 'dtype' in this instance's unet_config and
        store *manual_cast_dtype* on the instance."""
        self.unet_config['dtype'] = dtype
        self.manual_cast_dtype = manual_cast_dtype