import torch

from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase

from text_generation_server.models.types import Batch, GeneratedText
from text_generation_server.pb.generate_pb2 import InfoResponse

B = TypeVar("B", bound=Batch)


class Model(ABC):
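    """Abstract base class for models served by text_generation_server.

    Subclasses pick a concrete `Batch` type and implement `generate_token`.
    """
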
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        requires_padding: bool,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
    ):
        self.model = model.eval()
        self.tokenizer = tokenizer
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.requires_padding = requires_padding
        self.dtype = dtype
        self.device = device
        self.rank = rank
        self.world_size = world_size
        self.check_initialized()

    @property
    def info(self) -> InfoResponse:
        return InfoResponse(
            requires_padding=self.requires_padding,
            dtype=str(self.dtype),
            device_type=self.device.type,
        )

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
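        """The concrete `Batch` subclass this model operates on."""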
        raise NotImplementedError

    @abstractmethod
    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
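        """Run one generation step on `batch`.

        Returns the texts finished during this step and the batch to feed
        into the next step, or None once every request in it is complete.
        """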
        raise NotImplementedError

    def decode_token(
        self,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
    ) -> Tuple[str, int, int]:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""

        # The prefix text is necessary only to defeat cleanup algorithms in the decode
        # which decide to add a space or not depending on the surrounding ids.
        prefix_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
        )
        new_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:], skip_special_tokens=False
        )

        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
            # A trailing replacement char ("�") means a potentially unfinished
            # byte sequence from byte fallback tokenization.
            # If it shows up in the middle, it's probably a real invalid id
            # generated by the model.
            new_text = new_text[len(prefix_text) :]
            return new_text, read_offset, len(all_input_ids)
        else:
            return "", prefix_offset, read_offset

    def check_initialized(self):
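        """Raise if any parameter is still on the meta device (weights never loaded)."""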
        uninitialized_parameters = []
        for n, p in self.model.named_parameters():
            if p.data.device == torch.device("meta"):
                uninitialized_parameters.append(n)
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )
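

# A minimal concrete subclass sketch (hypothetical, for illustration only;
# real implementations live elsewhere in text_generation_server.models):
#
#   class MyModel(Model):
#       @property
#       def batch_type(self) -> Type[MyBatch]:
#           return MyBatch
#
#       def generate_token(
#           self, batch: MyBatch
#       ) -> Tuple[List[GeneratedText], Optional[MyBatch]]:
#           ...  # one forward pass + sampling step per call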