register.py 1.94 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from typing import Dict, Type
import torch
import torch.nn as nn

from xfuser.logger import init_logger
from xfuser.model_executor.models.transformers.base_transformer import (
    xFuserTransformerBaseWrapper,
)

logger = init_logger(__name__)


class xFuserTransformerWrappersRegister:
    """Registry mapping original transformer classes to xFuser wrapper classes.

    Wrapper classes register themselves through :meth:`register`, used as a
    class decorator. :meth:`get_wrapper` later returns the wrapper class that
    applies to a given transformer instance.
    """

    # Class-level registry: origin transformer class -> wrapper class.
    # It is a single dict on this class, so every ``register`` /
    # ``get_wrapper`` call (including via subclasses) shares the same mapping.
    _XFUSER_TRANSFORMER_MAPPING: Dict[
        Type[nn.Module], Type[xFuserTransformerBaseWrapper]
    ] = {}

    @classmethod
    def register(cls, origin_transformer_class: Type[nn.Module]):
        """Return a class decorator that registers a wrapper for a transformer.

        Args:
            origin_transformer_class: The original transformer class for which
                the decorated wrapper class should be used.

        Returns:
            A decorator that records the mapping and returns the decorated
            class unchanged.

        Raises:
            ValueError: (raised by the returned decorator) if the decorated
                object is not a subclass of ``xFuserTransformerBaseWrapper``.
        """

        def decorator(
            xfuser_transformer_class: Type[xFuserTransformerBaseWrapper],
        ) -> Type[xFuserTransformerBaseWrapper]:
            # Check for a class first: ``issubclass`` raises a bare TypeError
            # when given a non-class (e.g. a decorated function).
            if not (
                isinstance(xfuser_transformer_class, type)
                and issubclass(
                    xfuser_transformer_class, xFuserTransformerBaseWrapper
                )
            ):
                # Report the decorated object's own name; ``__class__.__name__``
                # would name its metaclass (e.g. "type") instead.
                name = getattr(
                    xfuser_transformer_class,
                    "__name__",
                    repr(xfuser_transformer_class),
                )
                raise ValueError(
                    f"{name} is not "
                    f"a subclass of xFuserTransformerBaseWrapper"
                )

            mapping = cls._XFUSER_TRANSFORMER_MAPPING
            previous = mapping.get(origin_transformer_class)
            if previous is not None and previous is not xfuser_transformer_class:
                # Re-registration replaces the earlier wrapper; make that
                # visible instead of letting import order decide silently.
                logger.warning(
                    "Overriding xFuser wrapper for %s: %s -> %s",
                    origin_transformer_class,
                    previous,
                    xfuser_transformer_class,
                )
            mapping[origin_transformer_class] = xfuser_transformer_class
            return xfuser_transformer_class

        return decorator

    @classmethod
    def get_wrapper(
        cls, transformer: nn.Module
    ) -> Type[xFuserTransformerBaseWrapper]:
        """Return the wrapper class registered for ``transformer``.

        A wrapper registered for ``transformer``'s own class always wins.
        Otherwise, among registered origin classes that ``transformer`` is an
        instance of, a later match replaces the current pick only when it is
        a subclass of it. For unrelated matches (e.g. multiple inheritance),
        the result can therefore depend on registration order.

        Note that this returns the wrapper *class*, not an instance.

        Raises:
            ValueError: if no registered origin class matches ``transformer``.
        """
        mapping = cls._XFUSER_TRANSFORMER_MAPPING
        transformer_class = transformer.__class__

        # Fast path: an exact registration for the concrete class wins.
        exact = mapping.get(transformer_class)
        if exact is not None:
            return exact

        candidate = None
        candidate_origin = None
        for origin_transformer_class, wrapper_class in mapping.items():
            if not isinstance(transformer, origin_transformer_class):
                continue
            # Only move to a more specific (sub)class of the current pick.
            if candidate is None or issubclass(
                origin_transformer_class, candidate_origin
            ):
                candidate_origin = origin_transformer_class
                candidate = wrapper_class

        if candidate is None:
            raise ValueError(
                f"Transformer class {transformer_class.__name__} "
                f"is not supported by xFuser"
            )
        return candidate