models.py 4.21 KB
Newer Older
1
import json
2
import logging
3
4
5
from typing import Optional

import peewee as pw
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
from peewee import *

8
from playhouse.shortcuts import model_to_dict
9
from pydantic import BaseModel, ConfigDict
10

11
from apps.web.internal.db import DB, JSONField
12

Timothy J. Baek's avatar
Timothy J. Baek committed
13
from typing import List, Union, Optional
14
15
from config import SRC_LOG_LEVELS

Timothy J. Baek's avatar
Timothy J. Baek committed
16
17
import time

# Module logger; its level comes from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

####################
# Models DB Schema
####################


# ModelParams is a model for the data stored in the params field of the Model table
class ModelParams(BaseModel):
    """Free-form parameters stored in the ``params`` column of the Model table.

    No fields are declared; ``extra="allow"`` keeps whatever keys the caller
    supplies so they are preserved by ``model_dump()``.
    """

    model_config = ConfigDict(extra="allow")


# ModelMeta is a model for the data stored in the meta field of the Model table.
# It validates the `meta` field of ModelModel, ModelResponse and ModelForm.
class ModelMeta(BaseModel):
    """Display metadata stored in the ``meta`` column of the Model table.

    Keys not declared below are kept (``extra="allow"``).
    """

    # Image URL for the model; defaults to "/favicon.png".
    profile_image_url: Optional[str] = "/favicon.png"

    # User-facing description of the model.
    description: Optional[str] = None

    # A flag indicating if the model is capable of vision and thus image inputs.
    vision_capable: Optional[bool] = None

    model_config = ConfigDict(extra="allow")

class Model(pw.Model):
    """Peewee table holding model records, bound to the shared ``DB``."""

    # The model's id as used in the API. If set to an existing model, it will
    # override the model.
    id = pw.TextField(unique=True)

    # Id of the user the record belongs to (set from ``user_id`` on insert).
    user_id = pw.TextField()

    # An optional pointer to the actual model that should be used when
    # proxying requests.
    base_model_id = pw.TextField(null=True)

    # The human-readable display name of the model.
    name = pw.TextField()

    # Holds a JSON encoded blob of parameters, see `ModelParams`.
    params = JSONField()

    # Holds a JSON encoded blob of metadata, see `ModelMeta`.
    meta = JSONField()

    # Unix timestamps in whole seconds.
    updated_at = pw.BigIntegerField()
    created_at = pw.BigIntegerField()

    class Meta:
        database = DB


class ModelModel(BaseModel):
    """Pydantic view of a full ``Model`` row, as returned by ``ModelsTable``."""

    id: str
    user_id: str
    base_model_id: Optional[str] = None

    name: str
    params: ModelParams
    meta: ModelMeta

    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch

####################
# Forms
####################


class ModelResponse(BaseModel):
    """Subset of a model record: id, name, meta and timestamps."""

    id: str
    name: str
    meta: ModelMeta
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch


class ModelForm(BaseModel):
    """Client-supplied payload for inserting or updating a model record.

    ``user_id`` and the timestamps are not part of the form; ``ModelsTable``
    fills those in.
    """

    id: str
    base_model_id: Union[str, None] = None  # optional proxy target
    name: str
    meta: ModelMeta
    params: ModelParams


class ModelsTable:
    """CRUD helpers for the ``Model`` table.

    Methods return ``ModelModel`` instances (or ``None`` / ``False`` on
    failure) instead of raw peewee rows; database errors are logged.
    """

    def __init__(
        self,
        db: pw.SqliteDatabase | pw.PostgresqlDatabase,
    ):
        self.db = db
        # Table creation happens at construction time.
        self.db.create_tables([Model])

    def insert_new_model(
        self, form_data: ModelForm, user_id: str
    ) -> Optional[ModelModel]:
        """Insert a new model record owned by ``user_id``.

        Returns the stored ``ModelModel``, or ``None`` if the database insert
        fails (the error is logged). Validation errors raised while building
        the ``ModelModel`` are not caught.
        """
        # One timestamp so created_at and updated_at match exactly.
        now = int(time.time())
        model = ModelModel(
            **{
                **form_data.model_dump(),
                "user_id": user_id,
                "created_at": now,
                "updated_at": now,
            }
        )
        try:
            result = Model.create(**model.model_dump())
        except Exception:
            log.exception("Failed to insert model %s", model.id)
            return None
        return model if result else None

    def get_all_models(self) -> List[ModelModel]:
        """Return every stored model."""
        return [ModelModel(**model_to_dict(row)) for row in Model.select()]

    def get_model_by_id(self, id: str) -> Optional[ModelModel]:
        """Return the model with ``id``, or ``None`` if it is missing or on error."""
        try:
            row = Model.get(Model.id == id)
            return ModelModel(**model_to_dict(row))
        except Model.DoesNotExist:
            return None
        except Exception:
            log.exception("Failed to load model %s", id)
            return None

    def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
        """Overwrite the editable fields of model ``id`` with ``model``.

        Every form field except ``id`` is written; fields the client left
        unset are written with their form defaults (e.g. ``base_model_id``
        becomes ``None``). ``updated_at`` is set to the current time.

        Returns the updated record, or ``None`` if no row has ``id`` or the
        update fails (the error is logged).
        """
        try:
            # Exclude "id": writing the form's id would rename the row when it
            # differs from ``id``, and the re-read below would then miss.
            data = model.model_dump(exclude={"id"})
            data["updated_at"] = int(time.time())
            Model.update(**data).where(Model.id == id).execute()

            row = Model.get(Model.id == id)
            return ModelModel(**model_to_dict(row))
        except Model.DoesNotExist:
            return None
        except Exception:
            log.exception("Failed to update model %s", id)
            return None

    def delete_model_by_id(self, id: str) -> bool:
        """Delete model ``id``.

        Returns ``True`` when the delete query ran without error (also when
        no row matched) and ``False`` if it raised (the error is logged).
        """
        try:
            Model.delete().where(Model.id == id).execute()
        except Exception:
            log.exception("Failed to delete model %s", id)
            return False
        return True


# Shared ModelsTable bound to the module's DB; constructing it calls
# create_tables([Model]) at import time.
Models: ModelsTable = ModelsTable(DB)