base.py 2.66 KB
Newer Older
hepj's avatar
hepj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from dataclasses import dataclass, field, fields
from typing import Any, Dict

from fastvideo.v1.logger import init_logger

# Module-level logger; ModelConfig.update_model_config uses it to warn about unknown fields.
logger = init_logger(__name__)


@dataclass
class ArchConfig:
    """Architecture parameters of a model.

    * Holds the fields from a Diffusers/Transformers ``config.json``, i.e.
      everything that describes the model's architecture.
    * Each model architecture subclasses this and declares/overrides its own
      fields; this base class declares none.
    * Values are fixed once the config is initialized and are meant to stay
      hidden from end users rather than act as user-facing knobs.
    """


@dataclass
class ModelConfig:
    """Base configuration for a FastVideo model.

    Every model config parameter belongs to one of two groups:

    * architecture parameters from the Diffusers/Transformers config, stored on
      ``arch_config``; and
    * FastVideo-specific parameters (e.g. STA, quantization, TeaCache), meant to
      be declared as fields on this class / its subclasses.

    Attribute reads that miss on the ``ModelConfig`` itself fall through to
    ``arch_config`` (see ``__getattr__``), so ``cfg.<arch field>`` works.
    """

    # Every model config parameter can be categorized into either ArchConfig or everything else
    # Diffusers/Transformers parameters
    arch_config: ArchConfig = field(default_factory=ArchConfig)

    # FastVideo-specific parameters here
    # e.g. STA, quantization, teacache

    def __getattr__(self, name: str) -> Any:
        """Forward a missing attribute lookup to ``arch_config``.

        Python only calls this after normal lookup on the instance and its
        class has failed. Dunder names are never forwarded: they are protocol
        probes (``__setstate__``, ``__deepcopy__``, ``__post_init__``, ...)
        whose answer must describe this object, not its architecture config.

        Raises:
            AttributeError: if ``name`` is a dunder, if ``arch_config`` has not
                been set on this instance, or if ``arch_config`` lacks ``name``.
        """
        if not (name.startswith("__") and name.endswith("__")):
            # Read arch_config from the instance dict directly. Using
            # ``self.arch_config`` would re-enter __getattr__ (and recurse
            # without bound) on an instance whose state is not populated yet,
            # e.g. one created via ``__new__`` while unpickling or copying.
            arch_config = self.__dict__.get("arch_config")
            if arch_config is not None and hasattr(arch_config, name):
                return getattr(arch_config, name)
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'")

    def __getstate__(self) -> Dict[str, Any]:
        """Return a shallow copy of the instance ``__dict__`` for pickling."""
        state = self.__dict__.copy()
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore instance attributes from an unpickled ``state`` dict."""
        self.__dict__.update(state)

    # This should be used only when loading from transformers/diffusers
    def update_model_arch(self, source_model_dict: Dict[str, Any]) -> None:
        """Overwrite ``arch_config`` fields from ``source_model_dict``.

        Every key is validated before any field is written, so an unknown key
        leaves ``arch_config`` unchanged. After the update,
        ``arch_config.__post_init__`` is called again if the arch config
        defines one.

        Raises:
            AttributeError: naming the first key (in dict order) that is not a
                dataclass field of ``arch_config``.
        """
        arch_config = self.arch_config
        valid_fields = {f.name for f in fields(arch_config)}

        # Validate first so a bad key cannot leave arch_config half-updated.
        for key in source_model_dict:
            if key not in valid_fields:
                raise AttributeError(
                    f"{type(arch_config).__name__} has no field '{key}'")

        for key, value in source_model_dict.items():
            setattr(arch_config, key, value)

        if hasattr(arch_config, "__post_init__"):
            arch_config.__post_init__()

    def update_model_config(self, source_model_dict: Dict[str, Any]) -> None:
        """Overwrite top-level (non-architecture) fields from ``source_model_dict``.

        ``arch_config`` must not appear in the dict; use ``update_model_arch``
        for architecture fields. Every key is validated before any field is
        written, so an unknown key leaves this config unchanged. Afterwards
        ``self.__post_init__`` is called if this class defines one.

        Raises:
            AssertionError: if ``source_model_dict`` contains ``arch_config``.
            AttributeError: for the first key (in dict order) that is not a
                dataclass field of this config; a warning is logged first.
        """
        # Explicit raise instead of ``assert`` so the guard survives ``python -O``;
        # AssertionError is kept so existing callers see the same exception type.
        if "arch_config" in source_model_dict:
            raise AssertionError(
                "Source model config shouldn't contain arch_config.")

        valid_fields = {f.name for f in fields(self)}

        # Validate first so a bad key cannot leave this config half-updated.
        for key in source_model_dict:
            if key not in valid_fields:
                logger.warning("%s does not contain field '%s'!",
                               type(self).__name__, key)
                raise AttributeError(f"Invalid field: {key}")

        for key, value in source_model_dict.items():
            setattr(self, key, value)

        # Because __getattr__ no longer forwards dunders, this only detects a
        # __post_init__ defined on this class hierarchy, not on arch_config.
        if hasattr(self, "__post_init__"):
            self.__post_init__()