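# NOTE(review): the next three lines are residue from a web-page scrape of this
# file and are not part of the module; kept verbatim as comments:
# "src/diffusers/pipelines/stable_diffusion/pipeline_output.py" did not exist on "88735249da94266a433368d2b899e87dc33446c9"
# model.py 629 Bytes
# Newer Older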
1
2
import torch

3
from abc import ABC, abstractmethod
4
from typing import List, Tuple, Optional, TypeVar, Type
5
from tokenizers import Tokenizer
6
7
8

from text_generation.models.types import Batch, GeneratedText

9
10
# Type variable for the concrete batch type a model works with; bound to
# ``Batch`` so only ``Batch`` (or its subclasses) can be substituted.
B = TypeVar("B", bound=Batch)

11

12
class Model(ABC):
    """Abstract base class for text-generation models.

    Stores the tokenizer and the torch device passed at construction.
    Concrete subclasses must declare which ``Batch`` type they consume
    (``batch_type``) and implement a single generation step
    (``generate_token``).
    """

    def __init__(self, tokenizer: Tokenizer, device: torch.device):
        """Store the state shared by all models.

        Args:
            tokenizer: Tokenizer instance used by the model.
            device: Torch device the model is associated with.
        """
        self.tokenizer = tokenizer
        self.device = device

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        """The concrete ``Batch`` class this model operates on.

        Abstract: subclasses must override this property.
        """
        raise NotImplementedError

    @abstractmethod
    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
        """Run one generation step over ``batch``.

        Args:
            batch: Batch of requests, expected to be of the type reported by
                ``batch_type``.

        Returns:
            A tuple ``(generated, next_batch)``: the ``GeneratedText`` results
            produced by this step, and an optional batch to continue with.
            NOTE(review): ``next_batch`` is presumably ``None`` once no
            requests remain — confirm against concrete implementations.
        """
        raise NotImplementedError