model.py 4.63 KB
Newer Older
1
import inspect
2
3
import torch

4
from abc import ABC, abstractmethod
drbh's avatar
drbh committed
5
6
from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
from collections import defaultdict
7
from transformers import PreTrainedTokenizerBase
8

9
from text_generation_server.models.types import Batch, Generation
Nicolas Patry's avatar
Nicolas Patry committed
10
from text_generation_server.utils.speculate import get_speculate
11
from text_generation_server.pb.generate_pb2 import InfoResponse
drbh's avatar
drbh committed
12
13
14
15
from text_generation_server.adapters.weights import LayerAdapterWeights

# Adapter id used as the default value of the `adapter_id` argument of
# `Model.__init__`.
# NOTE(review): presumably means "no adapter, plain base model" — confirm.
BASE_MODEL_ADAPTER_ID = "__base_model__"

16

17
18
# Type variable for the concrete `Batch` subclass a `Model` works with; used
# to tie together the batch types in `Model`'s method signatures.
B = TypeVar("B", bound=Batch)

19

20
class Model(ABC):
    """Abstract base class pairing a torch module with its tokenizer.

    Subclasses must implement the ``batch_type`` property and
    ``generate_token``; the base class stores the shared configuration and
    provides incremental token decoding plus an initialization sanity check.
    """

    def __init__(
        self,
        model_id: str,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        requires_padding: bool,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
        speculate: Optional[int] = None,
        adapter_id: str = BASE_MODEL_ADAPTER_ID,
    ):
        """Store the model configuration and validate the loaded weights.

        Args:
            model_id: Identifier of the model, stored as-is.
            model: The torch module; it is switched to eval mode.
            tokenizer: Tokenizer used by ``decode_token``.
            requires_padding: Reported through ``info``.
            dtype: Reported through ``info`` (as a string).
            device: Device of the model; its ``type`` is reported through
                ``info``.
            rank: Stored as-is.
                NOTE(review): presumably the process rank when sharded.
            world_size: Stored as-is.
                NOTE(review): presumably the number of shards.
            sliding_window: Window size; ``-1`` is normalized to ``None``.
            speculate: Speculation setting; when ``None`` the value returned
                by ``get_speculate()`` is used instead.
            adapter_id: Stored as ``static_adapter_id``.

        Raises:
            RuntimeError: If any model parameter is still on the ``meta``
                device (see ``check_initialized``).
        """
        self.model_id = model_id
        self.model = model.eval()
        self.tokenizer = tokenizer

        # all_special_ids is not set correctly if the rust tokenizer is unpacked
        # TODO report this to transformers.
        other_special_ids = {
            token_id
            for token_id, token in tokenizer.added_tokens_decoder.items()
            if token.special
        }
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.all_special_ids.update(other_special_ids)

        self.requires_padding = requires_padding
        self.dtype = dtype
        self.device = device
        self.rank = rank
        self.world_size = world_size
        # -1 is treated as "no sliding window".
        self.sliding_window = sliding_window if sliding_window != -1 else None

        # Missing keys get a fresh LayerAdapterWeights on first access.
        self.layer_to_adapter_weights: DefaultDict[str, LayerAdapterWeights] = (
            defaultdict(LayerAdapterWeights)
        )
        self.loaded_adapters = set()
        self.static_adapter_id = adapter_id

        if speculate is None:
            speculate = get_speculate()
        self.speculate = speculate

        # Whether the module's forward() accepts a `position_ids` argument.
        self.has_position_ids = (
            "position_ids" in inspect.signature(model.forward).parameters
        )

        self.check_initialized()

    @property
    def info(self) -> InfoResponse:
        """Build the ``InfoResponse`` describing this model.

        Raises:
            NotImplementedError: If the model requires padding and a sliding
                window is configured at the same time.
        """
        if self.requires_padding and self.sliding_window is not None:
            raise NotImplementedError("sliding_window is not implemented with padding")

        return InfoResponse(
            requires_padding=self.requires_padding,
            dtype=str(self.dtype),
            device_type=self.device.type,
            window_size=self.sliding_window,
            speculate=self.speculate,
        )

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        """Return the concrete ``Batch`` class used by this model."""
        raise NotImplementedError

    @abstractmethod
    def generate_token(
        self, batch: B
    ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
        """Run one generation step on ``batch``; implemented by subclasses.

        Returns:
            A tuple of the generations, an optional batch and a pair of ints.
            NOTE(review): the optional batch is presumably the batch to
            continue with and the ints presumably timings — confirm against
            the implementations.
        """
        raise NotImplementedError

    def warmup(self, batch: B) -> Optional[int]:
        """Run a single ``generate_token`` pass on ``batch``.

        The base implementation discards the result and returns ``None``.
        """
        self.generate_token(batch)
        return None

    def decode_token(
        self,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
        skip_special_tokens: bool = False,
    ) -> Tuple[str, int, int]:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers

        Decodes ``all_input_ids[prefix_offset:]`` and returns only the text
        that extends the decoding of ``all_input_ids[prefix_offset:read_offset]``.

        Returns:
            ``(new_text, read_offset, len(all_input_ids))`` when new text is
            available, otherwise ``("", prefix_offset, read_offset)`` so the
            caller can retry with the same offsets once more ids exist.
        """
        # The prefix text is necessary only to defeat cleanup algorithms in the decode
        # which decide to add a space or not depending on the surrounding ids.
        prefix_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
        )
        new_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
        )

        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
            # utf-8 char at the end means it's a potential unfinished byte sequence
            # from byte fallback tokenization.
            # If it's in the middle, it's probably a real invalid id generated
            # by the model
            new_text = new_text[len(prefix_text) :]
            return new_text, read_offset, len(all_input_ids)

        return "", prefix_offset, read_offset

    def check_initialized(self):
        """Ensure no model parameter is left on the ``meta`` device.

        Raises:
            RuntimeError: Listing the names of all parameters whose data is
                still on the ``meta`` device.
        """
        meta_device = torch.device("meta")
        uninitialized_parameters = [
            name
            for name, param in self.model.named_parameters()
            if param.data.device == meta_device
        ]
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )