models.py 3.73 KB
Newer Older
1
import json
2
import logging
3
4
5
6
7
8
from typing import Optional

import peewee as pw
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel

9
from apps.web.internal.db import DB, JSONField
10

11
12
13
14
15
from config import SRC_LOG_LEVELS

# Module-level logger; its level is taken from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

16
17
18
19
20
21
22
23
24

####################
# Models DB Schema
####################


# ModelParams is a model for the data stored in the params field of the Model table.
# Within this file it is only used as the type of ModelModel.params; it mainly serves as a reference.
class ModelParams(BaseModel):
    # Schema for the JSON blob kept in `Model.params`.
    # No fields are declared yet.
    pass


# ModelMeta is a model for the data stored in the meta field of the Model table.
# Within this file it is only used as the type of ModelModel.meta; it mainly serves as a reference.
class ModelMeta(BaseModel):
    # Schema for the JSON blob kept in `Model.meta`; both fields are required.

    # User-facing description of the model.
    description: str

    # A flag indicating if the model is capable of vision and thus image inputs.
    vision_capable: bool
40
41
42


class Model(pw.Model):
    # Peewee table definition for stored models; bound to DB via Meta below.

    # The model's id as used in the API. If set to an existing model, it will
    # override the model.
    id = pw.TextField(unique=True)

    # Holds a JSON encoded blob of metadata, see `ModelMeta`.
    meta = JSONField()

    # An optional pointer to the actual model that should be used when
    # proxying requests. Currently unused - but will be used to support
    # Modelfile like behaviour in the future.
    base_model_id = pw.TextField(null=True)

    # The human-readable display name of the model.
    name = pw.TextField()

    # Holds a JSON encoded blob of parameters, see `ModelParams`.
    params = JSONField()

    class Meta:
        database = DB


class ModelModel(BaseModel):
    # Pydantic counterpart of the `Model` table: one attribute per column.

    # Identifier of the model; mirrors `Model.id`.
    id: str

    # Metadata blob, validated against `ModelMeta`.
    meta: ModelMeta

    # Mirrors the nullable `Model.base_model_id` column; defaults to None.
    base_model_id: Optional[str] = None

    # Display name; mirrors `Model.name`.
    name: str

    # Parameter blob, validated against `ModelParams`.
    params: ModelParams
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97


####################
# Forms
####################


class ModelsTable:
    """Data-access helper around the `Model` table."""

    def __init__(
        self,
        db: pw.SqliteDatabase | pw.PostgresqlDatabase,
    ):
        """Store the database handle and create the `Model` table on it.

        Args:
            db: The peewee database used for table creation and for the
                transaction in `update_all_models`.
        """
        self.db = db
        self.db.create_tables([Model])

    def get_all_models(self) -> list[ModelModel]:
        """Return every row of the `Model` table as a `ModelModel`."""
        return [ModelModel(**model_to_dict(model)) for model in Model.select()]

    def update_all_models(self, models: list[ModelModel]) -> bool:
        """Make the `Model` table match `models` exactly.

        Rows whose id is not yet stored are created, rows whose id is already
        stored are overwritten, and stored rows whose id is absent from
        `models` are deleted. All statements run inside one `db.atomic()`
        block, so an error in any of them leaves the table unchanged.

        Args:
            models: The complete desired set of models.

        Returns:
            True when every operation succeeded; False when any exception
            was raised (the exception is logged, not propagated).
        """
        try:
            with self.db.atomic():
                # Only the ids are needed to split the work into
                # create / update / delete.
                current_model_keys = {model.id for model in self.get_all_models()}
                new_model_keys = {model.id for model in models}

                # Determine which models need to be created, updated, or deleted
                models_to_create = [
                    model for model in models if model.id not in current_model_keys
                ]
                models_to_update = [
                    model for model in models if model.id in current_model_keys
                ]
                models_to_delete = current_model_keys - new_model_keys

                # Perform the necessary database operations
                for model in models_to_create:
                    Model.create(**model.model_dump())

                for model in models_to_update:
                    Model.update(**model.model_dump()).where(
                        Model.id == model.id
                    ).execute()

                # models_to_delete is a set of id strings, so iterate the ids
                # directly; unpacking each id into two names raised ValueError
                # (or split a two-character id into single characters).
                for model_id in models_to_delete:
                    Model.delete().where(Model.id == model_id).execute()

            return True
        except Exception as e:
            # Best-effort contract: callers get a boolean, details go to the log.
            log.exception(e)
            return False


# Shared module-level ModelsTable bound to DB; constructing it calls
# DB.create_tables([Model]) when this module is imported.
Models = ModelsTable(DB)