# requestscopedpipeline.py
import copy
import threading
import types
from typing import Any, Iterable, List, Optional

import torch

from diffusers.utils import logging

from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps


logger = logging.get_logger(__name__)


def safe_tokenize(tokenizer, *args, lock, **kwargs):
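    """Run the tokenizer call while holding ``lock`` so concurrent requests do not race on shared tokenizer state."""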
    with lock:
        return tokenizer(*args, **kwargs)


class RequestScopedPipeline:
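    """Request-scoped wrapper around a shared diffusers pipeline.

    Heavy modules (UNet, VAE, text encoder) stay shared across requests, while each
    ``generate()`` call gets its own scheduler clone, copies of small mutable attributes,
    and lock-protected tokenizer calls, so concurrent requests do not interfere.
    """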
    DEFAULT_MUTABLE_ATTRS = [
        "_all_hooks",
        "_offload_device",
        "_progress_bar_config",
        "_progress_bar",
        "_rng_state",
        "_last_seed",
        "latents",
    ]

    def __init__(
        self,
        pipeline: Any,
        mutable_attrs: Optional[Iterable[str]] = None,
        auto_detect_mutables: bool = True,
        tensor_numel_threshold: int = 1_000_000,
        tokenizer_lock: Optional[threading.Lock] = None,
        wrap_scheduler: bool = True,
    ):
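        """
        Args:
            pipeline: shared base pipeline to wrap.
            mutable_attrs: attribute names cloned per request (defaults to ``DEFAULT_MUTABLE_ATTRS``).
            auto_detect_mutables: also scan the pipeline for small mutable attributes to clone.
            tensor_numel_threshold: tensors larger than this (in elements) are shared, not cloned.
            tokenizer_lock: lock guarding tokenizer calls; a new ``threading.Lock`` is created if omitted.
            wrap_scheduler: wrap the base scheduler in ``BaseAsyncScheduler`` if it is not one already.
        """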
        self._base = pipeline
        self.unet = getattr(pipeline, "unet", None)
        self.vae = getattr(pipeline, "vae", None)
        self.text_encoder = getattr(pipeline, "text_encoder", None)
        self.components = getattr(pipeline, "components", None)

        if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
            if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
                pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)

        self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
        self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()

        self._auto_detect_mutables = bool(auto_detect_mutables)
        self._tensor_numel_threshold = int(tensor_numel_threshold)

        self._auto_detected_attrs: List[str] = []

    def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
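        """Clone the base scheduler for this request, falling back to ``deepcopy`` and, as a last resort, the shared scheduler."""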
        base_sched = getattr(self._base, "scheduler", None)
        if base_sched is None:
            return None

        if not isinstance(base_sched, BaseAsyncScheduler):
            wrapped_scheduler = BaseAsyncScheduler(base_sched)
        else:
            wrapped_scheduler = base_sched

        try:
            return wrapped_scheduler.clone_for_request(
                num_inference_steps=num_inference_steps, device=device, **clone_kwargs
            )
        except Exception as e:
            logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
            try:
                return copy.deepcopy(wrapped_scheduler)
            except Exception as e:
                logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
                return wrapped_scheduler

    def _autodetect_mutables(self, max_attrs: int = 40):
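        """Scan the base pipeline for small mutable attributes (containers and small tensors) worth cloning per request; the result is cached after the first scan."""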
        if not self._auto_detect_mutables:
            return []

        if self._auto_detected_attrs:
            return self._auto_detected_attrs

        candidates: List[str] = []
        seen = set()
        for name in dir(self._base):
            if name.startswith("__"):
                continue
            if name in self._mutable_attrs:
                continue
            if name in ("to", "save_pretrained", "from_pretrained"):
                continue
            try:
                val = getattr(self._base, name)
            except Exception:
                continue

            # skip callables and modules
            if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
                continue

            # containers -> candidate
            if isinstance(val, (dict, list, set, tuple, bytearray)):
                candidates.append(name)
                seen.add(name)
            else:
                # try Tensor detection
                try:
                    if isinstance(val, torch.Tensor):
                        if val.numel() <= self._tensor_numel_threshold:
                            candidates.append(name)
                            seen.add(name)
                        else:
                            logger.debug(f"Ignoring large tensor attr '{name}', numel={val.numel()}")
                except Exception:
                    continue

            if len(candidates) >= max_attrs:
                break

        self._auto_detected_attrs = candidates
        logger.debug(f"Autodetected mutable attrs to clone: {self._auto_detected_attrs}")
        return self._auto_detected_attrs

    def _is_readonly_property(self, base_obj, attr_name: str) -> bool:
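        """Return True if ``attr_name`` is defined as a property without a setter on the pipeline class."""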
        try:
            cls = type(base_obj)
            descriptor = getattr(cls, attr_name, None)
            if isinstance(descriptor, property):
                return descriptor.fset is None
            # non-property descriptors and plain attributes are treated as settable
        except Exception:
            pass
        return False

    def _clone_mutable_attrs(self, base, local):
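        """Copy the configured and auto-detected mutable attributes from ``base`` onto ``local``; large tensors are kept as shared references."""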
        attrs_to_clone = list(self._mutable_attrs)
        attrs_to_clone.extend(self._autodetect_mutables())

        EXCLUDE_ATTRS = {
            "components",
        }

        for attr in attrs_to_clone:
            if attr in EXCLUDE_ATTRS:
                logger.debug(f"Skipping excluded attr '{attr}'")
                continue
            if not hasattr(base, attr):
                continue
            if self._is_readonly_property(base, attr):
                logger.debug(f"Skipping read-only property '{attr}'")
                continue

            try:
                val = getattr(base, attr)
            except Exception as e:
                logger.debug(f"Could not getattr('{attr}') on base pipeline: {e}")
                continue

            try:
                if isinstance(val, dict):
                    setattr(local, attr, dict(val))
                elif isinstance(val, (list, tuple, set)):
                    # shallow copy that preserves the container type
                    setattr(local, attr, copy.copy(val))
                elif isinstance(val, bytearray):
                    setattr(local, attr, bytearray(val))
                else:
                    # small tensors or atomic values
                    if isinstance(val, torch.Tensor):
                        if val.numel() <= self._tensor_numel_threshold:
                            setattr(local, attr, val.clone())
                        else:
                            # don't clone big tensors, keep reference
                            setattr(local, attr, val)
                    else:
                        try:
                            setattr(local, attr, copy.copy(val))
                        except Exception:
                            setattr(local, attr, val)
            except (AttributeError, TypeError) as e:
                logger.debug(f"Skipping cloning attribute '{attr}' because it is not settable: {e}")
                continue
            except Exception as e:
                logger.debug(f"Unexpected error cloning attribute '{attr}': {e}")
                continue

    def _is_tokenizer_component(self, component) -> bool:
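        """Heuristically decide whether ``component`` looks like a tokenizer (tokenizer-like methods plus a tokenizer-like class name or typical tokenizer attributes)."""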
        if component is None:
            return False

        tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
        has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)

        class_name = component.__class__.__name__.lower()
        has_tokenizer_in_name = "tokenizer" in class_name

        tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
        has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)

        return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)

    def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
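        """Run a single generation request on a request-local copy of the base pipeline.

        A per-request scheduler is configured, mutable attributes are cloned, and tokenizers
        are wrapped with a lock for the duration of the call; the original tokenizers are
        restored afterwards.
        """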
        # 1) create a request-local scheduler clone so per-request timestep state never leaks
        local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)

        # 2) shallow-copy the base pipeline; heavy modules (UNet, VAE, text encoder) stay shared
        try:
            local_pipe = copy.copy(self._base)
        except Exception as e:
            logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
            local_pipe = copy.deepcopy(self._base)

        # configure the cloned scheduler's timesteps and attach it to the local pipe
        if local_scheduler is not None:
            try:
                timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
                    local_scheduler.scheduler,
                    num_inference_steps=num_inference_steps,
                    device=device,
                    return_scheduler=True,
                    **{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
                )

                final_scheduler = BaseAsyncScheduler(configured_scheduler)
                setattr(local_pipe, "scheduler", final_scheduler)
            except Exception as e:
                logger.warning(f"Could not configure scheduler on local pipe: {e}. Proceeding without replacing the scheduler.")

        # 3) clone small mutable attributes from the base pipeline onto the local copy
        self._clone_mutable_attrs(self._base, local_pipe)

        # 4) wrap tokenizers on the local pipe with the lock wrapper
        tokenizer_wrappers = {}  # name -> original_tokenizer
        try:
            # a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
            for name in dir(local_pipe):
                if "tokenizer" in name and not name.startswith("_"):
                    tok = getattr(local_pipe, name, None)
                    if tok is not None and self._is_tokenizer_component(tok):
                        tokenizer_wrappers[name] = tok
                        setattr(
                            local_pipe,
                            name,
                            lambda *args, tok=tok, **kwargs: safe_tokenize(
                                tok, *args, lock=self._tokenizer_lock, **kwargs
                            ),
                        )

            # b) wrap tokenizers in components dict
            if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
                for key, val in local_pipe.components.items():
                    if val is None:
                        continue

                    if self._is_tokenizer_component(val):
                        tokenizer_wrappers[f"components[{key}]"] = val
                        local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
                            tokenizer, *args, lock=self._tokenizer_lock, **kwargs
                        )

        except Exception as e:
            logger.debug(f"Tokenizer wrapping step encountered an error: {e}")

        # 5) run the pipeline, inside the model CPU offload context when one is available
        result = None
        cm = getattr(local_pipe, "model_cpu_offload_context", None)
        try:
            if callable(cm):
                try:
                    with cm():
                        result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
                except TypeError:
                    # cm might be a context manager instance rather than callable
                    try:
                        with cm:
                            result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
                    except Exception as e:
                        logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
                        result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
            else:
                # no offload context available — call directly
                result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)

            return result

        finally:
            # restore the original tokenizer objects that were swapped out above
            try:
                for name, tok in tokenizer_wrappers.items():
                    if name.startswith("components["):
                        key = name[len("components[") : -1]
                        local_pipe.components[key] = tok
                    else:
                        setattr(local_pipe, name, tok)
            except Exception as e:
                logger.debug(f"Error restoring wrapped tokenizers: {e}")