"docs/vscode:/vscode.git/clone" did not exist on "029a9da00b4015b6988a28e353fc389572ac6254"
server.py 8.5 KB
Newer Older
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
1
import asyncio
import os
import torch
import time
import signal

from grpc import aio
from loguru import logger

from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional

from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch


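# Flip a module-level flag on SIGINT/SIGTERM so the serving loop below can exit cleanly.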
class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.KEEP_PROCESSING = False


signal_handler = SignalHandler()


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    def __init__(
        self,
        model: Model,
        cache: Cache,
        quantize: Optional[str],
        server_urls: List[str],
    ):
        self.cache = cache
        self.model = model
        self.quantize = quantize
        self.server_urls = server_urls
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        if model.device.type == "cuda":
            # Force inference mode for the lifetime of TextGenerationService
            self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    async def Info(self, request, context):
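        # Static model metadata (e.g. dtype and device) reported back to the client.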
        return self.model.info

    async def Health(self, request, context):
        if self.model.device.type == "cuda":
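            # Tiny CUDA allocation: surfaces device errors during health checks.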
            torch.zeros((2, 2)).cuda()
        return generate_pb2.HealthResponse()

    async def ServiceDiscovery(self, request, context):
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
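        # Drop one cached batch when an id is given, otherwise flush the whole cache.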
        if request.HasField("id"):
            self.cache.delete(request.id)
        else:
            self.cache.clear()
        return generate_pb2.ClearCacheResponse()

    async def FilterBatch(self, request, context):
        batch = self.cache.pop(request.batch_id)
        if batch is None:
            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
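        # Keep only the requests that are still running; finished ones are pruned.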
        filtered_batch = batch.filter(request.request_ids)
        self.cache.set(filtered_batch)

        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

    async def Warmup(self, request, context):
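        # Run the model once on the warmup batch to pre-allocate memory and
        # measure how many total tokens the model can support.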
        if self.quantize == "gptq":
            try:
                # When using GPTQ, Exllama kernels need some global buffers
                # for which we have the final shapes only after the model has loaded.
                # This will allocate those buffers.
                from text_generation_server.layers.gptq import (
                    create_exllama_buffers,
                    set_device,
                )

                set_device(self.model.device)
                create_exllama_buffers(request.max_prefill_tokens)
            except ImportError:
                pass

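        # Multimodal batches also need the processor and model config to build image inputs.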
        if self.model.batch_type in {
            IdeficsCausalLMBatch,
            VlmCausalLMBatch,
        }:  # Hack: I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )
        max_supported_total_tokens = self.model.warmup(batch)

        return generate_pb2.WarmupResponse(
            max_supported_total_tokens=max_supported_total_tokens
        )

    async def Prefill(self, request, context):
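        # First forward pass for a new batch of prompts; the resulting batch state
        # is cached so Decode can continue from it.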
        start = time.time_ns()
        if self.model.batch_type in {
            IdeficsCausalLMBatch,
            VlmCausalLMBatch,
        }:  # Hack: I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )

        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.PrefillResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
        )

    async def Decode(self, request, context):
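        # Subsequent forward passes: pop cached batches, merge them if needed,
        # and generate one token per active request.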
        start = time.time_ns()
        if len(request.batches) == 0:
            raise ValueError("Must provide at least one batch")

        batches = []
        for batch_pb in request.batches:
            batch = self.cache.pop(batch_pb.id)
            if batch is None:
                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
            batches.append(batch)

        if len(batches) == 0:
            raise ValueError("All batches are empty")

        if len(batches) > 1:
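            # Concatenate batches coming from separate prefills so decoding
            # runs as a single forward pass.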
            start_concat = time.time_ns()
            batch = self.model.batch_type.concatenate(batches)
            concat_ns = time.time_ns() - start_concat
        else:
            batch = batches[0]
            concat_ns = None

        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.DecodeResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            concat_ns=concat_ns,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
        )


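# Entry point: load the model, then serve gRPC over unix sockets until a
# termination signal arrives.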
def serve(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    uds_path: Path,
):
    async def serve_inner(
        model_id: str,
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,
        speculate: Optional[int] = None,
        dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        unix_socket_template = "unix://{}-{}"
        if sharded:
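            # One unix socket per shard; WORLD_SIZE and RANK are expected in the environment.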
            server_urls = [
                unix_socket_template.format(uds_path, rank)
                for rank in range(int(os.environ["WORLD_SIZE"]))
            ]
            local_url = server_urls[int(os.environ["RANK"])]
        else:
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]

        try:
            model = get_model(
                model_id,
                revision,
                sharded,
                quantize,
                speculate,
                dtype,
                trust_remote_code,
            )
        except Exception:
            logger.exception("Error when initializing model")
            raise

        server = aio.server(
            interceptors=[
                ExceptionInterceptor(),
                UDSOpenTelemetryAioServerInterceptor(),
            ]
        )
        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
            TextGenerationService(model, Cache(), quantize, server_urls), server
        )
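        # Register service names for gRPC reflection so tools like grpcurl can
        # discover the API.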
        SERVICE_NAMES = (
            generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
        server.add_insecure_port(local_url)

        await server.start()

        logger.info("Server started at {}".format(local_url))

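        # Park here instead of awaiting server termination; the SignalHandler flag
        # flips on SIGINT/SIGTERM and lets the coroutine return.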
        while signal_handler.KEEP_PROCESSING:
            await asyncio.sleep(0.5)

    asyncio.run(
        serve_inner(
            model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
        )
    )