runs.py 1.73 KB
Newer Older
chenxl's avatar
chenxl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from time import time
from uuid import uuid4

from ktransformers.server.models.assistants.runs import Run
from ktransformers.server.schemas.assistants.runs import RunCreate,RunObject
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.utils.sql_utils import SQLUtil


class RunsDatabaseManager:
    """Build ``RunObject`` instances and persist/load them via ``SQLUtil``.

    Every database operation opens its own session through
    ``self.sql_util.get_db()`` used as a context manager.
    """

    def __init__(self) -> None:
        """Create the ``SQLUtil`` helper used for all database access."""
        self.sql_util = SQLUtil()

    def create_run_object(self, thread_id: ObjectID, run: RunCreate) -> RunObject:
        """Build a new queued ``RunObject`` without touching the database.

        Args:
            thread_id: Identifier of the thread the run belongs to.
            run: Creation request; its fields (except ``stream``) are copied
                onto the new object.

        Returns:
            A ``RunObject`` with a fresh UUID4 string id, ``object='run'``,
            ``created_at`` set to the current Unix time in whole seconds and
            status ``RunObject.Status.queued``.
        """
        run_obj = RunObject(
            # "stream" is excluded from the dump, so it is never forwarded
            # to the RunObject constructor.
            **run.model_dump(mode='json', exclude={"stream"}),
            id=str(uuid4()),
            object='run',
            created_at=int(time()),
            thread_id=thread_id,
            status=RunObject.Status.queued,
        )
        # NOTE(review): presumably resets the object's save bookkeeping;
        # confirm against RunObject.set_compute_save.
        run_obj.set_compute_save(0)
        return run_obj

    def db_create_run(self, thread_id: str, run: RunCreate):
        """Insert a new queued run row and return it as a ``RunObject``.

        Args:
            thread_id: Identifier of the thread the run belongs to.
            run: Creation request; its fields (except ``stream``) are copied
                onto the new row.

        Returns:
            The ``RunObject`` validated from the stored row's attributes.
        """
        db_run = Run(
            **run.model_dump(mode="json", exclude={"stream"}),
            id=str(uuid4()),
            created_at=int(time()),
            status="queued",
            thread_id=thread_id,
        )
        with self.sql_util.get_db() as db:
            self.sql_util.db_add_commit_refresh(db, db_run)
            # Validate while the session is still open, from the ORM
            # instance's attribute dict.
            run_obj = RunObject.model_validate(db_run.__dict__)
            run_obj.set_compute_save(0)
        return run_obj

    def db_sync_run(self, run: RunObject) -> None:
        """Write the current state of ``run`` to the database.

        Args:
            run: Run whose JSON-mode dump is used to build the ``Run`` row
                handed to ``SQLUtil.db_merge_commit``.
        """
        db_run = Run(
            **run.model_dump(mode='json'),
        )
        with self.sql_util.get_db() as db:
            self.sql_util.db_merge_commit(db, db_run)

    def db_get_run(self, run_id: ObjectID) -> RunObject:
        """Load the run with the given id from the database.

        Args:
            run_id: Identifier of the run to fetch.

        Returns:
            The ``RunObject`` validated from the stored row's attributes.

        Raises:
            LookupError: If no run with ``run_id`` exists.
        """
        with self.sql_util.get_db() as db:
            db_run = db.query(Run).filter(Run.id == run_id).first()
            # ``first()`` yields None when nothing matches; fail with a clear
            # message instead of an AttributeError on ``None.__dict__``.
            if db_run is None:
                raise LookupError(f"Run not found: {run_id}")
            return RunObject.model_validate(db_run.__dict__)