import asyncio
import os
import torch
import time
import signal

from grpc import aio
from loguru import logger

from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional

from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model_with_lora_adapters
from text_generation_server.utils.adapter import AdapterInfo
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens

try:
    from text_generation_server.models.pali_gemma import PaliGemmaBatch
    from text_generation_server.models.vlm_causal_lm import (
        VlmCausalLMBatch,
    )
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch

    VLM_BATCH_TYPES = {
        PaliGemmaBatch,
        VlmCausalLMBatch,
        IdeficsCausalLMBatch,
        MllamaCausalLMBatch,
    }
except (ImportError, NotImplementedError):
    # These imports can fail on CPU or non-flash-attention builds.
    VLM_BATCH_TYPES = set()

from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_adapter_to_index


class SignalHandler:
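    """Flip KEEP_PROCESSING to False on SIGINT/SIGTERM so the serve loop can exit cleanly."""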
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def set_keep_processing(self, value: bool):
        self.KEEP_PROCESSING = value

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.set_keep_processing(False)


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
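    """gRPC servicer exposing the model over Unix domain sockets: health checks, cache management, warmup, prefill and decode."""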
    def __init__(
        self,
        model: Model,
        cache: Cache,
        server_urls: List[str],
    ):
        self.cache = cache
        self.model = model
        # Quantize is resolved during model loading
        self.quantize = model.quantize
        self.server_urls = server_urls
        # For some reason, inference_mode does not work well with GLOO, which we use on CPU
        if model.device.type == "cuda":
            # Force inference mode for the lifetime of TextGenerationService
            self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    async def Info(self, request, context):
        return self.model.info

    async def Health(self, request, context):
        if self.model.device.type == "cuda":
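            # Touch the GPU with a tiny tensor so an unresponsive device fails the health check.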
            torch.zeros((2, 2)).cuda()
        return generate_pb2.HealthResponse()

    async def ServiceDiscovery(self, request, context):
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
        if request.HasField("id"):
            self.cache.delete(request.id)
        else:
            self.cache.clear()
        return generate_pb2.ClearCacheResponse()

    async def FilterBatch(self, request, context):
        batch = self.cache.pop(request.batch_id)
        if batch is None:
            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
        filtered_batch = batch.filter(request.request_ids)
        self.cache.set(filtered_batch)

        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

    async def Warmup(self, request, context):
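        """Run a warmup forward pass on the provided batch and report the maximum number of supported total tokens."""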
        set_max_prefill_tokens(request.max_prefill_tokens)

        if self.quantize in {"exl2", "gptq"}:
            try:
                # When using GPTQ, Exllama kernels need some global buffers
                # whose final shapes are only known after the model has loaded.
                # This will allocate those buffers.
                from text_generation_server.layers.gptq import (
                    create_exllama_buffers,
                    set_device,
                )

                set_device(self.model.device)
                create_exllama_buffers(request.max_prefill_tokens)
            except ImportError:
                pass

        if (
            self.model.batch_type in VLM_BATCH_TYPES
        ):  # Hack: I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )
        max_supported_total_tokens = self.model.warmup(batch)

        return generate_pb2.WarmupResponse(
            max_supported_total_tokens=max_supported_total_tokens
        )

    async def Prefill(self, request, context):
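        """Run the prefill pass for a new batch, cache the resulting batch state and return the first generations."""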
        start = time.time_ns()
        if (
            self.model.batch_type in VLM_BATCH_TYPES
        ):  # Hack: I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )

        concat_ns = None
        if self.model.support_chunking:
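            # With chunked prefill, the new chunk may need to be merged with a batch
            # that is already cached from a previous prefill call.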
            if request.HasField("cached_batch"):
                cached_batch = self.cache.pop(request.cached_batch.id)
                if cached_batch is None:
                    raise ValueError(
                        f"Batch ID {request.cached_batch.id} not found in cache."
                    )
                start_concat = time.time_ns()
                batch = self.model.batch_type.concatenate([cached_batch, batch])
                concat_ns = time.time_ns() - start_concat

        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.PrefillResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
            concat_ns=concat_ns,
        )

    async def Decode(self, request, context):
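        """Run one decode step over the cached batches, concatenating them first when more than one is given."""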
        start = time.time_ns()
        if len(request.batches) == 0:
            raise ValueError("Must provide at least one batch")

        batches = []
        for batch_pb in request.batches:
            batch = self.cache.pop(batch_pb.id)
            if batch is None:
                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
            batches.append(batch)

        if len(batches) == 0:
            raise ValueError("All batches are empty")

        if len(batches) > 1:
            start_concat = time.time_ns()
            batch = self.model.batch_type.concatenate(batches)
            concat_ns = time.time_ns() - start_concat
        else:
            batch = batches[0]
            concat_ns = None

        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.DecodeResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            concat_ns=concat_ns,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
        )


def serve(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    uds_path: Path,
    max_input_tokens: int,
):
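    """Load the model and run the async gRPC server until a termination signal is received."""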
    async def serve_inner(
        model_id: str,
        lora_adapters: Optional[List[AdapterInfo]],
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,
        speculate: Optional[int] = None,
        dtype: Optional[str] = None,
        kv_cache_dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        unix_socket_template = "unix://{}-{}"
        adapter_to_index = {}
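        # In sharded mode each rank listens on its own socket, derived from the
        # RANK / WORLD_SIZE environment variables; otherwise a single socket is used.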
        if sharded:
            server_urls = [
                unix_socket_template.format(uds_path, rank)
                for rank in range(int(os.environ["WORLD_SIZE"]))
            ]
            local_url = server_urls[int(os.environ["RANK"])]
        else:
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]

        try:
            model = get_model_with_lora_adapters(
                model_id,
                lora_adapters,
                revision,
                sharded,
                quantize,
                speculate,
                dtype,
                kv_cache_dtype,
                trust_remote_code,
                max_input_tokens,
                adapter_to_index,
            )

        except Exception:
            logger.exception("Error when initializing model")
            raise

        signal_handler = SignalHandler()

        set_adapter_to_index(adapter_to_index)
        server = aio.server(
            interceptors=[
                ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False)),
                UDSOpenTelemetryAioServerInterceptor(),
            ],
            options=[
                # Set the maximum possible message length: i32::MAX
                ("grpc.max_receive_message_length", (1 << 31) - 1)
            ],
        )
        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
            TextGenerationService(model, Cache(), server_urls), server
        )
        SERVICE_NAMES = (
            generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
        server.add_insecure_port(local_url)

        await server.start()

        logger.info("Server started at {}".format(local_url))
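        # Keep the event loop alive until SignalHandler flips KEEP_PROCESSING off (SIGINT/SIGTERM).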
        while signal_handler.KEEP_PROCESSING:
            await asyncio.sleep(0.5)

    asyncio.run(
        serve_inner(
            model_id,
            lora_adapters,
            revision,
            sharded,
            quantize,
            speculate,
            dtype,
            kv_cache_dtype,
            trust_remote_code,
        )
    )