server.py
import asyncio
import os
import torch
import time
import signal

from grpc import aio
from loguru import logger

from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional

from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model_with_lora_adapters
from text_generation_server.utils.adapter import AdapterInfo
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens

try:
    from text_generation_server.models.pali_gemma import PaliGemmaBatch
    from text_generation_server.models.vlm_causal_lm import (
        VlmCausalLMBatch,
    )
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch

    VLM_BATCH_TYPES = {
        PaliGemmaBatch,
        VlmCausalLMBatch,
        IdeficsCausalLMBatch,
        MllamaCausalLMBatch,
    }
except (ImportError, NotImplementedError):
    # These imports can fail on CPU or non-flash-attention builds.
    VLM_BATCH_TYPES = set()

from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_adapter_to_index


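# Handles SIGINT/SIGTERM: flips KEEP_PROCESSING to False so the serving loop in serve_inner can exit gracefully.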
class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def set_keep_processing(self, value: bool):
        self.KEEP_PROCESSING = value

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.set_keep_processing(False)


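# gRPC servicer exposing the text-generation protocol: Info, Health, ServiceDiscovery, cache management, Warmup, Prefill and Decode.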
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    def __init__(
        self,
        model: Model,
        cache: Cache,
        server_urls: List[str],
    ):
        self.cache = cache
        self.model = model
        # Quantize is resolved during model loading
        self.quantize = model.quantize
        self.server_urls = server_urls
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        if model.device.type == "cuda":
            # Force inference mode for the lifetime of TextGenerationService
            self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    async def Info(self, request, context):
        return self.model.info

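    # Touch the GPU with a tiny allocation so a broken CUDA context is reported as unhealthy.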
    async def Health(self, request, context):
        if self.model.device.type == "cuda":
            torch.zeros((2, 2)).cuda()
        return generate_pb2.HealthResponse()

    async def ServiceDiscovery(self, request, context):
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
        if request.HasField("id"):
            self.cache.delete(request.id)
        else:
            self.cache.clear()
        return generate_pb2.ClearCacheResponse()

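    # Keep only the given request ids in a cached batch and re-cache the filtered result.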
    async def FilterBatch(self, request, context):
        batch = self.cache.pop(request.batch_id)
        if batch is None:
            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
        filtered_batch = batch.filter(request.request_ids)
        self.cache.set(filtered_batch)

        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

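    # Build a batch from the warmup request and run the model once to find the maximum number of total tokens it can support.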
    async def Warmup(self, request, context):
        set_max_prefill_tokens(request.max_prefill_tokens)

        if self.quantize in {"exl2", "gptq"}:
            try:
                # When using GPTQ, the Exllama kernels need some global buffers
                # whose final shapes are only known once the model has loaded.
                # This allocates those buffers.
                from text_generation_server.layers.gptq import (
                    create_exllama_buffers,
                    set_device,
                )

                set_device(self.model.device)
                create_exllama_buffers(request.max_prefill_tokens)
            except ImportError:
                pass

        if (
            self.model.batch_type in VLM_BATCH_TYPES
        ):  # Hack, I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )
        max_supported_total_tokens = self.model.warmup(batch)

        return generate_pb2.WarmupResponse(
            max_supported_total_tokens=max_supported_total_tokens
        )

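    # Prefill (prompt) forward pass; with prefill chunking, the incoming batch may first be concatenated with a previously cached batch.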
    async def Prefill(self, request, context):
        start = time.time_ns()
        if (
            self.model.batch_type in VLM_BATCH_TYPES
        ):  # Hack, I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )

        concat_ns = None
        if self.model.support_chunking:
            if request.HasField("cached_batch"):
                cached_batch = self.cache.pop(request.cached_batch.id)
                if cached_batch is None:
                    raise ValueError(
                        f"Batch ID {request.cached_batch.id} not found in cache."
                    )
                start_concat = time.time_ns()
                batch = self.model.batch_type.concatenate([cached_batch, batch])
                concat_ns = time.time_ns() - start_concat

        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.PrefillResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
            concat_ns=concat_ns,
        )

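    # One decode step over the cached batches referenced by the request, concatenated first when more than one is given.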
    async def Decode(self, request, context):
        start = time.time_ns()
        if len(request.batches) == 0:
            raise ValueError("Must provide at least one batch")

        batches = []
        for batch_pb in request.batches:
            batch = self.cache.pop(batch_pb.id)
            if batch is None:
                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
            batches.append(batch)

        if len(batches) == 0:
            raise ValueError("All batches are empty")

        if len(batches) > 1:
            start_concat = time.time_ns()
            batch = self.model.batch_type.concatenate(batches)
            concat_ns = time.time_ns() - start_concat
        else:
            batch = batches[0]
            concat_ns = None

        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.DecodeResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            concat_ns=concat_ns,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
        )


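# Load the model (with any LoRA adapters) and serve the gRPC API over a Unix domain socket.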
def serve(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    uds_path: Path,
    max_input_tokens: int,
):
    async def serve_inner(
        model_id: str,
        lora_adapters: Optional[List[AdapterInfo]],
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,
        speculate: Optional[int] = None,
        dtype: Optional[str] = None,
        kv_cache_dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        unix_socket_template = "unix://{}-{}"
        adapter_to_index = {}
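        # When sharded, each rank binds its own socket ("{uds_path}-{rank}"); otherwise a single socket is used.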
        if sharded:
            server_urls = [
                unix_socket_template.format(uds_path, rank)
                for rank in range(int(os.environ["WORLD_SIZE"]))
            ]
            local_url = server_urls[int(os.environ["RANK"])]
        else:
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]

        try:
            model = get_model_with_lora_adapters(
                model_id,
                lora_adapters,
                revision,
                sharded,
                quantize,
                speculate,
                dtype,
                kv_cache_dtype,
                trust_remote_code,
                max_input_tokens,
                adapter_to_index,
            )
        except Exception:
            logger.exception("Error when initializing model")
            raise

        signal_handler = SignalHandler()

        set_adapter_to_index(adapter_to_index)
        server = aio.server(
            interceptors=[
                ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False)),
                UDSOpenTelemetryAioServerInterceptor(),
            ],
            options=[
                # Set the maximum possible message length: i32::MAX
                ("grpc.max_receive_message_length", (1 << 31) - 1)
            ],
        )
        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
            TextGenerationService(model, Cache(), server_urls), server
        )
        SERVICE_NAMES = (
            generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
        server.add_insecure_port(local_url)

        await server.start()

        logger.info("Server started at {}".format(local_url))

        while signal_handler.KEEP_PROCESSING:
            await asyncio.sleep(0.5)

    asyncio.run(
        serve_inner(
            model_id,
            lora_adapters,
            revision,
            sharded,
            quantize,
            speculate,
            dtype,
            kv_cache_dtype,
            trust_remote_code,
        )
    )