model.py

import torch

from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase

from text_generation_server.models.types import Batch, GeneratedText

B = TypeVar("B", bound=Batch)


class Model(ABC):
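    """Abstract base class for all models served by text_generation_server.

    Concrete subclasses pick the batch representation they operate on
    (`batch_type`) and implement the per-step `generate_token` logic.
    """
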
    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
        self.tokenizer = tokenizer
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.device = device

        # See `decode_token`: register a sentinel special token so a single
        # token id can always be decoded against a known, strippable prefix.
        self.tokenizer.add_special_tokens(
            {"additional_special_tokens": ["<decode-token>"]}
        )
        self.special_decode_token_id = self.tokenizer.convert_tokens_to_ids(
            "<decode-token>"
        )
        self.special_decode_token_length = len("<decode-token>")

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
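        """The concrete `Batch` subclass this model consumes."""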
        raise NotImplementedError

    @abstractmethod
    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
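        """Run one generation step over `batch`.

        Expected to return the texts that finished during this step and the
        batch to feed back into the next step, or `None` once every request
        in the batch is complete.
        """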
        raise NotImplementedError

    def decode_token(self, token_id: int) -> str:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
        # append token to special decode token and decode both
        result = self.tokenizer.decode(
            [self.special_decode_token_id, token_id], skip_special_tokens=False
        )
        # slice to remove special decode token
        return result[self.special_decode_token_length :]
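

# Illustration only: a sketch of how a concrete subclass might fill in this
# interface. `MyModel` and `MyBatch` are hypothetical names standing in for a
# real implementation that wraps an actual language model.
#
#     class MyModel(Model):
#         @property
#         def batch_type(self) -> Type[MyBatch]:
#             return MyBatch
#
#         def generate_token(
#             self, batch: MyBatch
#         ) -> Tuple[List[GeneratedText], Optional[MyBatch]]:
#             ...  # one forward pass + sampling per request in the batch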