db.py 1.92 KB
Newer Older
1
import os
2
import logging
3
import json
4
from contextlib import contextmanager
5
6
from typing import Optional, Any
from typing_extensions import Self
Timothy J. Baek's avatar
Timothy J. Baek committed
7

8
9
from sqlalchemy import create_engine, types, Dialect
from sqlalchemy.ext.declarative import declarative_base
10
from sqlalchemy.orm import sessionmaker, scoped_session
11
from sqlalchemy.sql.type_api import _T
12

13
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
14

15
16
# Module-level logger; its verbosity comes from the "DB" entry of the
# SRC_LOG_LEVELS mapping imported from config.
log: logging.Logger = logging.getLogger(__name__)
log.setLevel(level=SRC_LOG_LEVELS["DB"])
17

Timothy J. Baek's avatar
Timothy J. Baek committed
18

19
20
21
22
23
24
25
26
27
28
29
30
31
32
class JSONField(types.TypeDecorator):
    """Column type that stores JSON-serialisable Python values as TEXT.

    Values are encoded with ``json.dumps`` when bound to a statement and
    decoded with ``json.loads`` when read back. Note the asymmetry around
    ``None``: a Python ``None`` is bound as the JSON text ``"null"`` (not SQL
    NULL), while a SQL NULL read from the column is returned as ``None``.

    ``db_value`` / ``python_value`` repeat the same conversions under the
    method names peewee uses for custom fields -- presumably kept for
    peewee-era callers; TODO confirm whether anything still calls them.
    """

    impl = types.Text
    # Tells SQLAlchemy this type is safe to include in statement cache keys.
    cache_ok = True

    def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Any:
        """Serialise *value* to a JSON string before it is sent to the database."""
        return json.dumps(value)

    def process_result_value(self, value: Optional[_T], dialect: Dialect) -> Any:
        """Deserialise a JSON string read from the database; SQL NULL -> ``None``."""
        if value is None:
            return None
        return json.loads(value)

    def copy(self, **kw: Any) -> Self:
        """Return a fresh instance of this type with the same TEXT length.

        ``type(self)`` is used instead of a hard-coded ``JSONField`` so that
        copying a subclass yields the subclass, as the ``Self`` annotation
        promises.
        """
        return type(self)(self.impl.length)

    def db_value(self, value):
        """Serialise *value* to JSON text (peewee-style hook name)."""
        return json.dumps(value)

    def python_value(self, value):
        """Deserialise JSON text, passing ``None`` through (peewee-style hook name)."""
        if value is None:
            return None
        return json.loads(value)

Timothy J. Baek's avatar
Timothy J. Baek committed
40

41
42
43
44
# One-time migration: rename a database file left under the old
# Ollama-WebUI name (ollama.db) to the current name (webui.db).
if os.path.exists(f"{DATA_DIR}/ollama.db"):
    if os.path.exists(f"{DATA_DIR}/webui.db"):
        # os.rename would silently replace an existing webui.db on POSIX
        # (and raise FileExistsError on Windows), so never clobber it.
        log.warning(
            "Legacy database %s/ollama.db found but %s/webui.db already "
            "exists; skipping migration.",
            DATA_DIR,
            DATA_DIR,
        )
    else:
        os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
        log.info("Database migrated from Ollama-WebUI successfully.")

49
50
51
52
53
54
55
SQLALCHEMY_DATABASE_URL = DATABASE_URL

# Detect SQLite by the URL scheme ("sqlite://", "sqlite+pysqlite://", ...),
# not by substring: a non-SQLite URL that merely contains "sqlite" (e.g. in a
# database name or password) must not receive SQLite-only connect args.
if SQLALCHEMY_DATABASE_URL.startswith("sqlite"):
    # sqlite3 connections by default refuse use from a thread other than the
    # one that created them; disable that check so pooled connections can be
    # shared across threads.
    engine = create_engine(
        SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
    )
else:
    # Test pooled connections for liveness on checkout so stale connections
    # (e.g. dropped by the server) are replaced rather than failing a query.
    engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
Timothy J. Baek's avatar
Timothy J. Baek committed
56
57


58
59
60
# Declarative base class shared by the ORM models.
Base = declarative_base()

# Session factory bound to ``engine``. Instances are not expired after a
# commit, so loaded attributes stay readable afterwards, and queries do not
# auto-flush pending changes.
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)

# Scoped registry built on SessionLocal (SQLAlchemy scopes it per thread by
# default).
Session = scoped_session(SessionLocal)
Timothy J. Baek's avatar
Timothy J. Baek committed
63
64


Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
65
66
67
from contextlib import contextmanager


Timothy J. Baek's avatar
Timothy J. Baek committed
68
# Dependency
def get_session():
    """Yield a new ``SessionLocal`` session and close it afterwards.

    Implemented as a generator so it can be driven by a dependency-injection
    hook (presumably FastAPI's ``Depends``; confirm at call sites) or wrapped
    with ``contextlib.contextmanager``. The ``finally`` clause guarantees the
    session is closed whether the consumer finishes normally or raises.
    """
    session = SessionLocal()
    try:
        # Hand the session to the consumer; resume here when it is done.
        yield session
    finally:
        session.close()
Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
75
76
77


# Context-manager form of get_session(): ``with get_db() as db:`` opens a
# SessionLocal session and closes it when the ``with`` block exits.
get_db = contextmanager(get_session)