import inspect
import torch

from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase, PretrainedConfig

from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse

B = TypeVar("B", bound=Batch)


class Model(ABC):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        requires_padding: bool,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
        speculate: Optional[int] = None,
    ):
        self.model = model.eval()
        self.tokenizer = tokenizer

        # all_special_ids is not set correctly if the rust tokenizer is unpacked
        # TODO report this to transformers.
        other_special_ids = {
            id for id, token in tokenizer.added_tokens_decoder.items() if token.special
        }
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.all_special_ids.update(other_special_ids)
        self.requires_padding = requires_padding
        self.dtype = dtype
        self.device = device
        self.rank = rank
        self.world_size = world_size
        # A sliding_window of -1 means "no sliding window".
        self.sliding_window = sliding_window if sliding_window != -1 else None

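        # Number of tokens to decode speculatively per step; when not set
        # explicitly, fall back to the server-wide default from get_speculate().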
        if speculate is None:
            speculate = get_speculate()
        self.speculate = speculate

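        # Record whether this model's forward() accepts `position_ids`, so
        # callers can pass them only when the signature supports it.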
        self.has_position_ids = (
            inspect.signature(model.forward).parameters.get("position_ids", None)
            is not None
        )

        self.check_initialized()

    @property
    def info(self) -> InfoResponse:
        if self.requires_padding and self.sliding_window is not None:
            raise NotImplementedError("sliding_window is not implemented with padding")

        return InfoResponse(
            requires_padding=self.requires_padding,
            dtype=str(self.dtype),
            device_type=self.device.type,
            window_size=self.sliding_window,
            speculate=self.speculate,
        )

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        raise NotImplementedError

    @abstractmethod
    def generate_token(
        self, batch: B
    ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
        raise NotImplementedError

    def warmup(self, batch: B) -> Optional[int]:
        # Base implementation runs a single generation step so allocation
        # failures surface early; subclasses may instead return the maximum
        # number of total tokens they can serve.
        self.generate_token(batch)
        return None

    def decode_token(
        self,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
        skip_special_tokens: bool = False,
    ) -> Tuple[str, int, int]:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""

        # The prefix text is necessary only to defeat cleanup algorithms in the decode
        # which decide to add a space or not depending on the surrounding ids.
        prefix_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
        )
        new_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
        )

        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
            # A replacement char (�) at the end means a potentially unfinished
            # byte sequence from byte fallback tokenization.
            # If it appears in the middle, it is probably a real invalid id
            # generated by the model.
            new_text = new_text[len(prefix_text) :]
            return new_text, read_offset, len(all_input_ids)
        else:
            return "", prefix_offset, read_offset
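
    # A minimal sketch of the intended streaming call pattern (illustrative
    # only; this caller-side loop is an assumption, not part of this file):
    #
    #   text, prefix_offset, read_offset = "", 0, 0
    #   for next_id in sampled_ids:  # sampled_ids: hypothetical token stream
    #       all_input_ids.append(next_id)
    #       chunk, prefix_offset, read_offset = model.decode_token(
    #           all_input_ids, prefix_offset, read_offset
    #       )
    #       text += chunk  # only finalized text is emitted per step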

    def check_initialized(self):
        # Parameters left on the "meta" device were never materialized with
        # real weights; fail fast rather than computing with garbage.
        uninitialized_parameters = []
        for n, p in self.model.named_parameters():
            if p.data.device == torch.device("meta"):
                uninitialized_parameters.append(n)
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )