register.py 1.77 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from typing import Dict, Type
import torch
import torch.nn as nn

from xfuser.logger import init_logger
from xfuser.model_executor.layers.base_layer import xFuserLayerBaseWrapper

# Module-level logger, created through xfuser's init_logger with this module's name.
logger = init_logger(__name__)


class xFuserLayerWrappersRegister:
    """Registry that maps original layer classes to their xFuser wrappers.

    Wrapper classes are added with the :meth:`register` decorator and looked
    up for a concrete layer instance with :meth:`get_wrapper`. The mapping is
    stored on the class itself, so it is shared by every user of the registry.
    """

    # Original layer class -> wrapper class registered for it.
    _XFUSER_LAYER_MAPPING: Dict[
        Type[nn.Module], Type[xFuserLayerBaseWrapper]
    ] = {}

    @classmethod
    def register(cls, origin_layer_class: Type[nn.Module]):
        """Return a class decorator that registers a wrapper for a layer class.

        Args:
            origin_layer_class: The original layer class the decorated
                wrapper should be used for.

        Returns:
            A decorator that validates the decorated class, stores it in the
            registry under ``origin_layer_class`` (replacing any wrapper
            previously stored for that same class) and returns it unchanged.

        Raises:
            ValueError: Raised by the decorator when the decorated class is
                not a subclass of ``xFuserLayerBaseWrapper``.
        """

        def decorator(xfuser_layer_wrapper: Type[xFuserLayerBaseWrapper]):
            if not issubclass(xfuser_layer_wrapper, xFuserLayerBaseWrapper):
                # The argument is itself a class, so report its own name;
                # ``__class__.__name__`` would name its metaclass instead.
                raise ValueError(
                    f"{xfuser_layer_wrapper.__name__} is not a "
                    f"subclass of xFuserLayerBaseWrapper"
                )
            cls._XFUSER_LAYER_MAPPING[origin_layer_class] = xfuser_layer_wrapper
            return xfuser_layer_wrapper

        return decorator

    @classmethod
    def get_wrapper(cls, layer: nn.Module) -> Type[xFuserLayerBaseWrapper]:
        """Return the wrapper class registered for ``layer``.

        Every registered original class that ``layer`` is an instance of is
        considered. The current choice is replaced whenever a later entry is
        the layer's exact class or a subclass of the currently chosen
        original class, so more specific registrations win over base-class
        ones.

        Args:
            layer: The layer instance to find a wrapper for.

        Returns:
            The wrapper class (not an instance) selected for ``layer``.

        Raises:
            ValueError: If ``layer`` is not an instance of any registered
                original layer class.
        """
        candidate = None
        candidate_origin = None
        for (
            origin_layer_class,
            xfuser_layer_wrapper,
        ) in cls._XFUSER_LAYER_MAPPING.items():
            if not isinstance(layer, origin_layer_class):
                continue
            # ``candidate`` and ``candidate_origin`` are always set together,
            # so checking one of them is enough to detect the first match.
            if (
                candidate_origin is None
                or origin_layer_class == layer.__class__
                or issubclass(origin_layer_class, candidate_origin)
            ):
                candidate_origin = origin_layer_class
                candidate = xfuser_layer_wrapper

        if candidate is None:
            raise ValueError(
                f"Layer class {layer.__class__.__name__} "
                f"is not supported by xFuser"
            )
        return candidate