"vscode:/vscode.git/clone" did not exist on "3dd069c23a03d5f9dada1b517224a126a0705f04"
models.py 4.38 KB
Newer Older
1
import logging
Michael Poluektov's avatar
Michael Poluektov committed
2
from typing import Optional, List
3

4
from pydantic import BaseModel, ConfigDict
Michael Poluektov's avatar
Michael Poluektov committed
5
from sqlalchemy import Column, BigInteger, Text
6

Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
7
from apps.webui.internal.db import Base, JSONField, get_db
8

9
10
from config import SRC_LOG_LEVELS

Timothy J. Baek's avatar
Timothy J. Baek committed
11
12
import time

13
14
15
# Module logger; its level comes from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(name=__name__)
log.setLevel(level=SRC_LOG_LEVELS["MODELS"])

16
17
18
19
20
21
22
23

####################
# Models DB Schema
####################


# ModelParams describes the JSON blob stored in the `params` column of the
# Model table. It declares no fields of its own: extra="allow" keeps every key
# supplied by the caller, so arbitrary parameters round-trip unchanged.
class ModelParams(BaseModel):
    model_config = ConfigDict(extra="allow")


# ModelMeta describes the JSON blob stored in the `meta` column of the Model
# table. The fields below are the known keys; extra="allow" keeps any
# additional keys as well.
class ModelMeta(BaseModel):
    model_config = ConfigDict(extra="allow")

    # Profile image URL; defaults to the bundled favicon.
    profile_image_url: Optional[str] = "/static/favicon.png"

    # User-facing description of the model.
    description: Optional[str] = None

    # Optional dict of capabilities; its keys are not validated here.
    capabilities: Optional[dict] = None

43

44
45
46
class Model(Base):
    """SQLAlchemy mapping for the ``model`` table."""

    __tablename__ = "model"

    # The model's id as used in the API. If it matches an existing model's id,
    # this entry overrides that model.
    id = Column(Text, primary_key=True)

    # Id of the user who created the row.
    user_id = Column(Text)

    # Optional pointer to the actual model that should be used when proxying
    # requests.
    base_model_id = Column(Text, nullable=True)

    # Human-readable display name of the model.
    name = Column(Text)

    # JSON-encoded parameters blob; shape described by `ModelParams`.
    params = Column(JSONField)

    # JSON-encoded metadata blob; shape described by `ModelMeta`.
    meta = Column(JSONField)

    # Unix epoch timestamps.
    updated_at = Column(BigInteger)
    created_at = Column(BigInteger)
75
76
77
78


class ModelModel(BaseModel):
    # Pydantic mirror of a `Model` row. from_attributes=True lets
    # model_validate() read attributes straight off an ORM instance.
    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    base_model_id: Optional[str] = None

    name: str
    params: ModelParams
    meta: ModelMeta

    # Unix epoch timestamps (seconds).
    updated_at: int
    created_at: int

91
92
93
94
95
96

####################
# Forms
####################


Timothy J. Baek's avatar
Timothy J. Baek committed
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
class ModelResponse(BaseModel):
    # Reduced view of a model: identity, display name and metadata only.
    id: str
    name: str
    meta: ModelMeta

    # Unix epoch timestamps.
    updated_at: int
    created_at: int


class ModelForm(BaseModel):
    # Client-supplied fields used to create or update a model; ownership and
    # timestamps are filled in server-side.
    id: str
    base_model_id: Optional[str] = None
    name: str

    meta: ModelMeta
    params: ModelParams


113
class ModelsTable:
    """Data-access helpers for the ``model`` table.

    Each method obtains a database handle via ``get_db()`` and converts ORM
    rows into ``ModelModel`` instances while that handle is still open.
    """

    def insert_new_model(
        self, form_data: ModelForm, user_id: str
    ) -> Optional[ModelModel]:
        """Insert a new model built from ``form_data`` and owned by ``user_id``.

        Returns the stored model, or ``None`` if the database write failed
        (the error is logged). Validation errors while building the record
        propagate to the caller, as before.
        """
        # One timestamp for both columns so a fresh row has
        # created_at == updated_at.
        now = int(time.time())
        model = ModelModel(
            **{
                **form_data.model_dump(),
                "user_id": user_id,
                "created_at": now,
                "updated_at": now,
            }
        )
        try:
            with get_db() as db:
                result = Model(**model.model_dump())
                db.add(result)
                db.commit()
                db.refresh(result)
                return ModelModel.model_validate(result)
        except Exception:
            log.exception("Failed to insert model %s", form_data.id)
            return None

    def get_all_models(self) -> List[ModelModel]:
        """Return every row of the ``model`` table."""
        with get_db() as db:
            return [ModelModel.model_validate(model) for model in db.query(Model).all()]

    def get_model_by_id(self, id: str) -> Optional[ModelModel]:
        """Return the model whose primary key is ``id``.

        Returns ``None`` when no such row exists, or when the lookup fails
        (the error is logged).
        """
        try:
            with get_db() as db:
                model = db.get(Model, id)
                # A missing row is a normal outcome, not an error.
                if model is None:
                    return None
                return ModelModel.model_validate(model)
        except Exception:
            log.exception("Failed to load model %s", id)
            return None

    def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
        """Overwrite the stored model ``id`` with the non-``None`` fields of ``model``.

        ``id`` itself is never rewritten, and ``updated_at`` is set to the
        current time. ``exclude_none`` also applies inside the nested
        ``meta``/``params`` dicts, which replace the stored blobs wholesale.

        Returns the updated model, or ``None`` if no row has that id or the
        update failed (the error is logged).
        """
        try:
            with get_db() as db:
                # Update only the fields that are present in the form.
                values = model.model_dump(exclude={"id"}, exclude_none=True)
                values["updated_at"] = int(time.time())
                db.query(Model).filter_by(id=id).update(values)
                db.commit()

                row = db.get(Model, id)
                if row is None:
                    return None
                # Reload from the database so the returned data reflects the
                # committed UPDATE.
                db.refresh(row)
                return ModelModel.model_validate(row)
        except Exception:
            log.exception("Failed to update model %s", id)
            return None

    def delete_model_by_id(self, id: str) -> bool:
        """Delete the model ``id``.

        Returns ``True`` once the delete is committed, including when no row
        matched. Returns ``False`` if the database operation raised (the
        error is logged).
        """
        try:
            with get_db() as db:
                db.query(Model).filter_by(id=id).delete()
                db.commit()
                return True
        except Exception:
            log.exception("Failed to delete model %s", id)
            return False


182
# Module-level singleton instance of ModelsTable.
Models: ModelsTable = ModelsTable()