import json
import logging
from typing import Optional

import peewee as pw
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict

from apps.web.internal.db import DB, JSONField

from config import SRC_LOG_LEVELS

# Module logger; its level comes from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

####################
# Models DB Schema
####################


# ModelParams describes the data stored in the params field of the Model table.
# No code in this module reads its fields; it documents the expected shape and
# is used as the type of ModelModel.params.
class ModelParams(BaseModel):
    # Keep undeclared keys instead of rejecting them: params is free-form.
    model_config = ConfigDict(extra="allow")


# ModelMeta describes the data stored in the meta field of the Model table.
# No code in this module reads its fields; it documents the expected shape and
# is used as the type of ModelModel.meta.
class ModelMeta(BaseModel):
    description: Optional[str] = None
    """
        User-facing description of the model.
    """

    vision_capable: Optional[bool] = None
    """
        A flag indicating if the model is capable of vision and thus image inputs.
    """

    # Keep undeclared keys instead of rejecting them, so meta can carry
    # additional entries beyond the two declared above.
    model_config = ConfigDict(extra="allow")

class Model(pw.Model):
    """Peewee schema for stored model definitions.

    The string literals after some fields are attribute docstrings for
    readers only; they do not affect the generated table.
    """

    id = pw.TextField(unique=True)
    """
        The model's id as used in the API. If set to an existing model, it will override the model.
    """

    # Non-null owner column; presumably the id of the user who created the
    # entry — TODO confirm.
    user_id = pw.TextField()

    base_model_id = pw.TextField(null=True)
    """
        An optional pointer to the actual model that should be used when proxying requests.
        Currently unused, but will be used to support Modelfile-like behaviour in the future.
    """

    name = pw.TextField()
    """
        The human-readable display name of the model.
    """

    params = JSONField()
    """
        Holds a JSON encoded blob of parameters, see `ModelParams`.
    """

    meta = JSONField()
    """
        Holds a JSON encoded blob of metadata, see `ModelMeta`.
    """

    # NOTE(review): these are bare annotations, not peewee Field instances, so
    # no columns are created for them. Confirm whether real timestamp columns
    # (e.g. BigIntegerField) were intended; adding them needs a migration.
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch

    class Meta:
        database = DB


class ModelModel(BaseModel):
    """Pydantic representation of a row of the `Model` table.

    `user_id` is not part of this schema. Pydantic's default config ignores
    unknown keys, so passing a full `model_to_dict` row (which includes
    `user_id`) still validates.
    """

    id: str
    base_model_id: Optional[str] = None
    name: str
    params: ModelParams
    meta: ModelMeta


####################
# Forms
####################


class ModelsTable:
    """Data-access helper for the `Model` table."""

    def __init__(
        self,
        db: pw.SqliteDatabase | pw.PostgresqlDatabase,
    ):
        """Bind to *db* and ensure the `Model` table exists.

        Args:
            db: The peewee database holding the `Model` table.
        """
        self.db = db
        # Runs on every construction; peewee's create_tables defaults to
        # safe=True, so an existing table is left untouched.
        self.db.create_tables([Model])

    def get_all_models(self) -> list[ModelModel]:
        """Return every row of the `Model` table as a validated `ModelModel`."""
        return [ModelModel(**model_to_dict(model)) for model in Model.select()]

    def update_all_models(self, models: list[ModelModel]) -> bool:
        """Make the `Model` table match *models* inside one transaction.

        Ids in *models* that are not stored yet are inserted. Ids present in
        both are overwritten with the new values. Stored ids missing from
        *models* are deleted.

        Args:
            models: The complete desired set of models.

        Returns:
            True when every statement succeeded. False when any statement
            raised; the exception is logged and the atomic block is rolled
            back, so the table is left unchanged.
        """
        try:
            with self.db.atomic():
                # The diff only needs the stored ids, so fetch just that column
                # instead of loading and re-validating every full row.
                current_model_keys = {row.id for row in Model.select(Model.id)}
                new_model_keys = {model.id for model in models}

                models_to_create = [
                    model for model in models if model.id not in current_model_keys
                ]
                models_to_update = [
                    model for model in models if model.id in current_model_keys
                ]
                # A set of bare id strings.
                models_to_delete = current_model_keys - new_model_keys

                for model in models_to_create:
                    # NOTE(review): ModelModel carries no user_id, while
                    # Model.user_id is a non-null TextField, so the database may
                    # reject this insert. Confirm how user_id should be filled.
                    Model.create(**model.model_dump())

                for model in models_to_update:
                    Model.update(**model.model_dump()).where(
                        Model.id == model.id
                    ).execute()

                # Each element is a single id string; do not unpack it.
                for model_id in models_to_delete:
                    Model.delete().where(Model.id == model_id).execute()

            return True
        except Exception as e:
            log.exception(e)
            return False


# Shared module-level accessor bound to the application database.
# Constructing it calls create_tables for the Model table.
Models = ModelsTable(db=DB)