register.py 2.63 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from typing import Callable, Dict, Optional, Type, Union

from diffusers.pipelines.pipeline_utils import DiffusionPipeline

from xfuser.logger import init_logger
from .base_pipeline import xFuserPipelineBaseWrapper

logger = init_logger(__name__)

class xFuserPipelineWrapperRegister:
    """Registry mapping diffusers pipeline classes to xFuser wrapper classes.

    Wrapper classes are added with the :meth:`register` class decorator and
    looked up with :meth:`get_class`. A lookup resolves to the wrapper
    registered for the most derived registered class that matches the
    given pipeline, independent of registration order.
    """

    # Registered pipeline class -> xFuser wrapper class. This is a class
    # attribute mutated in place by `register`, so all callers share it.
    _XFUSER_PIPE_MAPPING: Dict[
        Type[DiffusionPipeline], 
        Type[xFuserPipelineBaseWrapper]
    ] = {}

    @classmethod
    def register(cls, origin_pipe_class: Type[DiffusionPipeline]):
        """Return a class decorator that registers a wrapper for a pipeline.

        Args:
            origin_pipe_class: The pipeline class that the decorated wrapper
                class handles. Subclasses of it also resolve to the wrapper
                unless a more derived class has its own registration.

        Returns:
            A decorator that stores the decorated class in the registry and
            returns it unchanged.

        Raises:
            ValueError: Raised by the decorator when the decorated class is
                not a subclass of ``xFuserPipelineBaseWrapper``.
        """
        def decorator(xfuser_pipe_class: Type[xFuserPipelineBaseWrapper]):
            if not issubclass(xfuser_pipe_class, xFuserPipelineBaseWrapper):
                raise ValueError(f"{xfuser_pipe_class} is not a subclass of"
                                 f" xFuserPipelineBaseWrapper")
            previous = cls._XFUSER_PIPE_MAPPING.get(origin_pipe_class)
            if previous is not None and previous is not xfuser_pipe_class:
                # The latest registration wins; surface the override instead
                # of replacing the earlier wrapper silently.
                logger.warning(
                    "Overriding xFuser wrapper for %s: %s -> %s",
                    origin_pipe_class, previous, xfuser_pipe_class,
                )
            cls._XFUSER_PIPE_MAPPING[origin_pipe_class] = xfuser_pipe_class
            return xfuser_pipe_class
        return decorator

    @classmethod
    def _find_wrapper(
        cls,
        matches: Callable[[type], bool],
    ) -> Optional[Type[xFuserPipelineBaseWrapper]]:
        """Return the wrapper of the most derived registered class accepted
        by *matches*, or ``None`` when no registered class is accepted."""
        best_origin = None
        best_wrapper = None
        for origin, wrapper in cls._XFUSER_PIPE_MAPPING.items():
            if not matches(origin):
                continue
            # Prefer a more derived registered class over one of its bases,
            # whatever order they were registered in.
            if best_origin is None or issubclass(origin, best_origin):
                best_origin = origin
                best_wrapper = wrapper
        return best_wrapper

    @classmethod
    def get_class(
        cls,
        pipe: Union[DiffusionPipeline, Type[DiffusionPipeline]]
    ) -> Type[xFuserPipelineBaseWrapper]:
        """Return the xFuser wrapper class registered for *pipe*.

        Args:
            pipe: Either a pipeline class or a ``DiffusionPipeline``
                instance.

        Returns:
            The wrapper registered for the most derived registered class
            that *pipe* is a subclass (or instance) of.

        Raises:
            ValueError: If *pipe* is neither a class nor a
                ``DiffusionPipeline`` instance, or if no registered class
                matches it.
        """
        if isinstance(pipe, type):
            wrapper = cls._find_wrapper(lambda origin: issubclass(pipe, origin))
            pipe_class = pipe
        elif isinstance(pipe, DiffusionPipeline):
            wrapper = cls._find_wrapper(lambda origin: isinstance(pipe, origin))
            pipe_class = pipe.__class__
        else:
            raise ValueError(f"Unsupported type {type(pipe)} for pipe")

        if wrapper is None:
            raise ValueError(f"Diffusion Pipeline class {pipe_class} "
                             f"is not supported by xFuser")
        return wrapper