import asyncio
import os
import torch
import time

from grpc import aio
from loguru import logger

from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional

from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
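    """gRPC servicer that exposes the model through the generate_pb2 TextGenerationService API."""
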
    def __init__(
        self,
        model: Model,
        cache: Cache,
        quantize: Optional[str],
        server_urls: List[str],
    ):
        self.cache = cache
        self.model = model
        self.quantize = quantize
        self.server_urls = server_urls
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        if model.device.type == "cuda":
            # Force inference mode for the lifetime of TextGenerationService
            self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    async def Info(self, request, context):
        return self.model.info

    async def Health(self, request, context):
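        """Allocate a tiny CUDA tensor so GPU failures surface as a failed health check."""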
        if self.model.device.type == "cuda":
            torch.zeros((2, 2)).cuda()
        return generate_pb2.HealthResponse()

    async def ServiceDiscovery(self, request, context):
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
        if request.HasField("id"):
            self.cache.delete(request.id)
        else:
            self.cache.clear()
        return generate_pb2.ClearCacheResponse()

    async def FilterBatch(self, request, context):
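        """Drop finished request ids from a cached batch and re-cache the filtered batch."""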
        batch = self.cache.pop(request.batch_id)
        if batch is None:
            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
        filtered_batch = batch.filter(request.request_ids)
        self.cache.set(filtered_batch)

        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

    async def Warmup(self, request, context):
        if self.quantize == "gptq":
            try:
                # When using GPTQ, the Exllama kernels need some global buffers
                # whose final shapes are only known once the model has loaded.
                # This will allocate those buffers.
                from text_generation_server.utils.layers import (
                    create_exllama_buffers,
                    set_device,
                )

                set_device(self.model.device)
                create_exllama_buffers(request.max_prefill_tokens)
            except ImportError:
                pass

        if self.model.batch_type in {
            IdeficsCausalLMBatch,
            VlmCausalLMBatch,
        }:  # Hack: I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )
        max_supported_total_tokens = self.model.warmup(batch)

        return generate_pb2.WarmupResponse(
            max_supported_total_tokens=max_supported_total_tokens
        )

    async def Prefill(self, request, context):
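        """Build a batch from the request, run the first forward pass, and cache the resulting state."""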
        start = time.time_ns()
        if self.model.batch_type in {
            IdeficsCausalLMBatch,
            VlmCausalLMBatch,
        }:  # Hack: I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )

        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.PrefillResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
        )

    async def Decode(self, request, context):
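        """Run one decode step over the cached batches, concatenating them first when several are given."""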
        start = time.time_ns()
        if len(request.batches) == 0:
            raise ValueError("Must provide at least one batch")

        batches = []
        for batch_pb in request.batches:
            batch = self.cache.pop(batch_pb.id)
            if batch is None:
                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
            batches.append(batch)

        if len(batches) == 0:
            raise ValueError("All batches are empty")

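        # Merge batches produced by separate prefills so decoding runs as a single forward pass.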
        if len(batches) > 1:
            start_concat = time.time_ns()
            batch = self.model.batch_type.concatenate(batches)
            concat_ns = time.time_ns() - start_concat
        else:
            batch = batches[0]
            concat_ns = None

        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.DecodeResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            concat_ns=concat_ns,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
        )


def serve(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    uds_path: Path,
):
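    """Load the model and serve it over gRPC on unix sockets under `uds_path`."""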
    async def serve_inner(
        model_id: str,
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,
        speculate: Optional[int] = None,
        dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        unix_socket_template = "unix://{}-{}"
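        # When sharded, each rank listens on its own unix socket, selected via the RANK and WORLD_SIZE env vars.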
        if sharded:
            server_urls = [
                unix_socket_template.format(uds_path, rank)
                for rank in range(int(os.environ["WORLD_SIZE"]))
            ]
            local_url = server_urls[int(os.environ["RANK"])]
        else:
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]

        try:
            model = get_model(
                model_id,
                revision,
                sharded,
                quantize,
                speculate,
                dtype,
                trust_remote_code,
            )
        except Exception:
            logger.exception("Error when initializing model")
            raise

        server = aio.server(
            interceptors=[
                ExceptionInterceptor(),
                UDSOpenTelemetryAioServerInterceptor(),
            ]
        )
        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
            TextGenerationService(model, Cache(), quantize, server_urls), server
        )
        SERVICE_NAMES = (
            generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
        server.add_insecure_port(local_url)

        await server.start()

        logger.info("Server started at {}".format(local_url))

        try:
            await server.wait_for_termination()
        except KeyboardInterrupt:
            logger.info("Signal received. Shutting down")
            await server.stop(0)

    asyncio.run(
        serve_inner(
            model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
        )
    )
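

# A rough sketch of how `serve` is typically invoked by the CLI layer; the model id and
# uds_path below are illustrative values, not defaults defined in this file:
#
#     serve(
#         model_id="bigscience/bloom-560m",
#         revision=None,
#         sharded=False,
#         quantize=None,
#         speculate=None,
#         dtype=None,
#         trust_remote_code=False,
#         uds_path=Path("/tmp/text-generation-server"),
#     )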