register.py 1.84 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import Dict, Type
import torch
import torch.nn as nn

from xfuser.logger import init_logger
from xfuser.model_executor.schedulers.base_scheduler import xFuserSchedulerBaseWrapper

# Module-level logger obtained from xFuser's init_logger helper.
# NOTE(review): not referenced by the code visible in this file section.
logger = init_logger(__name__)

class xFuserSchedulerWrappersRegister:
    """Registry mapping original scheduler classes to xFuser wrapper classes.

    Wrapper classes are recorded with the :meth:`register` class decorator
    and looked up for a concrete scheduler instance with :meth:`get_wrapper`.
    """

    # Original scheduler class -> xFuser wrapper class handling it. Stored on
    # the class itself, so every registration shares this single table.
    _XFUSER_SCHEDULER_MAPPING: Dict[
        Type[nn.Module],
        Type[xFuserSchedulerBaseWrapper]
    ] = {}

    @classmethod
    def register(cls, origin_scheduler_class: Type[nn.Module]):
        """Return a decorator that registers a wrapper for a scheduler class.

        Args:
            origin_scheduler_class: The original scheduler class handled by
                the decorated wrapper class.

        Returns:
            A class decorator that stores the decorated wrapper class in the
            registry (overwriting any previous entry for
            ``origin_scheduler_class``) and returns it unchanged.

        Raises:
            ValueError: Raised by the returned decorator when the decorated
                object is not a subclass of ``xFuserSchedulerBaseWrapper``.
        """
        def decorator(
            xfuser_scheduler_class: Type[xFuserSchedulerBaseWrapper],
        ) -> Type[xFuserSchedulerBaseWrapper]:
            # Check for a class first: issubclass() raises TypeError on
            # non-class arguments, which would mask the intended ValueError.
            if not (isinstance(xfuser_scheduler_class, type)
                    and issubclass(xfuser_scheduler_class,
                                   xFuserSchedulerBaseWrapper)):
                # The decorated object is a class, so report its own name;
                # `.__class__.__name__` would give the metaclass ("type").
                name = getattr(xfuser_scheduler_class, "__name__",
                               repr(xfuser_scheduler_class))
                raise ValueError(
                    f"{name} is not "
                    f"a subclass of xFuserSchedulerBaseWrapper"
                )
            cls._XFUSER_SCHEDULER_MAPPING[origin_scheduler_class] = \
                xfuser_scheduler_class
            return xfuser_scheduler_class
        return decorator

    @classmethod
    def get_wrapper(
        cls,
        scheduler: nn.Module
    ) -> Type[xFuserSchedulerBaseWrapper]:
        """Return the wrapper class registered for ``scheduler``.

        Only registered origin classes that ``scheduler`` is an instance of
        are considered. An origin equal to ``type(scheduler)`` always takes
        over. Otherwise a later origin replaces the current one only if it
        is a subclass of it. Unrelated matches keep the earliest-registered
        origin.

        Args:
            scheduler: The scheduler instance to find a wrapper for.

        Returns:
            The wrapper class (not an instance) for the best matching origin.

        Raises:
            ValueError: If no registered origin class matches ``scheduler``.
        """
        candidate = None
        candidate_origin = None
        for (origin_scheduler_class,
             wrapper_class) in cls._XFUSER_SCHEDULER_MAPPING.items():
            if not isinstance(scheduler, origin_scheduler_class):
                continue
            # candidate and candidate_origin are always assigned together,
            # so checking candidate_origin alone suffices for "first match".
            if (candidate_origin is None or
                    origin_scheduler_class == scheduler.__class__ or
                    issubclass(origin_scheduler_class, candidate_origin)):
                candidate_origin = origin_scheduler_class
                candidate = wrapper_class

        if candidate is None:
            raise ValueError(f"Scheduler class {scheduler.__class__.__name__} "
                             f"is not supported by xFuser")
        return candidate