"vscode:/vscode.git/clone" did not exist on "c43356267b6e74a80e7b76ac3d680a0c2aca3a80"
server.py 8.63 KB
Newer Older
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
1
import asyncio
Olivier Dehaene's avatar
Olivier Dehaene committed
2
import os
3
import torch
4
import time
5
import signal
Olivier Dehaene's avatar
Olivier Dehaene committed
6

Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
7
from grpc import aio
8
from loguru import logger
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
9
10
11

from grpc_reflection.v1alpha import reflection
from pathlib import Path
12
from typing import List, Optional
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
13

from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.models.pali_gemma import PaliGemmaBatch
from text_generation_server.models.vlm_causal_lm import (
    VlmCausalLMBatch,
)
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch


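# Graceful shutdown: SIGINT/SIGTERM flip KEEP_PROCESSING to False so that the
# serving loop in serve_inner below can exit cleanly.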
class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.KEEP_PROCESSING = False


signal_handler = SignalHandler()


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    def __init__(
        self,
        model: Model,
        cache: Cache,
        quantize: Optional[str],
        server_urls: List[str],
    ):
        self.cache = cache
        self.model = model
        self.quantize = quantize
        self.server_urls = server_urls
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        if model.device.type == "cuda":
            # Force inference mode for the lifetime of TextGenerationService
            self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    async def Info(self, request, context):
        return self.model.info

    async def Health(self, request, context):
        if self.model.device.type == "cuda":
            torch.zeros((2, 2)).cuda()
        return generate_pb2.HealthResponse()

    async def ServiceDiscovery(self, request, context):
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
        if request.HasField("id"):
            self.cache.delete(request.id)
        else:
            self.cache.clear()
        return generate_pb2.ClearCacheResponse()

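    # Keep only the requests listed in request_ids and put the filtered batch
    # back in the cache for subsequent Decode calls.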
    async def FilterBatch(self, request, context):
        batch = self.cache.pop(request.batch_id)
        if batch is None:
            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
        filtered_batch = batch.filter(request.request_ids)
        self.cache.set(filtered_batch)

        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

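    # Allocate GPTQ/Exllama buffers if needed, then run the provided batch once
    # to determine how many total tokens the server can support.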
    async def Warmup(self, request, context):
        if self.quantize == "gptq":
            try:
                # When using GPTQ, the Exllama kernels need some global buffers
                # for which we only know the final shapes once the model has loaded.
                # This will allocate those buffers.
                from text_generation_server.layers.gptq import (
                    create_exllama_buffers,
                    set_device,
                )

                set_device(self.model.device)
                create_exllama_buffers(request.max_prefill_tokens)
            except ImportError:
                pass

        if self.model.batch_type in {
            IdeficsCausalLMBatch,
            VlmCausalLMBatch,
            PaliGemmaBatch,
        }:  # Hack, I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )
        max_supported_total_tokens = self.model.warmup(batch)

        return generate_pb2.WarmupResponse(
            max_supported_total_tokens=max_supported_total_tokens
        )

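    # First forward pass for a new batch: deserialize it, run one generation
    # step, and cache the resulting batch for subsequent Decode calls.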
    async def Prefill(self, request, context):
        start = time.time_ns()
        if self.model.batch_type in {
            IdeficsCausalLMBatch,
            VlmCausalLMBatch,
            PaliGemmaBatch,
        }:  # Hack, I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )

        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.PrefillResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
        )

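    # Subsequent generation steps: pop the cached batches, merge them if there
    # is more than one, and generate the next token for every request.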
    async def Decode(self, request, context):
        start = time.time_ns()
        if len(request.batches) == 0:
            raise ValueError("Must provide at least one batch")

        batches = []
        for batch_pb in request.batches:
            batch = self.cache.pop(batch_pb.id)
            if batch is None:
                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
            batches.append(batch)

        if len(batches) == 0:
            raise ValueError("All batches are empty")

        if len(batches) > 1:
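            # Merging the cached batches lets a single forward pass serve
            # requests that arrived in separate Prefill calls.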
            start_concat = time.time_ns()
            batch = self.model.batch_type.concatenate(batches)
            concat_ns = time.time_ns() - start_concat
        else:
            batch = batches[0]
            concat_ns = None

        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.DecodeResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            concat_ns=concat_ns,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
        )


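# Load the model and expose TextGenerationService over a unix domain socket,
# one socket per shard when running sharded.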
def serve(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    uds_path: Path,
):
    async def serve_inner(
        model_id: str,
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,
        speculate: Optional[int] = None,
        dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        unix_socket_template = "unix://{}-{}"
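        # Each shard listens on its own socket, derived from uds_path and its rank.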
        if sharded:
            server_urls = [
                unix_socket_template.format(uds_path, rank)
                for rank in range(int(os.environ["WORLD_SIZE"]))
            ]
            local_url = server_urls[int(os.environ["RANK"])]
        else:
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]

        try:
            model = get_model(
                model_id,
                revision,
                sharded,
                quantize,
                speculate,
                dtype,
                trust_remote_code,
            )
        except Exception:
            logger.exception("Error when initializing model")
            raise

        server = aio.server(
            interceptors=[
                ExceptionInterceptor(),
                UDSOpenTelemetryAioServerInterceptor(),
            ]
        )
        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
            TextGenerationService(model, Cache(), quantize, server_urls), server
        )
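        # Enable gRPC reflection so clients can discover the service at runtime.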
        SERVICE_NAMES = (
            generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
        server.add_insecure_port(local_url)

        await server.start()

        logger.info("Server started at {}".format(local_url))

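        # Keep the event loop alive until a termination signal flips the flag.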
        while signal_handler.KEEP_PROCESSING:
            await asyncio.sleep(0.5)

    asyncio.run(
        serve_inner(
            model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
        )
    )