models.py 4.08 KB
Newer Older
1
import json
2
import logging
3
4
5
from typing import Optional

import peewee as pw
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
from peewee import *

8
from playhouse.shortcuts import model_to_dict
9
from pydantic import BaseModel, ConfigDict
10

11
from apps.webui.internal.db import DB, JSONField
12

Timothy J. Baek's avatar
Timothy J. Baek committed
13
from typing import List, Union, Optional
14
15
from config import SRC_LOG_LEVELS

Timothy J. Baek's avatar
Timothy J. Baek committed
16
17
import time

18
19
20
# Module-level logger; its level comes from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(name=__name__)
log.setLevel(level=SRC_LOG_LEVELS["MODELS"])

21
22
23
24
25
26
27
28

####################
# Models DB Schema
####################


class ModelParams(BaseModel):
    """Schema for the JSON stored in the ``params`` field of the Model table.

    No fields are declared; ``extra="allow"`` keeps every supplied key, so
    arbitrary parameters are accepted as-is.
    """

    model_config = ConfigDict(extra="allow")


class ModelMeta(BaseModel):
    """Schema for the JSON stored in the ``meta`` field of the Model table.

    Keys beyond the declared fields are accepted and kept (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    profile_image_url: str | None = "/favicon.png"
    # User-facing description of the model.
    description: str | None = None
    capabilities: dict | None = None

48
49

class Model(pw.Model):
    """Peewee table storing model definitions (pydantic view: ``ModelModel``)."""

    # The model's id as used in the API. If set to an existing model's id,
    # this entry overrides that model.
    id = pw.TextField(unique=True)
    user_id = pw.TextField()
    # Optional pointer to the actual model that should be used when proxying
    # requests.
    base_model_id = pw.TextField(null=True)
    # Human-readable display name of the model.
    name = pw.TextField()
    # JSON-encoded blob of parameters; see `ModelParams`.
    params = JSONField()
    # JSON-encoded blob of metadata; see `ModelMeta`.
    meta = JSONField()

    # Epoch timestamps.
    updated_at = pw.BigIntegerField()
    created_at = pw.BigIntegerField()

    class Meta:
        database = DB


class ModelModel(BaseModel):
    """Pydantic representation of one row of the Model table."""

    id: str
    user_id: str
    base_model_id: str | None = None

    name: str
    params: ModelParams
    meta: ModelMeta

    # Both timestamps are epoch times.
    updated_at: int
    created_at: int

95
96
97
98
99
100

####################
# Forms
####################


Timothy J. Baek's avatar
Timothy J. Baek committed
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
class ModelResponse(BaseModel):
    """Reduced view of a model: id, display name, metadata and timestamps.

    Unlike ``ModelModel`` it carries no owner, base model, or params.
    """

    id: str
    name: str
    meta: ModelMeta
    # Both timestamps are epoch times.
    updated_at: int
    created_at: int


class ModelForm(BaseModel):
    """Caller-supplied fields used to create or update a model entry."""

    id: str
    base_model_id: str | None = None
    name: str
    meta: ModelMeta
    params: ModelParams


117
118
119
120
121
122
123
124
class ModelsTable:
    """Data-access helpers for the ``Model`` table.

    Rows are returned as ``ModelModel`` instances. The single-row operations
    log database failures through the module logger and return ``None`` /
    ``False`` instead of raising; ``get_all_models`` lets errors propagate.
    """

    def __init__(
        self,
        db: pw.SqliteDatabase | pw.PostgresqlDatabase,
    ):
        """Store ``db`` and create the ``Model`` table in it.

        Args:
            db: The peewee database that holds the ``Model`` table.
        """
        self.db = db
        self.db.create_tables([Model])

    @staticmethod
    def _to_model_model(row: Model) -> ModelModel:
        """Convert a peewee ``Model`` row into a pydantic ``ModelModel``."""
        return ModelModel(**model_to_dict(row))

    def insert_new_model(
        self, form_data: ModelForm, user_id: str
    ) -> Optional[ModelModel]:
        """Insert a new model owned by ``user_id``.

        Args:
            form_data: The submitted model definition.
            user_id: Id of the user creating the model.

        Returns:
            The stored ``ModelModel``, or ``None`` if the insert failed.
        """
        # One timestamp for both columns: two separate time.time() calls can
        # straddle a second boundary and leave created_at != updated_at.
        now = int(time.time())
        model = ModelModel(
            **{
                **form_data.model_dump(),
                "user_id": user_id,
                "created_at": now,
                "updated_at": now,
            }
        )
        try:
            result = Model.create(**model.model_dump())
        except Exception:
            log.exception("Failed to insert model %s", form_data.id)
            return None
        return model if result else None

    def get_all_models(self) -> List[ModelModel]:
        """Return every row of the ``Model`` table as ``ModelModel``."""
        return [self._to_model_model(row) for row in Model.select()]

    def get_model_by_id(self, id: str) -> Optional[ModelModel]:
        """Look up a model by its API id.

        Returns:
            The matching ``ModelModel``, or ``None`` if no row has this id or
            the row could not be loaded.
        """
        try:
            row = Model.get(Model.id == id)
            return self._to_model_model(row)
        except Model.DoesNotExist:
            # A missing id is an expected outcome, not an error worth logging.
            return None
        except Exception:
            log.exception("Failed to load model %s", id)
            return None

    def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
        """Overwrite the editable fields of model ``id`` with ``model``.

        The row's ``id`` is never changed (the form's ``id`` is ignored), and
        ``updated_at`` is set to the current time.

        Returns:
            The updated ``ModelModel``, or ``None`` if the update failed or no
            row has this id.
        """
        try:
            # Excluding "id" matters: writing the form's id would rename the
            # row, and the re-read below by the original id would then miss it.
            fields = model.model_dump(exclude={"id"})
            fields["updated_at"] = int(time.time())
            Model.update(**fields).where(Model.id == id).execute()

            row = Model.get(Model.id == id)
            return self._to_model_model(row)
        except Exception:
            log.exception("Failed to update model %s", id)
            return None

    def delete_model_by_id(self, id: str) -> bool:
        """Delete the row(s) whose id is ``id``.

        Returns:
            ``True`` once the delete statement has executed (also when no row
            matched), ``False`` if it raised.
        """
        try:
            Model.delete().where(Model.id == id).execute()
        except Exception:
            log.exception("Failed to delete model %s", id)
            return False
        return True


# Shared module-level accessor bound to DB; constructing it creates the Model table.
Models = ModelsTable(db=DB)