import copy
import inspect
from typing import Any, List, Optional, Union

import torch


class BaseAsyncScheduler:
    """Thin wrapper that delegates attribute access to a wrapped scheduler and can hand out
    per-request copies via `clone_for_request`, so the shared scheduler is never mutated."""

    def __init__(self, scheduler: Any):
        self.scheduler = scheduler

    def __getattr__(self, name: str):
        # Look the wrapped scheduler up via __dict__ to avoid infinite recursion when
        # `scheduler` has not been set yet (e.g. during deepcopy or unpickling).
        scheduler = self.__dict__.get("scheduler")
        if scheduler is not None and hasattr(scheduler, name):
            return getattr(scheduler, name)
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def __setattr__(self, name: str, value):
        # `scheduler` itself always lives on the wrapper; other attributes are forwarded to the
        # wrapped scheduler when it already defines them, and stored on the wrapper otherwise.
        if name == "scheduler":
            super().__setattr__(name, value)
        else:
            if hasattr(self, "scheduler") and hasattr(self.scheduler, name):
                setattr(self.scheduler, name, value)
            else:
                super().__setattr__(name, value)

    def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):
        """Return a new wrapper around a deep copy of the wrapped scheduler, with timesteps already
        set for this request, leaving the shared scheduler untouched."""
        local = copy.deepcopy(self.scheduler)
        local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
        cloned = self.__class__(local)
        return cloned

    def __repr__(self):
        return f"BaseAsyncScheduler({repr(self.scheduler)})"

    def __str__(self):
        return f"BaseAsyncScheduler wrapping: {str(self.scheduler)}"


def async_retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call.
    Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Backwards compatible: by default the function behaves exactly as before and returns
        (timesteps_tensor, num_inference_steps)

    If the caller passes `return_scheduler=True` in kwargs, the function will **not** mutate the passed
    scheduler. Instead it will use a cloned scheduler if available (via `scheduler.clone_for_request`)
    or a deepcopy fallback, call `set_timesteps` on that cloned scheduler, and return:
        (timesteps_tensor, num_inference_steps, scheduler_in_use)

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`, *optional*):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Optional kwargs:
        return_scheduler (bool, default False): if True, return (timesteps, num_inference_steps, scheduler_in_use)
            where `scheduler_in_use` is a scheduler instance that already has timesteps set.
            This mode will prefer `scheduler.clone_for_request(...)` if available, to avoid mutating the original scheduler.

    Returns:
        `(timesteps_tensor, num_inference_steps)` by default (backwards compatible), or
        `(timesteps_tensor, num_inference_steps, scheduler_in_use)` if `return_scheduler=True`.
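
    Example (illustrative sketch; `DDIMScheduler` is just a stand-in for any scheduler with a
    standard `set_timesteps` method):

        >>> from diffusers import DDIMScheduler
        >>> scheduler = DDIMScheduler()

        >>> # default, backwards-compatible call: configures `scheduler` in place
        >>> timesteps, num_steps = async_retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")

        >>> # per-request call: `scheduler` is left untouched, timesteps are set on a clone
        >>> timesteps, num_steps, local_scheduler = async_retrieve_timesteps(
        ...     scheduler, num_inference_steps=30, device="cpu", return_scheduler=True
        ... )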
    """
    # pop the optional control kwarg so it is not forwarded to `set_timesteps`
    return_scheduler = bool(kwargs.pop("return_scheduler", False))

    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # choose scheduler to call set_timesteps on
    scheduler_in_use = scheduler
    if return_scheduler:
        # Do not mutate the provided scheduler: prefer cloning when the wrapper supports it
        if hasattr(scheduler, "clone_for_request"):
            try:
                # clone_for_request also calls set_timesteps; pass 0 when num_inference_steps is None
                # (custom timesteps/sigmas) and rely on the deepcopy fallback below if that fails
                scheduler_in_use = scheduler.clone_for_request(
                    num_inference_steps=num_inference_steps or 0, device=device
                )
            except Exception:
                scheduler_in_use = copy.deepcopy(scheduler)
        else:
            # fallback: deepcopy (schedulers are typically small, so this is acceptable)
            scheduler_in_use = copy.deepcopy(scheduler)

    # helper to test whether set_timesteps supports a particular kwarg
    def _accepts(param_name: str) -> bool:
        try:
            return param_name in inspect.signature(scheduler_in_use.set_timesteps).parameters
        except (ValueError, TypeError):
            # if signature introspection fails, report the kwarg as unsupported so the caller
            # gets a clear error below instead of an unexpected TypeError from set_timesteps
            return False

    # now call set_timesteps on the chosen scheduler_in_use (may be original or clone)
    if timesteps is not None:
        accepts_timesteps = _accepts("timesteps")
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler_in_use.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps_out = scheduler_in_use.timesteps
        num_inference_steps = len(timesteps_out)
    elif sigmas is not None:
        accept_sigmas = _accepts("sigmas")
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler_in_use.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps_out = scheduler_in_use.timesteps
        num_inference_steps = len(timesteps_out)
    else:
        # default path
        scheduler_in_use.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps_out = scheduler_in_use.timesteps

    if return_scheduler:
        return timesteps_out, num_inference_steps, scheduler_in_use
    return timesteps_out, num_inference_steps
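
# Example (illustrative sketch, assuming an asyncio-based serving loop; `handle_request` and `serve`
# are hypothetical helpers, not part of this module): with `return_scheduler=True` the shared
# scheduler is never mutated, so concurrent requests with different step counts cannot clobber
# each other's timesteps.
#
#     import asyncio
#
#     async def handle_request(shared_scheduler, steps: int):
#         timesteps, num_steps, local = async_retrieve_timesteps(
#             shared_scheduler, num_inference_steps=steps, return_scheduler=True
#         )
#         ...  # run the denoising loop with `local` and `timesteps`
#
#     async def serve(shared_scheduler):
#         await asyncio.gather(handle_request(shared_scheduler, 20), handle_request(shared_scheduler, 50))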