import inspect
import torch

from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
from collections import defaultdict
from transformers import PreTrainedTokenizerBase, PretrainedConfig

from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights
from text_generation_server.utils.adapter import (
    load_and_merge_adapters,
    AdapterParameters,
    AdapterSource,
)
from text_generation_server.utils.log import log_master
from loguru import logger


# Adapter id used as the default for `Model(adapter_id=...)`; it stands for the
# plain base model, i.e. no adapter applied statically at initialization.
BASE_MODEL_ADAPTER_ID: str = "__base_model__"

# Generic placeholder for the concrete Batch subclass a Model implementation uses.
B = TypeVar("B", bound=Batch)


class Model(ABC):
    """Abstract base class wrapping a model, its tokenizer and serving metadata.

    Stores the ``torch.nn.Module`` together with its tokenizer, dtype, device,
    rank / world size, sliding window and speculation depth, and keeps track
    of adapter weights registered per adapter layer.

    Subclasses must implement :attr:`batch_type` and :meth:`generate_token`.
    The adapter hooks (:attr:`supports_adapter_loading`,
    :meth:`adapter_target_to_layer`, :attr:`adapter_layers`, ...) default to
    "no adapter support" and are meant to be overridden by subclasses that
    support adapters.
    """

    def __init__(
        self,
        model_id: str,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        requires_padding: bool,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
        speculate: Optional[int] = None,
        adapter_id: str = BASE_MODEL_ADAPTER_ID,
    ):
        """Store the model and its metadata, then check it is fully initialized.

        Args:
            model_id: Identifier of the base model; forwarded to
                ``load_and_merge_adapters`` when adapters are loaded.
            model: The wrapped module; it is switched to eval mode.
            tokenizer: Tokenizer used by :meth:`decode_token`.
            requires_padding: Whether this model requires padded batches.
            dtype: Data type reported by :attr:`info`.
            device: Device reported by :attr:`info`.
            rank: Rank of this process.
            world_size: Total number of processes.
            sliding_window: Sliding-window size; ``-1`` is treated like
                ``None`` (no window).
            speculate: Number of speculative tokens; defaults to
                ``get_speculate()`` when ``None``.
            adapter_id: Adapter the model was statically initialized with, or
                ``BASE_MODEL_ADAPTER_ID`` for the plain base model.

        Raises:
            RuntimeError: If a model parameter is still on the ``meta`` device.
        """
        self.model_id = model_id
        self.model = model.eval()
        self.tokenizer = tokenizer

        # all_special_ids is not set correctly if the rust tokenizer is unpacked
        # TODO report this to transformers.
        other_special_ids = {
            token_id
            for token_id, token in tokenizer.added_tokens_decoder.items()
            if token.special
        }
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.all_special_ids.update(other_special_ids)

        self.requires_padding = requires_padding
        self.dtype = dtype
        self.device = device

        self.rank = rank
        self.world_size = world_size
        # -1 is accepted as an alias for "no sliding window".
        self.sliding_window = sliding_window if sliding_window != -1 else None

        # Per adapter-layer registry of adapter weights; entries are created on demand.
        self.layer_to_adapter_weights: DefaultDict[str, LayerAdapterWeights] = (
            defaultdict(LayerAdapterWeights)
        )
        # Lazily filled by load_adapter() from adapter_target_to_layer().
        self.target_to_layer = None
        # Indices of adapters whose weights are currently registered.
        self.loaded_adapters = set()
        self.static_adapter_id = adapter_id

        if speculate is None:
            speculate = get_speculate()
        self.speculate = speculate

        # Whether the wrapped module's forward() accepts a `position_ids` argument.
        self.has_position_ids = (
            "position_ids" in inspect.signature(model.forward).parameters
        )

        self.check_initialized()

    @property
    def dynamic_adapter_loading_enabled(self) -> bool:
        """Whether adapters may be loaded (``dynamic=True``) or offloaded at runtime.

        ``load_adapter`` and ``offload_adapter`` consult this flag. Defining
        it here means those checks work even if no subclass provides it.
        Unless a value was assigned explicitly, dynamic loading is allowed
        only when the model was not initialized with a static adapter
        (``static_adapter_id == BASE_MODEL_ADAPTER_ID``). Subclasses may
        assign to it or override the property.
        """
        explicit = getattr(self, "_dynamic_adapter_loading_enabled", None)
        if explicit is not None:
            return explicit
        return self.static_adapter_id == BASE_MODEL_ADAPTER_ID

    @dynamic_adapter_loading_enabled.setter
    def dynamic_adapter_loading_enabled(self, value: bool) -> None:
        # Assigning None restores the default derived from static_adapter_id.
        self._dynamic_adapter_loading_enabled = value

    @property
    def info(self) -> InfoResponse:
        """Build the protobuf ``InfoResponse`` describing this model.

        Raises:
            NotImplementedError: If the model requires padding and also has a
                sliding window; that combination is not supported.
        """
        if self.requires_padding and self.sliding_window is not None:
            raise NotImplementedError("sliding_window is not implemented with padding")

        return InfoResponse(
            requires_padding=self.requires_padding,
            dtype=str(self.dtype),
            device_type=self.device.type,
            window_size=self.sliding_window,
            speculate=self.speculate,
        )

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        """Concrete Batch subclass used by this model."""
        raise NotImplementedError

    @abstractmethod
    def generate_token(
        self, batch: B
    ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
        """Run one generation step on ``batch``.

        Returns:
            The generations produced, the batch to continue with (or ``None``),
            and a pair of ints whose meaning is defined by implementations.
        """
        raise NotImplementedError

    def warmup(self, batch: B) -> Optional[int]:
        """Run a single generation step on ``batch``; the base version returns ``None``."""
        self.generate_token(batch)
        return None

    def decode_token(
        self,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
        skip_special_tokens: bool = False,
    ) -> Tuple[str, int, int]:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers

        Incrementally decode the text produced since the previous call.

        ``all_input_ids[prefix_offset:read_offset]`` is already-emitted text
        that is decoded only as context. When decoding
        ``all_input_ids[prefix_offset:]`` yields more text than that context
        and does not end in U+FFFD, the new suffix is returned with the
        offsets advanced to ``(read_offset, len(all_input_ids))``. Otherwise
        nothing is emitted and the offsets are returned unchanged, so the
        caller retries once more tokens are available.

        Returns:
            ``(new_text, new_prefix_offset, new_read_offset)``.
        """
        # The prefix text is necessary only to defeat cleanup algorithms in the decode
        # which decide to add a space or not depending on the surrounding ids.
        prefix_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
        )
        new_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
        )

        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
            # utf-8 char at the end means it's a potential unfinished byte sequence
            # from byte fallback tokenization.
            # If it's in the middle, it's probably a real invalid id generated
            # by the model
            new_text = new_text[len(prefix_text) :]
            return new_text, read_offset, len(all_input_ids)
        else:
            return "", prefix_offset, read_offset

    def check_initialized(self):
        """Raise if any parameter of ``self.model`` is still on the ``meta`` device.

        Meta tensors carry shape and dtype but no data, so such parameters
        were never materialized with real weights.

        Raises:
            RuntimeError: Listing the names of all uninitialized parameters.
        """
        meta_device = torch.device("meta")
        uninitialized_parameters = [
            name
            for name, param in self.model.named_parameters()
            if param.data.device == meta_device
        ]
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )

    @property
    def supports_adapter_loading(self) -> bool:
        """Whether this model can load adapters; ``False`` in the base class."""
        return False

    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
        """Map adapter target names to tuples whose first item is a weight name.

        Those weight names are passed to ``load_and_merge_adapters`` by
        :meth:`load_adapter`. The base class returns an empty mapping.
        """
        return {}

    @property
    def adapter_layers(self) -> List[str]:
        """Names of the layers adapters are loaded into; empty in the base class."""
        return []

    @property
    def default_traced_adapter_layers(self) -> List[str]:
        """Adapter layer names used by default for tracing; empty in the base class."""
        return []

    def get_num_layers_for_type(self, layer_type: str) -> int:
        """Number of layers of ``layer_type``; ``0`` in the base class."""
        return 0

    def is_row_parallel(self, layer_type: str) -> bool:
        """Whether ``layer_type`` is row-parallel; ``False`` in the base class."""
        return False

    @property
    def max_speculative_tokens(self) -> int:
        """Largest ``max_speculative_tokens`` over all adapter layers (``0`` if none)."""
        return max(
            (
                weights.max_speculative_tokens
                for weights in self.layer_to_adapter_weights.values()
            ),
            default=0,
        )

    def load_adapter(
        self,
        adapter_parameters: AdapterParameters,
        adapter_source: AdapterSource,
        adapter_index: int,
        api_token: str,
        dynamic: bool = True,
    ):
        """Loads adapter weights from disk / host memory on the GPU.

        adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded
        into model. Otherwise, the adapter weights are applied during the forward
        pass and stored separately from the base model parameters.

        Loading an ``adapter_index`` that is already loaded is a no-op.

        Args:
            adapter_parameters: Describes the adapter(s) to load; its
                ``adapter_ids`` are used in log messages.
            adapter_source: Forwarded to ``load_and_merge_adapters``.
            adapter_index: Key under which the adapter weights are registered.
            api_token: Forwarded to ``load_and_merge_adapters``.
            dynamic: Forwarded to ``load_batched_adapter_weights``. When True,
                :attr:`dynamic_adapter_loading_enabled` must also be True.

        Raises:
            ValueError: If the model does not support adapter loading, or if
                ``dynamic`` is requested while dynamic loading is disabled.
        """
        # Resolve the target -> layer mapping once, on first use.
        if self.target_to_layer is None:
            self.target_to_layer = self.adapter_target_to_layer()

        if adapter_index in self.loaded_adapters:
            # Adapter already loaded
            return

        if not self.supports_adapter_loading:
            raise ValueError("This model does not support adapter loading.")

        if dynamic and not self.dynamic_adapter_loading_enabled:
            raise ValueError(
                f"This model was initialized with the adapter {self.static_adapter_id} "
                "and therefore does not support dynamic adapter loading. "
                "Please initialize a new model instance from the base model in "
                "order to use the dynamic adapter loading feature."
            )

        log_master(
            logger.info,
            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}",
        )
        weight_names = tuple(v[0] for v in self.target_to_layer.values())
        (
            module_map,
            adapter_config,
            adapter_weight_names,
            adapter_tokenizer,
        ) = load_and_merge_adapters(
            self.model_id,
            adapter_parameters,
            adapter_source,
            adapter_index,
            weight_names,
            api_token,
            False,
        )

        # Presumably load_batched_adapter_weights removes the names it consumes
        # from this copy; whatever remains is reported as unused below.
        unused_weight_names = adapter_weight_names.copy()
        for layer_name in self.adapter_layers:
            adapter_weights = adapter_config.load_batched_adapter_weights(
                self,
                module_map,
                layer_name,
                unused_weight_names,
                dynamic,
            )

            if adapter_weights is None:
                continue

            layer_weights = self.layer_to_adapter_weights[layer_name]
            layer_weights.add_adapter(adapter_index, adapter_weights)

        if len(unused_weight_names) > 0:
            log_master(
                logger.warning,
                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}",
            )

        if adapter_tokenizer is not None:
            # NOTE(review): `self.tokenizers` is not defined in this class; a
            # subclass must provide it (with an `add_tokenizer` method),
            # otherwise this raises AttributeError after the layer weights
            # above were already registered.
            self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

        self.loaded_adapters.add(adapter_index)

    def offload_adapter(
        self,
        adapter_parameters: AdapterParameters,
        adapter_source: AdapterSource,
        adapter_index: int,
    ):
        """Unregister the adapter weights stored under ``adapter_index``.

        Calls ``remove_adapter`` on every adapter layer that holds weights and
        marks the adapter as no longer loaded. Offloading an adapter that is
        not loaded is a no-op. ``adapter_parameters`` and ``adapter_source``
        are accepted for interface symmetry with :meth:`load_adapter` but are
        not used here.

        Raises:
            ValueError: If the model does not support adapter loading or
                dynamic adapter loading is disabled.
        """
        if adapter_index not in self.loaded_adapters:
            # Adapter already offloaded
            return

        if not self.supports_adapter_loading:
            raise ValueError("This model does not support adapter loading.")

        if not self.dynamic_adapter_loading_enabled:
            raise ValueError(
                f"This model was initialized with the adapter {self.static_adapter_id} "
                "and therefore does not support dynamic adapter loading. "
                "Please initialize a new model instance from the base model in "
                "order to use the dynamic adapter loading feature."
            )

        for layer_name in self.adapter_layers:
            # Membership test first, so the defaultdict does not create empty entries.
            if layer_name in self.layer_to_adapter_weights:
                self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)

        self.loaded_adapters.remove(adapter_index)