components_manager.py

# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import time
from collections import OrderedDict
from itertools import combinations
from typing import Any, Dict, List, Optional, Union

import torch

from ..hooks import ModelHook
from ..utils import (
    is_accelerate_available,
    logging,
)
from ..utils.torch_utils import get_device


if is_accelerate_available():
    from accelerate.hooks import add_hook_to_module, remove_hook_from_module
    from accelerate.state import PartialState
    from accelerate.utils import send_to_device
    from accelerate.utils.memory import clear_device_cache
    from accelerate.utils.modeling import convert_file_size_to_int

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class CustomOffloadHook(ModelHook):
    """
    A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are
    on the given device. Optionally offloads other models to the CPU before the forward pass is called.

    Args:
        execution_device(`str`, `int` or `torch.device`, *optional*):
            The device on which the model should be executed. Will default to the MPS device if it's available, then
            GPU 0 if there is a GPU, and finally to the CPU.
    """

    no_grad = False

    def __init__(
        self,
        execution_device: Optional[Union[str, int, torch.device]] = None,
        other_hooks: Optional[List["UserCustomOffloadHook"]] = None,
        offload_strategy: Optional["AutoOffloadStrategy"] = None,
    ):
        self.execution_device = execution_device if execution_device is not None else PartialState().default_device
        self.other_hooks = other_hooks
        self.offload_strategy = offload_strategy
        self.model_id = None

    def set_strategy(self, offload_strategy: "AutoOffloadStrategy"):
        self.offload_strategy = offload_strategy

    def add_other_hook(self, hook: "UserCustomOffloadHook"):
        """
        Add a hook to the list of hooks to consider for offloading.
        """
        if self.other_hooks is None:
            self.other_hooks = []
        self.other_hooks.append(hook)

    def init_hook(self, module):
        return module.to("cpu")

    def pre_forward(self, module, *args, **kwargs):
        if module.device != self.execution_device:
            if self.other_hooks is not None:
                hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device]
                # offload all other hooks
                start_time = time.perf_counter()
                if self.offload_strategy is not None:
                    hooks_to_offload = self.offload_strategy(
                        hooks=hooks_to_offload,
                        model_id=self.model_id,
                        model=module,
                        execution_device=self.execution_device,
                    )
                end_time = time.perf_counter()
                logger.info(
                    f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds"
                )

                for hook in hooks_to_offload:
                    logger.info(
                        f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu"
                    )
                    hook.offload()

                if hooks_to_offload:
                    clear_device_cache()
            module.to(self.execution_device)
        return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)


class UserCustomOffloadHook:
    """
    A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs to call the init method of
    the hook or remove it entirely.
    """

    def __init__(self, model_id, model, hook):
        self.model_id = model_id
        self.model = model
        self.hook = hook

    def offload(self):
        self.hook.init_hook(self.model)

    def attach(self):
        add_hook_to_module(self.model, self.hook)
        self.hook.model_id = self.model_id

    def remove(self):
        remove_hook_from_module(self.model)
        self.hook.model_id = None

    def add_other_hook(self, hook: "UserCustomOffloadHook"):
        self.hook.add_other_hook(hook)


def custom_offload_with_hook(
    model_id: str,
    model: torch.nn.Module,
    execution_device: Optional[Union[str, int, torch.device]] = None,
    offload_strategy: Optional["AutoOffloadStrategy"] = None,
):
    hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy)
    user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook)
    user_hook.attach()
    return user_hook
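
# Illustrative sketch of how these hooks are wired together (the `unet` and `vae` modules below are
# placeholders, not defined in this file):
#
#     device = torch.device("cuda", 0)
#     unet_hook = custom_offload_with_hook("unet", unet, execution_device=device)
#     vae_hook = custom_offload_with_hook("vae", vae, execution_device=device)
#     # each hook is told about the other so it can offload it before its own forward pass
#     unet_hook.add_other_hook(vae_hook)
#     vae_hook.add_other_hook(unet_hook)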


# this is the class that user can customize to implement their own offload strategy
class AutoOffloadStrategy:
    """
    Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on
    the available memory on the device.
    """

    # YiYi TODO: instead of memory_reserve_margin, we should let the user set the maximum total model size to keep
    # on device. The actual memory usage would be higher, but it's simpler this way and can be tested.
    def __init__(self, memory_reserve_margin="3GB"):
        self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin)

    def __call__(self, hooks, model_id, model, execution_device):
        if len(hooks) == 0:
            return []

        current_module_size = model.get_memory_footprint()

        device_type = execution_device.type
        device_module = getattr(torch, device_type, torch.cuda)
        mem_on_device = device_module.mem_get_info(execution_device.index)[0]
        mem_on_device = mem_on_device - self.memory_reserve_margin
        if current_module_size < mem_on_device:
            return []

        min_memory_offload = current_module_size - mem_on_device
        logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory")

        # exclude models that are not currently loaded on the device
        module_sizes = dict(
            sorted(
                {hook.model_id: hook.model.get_memory_footprint() for hook in hooks}.items(),
                key=lambda x: x[1],
                reverse=True,
            )
        )

        # YiYi/Dhruv TODO: sort smallest to largest and offload in that order, so we tend to keep the larger models on GPU more often
        def search_best_candidate(module_sizes, min_memory_offload):
            """
            search the optimal combination of models to offload to cpu, given a dictionary of module sizes and a
            minimum memory offload size. the combination of models should add up to the smallest modulesize that is
            larger than `min_memory_offload`
            """
            model_ids = list(module_sizes.keys())
            best_candidate = None
            best_size = float("inf")
            for r in range(1, len(model_ids) + 1):
                for candidate_model_ids in combinations(model_ids, r):
                    candidate_size = sum(
                        module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids
                    )
                    if candidate_size < min_memory_offload:
                        continue
                    else:
                        if best_candidate is None or candidate_size < best_size:
                            best_candidate = candidate_model_ids
                            best_size = candidate_size

            return best_candidate

        best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload)

        if best_offload_model_ids is None:
            # if no combination is found, meaning that we cannot meet the memory requirement, offload all models
            logger.warning("no combination of models to offload to cpu is found, offloading all models")
            hooks_to_offload = hooks
        else:
            hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids]

        return hooks_to_offload
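
# A custom strategy only needs to be a callable with the same signature as `AutoOffloadStrategy.__call__`
# (an illustrative sketch, not part of this module):
#
#     class OffloadEverythingStrategy:
#         def __call__(self, hooks, model_id, model, execution_device):
#             # always offload every other model currently on the execution device
#             return hooks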


# utils for display component info in a readable format
# TODO: move to a different file
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
    """Summarizes a dictionary by finding common prefixes that share the same value.

    For a dictionary with dot-separated keys like: {
        'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
        'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
        'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
    }

    Returns a dictionary where keys are the shortest common prefixes and values are their shared values: {
        'down_blocks': [0.6], 'up_blocks': [0.3]
    }
    """
    # First group by values - convert lists to tuples to make them hashable
    value_to_keys = {}
    for key, value in d.items():
        value_tuple = tuple(value) if isinstance(value, list) else value
        if value_tuple not in value_to_keys:
            value_to_keys[value_tuple] = []
        value_to_keys[value_tuple].append(key)

    def find_common_prefix(keys: List[str]) -> str:
        """Find the shortest common prefix among a list of dot-separated keys."""
        if not keys:
            return ""
        if len(keys) == 1:
            return keys[0]

        # Split all keys into parts
        key_parts = [k.split(".") for k in keys]

        # Find how many initial parts are common
        common_length = 0
        for parts in zip(*key_parts):
            if len(set(parts)) == 1:  # All parts at this position are the same
                common_length += 1
            else:
                break

        if common_length == 0:
            return ""

        # Return the common prefix
        return ".".join(key_parts[0][:common_length])

    # Create summary by finding common prefixes for each value group
    summary = {}
    for value_tuple, keys in value_to_keys.items():
        prefix = find_common_prefix(keys)
        # Convert tuple back to list if the original values were lists
        value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple
        if prefix:  # Only add under the prefix if we found a common one
            summary[prefix] = value
        else:
            summary[""] = value  # Use empty string if no common prefix

    return summary


class ComponentsManager:
    """
    A central registry and management system for model components across multiple pipelines.

    [`ComponentsManager`] provides a unified way to register, track, and reuse model components (like UNet, VAE, text
    encoders, etc.) across different modular pipelines. It includes features for duplicate detection, memory
    management, and component organization.

    <Tip warning={true}>

        This is an experimental feature and is likely to change in the future.

    </Tip>

    Example:
        ```python
        from diffusers import ComponentsManager

        # Create a components manager
        cm = ComponentsManager()

        # Add components
        cm.add("unet", unet_model, collection="sdxl")
        cm.add("vae", vae_model, collection="sdxl")

        # Enable auto offloading
        cm.enable_auto_cpu_offload()

        # Retrieve components
        unet = cm.get_one(name="unet", collection="sdxl")
        ```
    """

    _available_info_fields = [
        "model_id",
        "added_time",
        "collection",
        "class_name",
        "size_gb",
        "adapters",
        "has_hook",
        "execution_device",
        "ip_adapter",
    ]

    def __init__(self):
        self.components = OrderedDict()
        # YiYi TODO: can remove once confirm we don't need this in mellon
        self.added_time = OrderedDict()  # Store when components were added
        self.collections = OrderedDict()  # collection_name -> set of component_names
        self.model_hooks = None
        self._auto_offload_enabled = False

    def _lookup_ids(
        self,
        name: Optional[str] = None,
        collection: Optional[str] = None,
        load_id: Optional[str] = None,
        components: Optional[OrderedDict] = None,
    ):
        """
        Lookup component_ids by name, collection, or load_id. Does not support pattern matching. Returns a set of
        component_ids
        """
        if components is None:
            components = self.components

        if name:
            ids_by_name = set()
            for component_id, component in components.items():
                comp_name = self._id_to_name(component_id)
                if comp_name == name:
                    ids_by_name.add(component_id)
        else:
            ids_by_name = set(components.keys())
        if collection:
            ids_by_collection = set()
            for component_id, component in components.items():
                if component_id in self.collections[collection]:
                    ids_by_collection.add(component_id)
        else:
            ids_by_collection = set(components.keys())
        if load_id:
            ids_by_load_id = set()
            for component_id, component in components.items():
                if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
                    ids_by_load_id.add(component_id)
        else:
            ids_by_load_id = set(components.keys())

        ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id)
        return ids

    @staticmethod
    def _id_to_name(component_id: str):
        return "_".join(component_id.split("_")[:-1])

    def add(self, name: str, component: Any, collection: Optional[str] = None):
        """
        Add a component to the ComponentsManager.

        Args:
            name (str): The name of the component
            component (Any): The component to add
            collection (Optional[str]): The collection to add the component to

        Returns:
            str: The unique component ID, which is generated as "{name}_{id(component)}" where
                 id(component) is Python's built-in unique identifier for the object
        """
        component_id = f"{name}_{id(component)}"
        is_new_component = True

        # check for duplicated components
        for comp_id, comp in self.components.items():
            if comp == component:
                comp_name = self._id_to_name(comp_id)
                if comp_name == name:
                    logger.warning(f"ComponentsManager: component '{name}' already exists as '{comp_id}'")
                    component_id = comp_id
                    is_new_component = False
                    break
                else:
                    logger.warning(
                        f"ComponentsManager: adding component '{name}' as '{component_id}', but it is a duplicate of '{comp_id}'. "
                        f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
                    )

        # check for duplicated load_id and warn (we do not delete for you)
        if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
            components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id)
            components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id]

            if components_with_same_load_id:
                existing = ", ".join(components_with_same_load_id)
                logger.warning(
                    f"ComponentsManager: adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
                    f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
                )

        # add component to components manager
        self.components[component_id] = component
        self.added_time[component_id] = time.time()

        if collection:
            if collection not in self.collections:
                self.collections[collection] = set()
            if component_id not in self.collections[collection]:
                comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
                for comp_id in comp_ids_in_collection:
                    logger.warning(
                        f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}"
                    )
                    # remove existing component from this collection (if it is not in any other collection, will be removed from ComponentsManager)
                    self.remove_from_collection(comp_id, collection)

                self.collections[collection].add(component_id)
                logger.info(
                    f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}"
                )
        else:
            logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'")

        if self._auto_offload_enabled and is_new_component:
            self.enable_auto_cpu_offload(self._auto_offload_device)

        return component_id

    def remove_from_collection(self, component_id: str, collection: str):
        """
        Remove a component from a collection.
        """
        if collection not in self.collections:
            logger.warning(f"Collection '{collection}' not found in ComponentsManager")
            return
        if component_id not in self.collections[collection]:
            logger.warning(f"Component '{component_id}' not found in collection '{collection}'")
            return
        # remove from the collection
        self.collections[collection].remove(component_id)
        # check if this component is in any other collection
        comp_colls = [coll for coll, comps in self.collections.items() if component_id in comps]
        if not comp_colls:  # only if no other collection contains this component, remove it
            logger.warning(f"ComponentsManager: removing component '{component_id}' from ComponentsManager")
            self.remove(component_id)

    def remove(self, component_id: str = None):
        """
        Remove a component from the ComponentsManager.

        Args:
            component_id (str): The ID of the component to remove
        """
        if component_id not in self.components:
            logger.warning(f"Component '{component_id}' not found in ComponentsManager")
            return

        component = self.components.pop(component_id)
        self.added_time.pop(component_id)

        for collection in self.collections:
            if component_id in self.collections[collection]:
                self.collections[collection].remove(component_id)

        if self._auto_offload_enabled:
            self.enable_auto_cpu_offload(self._auto_offload_device)
        else:
            if isinstance(component, torch.nn.Module):
                component.to("cpu")
            del component
            import gc

            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if torch.xpu.is_available():
                torch.xpu.empty_cache()

    # YiYi TODO: rename to search_components for now, may remove this method
    def search_components(
        self,
        names: Optional[str] = None,
        collection: Optional[str] = None,
        load_id: Optional[str] = None,
        return_dict_with_names: bool = True,
    ):
        """
        Search components by name with simple pattern matching. Optionally filter by collection or load_id.

        Args:
            names: Component name(s) or pattern(s)
                Patterns:
                - "unet" : match any component with base name "unet" (e.g., unet_123abc)
                - "!unet" : everything except components with base name "unet"
                - "unet*" : anything with base name starting with "unet"
                - "!unet*" : anything with base name NOT starting with "unet"
                - "*unet*" : anything with base name containing "unet"
                - "!*unet*" : anything with base name NOT containing "unet"
                - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet"
                - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet"
                - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae"
            collection: Optional collection to filter by
            load_id: Optional load_id to filter by
            return_dict_with_names:
                If True, returns a dictionary with component names as keys and raises an error if multiple
                components with the same name are found. If False, returns a dictionary with component IDs as keys.

        Returns:
            Dictionary mapping component names to components if return_dict_with_names=True, or a dictionary mapping
            component IDs to components if return_dict_with_names=False
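
        Example (a minimal sketch, assuming `cm` already has these components registered):
            ```python
            # all components whose base name starts with "unet"
            unets = cm.search_components("unet*")

            # everything except components with base name "vae", keyed by component ID
            non_vae = cm.search_components("!vae", return_dict_with_names=False)

            # components named "unet" within the "sdxl" collection
            sdxl_unet = cm.search_components("unet", collection="sdxl")
            ```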
        """

        # select components based on collection and load_id filters
        selected_ids = self._lookup_ids(collection=collection, load_id=load_id)
        components = {k: self.components[k] for k in selected_ids}

        def get_return_dict(components, return_dict_with_names):
            """
            Create a dictionary mapping component names to components if return_dict_with_names=True, or a dictionary
            mapping component IDs to components if return_dict_with_names=False, throw an error if duplicate component
            names are found when return_dict_with_names=True
            """
            if return_dict_with_names:
                dict_to_return = {}
                for comp_id, comp in components.items():
                    comp_name = self._id_to_name(comp_id)
                    if comp_name in dict_to_return:
                        raise ValueError(
                            f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
                        )
                    dict_to_return[comp_name] = comp
                return dict_to_return
            else:
                return components

        # if no names are provided, return the filtered components as it is
        if names is None:
            return get_return_dict(components, return_dict_with_names)

        # if names is not a string, raise an error
        elif not isinstance(names, str):
            raise ValueError(f"Invalid type for `names: {type(names)}, only support string")

        # Create mapping from component_id to base_name for components to be used for pattern matching
        base_names = {comp_id: self._id_to_name(comp_id) for comp_id in components.keys()}

        # Helper function to check if a component matches a pattern based on its base name
        def matches_pattern(component_id, pattern, exact_match=False):
            """
            Helper function to check if a component matches a pattern based on its base name.

            Args:
                component_id: The component ID to check
                pattern: The pattern to match against
                exact_match: If True, only exact matches to base_name are considered
            """
            base_name = base_names[component_id]

            # Exact match with base name
            if exact_match:
                return pattern == base_name

            # Prefix match (ends with *)
            elif pattern.endswith("*"):
                prefix = pattern[:-1]
                return base_name.startswith(prefix)

            # Contains match (starts with *)
            elif pattern.startswith("*"):
                search = pattern[1:-1] if pattern.endswith("*") else pattern[1:]
                return search in base_name

            # Exact match (no wildcards)
            else:
                return pattern == base_name

        # Check if this is a "not" pattern
        is_not_pattern = names.startswith("!")
        if is_not_pattern:
            names = names[1:]  # Remove the ! prefix

        # Handle OR patterns (containing |)
        if "|" in names:
            terms = names.split("|")
            matches = {}

            # For OR patterns with exact names (no wildcards), we do exact matching on base names
            exact_match = all(not (term.startswith("*") or term.endswith("*")) for term in terms)

            for comp_id, comp in components.items():
                # Check if any of the terms match this component
                should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)

                # Flip the decision if this is a NOT pattern
                if is_not_pattern:
                    should_include = not should_include

                if should_include:
                    matches[comp_id] = comp

            log_msg = "NOT " if is_not_pattern else ""
            match_type = "exactly matching" if exact_match else "matching any of patterns"
            logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")

        # Try exact match with a base name
        elif any(names == base_name for base_name in base_names.values()):
            # Find all components with this base name
            matches = {
                comp_id: comp
                for comp_id, comp in components.items()
                if (base_names[comp_id] == names) != is_not_pattern
            }

            if is_not_pattern:
                logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
            else:
                logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")

        # Prefix match (ends with *)
        elif names.endswith("*"):
            prefix = names[:-1]
            matches = {
                comp_id: comp
                for comp_id, comp in components.items()
                if base_names[comp_id].startswith(prefix) != is_not_pattern
            }
            if is_not_pattern:
                logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
            else:
                logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")

        # Contains match (starts with *)
        elif names.startswith("*"):
            search = names[1:-1] if names.endswith("*") else names[1:]
            matches = {
                comp_id: comp
                for comp_id, comp in components.items()
                if (search in base_names[comp_id]) != is_not_pattern
            }
            if is_not_pattern:
                logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
            else:
                logger.info(f"Getting components containing '{search}': {list(matches.keys())}")

        # Substring match (no wildcards, but not an exact component name)
        elif any(names in base_name for base_name in base_names.values()):
            matches = {
                comp_id: comp
                for comp_id, comp in components.items()
                if (names in base_names[comp_id]) != is_not_pattern
            }
            if is_not_pattern:
                logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
            else:
                logger.info(f"Getting components containing '{names}': {list(matches.keys())}")

        else:
            raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")

        if not matches:
            raise ValueError(f"No components found matching pattern '{names}'")

        return get_return_dict(matches, return_dict_with_names)

    def enable_auto_cpu_offload(
        self, device: Optional[Union[str, int, torch.device]] = None, memory_reserve_margin="3GB"
    ):
        """
        Enable automatic CPU offloading for all components.

        The algorithm works as follows:
        1. All models start on CPU by default
        2. When a model's forward pass is called, it's moved to the execution device
        3. If there's insufficient memory, other models on the device are moved back to CPU
        4. The system tries to offload the smallest combination of models that frees enough memory
        5. Models stay on the execution device until another model needs memory and forces them off

        Args:
            device (Union[str, int, torch.device]): The execution device where models are moved for forward passes
            memory_reserve_margin (str): The memory reserve margin to use, default is 3GB. This is the amount of
                                        memory to keep free on the device to avoid running out of memory during model
                                        execution (e.g., for intermediate activations, gradients, etc.)
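
        Example (a minimal sketch, assuming `cm` is a `ComponentsManager` with components already added):
            ```python
            cm.enable_auto_cpu_offload(device="cuda")
            # ... run your pipelines; models are moved to "cuda" for their forward pass and offloaded
            # back to CPU automatically when memory is needed ...
            cm.disable_auto_cpu_offload()
            ```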
        """
        if not is_accelerate_available():
            raise ImportError("Make sure to install accelerate to use auto_cpu_offload")

        for name, component in self.components.items():
            if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
                remove_hook_from_module(component, recurse=True)

        self.disable_auto_cpu_offload()
        offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
        if device is None:
            device = get_device()
        device = torch.device(device)
        if device.index is None:
            device = torch.device(f"{device.type}:{0}")
        all_hooks = []
        for name, component in self.components.items():
            if isinstance(component, torch.nn.Module):
                hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy)
                all_hooks.append(hook)

        for hook in all_hooks:
            other_hooks = [h for h in all_hooks if h is not hook]
            for other_hook in other_hooks:
                if other_hook.hook.execution_device == hook.hook.execution_device:
                    hook.add_other_hook(other_hook)

        self.model_hooks = all_hooks
        self._auto_offload_enabled = True
        self._auto_offload_device = device

    def disable_auto_cpu_offload(self):
        """
        Disable automatic CPU offloading for all components.
        """
        if self.model_hooks is None:
            self._auto_offload_enabled = False
            return

        for hook in self.model_hooks:
            hook.offload()
            hook.remove()
        if self.model_hooks:
            clear_device_cache()
        self.model_hooks = None
        self._auto_offload_enabled = False

    # YiYi TODO: (1) add quantization info
    def get_model_info(
        self,
        component_id: str,
        fields: Optional[Union[str, List[str]]] = None,
    ) -> Optional[Dict[str, Any]]:
        """Get comprehensive information about a component.

        Args:
            component_id (str): ID of the component to get info for
            fields (Optional[Union[str, List[str]]]):
                   Field(s) to return. Can be a string for a single field or a list of fields. If None, all
                   available fields are returned.

        Returns:
            Dictionary containing requested component metadata. If fields is specified, returns only those fields.
            Otherwise, returns all fields.
        """
        if component_id not in self.components:
            raise ValueError(f"Component '{component_id}' not found in ComponentsManager")

        component = self.components[component_id]

        # Validate fields if specified
        if fields is not None:
            if isinstance(fields, str):
                fields = [fields]
            for field in fields:
                if field not in self._available_info_fields:
                    raise ValueError(f"Field '{field}' not found in available_info_fields")

        # Build complete info dict first
        info = {
            "model_id": component_id,
            "added_time": self.added_time[component_id],
            "collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps])
            or None,
        }

        # Additional info for torch.nn.Module components
        if isinstance(component, torch.nn.Module):
            # Check for hook information
            has_hook = hasattr(component, "_hf_hook")
            execution_device = None
            if has_hook and hasattr(component._hf_hook, "execution_device"):
                execution_device = component._hf_hook.execution_device

            info.update(
                {
                    "class_name": component.__class__.__name__,
                    "size_gb": component.get_memory_footprint() / (1024**3),
                    "adapters": None,  # Default to None
                    "has_hook": has_hook,
                    "execution_device": execution_device,
                }
            )

            # Get adapters if applicable
            if hasattr(component, "peft_config"):
                info["adapters"] = list(component.peft_config.keys())

            # Check for IP-Adapter scales
            if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"):
                processors = copy.deepcopy(component.attn_processors)
                # First check if any processor is an IP-Adapter
                processor_types = [v.__class__.__name__ for v in processors.values()]
                if any("IPAdapter" in ptype for ptype in processor_types):
                    # Then get scales only from IP-Adapter processors
                    scales = {
                        k: v.scale
                        for k, v in processors.items()
                        if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
                    }
                    if scales:
                        info["ip_adapter"] = summarize_dict_by_value_and_parts(scales)

        # If fields specified, filter info
        if fields is not None:
            return {k: v for k, v in info.items() if k in fields}
        else:
            return info

    # YiYi TODO: (1) add display fields, allow user to set which fields to display in the components table
    def __repr__(self):
        # Handle empty components case
        if not self.components:
            return "Components:\n" + "=" * 50 + "\nNo components registered.\n" + "=" * 50

        # Extract load_id if available
        def get_load_id(component):
            if hasattr(component, "_diffusers_load_id"):
                return component._diffusers_load_id
            return "N/A"

        # Format device info compactly
        def format_device(component, info):
            if not info["has_hook"]:
                return str(getattr(component, "device", "N/A"))
            else:
                device = str(getattr(component, "device", "N/A"))
                exec_device = str(info["execution_device"] or "N/A")
                return f"{device}({exec_device})"

        # Get max length of load_ids for models
        load_ids = [
            get_load_id(component)
            for component in self.components.values()
            if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id")
        ]
        max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15

        # Get all collections for each component
        component_collections = {}
        for name in self.components.keys():
            component_collections[name] = []
            for coll, comps in self.collections.items():
                if name in comps:
                    component_collections[name].append(coll)
            if not component_collections[name]:
                component_collections[name] = ["N/A"]

        # Find the maximum collection name length
        all_collections = [coll for colls in component_collections.values() for coll in colls]
        max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10

        col_widths = {
            "id": max(15, max(len(name) for name in self.components.keys())),
            "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
            "device": 20,
            "dtype": 15,
            "size": 10,
            "load_id": max_load_id_len,
            "collection": max_collection_len,
        }

        # Create the header lines
        sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
        dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"

        output = "Components:\n" + sep_line

        # Separate components into models and others
        models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
        others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)}

        # Models section
        if models:
            output += "Models:\n" + dash_line
            # Column headers
            output += f"{'Name_ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | "
            output += f"{'Device: act(exec)':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | "
            output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n"
            output += dash_line

            # Model entries
            for name, component in models.items():
                info = self.get_model_info(name)
                device_str = format_device(component, info)
                dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
                load_id = get_load_id(component)

                # Print first collection on the main line
                first_collection = component_collections[name][0] if component_collections[name] else "N/A"

                output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | "
                output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
                output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n"

                # Print additional collections on separate lines if they exist
                for i in range(1, len(component_collections[name])):
                    collection = component_collections[name][i]
                    output += f"{'':<{col_widths['id']}} | {'':<{col_widths['class']}} | "
                    output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | "
                    output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n"

            output += dash_line

        # Other components section
        if others:
            if models:  # Add extra newline if we had models section
                output += "\n"
            output += "Other Components:\n" + dash_line
            # Column headers for other components
            output += f"{'ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | Collection\n"
            output += dash_line

            # Other component entries
            for name, component in others.items():
                info = self.get_model_info(name)

                # Print first collection on the main line
                first_collection = component_collections[name][0] if component_collections[name] else "N/A"

                output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n"

                # Print additional collections on separate lines if they exist
                for i in range(1, len(component_collections[name])):
                    collection = component_collections[name][i]
                    output += f"{'':<{col_widths['id']}} | {'':<{col_widths['class']}} | {collection}\n"

            output += dash_line

        # Add additional component info
        output += "\nAdditional Component Info:\n" + "=" * 50 + "\n"
        for name in self.components:
            info = self.get_model_info(name)
            if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")):
                output += f"\n{name}:\n"
                if info.get("adapters") is not None:
                    output += f"  Adapters: {info['adapters']}\n"
                if info.get("ip_adapter"):
                    output += "  IP-Adapter: Enabled\n"

        return output

    def get_one(
        self,
        component_id: Optional[str] = None,
        name: Optional[str] = None,
        collection: Optional[str] = None,
        load_id: Optional[str] = None,
    ) -> Any:
        """
        Get a single component by either:
        - searching name (pattern matching), collection, or load_id.
        - passing in a component_id
        Raises an error if multiple components match or none are found.

        Args:
            component_id (Optional[str]): Optional component ID to get
            name (Optional[str]): Component name or pattern
            collection (Optional[str]): Optional collection to filter by
            load_id (Optional[str]): Optional load_id to filter by

        Returns:
            A single component

        Raises:
            ValueError: If no components match or multiple components match
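
        Example (a minimal sketch, assuming a single "unet" is registered in the "sdxl" collection):
            ```python
            unet = cm.get_one(name="unet", collection="sdxl")
            ```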
        """

        if component_id is not None and (name is not None or collection is not None or load_id is not None):
            raise ValueError("If searching by component_id, do not pass name, collection, or load_id")

        # search by component_id
        if component_id is not None:
            if component_id not in self.components:
                raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
            return self.components[component_id]
        # search with name/collection/load_id
        results = self.search_components(name, collection, load_id)

        if not results:
            raise ValueError(f"No components found matching '{name}'")

        if len(results) > 1:
            raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")

        return next(iter(results.values()))

    def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] = None):
        """
        Get component IDs by a list of names, optionally filtered by collection.

        Args:
            names (Union[str, List[str]]): List of component names
            collection (Optional[str]): Optional collection to filter by

        Returns:
            List[str]: List of component IDs
        """
        ids = set()
        if not isinstance(names, list):
            names = [names]
        for name in names:
            ids.update(self._lookup_ids(name=name, collection=collection))
        return list(ids)

    def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional[bool] = True):
        """
        Get components by a list of IDs.

        Args:
            ids (List[str]):
                List of component IDs
            return_dict_with_names (Optional[bool]):
                Whether to return a dictionary with component names as keys.

        Returns:
            Dict[str, Any]: Dictionary of components.
                - If return_dict_with_names=True, keys are component names.
                - If return_dict_with_names=False, keys are component IDs.

        Raises:
            ValueError: If duplicate component names are found in the search results when return_dict_with_names=True
        """
        components = {id: self.components[id] for id in ids}

        if return_dict_with_names:
            dict_to_return = {}
            for comp_id, comp in components.items():
                comp_name = self._id_to_name(comp_id)
                if comp_name in dict_to_return:
                    raise ValueError(
                        f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
                    )
                dict_to_return[comp_name] = comp
            return dict_to_return
        else:
            return components

    def get_components_by_names(self, names: List[str], collection: Optional[str] = None):
        """
        Get components by a list of names, optionally filtered by collection.

        Args:
            names (List[str]): List of component names
            collection (Optional[str]): Optional collection to filter by

        Returns:
            Dict[str, Any]: Dictionary of components with component names as keys

        Raises:
            ValueError: If duplicate component names are found in the search results
        """
        ids = self.get_ids(names, collection)
        return self.get_components_by_ids(ids)