import inspect
import torch

from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict
from collections import defaultdict
from transformers import PreTrainedTokenizerBase
from loguru import logger

from text_generation_server.models.globals import (
    ATTENTION,
    PREFIX_CACHING,
    BLOCK_SIZE,
    PREFILL_CHUNKING,
)
from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import set_support_chunking
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights

BASE_MODEL_ADAPTER_ID = "__base_model__"


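# B is the concrete Batch subclass a given model implementation works with.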
B = TypeVar("B", bound=Batch)


class Model(ABC):
    def __init__(
        self,
        model_id: str,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        requires_padding: bool,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
        speculate: Optional[int] = None,
        adapter_id: str = BASE_MODEL_ADAPTER_ID,
        support_chunking: bool = False,
    ):
        self.model_id = model_id
        self.model = model.eval()
        self.tokenizer = tokenizer

        # all_special_ids is not set correctly if the rust tokenizer is unpacked
        # TODO report this to transformers.
        other_special_ids = {
            id for id, token in tokenizer.added_tokens_decoder.items() if token.special
        }
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.all_special_ids.update(other_special_ids)
        self.requires_padding = requires_padding
        self.dtype = dtype
        self.device = device
        self.rank = rank
        self.world_size = world_size
        self.sliding_window = sliding_window if sliding_window != -1 else None

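        # Bookkeeping for dynamically loaded adapter (e.g. LoRA) weights,
        # keyed by layer name, plus the set of adapter ids currently loaded.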
        self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
            LayerAdapterWeights
        )
        self.loaded_adapters = set()
        self.static_adapter_id = adapter_id

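        # Fall back to the globally configured number of speculative tokens
        # when no per-model value is provided.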
        if speculate is None:
            speculate = get_speculate()
        self.speculate = speculate

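        # Prefill chunking stays enabled only when the model supports it and
        # the global PREFILL_CHUNKING flag is set; the checks below turn it
        # off when it is incompatible with speculation or the attention backend.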
        support_chunking = support_chunking and PREFILL_CHUNKING

        if speculate != 0 and support_chunking:
            log_master(
                logger.warning,
                "Prefill chunking does not support speculation yet. "
                "Prefill chunking will be turned off",
            )
            support_chunking = False
        if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
            log_master(
                logger.warning,
                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
            )
            support_chunking = False

        log_master(logger.info, f"Using prefill chunking = {support_chunking}")

        self.support_chunking = support_chunking
        set_support_chunking(support_chunking)

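        # Remember whether the model's forward accepts `position_ids`, so
        # callers can decide whether to build and pass them.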
        self.has_position_ids = (
            inspect.signature(model.forward).parameters.get("position_ids", None)
            is not None
        )

        self.check_initialized()

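    # Static model metadata exposed through the InfoResponse protobuf message.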
    @property
    def info(self) -> InfoResponse:
        if self.requires_padding and self.sliding_window is not None:
            raise NotImplementedError("sliding_window is not implemented with padding")

        return InfoResponse(
            requires_padding=self.requires_padding,
            dtype=str(self.dtype),
            device_type=self.device.type,
            window_size=self.sliding_window,
            speculate=self.speculate,
            support_chunking=self.support_chunking,
            use_prefix_caching=PREFIX_CACHING,
            attention_impl=ATTENTION,
            block_size=BLOCK_SIZE,
        )

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        raise NotImplementedError

    @abstractmethod
    def generate_token(
        self, batch: B
    ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
        raise NotImplementedError

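    # Default warmup: run a single generate_token pass on the warmup batch and
    # derive token limits from it when the caller did not provide any.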
    def warmup(
        self, batch: B, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
    ) -> Tuple[Optional[int], int, int]:
        self.generate_token(batch)
        total = sum(len(i) for i in batch.input_ids)
        if max_total_tokens is None:
            max_total_tokens = total

        if max_input_tokens is None:
            max_input_tokens = max_total_tokens - 1
        return None, max_input_tokens, max_total_tokens

    def decode_token(
        self,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
        skip_special_tokens: bool = False,
    ) -> Tuple[str, int, int]:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""

        # The prefix text is necessary only to defeat cleanup algorithms in the decode
        # which decide to add a space or not depending on the surrounding ids.
        prefix_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
        )
        new_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
        )

        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
            # utf-8 char at the end means it's a potential unfinished byte sequence
            # from byte fallback tokenization.
            # If it's in the middle, it's probably a real invalid id generated
            # by the model
            new_text = new_text[len(prefix_text) :]
            return new_text, read_offset, len(all_input_ids)
        else:
            return "", prefix_offset, read_offset

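    # Sanity check: every parameter must have been materialized during loading;
    # anything still on the "meta" device was never assigned real weights.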
    def check_initialized(self):
        uninitialized_parameters = []
        for n, p in self.model.named_parameters():
            if p.data.device == torch.device("meta"):
                uninitialized_parameters.append(n)
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )