import json
import logging
import time
from typing import List, Optional, Union

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String
from sqlalchemy.orm import Session

from apps.webui.internal.db import Base, JSONField, get_session
from config import SRC_LOG_LEVELS

log = logging.getLogger(name=__name__)
# Verbosity for this module comes from the "MODELS" entry of SRC_LOG_LEVELS.
log.setLevel(level=SRC_LOG_LEVELS["MODELS"])


####################
# Models DB Schema
####################


# ModelParams describes the data stored in the `params` column of the Model table.
class ModelParams(BaseModel):
    """Parameter blob persisted in ``Model.params``.

    No fields are declared here; ``extra="allow"`` keeps every key supplied.
    """

    model_config = ConfigDict(extra="allow")


# ModelMeta describes the data stored in the `meta` column of the Model table.
class ModelMeta(BaseModel):
    """Metadata blob persisted in ``Model.meta``.

    Keys beyond the declared fields are kept because of ``extra="allow"``.
    """

    model_config = ConfigDict(extra="allow")

    # Image URL for the model; defaults to "/favicon.png".
    profile_image_url: Optional[str] = "/favicon.png"

    # User-facing description of the model.
    description: Optional[str] = None

    # Free-form capabilities mapping; its schema is not defined in this module.
    capabilities: Optional[dict] = None


class Model(Base):
    """SQLAlchemy mapping for the ``model`` table."""

    __tablename__ = "model"

    # The model's id as used in the API. If set to an existing model's id,
    # this entry overrides that model.
    id = Column(String, primary_key=True)
    user_id = Column(String)

    # Optional pointer to the actual model that should be used when
    # proxying requests.
    base_model_id = Column(String, nullable=True)

    # Human-readable display name of the model.
    name = Column(String)

    # JSON-encoded blob of parameters; see `ModelParams`.
    params = Column(JSONField)

    # JSON-encoded blob of metadata; see `ModelMeta`.
    meta = Column(JSONField)

    # Unix epoch timestamps (seconds).
    updated_at = Column(BigInteger)
    created_at = Column(BigInteger)


class ModelModel(BaseModel):
    """Pydantic representation of one ``model`` table row.

    ``from_attributes=True`` lets ``ModelModel.model_validate`` accept ORM
    ``Model`` instances directly.
    """

    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    base_model_id: Optional[str] = None

    name: str
    params: ModelParams
    meta: ModelMeta

    # Unix epoch timestamps.
    updated_at: int
    created_at: int


####################
# Forms
####################


class ModelResponse(BaseModel):
    """Subset of a model record: id, display name, metadata and timestamps."""

    id: str
    name: str
    meta: ModelMeta

    # Unix epoch timestamps.
    updated_at: int
    created_at: int


class ModelForm(BaseModel):
    """Client-submitted fields for creating or updating a model.

    ``user_id`` and the timestamps are not part of the form; the table
    helpers fill them in server-side.
    """

    id: str
    base_model_id: Optional[str] = None
    name: str
    meta: ModelMeta
    params: ModelParams


class ModelsTable:
    """Data-access helpers for the ``model`` table.

    Each method opens its own session with ``get_session()`` and converts ORM
    rows into ``ModelModel`` instances while that session is still open.
    Failures are reported through the module logger.
    """

    def insert_new_model(
        self, form_data: ModelForm, user_id: str
    ) -> Optional[ModelModel]:
        """Create a model row owned by ``user_id``.

        Args:
            form_data: Client-supplied fields (id, base_model_id, name, meta, params).
            user_id: Id of the user creating the model.

        Returns:
            The stored row as a ``ModelModel``, or ``None`` if the database
            write failed.
        """
        # Take a single timestamp so a fresh row always has
        # created_at == updated_at (two time.time() calls can straddle a second).
        now = int(time.time())
        model = ModelModel(
            **{
                **form_data.model_dump(),
                "user_id": user_id,
                "created_at": now,
                "updated_at": now,
            }
        )
        try:
            with get_session() as db:
                result = Model(**model.model_dump())
                db.add(result)
                db.commit()
                db.refresh(result)
                return ModelModel.model_validate(result)
        except Exception:
            log.exception("Failed to insert model %s", form_data.id)
            return None

    def get_all_models(self) -> List[ModelModel]:
        """Return every row of the ``model`` table as ``ModelModel`` objects."""
        with get_session() as db:
            return [ModelModel.model_validate(row) for row in db.query(Model).all()]

    def get_model_by_id(self, id: str) -> Optional[ModelModel]:
        """Look up a model by primary key.

        Returns:
            The matching ``ModelModel``, or ``None`` when no row has this id
            or the lookup fails.
        """
        try:
            with get_session() as db:
                row = db.get(Model, id)
                # A missing row is a normal outcome, not an error.
                if row is None:
                    return None
                return ModelModel.model_validate(row)
        except Exception:
            log.exception("Failed to load model %s", id)
            return None

    def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
        """Overwrite the stored fields of model ``id`` with the values in ``model``.

        Every form field except ``id`` is written (the primary key is never
        rewritten from the form), and ``updated_at`` is set to the current time.

        Returns:
            The updated ``ModelModel``, or ``None`` when no row has this id or
            the update fails.
        """
        try:
            with get_session() as db:
                values = model.model_dump(exclude={"id"})
                values["updated_at"] = int(time.time())
                # Query.update returns the number of rows matched.
                matched = db.query(Model).filter_by(id=id).update(values)
                db.commit()
                if not matched:
                    return None

                row = db.get(Model, id)
                if row is None:
                    return None
                db.refresh(row)
                return ModelModel.model_validate(row)
        except Exception:
            log.exception("Failed to update model %s", id)
            return None

    def delete_model_by_id(self, id: str) -> bool:
        """Delete model ``id``.

        Returns:
            ``True`` once the delete is committed (also when no row matched),
            ``False`` if the database operation failed.
        """
        try:
            with get_session() as db:
                db.query(Model).filter_by(id=id).delete()
                # Commit explicitly, as the insert/update paths do, instead of
                # relying on get_session() to do it.
                db.commit()
            return True
        except Exception:
            log.exception("Failed to delete model %s", id)
            return False


# Shared module-level instance of the model table helpers.
Models: ModelsTable = ModelsTable()