models.py 4.48 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import json
import logging
from typing import Optional

import peewee as pw
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel

from apps.web.internal.db import DB


####################
# Models DB Schema
####################


# ModelParams describes the data stored (JSON-encoded) in the ``params`` field
# of the Model table. It is not currently used by the backend; it is kept here
# as a reference for the expected shape of that blob.
class ModelParams(BaseModel):
    """
    A Pydantic model describing the parameters of a model.

    Attributes:
        description (str): A description of the model.
        vision_capable (bool): Whether the model is capable of vision and thus accepts image inputs.
    """

    description: str
    vision_capable: bool


class Model(pw.Model):
    """Peewee table of model definitions, one row per (id, source) pair.

    Uniqueness of the (id, source) pair is enforced by the composite index
    declared in ``Meta``.
    """

    # The model's id as used in the API. If set to an existing model, it
    # overrides that model.
    # NOTE(review): this field is not declared primary_key=True, so it shares
    # its name with peewee's implicit auto-increment ``id`` primary key —
    # confirm the generated schema / primary key is what's intended.
    id = pw.TextField()

    # Where the model comes from, e.g. ollama, openai, or litellm.
    source = pw.TextField()

    # Optional pointer to the actual model that should be used when proxying
    # requests. Currently unused; intended to support Modelfile-like
    # behaviour in the future.
    base_model = pw.TextField(null=True)

    # Human-readable display name of the model.
    name = pw.TextField()

    # JSON-encoded blob of parameters (see ``ModelParams``).
    params = pw.TextField()

    class Meta:
        database = DB
        # Unique composite index over the (id, source) columns.
        indexes = ((("id", "source"), True),)


class ModelModel(BaseModel):
    """Storage-shaped record mirroring the columns of the ``Model`` table.

    ``params`` is kept as the raw JSON-encoded string, exactly as stored.
    """

    id: str
    source: str
    base_model: Optional[str] = None
    name: str
    params: str

    def to_form(self) -> "ModelForm":
        """Convert to a ``ModelForm``, decoding ``params`` from JSON."""
        payload = self.model_dump()
        payload["params"] = json.loads(self.params)
        return ModelForm(**payload)


####################
# Forms
####################


class ModelForm(BaseModel):
    """Model definition as exchanged with callers, with ``params`` decoded.

    ``params`` is a plain dict here; it is JSON-encoded when converted to the
    storage representation.
    """

    id: str
    source: str
    base_model: Optional[str] = None
    name: str
    params: dict

    def to_db_model(self) -> ModelModel:
        """Convert to a ``ModelModel``, encoding ``params`` as JSON text."""
        fields = self.model_dump()
        fields["params"] = json.dumps(self.params)
        return ModelModel(**fields)


class ModelsTable:
    """Data-access helper for the ``Model`` table."""

    def __init__(
        self,
        db: pw.SqliteDatabase | pw.PostgresqlDatabase,
    ):
        """Store ``db`` and create the ``Model`` table through it.

        Args:
            db: Database used for table creation and for the transaction in
                ``update_all_models``.
        """
        # NOTE(review): queries on ``Model`` below go through the database
        # declared in ``Model.Meta`` (``DB``), while ``self.db`` scopes the
        # transaction; presumably these are always the same object — confirm.
        self.db = db
        self.db.create_tables([Model])

    def get_all_models(self) -> list[ModelModel]:
        """Return every stored model as a ``ModelModel``."""
        return [ModelModel(**model_to_dict(model)) for model in Model.select()]

    def get_all_models_by_source(self, source: str) -> list[ModelModel]:
        """Return the stored models whose ``source`` column equals ``source``."""
        return [
            ModelModel(**model_to_dict(model))
            for model in Model.select().where(Model.source == source)
        ]

    def update_all_models(self, models: list[ModelForm]) -> bool:
        """Synchronise the table so it contains exactly ``models``.

        Rows are matched on the (id, source) pair. Inside one ``self.db.atomic()``
        block, models with a new key are inserted, models with an existing key
        overwrite the stored row, and stored rows whose key is absent from
        ``models`` are deleted.

        Args:
            models: The complete desired set of models.

        Returns:
            True if every operation succeeded; False if any step raised. In
            that case the exception is logged and the ``atomic()`` block is
            exited via the exception (peewee rolls the transaction back).
        """
        try:
            with self.db.atomic():
                # Only the keys of the stored rows are needed to classify work.
                current_keys = {
                    (model.id, model.source) for model in self.get_all_models()
                }
                new_keys = {(model.id, model.source) for model in models}

                models_to_create = [
                    model
                    for model in models
                    if (model.id, model.source) not in current_keys
                ]
                models_to_update = [
                    model
                    for model in models
                    if (model.id, model.source) in current_keys
                ]

                for model in models_to_create:
                    Model.create(**model.to_db_model().model_dump())

                for model in models_to_update:
                    Model.update(**model.to_db_model().model_dump()).where(
                        (Model.id == model.id) & (Model.source == model.source)
                    ).execute()

                for model_id, model_source in current_keys - new_keys:
                    Model.delete().where(
                        (Model.id == model_id) & (Model.source == model_source)
                    ).execute()

            return True
        except Exception:
            # Callers only receive a bool, so record the cause instead of
            # silently discarding it.
            logging.getLogger(__name__).exception("Failed to update models")
            return False


# Module-level instance bound to ``DB``. It is constructed at import time, so
# importing this module runs ``ModelsTable.__init__``, which calls
# ``create_tables([Model])`` on ``DB``.
Models = ModelsTable(DB)