models.py 4.4 KB
Newer Older
1
import logging
Michael Poluektov's avatar
Michael Poluektov committed
2
from typing import Optional, List
3

4
from pydantic import BaseModel, ConfigDict
Michael Poluektov's avatar
Michael Poluektov committed
5
from sqlalchemy import Column, BigInteger, Text
6

Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
7
from apps.webui.internal.db import Base, JSONField, get_db
8

9
10
from config import SRC_LOG_LEVELS

Timothy J. Baek's avatar
Timothy J. Baek committed
11
12
import time

13
14
15
# Module-level logger; its level is taken from the "MODELS" entry of
# SRC_LOG_LEVELS (imported from config).
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

16
17
18
19
20
21
22
23

####################
# Models DB Schema
####################


# Validates the JSON blob kept in the `params` column of the `model` table.
# No fields are declared here: with extra="allow" pydantic keeps every key it
# receives as an extra attribute, so arbitrary parameters round-trip intact.
class ModelParams(BaseModel):
    model_config = ConfigDict(extra="allow")


# Validates the JSON blob kept in the `meta` column of the `model` table.
# The three fields below are declared explicitly; any other keys are kept as
# extras because of extra="allow".
class ModelMeta(BaseModel):
    model_config = ConfigDict(extra="allow")

    # URL of the model's profile image; defaults to "/static/favicon.png".
    profile_image_url: Optional[str] = "/static/favicon.png"

    # User-facing description of the model.
    description: Optional[str] = None

    # Optional capabilities mapping; its keys are not constrained here.
    capabilities: Optional[dict] = None

43

44
45
46
class Model(Base):
    """SQLAlchemy mapping for the ``model`` table.

    ``params`` and ``meta`` store JSON blobs whose shapes are described by the
    ``ModelParams`` and ``ModelMeta`` pydantic classes respectively.
    """

    __tablename__ = "model"

    # The model's id as used in the API. If it equals the id of an existing
    # model, this entry overrides that model.
    id = Column(Text, primary_key=True)
    # Id of the user who created the entry (set by ModelsTable.insert_new_model).
    user_id = Column(Text)

    # Optional pointer to the actual model to use when proxying requests.
    base_model_id = Column(Text, nullable=True)

    # Human-readable display name of the model.
    name = Column(Text)

    # JSON-encoded parameters blob; see ModelParams.
    params = Column(JSONField)

    # JSON-encoded metadata blob; see ModelMeta.
    meta = Column(JSONField)

    # Unix epoch timestamps in seconds, written by ModelsTable.
    updated_at = Column(BigInteger)
    created_at = Column(BigInteger)
75
76
77
78


# Pydantic mirror of one `model` table row. from_attributes=True lets
# ModelModel.model_validate(row) read fields straight off a SQLAlchemy Model.
class ModelModel(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    base_model_id: Optional[str] = None

    name: str
    params: ModelParams
    meta: ModelMeta

    # Unix epoch timestamps (seconds).
    updated_at: int
    created_at: int

91
92
93
94
95
96

####################
# Forms
####################


Timothy J. Baek's avatar
Timothy J. Baek committed
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Reduced view of a model: carries id, name, meta and the timestamps, but not
# user_id, base_model_id or params.
class ModelResponse(BaseModel):
    id: str
    name: str
    meta: ModelMeta
    # Unix epoch timestamps (seconds).
    updated_at: int
    created_at: int


# Client-supplied payload consumed by ModelsTable.insert_new_model and
# ModelsTable.update_model_by_id; user_id and timestamps are filled in
# server-side, not taken from here.
class ModelForm(BaseModel):
    id: str
    base_model_id: Optional[str] = None
    name: str
    meta: ModelMeta
    params: ModelParams


113
class ModelsTable:
    """Data-access helpers for the ``model`` table.

    Every method opens its own session via ``get_db()``. All methods except
    ``get_all_models`` catch exceptions, log them with the traceback, and
    return ``None`` / ``False`` instead of raising.
    """

    def insert_new_model(
        self, form_data: ModelForm, user_id: str
    ) -> Optional[ModelModel]:
        """Create a new model row owned by ``user_id``.

        Args:
            form_data: Client-supplied fields (id, base_model_id, name, meta,
                params).
            user_id: Id of the user creating the model.

        Returns:
            The stored row as a ``ModelModel``, or ``None`` if the insert
            failed (for example because ``id``, the primary key, already
            exists).
        """
        # Take one timestamp so created_at == updated_at for a fresh row; two
        # separate time.time() calls could straddle a second boundary.
        now = int(time.time())
        model = ModelModel(
            **{
                **form_data.model_dump(),
                "user_id": user_id,
                "created_at": now,
                "updated_at": now,
            }
        )

        try:
            with get_db() as db:
                row = Model(**model.model_dump())
                db.add(row)
                db.commit()
                db.refresh(row)
                return ModelModel.model_validate(row)
        except Exception:
            log.exception("Failed to insert model %r", form_data.id)
            return None

    def get_all_models(self) -> list[ModelModel]:
        """Return every row of the ``model`` table as ``ModelModel`` objects."""
        with get_db() as db:
            return [ModelModel.model_validate(row) for row in db.query(Model).all()]

    def get_model_by_id(self, id: str) -> Optional[ModelModel]:
        """Look up a model by primary key.

        Returns:
            The matching ``ModelModel``, or ``None`` when no row has this id or
            the lookup failed.
        """
        try:
            with get_db() as db:
                row = db.get(Model, id)
                # A missing id is an expected outcome, not an error: return
                # None without logging instead of letting model_validate(None)
                # raise.
                if row is None:
                    return None
                return ModelModel.model_validate(row)
        except Exception:
            log.exception("Failed to fetch model %r", id)
            return None

    def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
        """Overwrite the stored fields of model ``id`` with the values in ``model``.

        The ``id`` inside ``model`` is ignored. Because the form is dumped with
        ``exclude_none=True``, top-level fields that are ``None`` (e.g. an
        unset ``base_model_id``) are omitted and keep their stored value;
        ``None`` entries nested inside ``meta``/``params`` are stripped as well.
        ``updated_at`` is set to the current time.

        Returns:
            The updated row, or ``None`` if no row has this id or the update
            failed.
        """
        try:
            with get_db() as db:
                # Update only the fields that are present in the form.
                values = model.model_dump(exclude={"id"}, exclude_none=True)
                # Record the modification time; without this, updated_at would
                # stay at its creation value forever.
                values["updated_at"] = int(time.time())

                matched = db.query(Model).filter_by(id=id).update(values)
                db.commit()

                # No row with this id: report "not found" directly rather than
                # relying on db.refresh(None) to raise.
                if not matched:
                    return None

                row = db.get(Model, id)
                db.refresh(row)
                return ModelModel.model_validate(row)
        except Exception:
            log.exception("Failed to update model %r", id)
            return None

    def delete_model_by_id(self, id: str) -> bool:
        """Delete model ``id``.

        Returns:
            ``True`` once the delete has been committed (also when no row
            matched), ``False`` if the operation raised.
        """
        try:
            with get_db() as db:
                db.query(Model).filter_by(id=id).delete()
                db.commit()
                return True
        except Exception:
            log.exception("Failed to delete model %r", id)
            return False


182
# Module-level instance used to access the `model` table. ModelsTable keeps no
# per-instance state, so a single shared instance is sufficient.
Models = ModelsTable()