# model.py
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type

from text_generation.models.types import Batch, GeneratedText

# Type variable for concrete batch types. It is bound to ``Batch``, so only
# ``Batch`` or its subclasses satisfy it.
B = TypeVar("B", bound=Batch)


class Model(ABC):
    """Abstract interface that concrete text-generation models implement.

    Subclasses must provide the :attr:`batch_type` property and the
    :meth:`generate_token` method; both are declared abstract here, so a
    subclass that omits either cannot be instantiated.

    NOTE(review): ``B`` is a module-level TypeVar, but ``Model`` does not
    inherit from ``Generic[B]``. Type checkers therefore do not tie the ``B``
    returned by ``batch_type`` to the ``B`` accepted and returned by
    ``generate_token``. Consider ``class Model(ABC, Generic[B])`` if that link
    is intended.
    """

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        """Return the ``Batch`` subclass (type ``B``) this model operates on."""
        raise NotImplementedError

    @abstractmethod
    def generate_token(
        self, batch: B
    ) -> Tuple[List[GeneratedText], Optional[B]]:
        """Run a generation step on *batch*.

        Args:
            batch: An instance of the model's batch type (``B`` is bound to
                ``Batch``).

        Returns:
            A tuple of a list of ``GeneratedText`` and an optional batch of the
            same type ``B``. Presumably the list holds sequences that finished
            during this call, and the batch is the remaining work (``None``
            when nothing is left). Confirm this against the concrete
            implementations.
        """
        raise NotImplementedError