wrappers.py 1.84 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from contextvars import ContextVar

from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError, MySQLDatabase, _ConnectionState
from peewee import OperationalError as PeeWeeOperationalError
from playhouse.db_url import register_database
from playhouse.pool import PooledPostgresqlDatabase, PooledMySQLDatabase
from playhouse.shortcuts import ReconnectMixin
from psycopg2 import OperationalError
from psycopg2.errors import InterfaceError


# Template for one context's connection state. The keys are the attribute
# names that PeeweeConnectionState stores in the context variable.
db_state_default = dict(closed=None, conn=None, ctx=None, transactions=None)

# NOTE(review): the default below is a single dict created once at import
# time. Any context that has not called db_state.set(...) receives this same
# object from db_state.get(), so state written through it is shared by all
# such contexts. Presumably callers install a fresh db_state_default.copy()
# per request/task with db_state.set(...) -- confirm against the callers.
db_state = ContextVar("db_state", default=dict(db_state_default))


class PeeweeConnectionState(_ConnectionState):
    """Peewee connection state that lives in the ``db_state`` context variable.

    Attribute writes and (failed-normal-lookup) reads are redirected to the
    dict returned by ``db_state.get()`` for the current context, instead of
    being stored on the instance.

    NOTE(review): if no value has been set for ``db_state`` in the current
    context, ``db_state.get()`` returns the variable's default, which is a
    single dict shared by every such context. Per-context isolation therefore
    presumably relies on callers doing ``db_state.set(...)`` with a fresh
    dict; verify against the callers.
    """

    def __init__(self, **kwargs):
        # Bypass our own __setattr__ so "_state" becomes a real instance
        # attribute (the ContextVar itself) rather than a key in the state dict.
        super().__setattr__("_state", db_state)
        super().__init__(**kwargs)

    def __setattr__(self, name, value):
        self._state.get()[name] = value

    def __getattr__(self, name):
        # __getattr__ runs only when normal lookup fails. If "_state" itself is
        # missing (instance created without __init__), fail fast instead of
        # recursing through self._state forever.
        if name == "_state":
            raise AttributeError(name)
        try:
            return self._state.get()[name]
        except KeyError:
            # The attribute protocol expects AttributeError; hasattr() and
            # getattr(obj, name, default) do not catch KeyError.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None


class CustomReconnectMixin(ReconnectMixin):
    """ReconnectMixin extended with PostgreSQL connection-loss errors.

    ``reconnect_errors`` is a tuple of ``(exception class, message fragment)``
    pairs. Peewee's ReconnectMixin matches on the exact exception class and
    then looks for the fragment in the lower-cased error message. Peewee's
    ``execute_sql`` re-raises psycopg2 errors as its own classes
    (``peewee.OperationalError`` / ``peewee.InterfaceError``). The PostgreSQL
    fragments must therefore also be registered against the peewee classes,
    or they never match.
    """

    reconnect_errors = (
        # default ReconnectMixin exceptions (MySQL specific)
        *ReconnectMixin.reconnect_errors,
        # psycopg2 classes, kept in case a code path surfaces the raw driver error
        (OperationalError, 'termin'),
        (InterfaceError, 'closed'),
        # peewee wrappers of the psycopg2 errors above; this is what execute_sql
        # raises. 'termin' matches e.g. "terminating connection ..." and
        # "... server terminated abnormally ...".
        (PeeWeeOperationalError, 'termin'),
        (PeeWeeInterfaceError, 'closed'),
    )


class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
    """PostgresqlDatabase with CustomReconnectMixin's reconnect-on-error handling."""


class ReconnectingPooledPostgresqlDatabase(CustomReconnectMixin, PooledPostgresqlDatabase):
    """PooledPostgresqlDatabase with CustomReconnectMixin's reconnect-on-error handling."""


class ReconnectingMySQLDatabase(CustomReconnectMixin, MySQLDatabase):
    """MySQLDatabase with CustomReconnectMixin's reconnect-on-error handling.

    NOTE(review): register_peewee_databases registers the plain MySQLDatabase
    for 'mysql', not this class. Confirm that is intended.
    """


class ReconnectingPooledMySQLDatabase(CustomReconnectMixin, PooledMySQLDatabase):
    """PooledMySQLDatabase with CustomReconnectMixin's reconnect-on-error handling.

    NOTE(review): register_peewee_databases registers the plain
    PooledMySQLDatabase for 'mysql+pool', not this class. Confirm that is
    intended.
    """


def register_peewee_databases():
    """Map database URL schemes to database classes via playhouse.db_url.

    PostgreSQL schemes (plain and '+pool') map to the reconnecting wrappers.
    The MySQL schemes map to the stock, non-reconnecting peewee classes.
    """
    scheme_table = (
        (MySQLDatabase, ('mysql',)),
        (PooledMySQLDatabase, ('mysql+pool',)),
        (ReconnectingPostgresqlDatabase, ('postgres', 'postgresql')),
        (ReconnectingPooledPostgresqlDatabase, ('postgres+pool', 'postgresql+pool')),
    )
    for database_class, schemes in scheme_table:
        register_database(database_class, *schemes)