base_scheduler.py 1.38 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from abc import abstractmethod, ABCMeta
from functools import wraps
from typing import List

from diffusers.schedulers import SchedulerMixin
from xfuser.core.distributed import (
    get_pipeline_parallel_world_size,
    get_sequence_parallel_world_size,
)
from xfuser.model_executor.base_wrapper import xFuserBaseWrapper


class xFuserSchedulerBaseWrapper(xFuserBaseWrapper, metaclass=ABCMeta):
    """Abstract base for xFuser wrappers around a diffusers scheduler.

    The wrapped scheduler is held in ``self.module``. Subclasses must
    implement ``step``. They can decorate it with
    ``check_to_use_naive_step`` so that the wrapped scheduler's own ``step``
    is used whenever neither pipeline nor sequence parallelism is active.
    """

    def __init__(
        self,
        module: SchedulerMixin,
    ):
        """Wrap ``module`` (a diffusers ``SchedulerMixin``) via the base wrapper."""
        super().__init__(module=module)

    def __setattr__(self, name, value):
        """Route attribute writes either to the wrapped scheduler or the wrapper.

        Writes to ``module`` itself always land on the wrapper. Any other
        name that already exists on a non-None wrapped scheduler is written
        to that scheduler instead. Everything else, including writes made
        before ``module`` exists (e.g. during base-class initialisation), is
        stored on the wrapper.
        """
        if name != "module" and hasattr(self, "module"):
            wrapped = self.module
            if wrapped is not None and hasattr(wrapped, name):
                setattr(wrapped, name, value)
                return
        super().__setattr__(name, value)

    @abstractmethod
    def step(self, *args, **kwargs):
        """Advance the scheduler by one step; concrete subclasses must override."""

    @staticmethod
    def check_to_use_naive_step(func):
        """Decorate a subclass ``step`` so it falls back to the plain scheduler.

        If the pipeline-parallel and sequence-parallel world sizes are both 1,
        the decorated call is forwarded to ``self.module.step`` with the same
        positional and keyword arguments (``self`` is not passed on).
        Otherwise ``func`` itself runs.
        """

        @wraps(func)
        def dispatch_step(self, *args, **kwargs):
            # Same evaluation order and short-circuiting as checking
            # "pp == 1 and sp == 1": the sequence-parallel size is only
            # queried when the pipeline-parallel size is 1.
            runs_parallel = (
                get_pipeline_parallel_world_size() != 1
                or get_sequence_parallel_world_size() != 1
            )
            if runs_parallel:
                return func(self, *args, **kwargs)
            return self.module.step(*args, **kwargs)

        return dispatch_step