"src/diffusers/models/controlnets/controlnet_xs.py" did not exist on "cf6e0407e051467b480830d3ed97d2873b5019d3"
server.py 11.1 KB
Newer Older
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
1
import asyncio
Olivier Dehaene's avatar
Olivier Dehaene committed
2
import os
3
import torch
4
import time
5
import signal
Olivier Dehaene's avatar
Olivier Dehaene committed
6

Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
7
from grpc import aio
8
from loguru import logger
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
9
10
11

from grpc_reflection.v1alpha import reflection
from pathlib import Path
12
from typing import List, Optional
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
13

14
15
from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
16
17
from text_generation_server.models import Model, get_model_with_lora_adapters
from text_generation_server.utils.adapter import AdapterInfo
18
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
19
20
21
22
23
24
25

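# Multimodal (VLM) batch classes pull in extra vision / flash-attention dependencies
# that may be missing on CPU-only or non-flash installs, so their import is optional.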
try:
    from text_generation_server.models.pali_gemma import PaliGemmaBatch
    from text_generation_server.models.vlm_causal_lm import (
        VlmCausalLMBatch,
    )
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch

    VLM_BATCH_TYPES = {
        PaliGemmaBatch,
        VlmCausalLMBatch,
        IdeficsCausalLMBatch,
        MllamaCausalLMBatch,
    }
except (ImportError, NotImplementedError):
    # These imports can fail on CPU/Non flash.
    VLM_BATCH_TYPES = set()

from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_adapter_to_index


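# Converts SIGINT/SIGTERM into a flag that the serve loop polls, so the server
# can stop cleanly instead of being killed mid-request.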
class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def set_keep_processing(self, value: bool):
        self.KEEP_PROCESSING = value

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.set_keep_processing(False)


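# gRPC servicer for a single shard: exposes Info/Health/ServiceDiscovery plus the
# batch lifecycle RPCs (Warmup, Prefill, Decode, FilterBatch, ClearCache) and keeps
# in-flight batches in a Cache between calls.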
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    def __init__(
        self,
        model: Model,
        cache: Cache,
        server_urls: List[str],
    ):
        self.cache = cache
        self.model = model
        # Quantize is resolved during model loading
        self.quantize = model.quantize
        self.server_urls = server_urls
        # For some reason, inference_mode does not work well with GLOO, which we use on CPU
        if model.device.type == "cuda":
            # Force inference mode for the lifetime of TextGenerationService
            self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    async def Info(self, request, context):
        return self.model.info

    async def Health(self, request, context):
        if self.model.device.type == "cuda":
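            # Touch the GPU with a tiny allocation so a broken device fails the health check.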
            torch.zeros((2, 2)).cuda()
        return generate_pb2.HealthResponse()

    async def ServiceDiscovery(self, request, context):
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
        if request.HasField("id"):
            self.cache.delete(request.id)
        else:
            self.cache.clear()
        return generate_pb2.ClearCacheResponse()

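    # Keeps only the requests in `request_ids` from a cached batch (the others have
    # finished) and stores the filtered batch back in the cache.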
    async def FilterBatch(self, request, context):
        batch = self.cache.pop(request.batch_id)
        if batch is None:
            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
        filtered_batch = batch.filter(request.request_ids)
        self.cache.set(filtered_batch)

        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

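    # Warmup runs the provided batch through the model once to probe memory limits and
    # returns the token budgets this shard can support.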
    async def Warmup(self, request, context):
        set_max_prefill_tokens(request.max_prefill_tokens)

        if self.quantize in {"exl2", "gptq"}:
            try:
                # When using GPTQ, the Exllama kernels need some global buffers
                # whose final shapes are only known after the model has loaded.
                # This will allocate those buffers.
                from text_generation_server.layers.gptq import (
                    create_exllama_buffers,
                    set_device,
                )

                set_device(self.model.device)
                create_exllama_buffers(request.max_prefill_tokens)
            except ImportError:
                pass

        if (
            self.model.batch_type in VLM_BATCH_TYPES
        ):  # Hack: I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )

        # Override default values with None for clearer semantics.
        max_input_tokens = (
            request.max_input_tokens if request.HasField("max_input_tokens") else None
        )
        max_total_tokens = (
            request.max_total_tokens if request.HasField("max_total_tokens") else None
        )
        max_supported_total_tokens, max_input_tokens, max_total_tokens = (
            self.model.warmup(batch, max_input_tokens, max_total_tokens)
        )

        return generate_pb2.WarmupResponse(
            max_supported_total_tokens=max_supported_total_tokens,
            max_input_tokens=max_input_tokens,
            max_total_tokens=max_total_tokens,
        )

    async def Prefill(self, request, context):
        start = time.time_ns()
        if (
            self.model.batch_type in VLM_BATCH_TYPES
        ):  # Hack: I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )

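        # Chunked prefill: if the router passed back a partially prefilled batch, merge
        # it with this chunk before running the model.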
        concat_ns = None
        if self.model.support_chunking:
            if request.HasField("cached_batch"):
                cached_batch = self.cache.pop(request.cached_batch.id)
                if cached_batch is None:
                    raise ValueError(
                        f"Batch ID {request.cached_batch.id} not found in cache."
                    )
                start_concat = time.time_ns()
                batch = self.model.batch_type.concatenate([cached_batch, batch])
                concat_ns = time.time_ns() - start_concat

        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.PrefillResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
            concat_ns=concat_ns,
        )

    async def Decode(self, request, context):
        start = time.time_ns()
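        # Every in-flight batch lives in the cache between steps; pop the batches
        # referenced by this request and merge them before generating the next tokens.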
        if len(request.batches) == 0:
            raise ValueError("Must provide at least one batch")

        batches = []
        for batch_pb in request.batches:
            batch = self.cache.pop(batch_pb.id)
            if batch is None:
                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
            batches.append(batch)

        if len(batches) == 0:
            raise ValueError("All batches are empty")

        if len(batches) > 1:
            start_concat = time.time_ns()
            batch = self.model.batch_type.concatenate(batches)
            concat_ns = time.time_ns() - start_concat
        else:
            batch = batches[0]
            concat_ns = None

        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.DecodeResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            concat_ns=concat_ns,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
        )


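# Entry point for one shard: loads the model (optionally with LoRA adapters), then
# runs the async gRPC server on a unix socket until shutdown.
#
# A minimal sketch of a direct invocation (normally the CLI wires these arguments
# from command-line flags; the model id, socket path and token budget below are
# illustrative assumptions, not values taken from this file):
#
#     serve(
#         model_id="bigscience/bloom-560m",
#         lora_adapters=None,
#         revision=None,
#         sharded=False,
#         quantize=None,
#         speculate=None,
#         dtype=None,
#         kv_cache_dtype=None,
#         trust_remote_code=False,
#         uds_path=Path("/tmp/text-generation-server"),
#         max_input_tokens=1024,
#     )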
def serve(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    uds_path: Path,
    max_input_tokens: int,
):
    async def serve_inner(
        model_id: str,
        lora_adapters: Optional[List[AdapterInfo]],
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,
        speculate: Optional[int] = None,
        dtype: Optional[str] = None,
        kv_cache_dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
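        # Each shard binds its own unix socket "<uds_path>-<rank>"; the full list is
        # handed back via ServiceDiscovery so the router can reach every rank.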
        unix_socket_template = "unix://{}-{}"
        adapter_to_index = {}
        if sharded:
            server_urls = [
                unix_socket_template.format(uds_path, rank)
                for rank in range(int(os.environ["WORLD_SIZE"]))
            ]
            local_url = server_urls[int(os.environ["RANK"])]
        else:
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]

        try:
            model = get_model_with_lora_adapters(
                model_id,
                lora_adapters,
                revision,
                sharded,
                quantize,
                speculate,
                dtype,
                kv_cache_dtype,
                trust_remote_code,
                max_input_tokens,
                adapter_to_index,
            )
        except Exception:
            logger.exception("Error when initializing model")
            raise

        signal_handler = SignalHandler()

        set_adapter_to_index(adapter_to_index)
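        # The exception interceptor maps handler failures to gRPC errors and can clear the
        # keep-processing flag via the callback; the OpenTelemetry interceptor traces calls
        # arriving over the unix socket.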
        server = aio.server(
            interceptors=[
                ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False)),
                UDSOpenTelemetryAioServerInterceptor(),
            ],
            options=[
                # Set the maximum possible message length: i32::MAX
                ("grpc.max_receive_message_length", (1 << 31) - 1)
            ],
        )
        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
            TextGenerationService(model, Cache(), server_urls), server
        )
        SERVICE_NAMES = (
            generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
        server.add_insecure_port(local_url)

        await server.start()

        logger.info("Server started at {}".format(local_url))
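        # grpc.aio serves requests on the running event loop; keep the loop alive until a
        # signal (or a fatal handler error) clears KEEP_PROCESSING.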
        while signal_handler.KEEP_PROCESSING:
            await asyncio.sleep(0.5)

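    # Block this thread on a fresh event loop until the server shuts down.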
    asyncio.run(
        serve_inner(
            model_id,
            lora_adapters,
            revision,
            sharded,
            quantize,
            speculate,
            dtype,
            kv_cache_dtype,
            trust_remote_code,
        )
    )