# wrappers.py: peewee database wrappers (context-local connection state and
# auto-reconnecting database classes).
from contextvars import ContextVar
2
3
from peewee import *
from playhouse.db_url import connect
4
from playhouse.pool import PooledPostgresqlDatabase
5
6
7
8
9
from playhouse.shortcuts import ReconnectMixin

# Template for the per-context connection-state dict; its keys are the
# attributes PeeweeConnectionState exposes.
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
# NOTE: db_state.get() returns this single default dict in every context that
# has not called db_state.set(...), so all such contexts share it. Code that
# needs isolation (e.g. per request) should set a fresh
# db_state_default.copy() before using the state.
db_state = ContextVar("db_state", default=db_state_default.copy())


class PeeweeConnectionState(object):
    """Connection-state object whose attributes live in the ``db_state`` ContextVar.

    Attribute reads and writes are redirected to the dict that ``db_state``
    holds in the current context, so a context with its own dict sees its own
    ``closed``/``conn``/``ctx``/``transactions`` values.

    ``reset()`` and ``set_connection()`` mirror the interface of peewee's
    internal ``_ConnectionState``, which a peewee ``Database`` calls on its
    ``_state`` attribute. A new instance starts in the "closed" state.
    """

    def __init__(self, **kwargs):
        # Bypass our own __setattr__: "_state" must be a real instance
        # attribute, not an entry in the context dict.
        super().__setattr__("_state", db_state)
        super().__init__(**kwargs)
        self.reset()

    def reset(self):
        """Mark the current context as having no open connection."""
        self.closed = True
        self.conn = None
        self.ctx = []
        self.transactions = []

    def set_connection(self, conn):
        """Record *conn* as the open connection for the current context."""
        self.conn = conn
        self.closed = False
        self.ctx = []
        self.transactions = []

    def __setattr__(self, name, value):
        # Every attribute assignment is stored in the current context's dict.
        self._state.get()[name] = value

    def __getattr__(self, name):
        # Only reached when normal lookup fails, i.e. for context-dict keys.
        # Guard "_state" so an instance created without __init__ (copy,
        # pickle) raises AttributeError instead of recursing forever.
        if name == "_state":
            raise AttributeError(name)
        try:
            return self._state.get()[name]
        except KeyError:
            # AttributeError (not KeyError) keeps hasattr(), getattr(default)
            # and the copy/pickle protocols working.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
21

22
class ReconnectingPostgresqlDatabase(ReconnectMixin, PostgresqlDatabase):
    """``PostgresqlDatabase`` with ``ReconnectMixin`` applied.

    The mixin is listed first so that, by Python's MRO, its methods take
    precedence over ``PostgresqlDatabase``'s; this module uses it to handle
    dropped connections.
    """

25
26
class ReconnectingPooledPostgresqlDatabase(ReconnectMixin, PooledPostgresqlDatabase):
    """Pooled PostgreSQL database with ``ReconnectMixin`` applied.

    Because the mixin comes first in the bases, its methods override those of
    ``PooledPostgresqlDatabase`` through the MRO.
    """
27

28
class ReconnectingSqliteDatabase(ReconnectMixin, SqliteDatabase):
    """``SqliteDatabase`` with ``ReconnectMixin`` applied.

    The mixin is listed first so that, by Python's MRO, its methods take
    precedence over ``SqliteDatabase``'s; this module uses it to handle
    dropped connections.
    """


32
33
34
35
36
37
38
39
40
41
42
43
44
# Pool settings carried over when re-creating a pooled database, as
# (constructor keyword, attribute the pooled instance stores it in).
# NOTE(review): these attributes are private to playhouse.pool.PooledDatabase;
# re-check them when upgrading peewee.
_POOL_SETTINGS = (
    ("max_connections", "_max_connections"),
    ("stale_timeout", "_stale_timeout"),
    ("timeout", "_wait_timeout"),
)


def _pool_kwargs(db):
    """Return *db*'s pool settings as constructor keyword arguments.

    Settings whose attribute is missing on *db* are omitted, so the new
    instance falls back to the constructor's own defaults for them.
    """
    return {
        kwarg: getattr(db, attr)
        for kwarg, attr in _POOL_SETTINGS
        if hasattr(db, attr)
    }


def register_connection(db_url):
    """Create an auto-reconnecting peewee database from *db_url*.

    The URL is resolved by ``playhouse.db_url.connect``. The resulting database
    is then re-created as the matching ``Reconnecting*`` class from this
    module, reusing its database name and connection parameters (and, for a
    pooled database, its pool settings).

    Args:
        db_url: Database URL in a form accepted by ``playhouse.db_url.connect``.

    Returns:
        A ``ReconnectingPooledPostgresqlDatabase``,
        ``ReconnectingPostgresqlDatabase`` or ``ReconnectingSqliteDatabase``.

    Raises:
        ValueError: If the URL resolves to any other kind of database.
    """
    db = connect(db_url)
    # PooledPostgresqlDatabase is a subclass of PostgresqlDatabase, so it must
    # be checked first; otherwise pooled URLs would silently lose pooling.
    if isinstance(db, PooledPostgresqlDatabase):
        # connect_params does not hold the pool settings, so copy them over
        # explicitly or the new pool would revert to its default limits.
        params = {**db.connect_params, **_pool_kwargs(db)}
        return ReconnectingPooledPostgresqlDatabase(db.database, **params)
    if isinstance(db, PostgresqlDatabase):
        return ReconnectingPostgresqlDatabase(db.database, **db.connect_params)
    if isinstance(db, SqliteDatabase):
        return ReconnectingSqliteDatabase(db.database, **db.connect_params)
    raise ValueError('Unsupported database connection')