import torch

from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase

from text_generation_server.models.types import Batch, GeneratedText

B = TypeVar("B", bound=Batch)

class Model(ABC):
    """Abstract base class for text-generation models.

    Concrete subclasses declare which `Batch` subclass they consume
    (`batch_type`) and implement one decoding step (`generate_token`).
    The shared `decode_token` helper turns a growing list of token ids
    into the newly generated text fragment for streaming.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
        self.tokenizer = tokenizer
        # Special-token ids are checked on every decode step; a set gives
        # O(1) membership tests.
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.device = device

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        """The concrete `Batch` subclass this model operates on."""
        raise NotImplementedError

    @abstractmethod
    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
        """Run one generation step, returning finished texts and the next batch."""
        raise NotImplementedError

    def decode_token(
        self,
        all_input_ids: List[int],
        offset: Optional[int] = None,
        token_offset: Optional[int] = None,
    ) -> Tuple[str, Optional[int], Optional[int]]:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
        last_token = all_input_ids[-1]

        # Special tokens decode standalone; no context window is needed.
        if last_token in self.all_special_ids:
            decoded = self.tokenizer.decode(last_token, skip_special_tokens=False)
            return decoded, None, None

        # First call: anchor the decode window a few tokens back so the
        # tokenizer has context (BPE merges can span token boundaries).
        if token_offset is None:
            token_offset = len(all_input_ids) - 3

        # Decode the window both without and with the newest token; the
        # new text is whatever the second decode adds beyond the first.
        without_last = all_input_ids[token_offset:-1]
        with_last = all_input_ids[token_offset:]
        prefix_text, full_text = self.tokenizer.batch_decode(
            [without_last, with_last],
            skip_special_tokens=False,
        )

        # First call: everything already emitted is exactly the prefix.
        if offset is None:
            offset = len(prefix_text)

        new_text = full_text[offset:]

        # A trailing replacement char means the newest token is an
        # incomplete UTF-8 sequence — emit nothing and keep the offsets
        # so the caller retries once more bytes arrive.
        if new_text and not new_text.endswith("�"):
            return new_text, None, None
        return "", offset, token_offset