import json
import logging
from typing import Optional

import peewee as pw
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel

from apps.web.internal.db import DB, JSONField

from config import SRC_LOG_LEVELS

# Module-level logger; its level is taken from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])


####################
# Models DB Schema
####################


# ModelParams describes the data stored in the `params` field of the Model table.
# It declares no fields of its own; it is the declared type of `ModelModel.params`.
class ModelParams(BaseModel):
    # With no declared fields and pydantic's default extra="ignore", every key of
    # the stored params blob would be dropped on validation. Allow extras so the
    # blob survives a round trip through ModelModel.
    model_config = {"extra": "allow"}


# ModelMeta describes the data stored in the `meta` field of the Model table.
# It is the declared type of `ModelModel.meta`.
class ModelMeta(BaseModel):
    description: str
    """
        User-facing description of the model.
    """

    vision_capable: bool
    """
        A flag indicating if the model is capable of vision and thus image inputs
    """


class Model(pw.Model):
    """Peewee table definition for a model entry, bound to the shared `DB`."""

    id = pw.TextField(unique=True)
    """
        The model's id as used in the API. If set to an existing model, it will override the model.
    """

    # NOTE(review): this column is required (TextField defaults to null=False), but
    # ModelModel declares no `user_id`, so Model.create(**model.model_dump()) passes
    # no value for it. Confirm how it is meant to be populated.
    user_id = pw.TextField()

    base_model_id = pw.TextField(null=True)
    """
        An optional pointer to the actual model that should be used when proxying requests.
        Currently unused - but will be used to support Modelfile like behaviour in the future
    """

    name = pw.TextField()
    """
        The human-readable display name of the model.
    """

    params = JSONField()
    """
        Holds a JSON encoded blob of parameters, see `ModelParams`.
    """

    meta = JSONField()
    """
        Holds a JSON encoded blob of metadata, see `ModelMeta`.
    """

    # NOTE(review): these two are bare annotations, not peewee Field instances, so
    # no database columns are created for them. Confirm whether real fields were
    # intended.
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch

    class Meta:
        database = DB


# ModelModel is the pydantic representation of a row of the Model table; it is
# built from `model_to_dict(...)` output and dumped back with `model_dump()`.
class ModelModel(BaseModel):
    # Model id as used in the API.
    id: str
    # Optional id of the underlying model; defaults to None.
    base_model_id: Optional[str] = None
    # Human-readable display name.
    name: str
    # Parameter blob, validated as ModelParams.
    params: ModelParams
    # Metadata blob, validated as ModelMeta.
    meta: ModelMeta


####################
# Forms
####################


class ModelsTable:
    """Helper around the `Model` table: list its rows and sync it to a given list."""

    def __init__(
        self,
        db: pw.SqliteDatabase | pw.PostgresqlDatabase,
    ):
        """Keep a reference to `db` and call `create_tables([Model])` on it."""
        self.db = db
        self.db.create_tables([Model])

    def get_all_models(self) -> list[ModelModel]:
        """Return every row of the `Model` table converted to a `ModelModel`."""
        return [ModelModel(**model_to_dict(model)) for model in Model.select()]

    def update_all_models(self, models: list[ModelModel]) -> bool:
        """Make the `Model` table match `models`.

        Rows whose id is not yet stored are created, rows whose id is already
        stored are updated, and stored rows whose id is absent from `models`
        are deleted. All statements run inside one `self.db.atomic()` block.

        Returns:
            True when every operation succeeded; False when any exception was
            raised (the exception is logged, not propagated).
        """
        try:
            with self.db.atomic():
                # Ids currently stored vs. ids requested by the caller.
                current_model_keys = {model.id for model in self.get_all_models()}
                new_model_keys = {model.id for model in models}

                # Determine which models need to be created, updated, or deleted
                models_to_create = [
                    model for model in models if model.id not in current_model_keys
                ]
                models_to_update = [
                    model for model in models if model.id in current_model_keys
                ]
                models_to_delete = current_model_keys - new_model_keys

                # Perform the necessary database operations
                for model in models_to_create:
                    Model.create(**model.model_dump())

                for model in models_to_update:
                    Model.update(**model.model_dump()).where(
                        Model.id == model.id
                    ).execute()

                # `models_to_delete` is a set of id strings, so iterate the ids
                # directly; unpacking each id into two names raised ValueError.
                for model_id in models_to_delete:
                    Model.delete().where(Model.id == model_id).execute()

            return True
        except Exception as e:
            log.exception(e)
            return False


# Module-level ModelsTable bound to the shared DB. Constructing it calls
# DB.create_tables([Model]) when this module is imported.
Models = ModelsTable(DB)