_fused_base.py 884 Bytes
Newer Older
yangql's avatar
yangql committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from abc import abstractmethod
from logging import getLogger

import torch.nn as nn

from .triton_utils.mixin import TritonModuleMixin


# Module-level logger, named after this module via ``__name__``.
logger = getLogger(__name__)


class FusedBaseModule(nn.Module, TritonModuleMixin):
    """Shared base class for fused modules.

    Combines ``nn.Module`` with ``TritonModuleMixin`` and declares the
    ``inject_to_model`` classmethod that concrete subclasses are expected
    to override.

    NOTE(review): ``@abstractmethod`` is only enforced at instantiation time
    when the class has an ``ABCMeta``-derived metaclass; none is visible in
    this file, so the decorator may be purely declarative here -- confirm.
    """

    @classmethod
    @abstractmethod
    def inject_to_model(cls, *args, **kwargs):
        """Declare the injection entry point; subclasses must override it.

        Accepts arbitrary positional and keyword arguments so that each
        subclass can define its own concrete signature.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError


class FusedBaseAttentionModule(FusedBaseModule):
    """Base class for fused attention modules.

    Narrows ``inject_to_model`` to an attention-specific signature and adds
    an optional ``warmup`` classmethod whose default implementation does
    nothing.
    """

    @classmethod
    @abstractmethod
    def inject_to_model(
        cls,
        model,
        use_triton=False,
        group_size=-1,
        use_cuda_fp16=True,
        desc_act=False,
        trainable=False,
        **kwargs,
    ):
        """Declare the attention injection entry point; subclasses override it.

        Args:
            model: The model to operate on (presumably modified in place or
                returned by the subclass -- verify against implementations).
            use_triton: Defaults to ``False``.
            group_size: Defaults to ``-1``.
            use_cuda_fp16: Defaults to ``True``.
            desc_act: Defaults to ``False``.
            trainable: Defaults to ``False``.
            **kwargs: Extra keyword arguments accepted for subclass use.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    @classmethod
    def warmup(cls, model, transpose=False, seqlen=2048):
        """Optional warm-up hook; the default implementation is a no-op.

        Subclasses may override this. All arguments are ignored here and
        ``None`` is returned.
        """


class FusedBaseMLPModule(FusedBaseModule):
    """Base class for fused MLP modules.

    Narrows ``inject_to_model`` to the signature used by MLP variants.
    """

    @classmethod
    @abstractmethod
    def inject_to_model(
        cls,
        model,
        use_triton=False,
        **kwargs,
    ):
        """Declare the MLP injection entry point; subclasses override it.

        Args:
            model: The model to operate on.
            use_triton: Defaults to ``False``.
            **kwargs: Extra keyword arguments accepted for subclass use.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError