models.py 4.02 KB
Newer Older
1
import json
2
import logging
3
4
5
from typing import Optional

import peewee as pw
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
from peewee import *

8
from playhouse.shortcuts import model_to_dict
9
from pydantic import BaseModel, ConfigDict
10

11
from apps.web.internal.db import DB, JSONField
12

Timothy J. Baek's avatar
Timothy J. Baek committed
13
from typing import List, Union, Optional
14
15
from config import SRC_LOG_LEVELS

Timothy J. Baek's avatar
Timothy J. Baek committed
16
17
import time

18
19
20
# Module-level logger; its level is read from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

21
22
23
24
25
26
27
28

####################
# Models DB Schema
####################


# ModelParams is a model for the data stored in the params field of the Model table
class ModelParams(BaseModel):
    # Schema for the blob stored in the `params` field of the Model table.
    # No fields are declared here; extra="allow" is the only configuration,
    # so the payload's keys are not restricted to a declared set.
    model_config = ConfigDict(extra="allow")


# ModelMeta is a model for the data stored in the meta field of the Model table
# It isn't currently used in the backend, but it's here as a reference
class ModelMeta(BaseModel):
    # Keys beyond the two declared below are allowed (extra="allow").
    model_config = ConfigDict(extra="allow")

    description: Optional[str] = None
    """
    Description of the model that is shown to users.
    """

    vision_capable: Optional[bool] = None
    """
    Whether the model is capable of vision, i.e. accepts image inputs.
    """

50
51

class Model(pw.Model):
    """Peewee table definition for model entries, bound to ``DB``."""

    id = pw.TextField(unique=True)
    """
    The model's id as used in the API. If set to the id of an existing model,
    it overrides that model.
    """

    user_id = pw.TextField()
    """
    The ``user_id`` supplied when the row was inserted.
    """

    base_model_id = pw.TextField(null=True)
    """
    An optional pointer to the actual model that should be used when proxying
    requests.
    """

    name = pw.TextField()
    """
    The human-readable display name of the model.
    """

    params = JSONField()
    """
    Holds a JSON encoded blob of parameters, see `ModelParams`.
    """

    meta = JSONField()
    """
    Holds a JSON encoded blob of metadata, see `ModelMeta`.
    """

    # Referenced through the `pw` namespace like every other column, so the
    # class does not depend on the wildcard `from peewee import *`.
    updated_at = pw.BigIntegerField()
    """
    Epoch timestamp (whole seconds) of the last update.
    """

    created_at = pw.BigIntegerField()
    """
    Epoch timestamp (whole seconds) of creation.
    """

    class Meta:
        database = DB


class ModelModel(BaseModel):
    # Pydantic view of a `Model` row; instances are built with
    # `ModelModel(**model_to_dict(row))`. The row's `user_id` column is not
    # declared here.

    # Identity
    id: str
    base_model_id: Optional[str] = None

    # Display name and the two JSON blobs
    name: str
    params: ModelParams
    meta: ModelMeta

    # Both timestamps are epoch values
    updated_at: int
    created_at: int

96
97
98
99
100
101

####################
# Forms
####################


Timothy J. Baek's avatar
Timothy J. Baek committed
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
class ModelResponse(BaseModel):
    # Reduced view of a model entry: the same fields as ModelModel minus
    # `base_model_id` and `params`.
    id: str
    name: str
    meta: ModelMeta
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch


class ModelForm(BaseModel):
    # Input accepted by ModelsTable.insert_new_model and
    # ModelsTable.update_model_by_id. It carries no `user_id` or timestamps;
    # insert_new_model adds those itself.
    id: str
    base_model_id: Optional[str] = None  # optional; defaults to None
    name: str
    meta: ModelMeta
    params: ModelParams


118
119
120
121
122
123
124
125
class ModelsTable:
    """Data-access helper around the ``Model`` peewee table.

    Rows are converted to ``ModelModel`` before being returned. Insert,
    lookup, update and delete are best-effort: on failure they log the
    exception and return ``None`` (``False`` for delete) instead of raising.
    """

    def __init__(
        self,
        db: pw.SqliteDatabase | pw.PostgresqlDatabase,
    ):
        """Store ``db`` and create the ``Model`` table on it."""
        self.db = db
        self.db.create_tables([Model])

    @staticmethod
    def _to_model(row: "Model") -> "ModelModel":
        """Convert a ``Model`` row into its pydantic ``ModelModel`` form."""
        return ModelModel(**model_to_dict(row))

    def insert_new_model(self, model: ModelForm, user_id: str) -> Optional[ModelModel]:
        """Insert a new row built from ``model`` and ``user_id``.

        Args:
            model: Form holding id, base_model_id, name, meta and params.
            user_id: Value stored in the row's ``user_id`` column.

        Returns:
            The stored entry, or ``None`` if the insert or conversion failed.
        """
        # Read the clock once so created_at and updated_at are identical on insert.
        now = int(time.time())
        try:
            row = Model.create(
                **{
                    **model.model_dump(),
                    "user_id": user_id,
                    "created_at": now,
                    "updated_at": now,
                }
            )
            return self._to_model(row)
        except Exception:
            log.exception("Failed to insert model %r", getattr(model, "id", None))
            return None

    def get_all_models(self) -> List[ModelModel]:
        """Return every row of the table as a ``ModelModel``.

        Unlike the other methods this one does not catch exceptions.
        """
        return [self._to_model(row) for row in Model.select()]

    def get_model_by_id(self, id: str) -> Optional[ModelModel]:
        """Look up a single entry by its id.

        Returns:
            The matching entry, or ``None`` if no row has that id or the
            lookup failed.
        """
        try:
            row = Model.get_or_none(Model.id == id)
            # A missing row is an expected outcome, so it is not logged.
            return self._to_model(row) if row is not None else None
        except Exception:
            log.exception("Failed to fetch model %r", id)
            return None

    def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
        """Overwrite the row whose id is ``id`` with the fields of ``model``.

        ``updated_at`` is refreshed to the current epoch second as part of
        the same UPDATE statement.

        Returns:
            The row re-read under ``id`` after the update, or ``None`` if no
            such row exists or the update failed.
        """
        try:
            # NOTE(review): the form's own `id` is written as well; if it
            # differs from the `id` argument, the re-read below looks under
            # the old id and yields None even though the UPDATE ran.
            changes = {**model.model_dump(), "updated_at": int(time.time())}
            Model.update(**changes).where(Model.id == id).execute()

            row = Model.get_or_none(Model.id == id)
            return self._to_model(row) if row is not None else None
        except Exception:
            log.exception("Failed to update model %r", id)
            return None

    def delete_model_by_id(self, id: str) -> bool:
        """Delete the row whose id is ``id``.

        Returns:
            ``True`` when the DELETE statement ran without raising (even if
            it matched no row), ``False`` if it raised.
        """
        try:
            Model.delete().where(Model.id == id).execute()
            return True
        except Exception:
            log.exception("Failed to delete model %r", id)
            return False


# Shared module-level instance; constructing it creates the Model table on DB.
Models = ModelsTable(DB)