models.py 4.09 KB
Newer Older
1
import json
2
import logging
3
4
from typing import Optional

5
from pydantic import BaseModel, ConfigDict
6
from sqlalchemy import String, Column, BigInteger, Text
7

8
from apps.webui.internal.db import Base, JSONField, Session
9

Timothy J. Baek's avatar
Timothy J. Baek committed
10
from typing import List, Union, Optional
11
12
from config import SRC_LOG_LEVELS

Timothy J. Baek's avatar
Timothy J. Baek committed
13
14
import time

15
16
17
log = logging.getLogger(__name__)
# This module's verbosity comes from the "MODELS" entry of SRC_LOG_LEVELS.
log.setLevel(level=SRC_LOG_LEVELS["MODELS"])

18
19
20
21
22
23
24
25

####################
# Models DB Schema
####################


# ModelParams is a model for the data stored in the params field of the Model table
class ModelParams(BaseModel):
    # No fields are declared: extra="allow" accepts and keeps arbitrary
    # parameter keys, so this schema is intentionally open-ended.
    model_config = ConfigDict(extra="allow")


# ModelMeta is a model for the data stored in the meta field of the Model table
class ModelMeta(BaseModel):
    """Schema for the data stored in the ``meta`` column of ``Model``.

    Undeclared keys are kept on the instance (``extra="allow"``), so metadata
    beyond the fields listed here survives a validate/dump round trip.
    """

    profile_image_url: Optional[str] = "/favicon.png"

    # User-facing description of the model.
    description: Optional[str] = None

    capabilities: Optional[dict] = None

    model_config = ConfigDict(extra="allow")

45

46
47
48
class Model(Base):
    """SQLAlchemy ORM mapping for the ``model`` table."""

    __tablename__ = "model"

    # The model's id as used in the API. If set to an existing model, it will
    # override the model.
    id = Column(Text, primary_key=True)
    user_id = Column(Text)

    # An optional pointer to the actual model that should be used when
    # proxying requests.
    base_model_id = Column(Text, nullable=True)

    # The human-readable display name of the model.
    name = Column(Text)

    # Holds a JSON encoded blob of parameters, see `ModelParams`.
    params = Column(JSONField)

    # Holds a JSON encoded blob of metadata, see `ModelMeta`.
    meta = Column(JSONField)

    # Epoch timestamps; ModelsTable writes them as int(time.time()).
    updated_at = Column(BigInteger)
    created_at = Column(BigInteger)
77
78
79
80


class ModelModel(BaseModel):
    """Pydantic representation of a ``Model`` row.

    ``from_attributes=True`` lets ``ModelModel.model_validate`` build an
    instance directly from an ORM object, which is how ``ModelsTable``
    converts query results.
    """

    id: str
    user_id: str
    base_model_id: Optional[str] = None

    name: str
    params: ModelParams
    meta: ModelMeta

    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch

    model_config = ConfigDict(from_attributes=True)

93
94
95
96
97
98

####################
# Forms
####################


Timothy J. Baek's avatar
Timothy J. Baek committed
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
class ModelResponse(BaseModel):
    # Trimmed view of a model: no owner, base model, or params.
    id: str
    name: str
    meta: ModelMeta

    # Both timestamps are epoch-based integers.
    updated_at: int
    created_at: int


class ModelForm(BaseModel):
    # Payload used to create or update a model. The owner id and timestamps
    # are not part of the form; ModelsTable.insert_new_model fills them in.
    id: str
    base_model_id: Union[str, None] = None
    name: str
    meta: ModelMeta
    params: ModelParams


115
116
class ModelsTable:
    """Data-access helpers for the ``model`` table.

    Every method works through the module-level ``Session`` and returns
    ``ModelModel`` instances rather than ORM rows. Failures are logged and
    reported through the return value (``None`` / ``False``) instead of being
    raised. Write paths roll the session back on error, so a failed statement
    does not poison later queries.
    """

    def insert_new_model(
        self, form_data: ModelForm, user_id: str
    ) -> Optional[ModelModel]:
        """Persist a new model owned by ``user_id``.

        Args:
            form_data: The submitted model definition.
            user_id: Id of the user creating the model.

        Returns:
            The stored model, or ``None`` if the database write failed.
        """
        # One timestamp for both fields: two time.time() calls could straddle
        # a second boundary and make created_at != updated_at on a new row.
        now = int(time.time())
        model = ModelModel(
            **{
                **form_data.model_dump(),
                "user_id": user_id,
                "created_at": now,
                "updated_at": now,
            }
        )
        try:
            result = Model(**model.model_dump())
            Session.add(result)
            Session.commit()
            Session.refresh(result)
            return ModelModel.model_validate(result)
        except Exception:
            Session.rollback()
            log.exception("Failed to insert model %s", form_data.id)
            return None

    def get_all_models(self) -> List[ModelModel]:
        """Return every stored model."""
        return [ModelModel.model_validate(row) for row in Session.query(Model).all()]

    def get_model_by_id(self, id: str) -> Optional[ModelModel]:
        """Return the model with primary key ``id``.

        Returns ``None`` when no such row exists or when loading it fails.
        """
        try:
            model = Session.get(Model, id)
            if model is None:
                return None
            return ModelModel.model_validate(model)
        except Exception:
            log.exception("Failed to load model %s", id)
            return None

    def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
        """Overwrite the stored model ``id`` with the values from ``model``.

        The form's own ``id`` is ignored, so an update cannot change the row's
        primary key. ``updated_at`` is refreshed to the current time.

        Returns:
            The updated model, or ``None`` if it does not exist or the update
            failed.
        """
        try:
            row = Session.get(Model, id)
            if row is None:
                return None
            # Copy the submitted fields onto the ORM row (kept separate from the
            # ``model`` form so the two are never confused).
            for field, value in model.model_dump(exclude={"id"}).items():
                setattr(row, field, value)
            row.updated_at = int(time.time())
            Session.commit()
            Session.refresh(row)
            return ModelModel.model_validate(row)
        except Exception:
            Session.rollback()
            log.exception("Failed to update model %s", id)
            return None

    def delete_model_by_id(self, id: str) -> bool:
        """Delete the model with primary key ``id``.

        Returns ``True`` once the delete is committed (including when no row
        matched), ``False`` if the database operation failed.
        """
        try:
            Session.query(Model).filter_by(id=id).delete()
            # Without a commit the DELETE stays in an open transaction and is
            # never persisted.
            Session.commit()
            return True
        except Exception:
            Session.rollback()
            log.exception("Failed to delete model %s", id)
            return False


175
# Module-level ModelsTable instance. ModelsTable keeps no per-instance state,
# so one instance serves every caller.
Models = ModelsTable()