base_layer.py
from abc import ABCMeta, abstractmethod

import torch.nn as nn

from xfuser.model_executor.base_wrapper import xFuserBaseWrapper


class xFuserLayerBaseWrapper(nn.Module, xFuserBaseWrapper, metaclass=ABCMeta):
    """Abstract base wrapper around an ``nn.Module`` layer.

    Subclasses must implement ``forward``. Attribute lookups that miss on
    the wrapper fall through to the wrapped module (see ``__getattr__``).
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        # super(nn.Module, self) skips past nn.Module in the MRO, so this
        # invokes xFuserBaseWrapper.__init__, which records the wrapped
        # module (accessed later as self.module).
        super(nn.Module, self).__init__(module=module)
        self.activation_cache = None

    def __getattr__(self, name: str):
        # __getattr__ is only called when normal attribute lookup fails,
        # so probe nn.Module's registries through self.__dict__ (as
        # nn.Module.__getattr__ does) to avoid re-entering this method.
        if "_parameters" in self.__dict__:
            _parameters = self.__dict__["_parameters"]
            if name in _parameters:
                return _parameters[name]
        if "_buffers" in self.__dict__:
            _buffers = self.__dict__["_buffers"]
            if name in _buffers:
                return _buffers[name]
        if "_modules" in self.__dict__:
            modules = self.__dict__["_modules"]
            if name in modules:
                return modules[name]
        # Otherwise delegate to the wrapped module so the wrapper acts as
        # a drop-in replacement. If self.module itself is missing, the
        # lookup would recurse forever; surface that as an AttributeError.
        try:
            return getattr(self.module, name)
        except RecursionError:
            raise AttributeError(
                f"module {type(self.module).__name__} has no attribute {name}"
            )

    @abstractmethod
    def forward(self, *args, **kwargs):
        pass
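

if __name__ == "__main__":
    # Hedged usage sketch, not part of the original file. It assumes
    # xFuserBaseWrapper.__init__(module=...) records the wrapped layer as
    # self.module, which the __getattr__ fallback above already relies on.
    # _DemoLinearWrapper is a hypothetical subclass invented for this demo.
    import torch

    class _DemoLinearWrapper(xFuserLayerBaseWrapper):
        def forward(self, *args, **kwargs):
            # Delegate to the wrapped layer and stash its output in the
            # activation_cache slot created by xFuserLayerBaseWrapper.
            output = self.module(*args, **kwargs)
            self.activation_cache = output
            return output

    wrapper = _DemoLinearWrapper(nn.Linear(16, 32))
    print(wrapper.out_features)               # 32, delegated via __getattr__
    print(wrapper(torch.randn(4, 16)).shape)  # torch.Size([4, 32])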