fork_hook.py 6.03 KB
Newer Older
maming's avatar
maming committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import functools
import os
import weakref
from dataclasses import dataclass
from typing import Callable


def _cleanup(hooks, key, wr):
    """
    Weak-reference callback that drops a dead hook from its registry.

    Args:
        hooks: The registry dict the hook was stored in.
        key: The key the hook was stored under.
        wr: The weak reference whose referent died (supplied by the weakref machinery).
    """
    # Only remove the entry if it is still the reference that died. The key may
    # already be gone, or may have been re-registered with a newer reference;
    # neither case should raise (weakref callbacks only print "Exception
    # ignored") or throw away a live registration.
    if hooks.get(key) is wr:
        del hooks[key]


class WeakCallbacks:
    """
    A class that manages weak references to callback functions.
    """

    # A dictionary of weak (or strong) references to functions.
    _hooks: dict[int, Callable[[], Callable[..., None] | None]]

    def __init__(self):
        """
        Initialize the registry.
        """
        self._hooks: dict[int, Callable[[], Callable[..., None] | None]] = {}

    def add_hook(self, callable: Callable[..., None], make_persistent: bool = False) -> None:
        """
        Add a callback to the registry.

        Args:
            callable: The function to run before the fork of a worker process.
            make_persistent: If True, the function will be stored as a strong reference, otherwise a weak reference is used.
        """
        if make_persistent:
            # Not a weakref, but always return the callable.
            self._hooks[id(callable)] = lambda: callable
        elif getattr(callable, "__self__", None):
            # Add a method reference to the hooks
            key = id(callable.__self__)
            self._hooks[key] = weakref.WeakMethod(
                callable, functools.partial(_cleanup, self._hooks, key)
            )
        else:
            # Add a function reference to the hooks
            key = id(callable)
            self._hooks[key] = weakref.ref(callable, functools.partial(_cleanup, self._hooks, key))

    def run(self, *args, **kwargs) -> None:
        """
        Run all the callbacks in the registry, passing the given arguments.
        """
        for hook in self._hooks.values():
            ref = hook()
            if ref is not None:
                ref(*args, **kwargs)


# One callback registry per fork phase, listed in the order the phases occur:
# before the fork (parent), then after it in the parent and in the child.
_before_fork_hooks: WeakCallbacks = WeakCallbacks()
_after_in_parent_fork_hooks: WeakCallbacks = WeakCallbacks()
_after_in_child_fork_hooks: WeakCallbacks = WeakCallbacks()


def before_fork_hook(callable: Callable[[], None], make_persistent: bool = False) -> None:
    """
    Register a function to run in the parent process just before it forks.

    By default only a weak reference is kept, so the hook disappears once
    ``callable`` (or, for a bound method, its instance) becomes unreachable.
    A lambda or other temporary callable that nothing else references is
    therefore dropped right away unless ``make_persistent=True`` is passed.

    Args:
        callable: The function to run before the fork (called with no arguments).
        make_persistent: If True, the function will be stored as a strong reference, otherwise a weak reference is used.
    """
    _before_fork_hooks.add_hook(callable, make_persistent)


def after_in_parent_fork_hook(callable: Callable[[], None], make_persistent: bool = False) -> None:
    """
    Register a function to run in the parent process right after a fork.

    By default only a weak reference is kept, so the hook disappears once
    ``callable`` (or, for a bound method, its instance) becomes unreachable.
    A lambda or other temporary callable that nothing else references is
    therefore dropped right away unless ``make_persistent=True`` is passed.

    Args:
        callable: The function to run in the parent after the fork (called with no arguments).
        make_persistent: If True, the function will be stored as a strong reference, otherwise a weak reference is used.
    """
    _after_in_parent_fork_hooks.add_hook(callable, make_persistent)


def after_in_child_fork_hook(callable: Callable[[], None], make_persistent: bool = False) -> None:
    """
    Register a function to run in the newly forked child process.

    By default only a weak reference is kept, so the hook disappears once
    ``callable`` (or, for a bound method, its instance) becomes unreachable.
    A lambda or other temporary callable that nothing else references is
    therefore dropped right away unless ``make_persistent=True`` is passed.

    Args:
        callable: The function to run in the child after the fork (called with no arguments).
        make_persistent: If True, the function will be stored as a strong reference, otherwise a weak reference is used.
    """
    _after_in_child_fork_hooks.add_hook(callable, make_persistent)


class ForkMixin:
    """
    Mixin that connects an instance's fork-hook methods to the fork registries.

    When an instance is constructed, each of ``__before_fork__``,
    ``__after_in_child_fork__`` and ``__after_in_parent_fork__`` that a
    subclass overrides is registered (weakly) with the matching registration
    helper. Methods left at their default no-op are not registered.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Register only after the remaining base classes' __init__ has run.
        self.__post_init__()

    def __post_init__(self):
        """
        Register every fork-hook method this instance's class overrides.
        """
        for method_name, register in (
            ("__before_fork__", before_fork_hook),
            ("__after_in_child_fork__", after_in_child_fork_hook),
            ("__after_in_parent_fork__", after_in_parent_fork_hook),
        ):
            bound = getattr(self, method_name)
            default_impl = getattr(ForkMixin, method_name)
            if getattr(bound, "__func__", None) is default_impl:
                # Still the inherited no-op: nothing to register.
                continue
            register(bound)

    def __before_fork__(self):
        """
        Hook run in the parent just before a fork; override to use it.
        """
        pass

    def __after_in_child_fork__(self):
        """
        Hook run in the child process after a fork; override to use it.
        """
        pass

    def __after_in_parent_fork__(self):
        """
        Hook run in the parent process after a fork; override to use it.
        """
        pass


@dataclass
class DataclassForkMixin:
    """
    Dataclass flavour of the fork-hook mixin.

    The dataclass-generated ``__init__`` calls ``__post_init__``, which
    registers (weakly) each of ``__before_fork__``, ``__after_in_child_fork__``
    and ``__after_in_parent_fork__`` that a subclass overrides. Subclasses that
    define their own ``__post_init__`` must call ``super().__post_init__()``
    for the hooks to be registered.
    """

    def __post_init__(self):
        """
        Register every fork-hook method this instance's class overrides.
        """
        for method_name, register in (
            ("__before_fork__", before_fork_hook),
            ("__after_in_child_fork__", after_in_child_fork_hook),
            ("__after_in_parent_fork__", after_in_parent_fork_hook),
        ):
            bound = getattr(self, method_name)
            default_impl = getattr(DataclassForkMixin, method_name)
            if getattr(bound, "__func__", None) is default_impl:
                # Still the inherited no-op: nothing to register.
                continue
            register(bound)

    def __before_fork__(self):
        """
        Hook run in the parent just before a fork; override to use it.
        """
        pass

    def __after_in_child_fork__(self):
        """
        Hook run in the child process after a fork; override to use it.
        """
        pass

    def __after_in_parent_fork__(self):
        """
        Hook run in the parent process after a fork; override to use it.
        """
        pass


# Connect the registries to the interpreter's fork machinery. os.register_at_fork
# only exists on platforms that support fork() (it is absent on Windows, for
# example); there is nothing to hook elsewhere, and calling it unconditionally
# would make importing this module fail with AttributeError.
if hasattr(os, "register_at_fork"):
    os.register_at_fork(
        before=_before_fork_hooks.run,
        after_in_child=_after_in_child_fork_hooks.run,
        after_in_parent=_after_in_parent_fork_hooks.run,
    )