models.py 4.43 KB
Newer Older
1
import json
2
import logging
3
4
from typing import Optional

5
from pydantic import BaseModel, ConfigDict
6
from sqlalchemy import String, Column, BigInteger, Text
7

Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
8
from apps.webui.internal.db import Base, JSONField, get_db
9

Timothy J. Baek's avatar
Timothy J. Baek committed
10
from typing import List, Union, Optional
11
12
from config import SRC_LOG_LEVELS

Timothy J. Baek's avatar
Timothy J. Baek committed
13
14
import time

15
16
17
# Module-level logger; its level comes from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

18
19
20
21
22
23
24
25

####################
# Models DB Schema
####################


# ModelParams describes the JSON blob kept in the `params` column of the Model
# table. It declares no fields of its own; extra="allow" lets arbitrary keys
# pass validation.
class ModelParams(BaseModel):
    model_config = ConfigDict(extra="allow")


# ModelMeta describes the JSON blob kept in the `meta` column of the Model
# table. Keys beyond the declared fields are accepted (extra="allow").
class ModelMeta(BaseModel):
    model_config = ConfigDict(extra="allow")

    # Falls back to "/favicon.png" when the caller supplies nothing.
    profile_image_url: Optional[str] = "/favicon.png"

    # User-facing description of the model.
    description: Optional[str] = None

    # NOTE(review): the key/value schema of this dict is not defined in this
    # file - confirm against the callers before relying on specific keys.
    capabilities: Optional[dict] = None

45

46
47
48
class Model(Base):
    """ORM mapping for rows of the ``model`` table."""

    __tablename__ = "model"

    # The model's id as used in the API. If set to an existing model, it will
    # override the model.
    id = Column(Text, primary_key=True)
    user_id = Column(Text)

    # Optional pointer to the actual model that should be used when proxying
    # requests.
    base_model_id = Column(Text, nullable=True)

    # Human-readable display name of the model.
    name = Column(Text)

    # JSON encoded blob of parameters, see `ModelParams`.
    params = Column(JSONField)

    # JSON encoded blob of metadata, see `ModelMeta`.
    meta = Column(JSONField)

    # Epoch timestamps stored as integers.
    updated_at = Column(BigInteger)
    created_at = Column(BigInteger)
77
78
79
80


# Pydantic counterpart of the `Model` ORM class; from_attributes=True allows
# it to be built straight from an ORM row via `ModelModel.model_validate(row)`.
class ModelModel(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    base_model_id: Optional[str] = None

    name: str
    params: ModelParams
    meta: ModelMeta

    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch

93
94
95
96
97
98

####################
# Forms
####################


Timothy J. Baek's avatar
Timothy J. Baek committed
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Reduced view of a model: id, name, meta and the two timestamps. Unlike
# ModelModel it carries no user_id, base_model_id or params.
class ModelResponse(BaseModel):
    id: str
    name: str
    meta: ModelMeta
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch


# Input payload taken by ModelsTable.insert_new_model and
# ModelsTable.update_model_by_id. user_id and the timestamps are not part of
# the form; insert_new_model supplies them itself.
class ModelForm(BaseModel):
    id: str
    base_model_id: Optional[str] = None  # omitted/None means no base model given
    name: str
    meta: ModelMeta
    params: ModelParams


115
116
class ModelsTable:
    """CRUD helpers for rows of the ``model`` table.

    Every method opens its own session through ``get_db()``. Database
    failures are logged on the module logger and reported through the return
    value (``None`` / ``False``) instead of being raised; the exception is
    ``get_all_models``, which lets errors propagate.
    """

    def insert_new_model(
        self, form_data: ModelForm, user_id: str
    ) -> Optional[ModelModel]:
        """Insert a new row built from ``form_data`` and owned by ``user_id``.

        Returns:
            The stored row as a ``ModelModel``, or ``None`` when the insert
            failed (the error is logged).

        Raises:
            Whatever ``ModelModel(...)`` raises if the assembled data does
            not validate; that step happens before the database is touched.
        """
        # Read the clock once so created_at and updated_at are identical.
        now = int(time.time())
        model = ModelModel(
            **{
                **form_data.model_dump(),
                "user_id": user_id,
                "created_at": now,
                "updated_at": now,
            }
        )
        try:
            with get_db() as db:
                row = Model(**model.model_dump())
                db.add(row)
                db.commit()
                db.refresh(row)
                return ModelModel.model_validate(row)
        except Exception:
            log.exception("Failed to insert model %s", form_data.id)
            return None

    def get_all_models(self) -> List[ModelModel]:
        """Return every row of the table as a list of ``ModelModel``."""
        with get_db() as db:
            return [ModelModel.model_validate(row) for row in db.query(Model).all()]

    def get_model_by_id(self, id: str) -> Optional[ModelModel]:
        """Return the model with primary key ``id``.

        Returns:
            The row as a ``ModelModel``, or ``None`` when no such row exists
            or the lookup failed (failures are logged).
        """
        try:
            with get_db() as db:
                row = db.get(Model, id)
                # A missing row is an expected outcome, not an error.
                if row is None:
                    return None
                return ModelModel.model_validate(row)
        except Exception:
            log.exception("Failed to fetch model %s", id)
            return None

    def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
        """Apply the non-``None`` fields of ``model`` to the row ``id``.

        ``id`` inside the form is ignored; fields that are ``None`` in the
        form are left untouched in the database. ``updated_at`` is set to the
        current epoch time.

        Returns:
            The refreshed row as a ``ModelModel``, or ``None`` when the row
            does not exist or the update failed (failures are logged).
        """
        try:
            with get_db() as db:
                # update only the fields that are present in the model
                changes = model.model_dump(exclude={"id"}, exclude_none=True)
                changes["updated_at"] = int(time.time())
                db.query(Model).filter_by(id=id).update(changes)
                db.commit()

                row = db.get(Model, id)
                if row is None:
                    log.warning("Cannot update model %s: no such row", id)
                    return None
                db.refresh(row)
                return ModelModel.model_validate(row)
        except Exception:
            log.exception("Failed to update model %s", id)
            return None

    def delete_model_by_id(self, id: str) -> bool:
        """Delete the row with primary key ``id``.

        Returns:
            ``True`` when the delete statement and commit completed (also
            when no row matched), ``False`` when they raised (logged).
        """
        try:
            with get_db() as db:
                db.query(Model).filter_by(id=id).delete()
                db.commit()
                return True
        except Exception:
            log.exception("Failed to delete model %s", id)
            return False


190
# Module-level ModelsTable instance, created once at import time.
Models = ModelsTable()