base.py 2.03 KB
Newer Older
hepj's avatar
hepj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch
from torch import nn

from fastvideo.v1.configs.models.encoders import (BaseEncoderOutput,
                                                  ImageEncoderConfig,
                                                  TextEncoderConfig)
from fastvideo.v1.platforms import _Backend


class TextEncoder(nn.Module, ABC):
    """Abstract base class for text encoder models.

    Subclasses must implement :meth:`forward` and are expected to override
    ``_supported_attention_backends`` with the backends they support; the
    constructor rejects subclasses whose backend tuple is empty.
    """

    # Default backend tuple, taken from a fresh TextEncoderConfig at class
    # definition time; subclasses typically shadow this class attribute.
    _supported_attention_backends: Tuple[
        _Backend, ...] = TextEncoderConfig()._supported_attention_backends

    def __init__(self, config: TextEncoderConfig) -> None:
        """Store the config and validate the declared attention backends.

        Raises:
            ValueError: if the subclass declares no supported backends.
        """
        super().__init__()
        self.config = config
        backends = self.supported_attention_backends
        if not backends:
            raise ValueError(
                f"Subclass {self.__class__.__name__} must define _supported_attention_backends"
            )

    @abstractmethod
    def forward(self,
                input_ids: Optional[torch.Tensor],
                position_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                output_hidden_states: Optional[bool] = None,
                **kwargs) -> BaseEncoderOutput:
        """Encode token ids (or precomputed embeddings) into encoder outputs."""
        pass

    @property
    def supported_attention_backends(self) -> Tuple[_Backend, ...]:
        """Attention backends this encoder class declares support for."""
        return self._supported_attention_backends


class ImageEncoder(nn.Module, ABC):
    """Abstract base class for image encoder models.

    Subclasses must implement :meth:`forward` and are expected to override
    ``_supported_attention_backends`` with the backends they support; the
    constructor rejects subclasses whose backend tuple is empty.
    """

    # Default backend tuple, taken from a fresh ImageEncoderConfig at class
    # definition time; subclasses typically shadow this class attribute.
    _supported_attention_backends: Tuple[
        _Backend, ...] = ImageEncoderConfig()._supported_attention_backends

    def __init__(self, config: ImageEncoderConfig) -> None:
        """Store the config and validate the declared attention backends.

        Raises:
            ValueError: if the subclass declares no supported backends.
        """
        super().__init__()
        self.config = config
        backends = self.supported_attention_backends
        if not backends:
            raise ValueError(
                f"Subclass {self.__class__.__name__} must define _supported_attention_backends"
            )

    @abstractmethod
    def forward(self, pixel_values: torch.Tensor,
                **kwargs) -> BaseEncoderOutput:
        """Encode pixel values into encoder outputs."""
        pass

    @property
    def supported_attention_backends(self) -> Tuple[_Backend, ...]:
        """Attention backends this encoder class declares support for."""
        return self._supported_attention_backends