import asyncio
import os
import torch
import time

from grpc import aio
from loguru import logger

from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional

from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
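    """gRPC servicer implementing the TextGenerationService protocol from generate_pb2."""
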
    def __init__(
        self,
        model: Model,
        cache: Cache,
        quantize: Optional[str],
        server_urls: List[str],
    ):
        self.cache = cache
        self.model = model
        self.quantize = quantize
        self.server_urls = server_urls
        # For some reason, inference_mode does not work well with GLOO, which we use on CPU
        if model.device.type == "cuda":
            # Force inference mode for the lifetime of TextGenerationService
            self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    async def Info(self, request, context):
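        """Return static information about the loaded model."""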
        return self.model.info

    async def Health(self, request, context):
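        """Report liveness; on CUDA devices, also run a small allocation to check the GPU responds."""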
        if self.model.device.type == "cuda":
            torch.zeros((2, 2)).cuda()
        return generate_pb2.HealthResponse()

    async def ServiceDiscovery(self, request, context):
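        """Return the socket URLs of every server shard."""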
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
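        """Delete a single cached batch by id, or clear the whole cache when no id is given."""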
        if request.HasField("id"):
            self.cache.delete(request.id)
        else:
            self.cache.clear()
        return generate_pb2.ClearCacheResponse()

    async def FilterBatch(self, request, context):
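        """Drop finished requests from a cached batch and put the filtered batch back in the cache."""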
        batch = self.cache.pop(request.batch_id)
        if batch is None:
            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
        filtered_batch = batch.filter(request.request_ids)
        self.cache.set(filtered_batch)

        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

    async def Warmup(self, request, context):
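        """Run a warmup pass to allocate buffers and report the maximum number of supported total tokens."""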
        if self.quantize == "gptq":
            try:
                # When using GPTQ, Exllama kernels need some global buffers
                # for which we only know the final shapes after the model has loaded.
                # This will allocate those buffers.
                from text_generation_server.utils.layers import (
                    create_exllama_buffers,
                    set_device,
                )

                set_device(self.model.device)
                create_exllama_buffers(request.max_prefill_tokens)
            except ImportError:
                pass

        if (
            self.model.batch_type == IdeficsCausalLMBatch
        ):  # Hack: I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )
        max_supported_total_tokens = self.model.warmup(batch)

        return generate_pb2.WarmupResponse(
            max_supported_total_tokens=max_supported_total_tokens
        )

    async def Prefill(self, request, context):
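        """Run the prefill pass on a new batch and cache the resulting batch for subsequent decode steps."""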
        start = time.time_ns()
        if (
            self.model.batch_type == IdeficsCausalLMBatch
        ):  # Hack: I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )

        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.PrefillResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
        )

    async def Decode(self, request, context):
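        """Run one decode step over the cached batches referenced by the request, concatenating them if needed."""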
        start = time.time_ns()
        if len(request.batches) == 0:
            raise ValueError("Must provide at least one batch")

        batches = []
        for batch_pb in request.batches:
            batch = self.cache.pop(batch_pb.id)
            if batch is None:
                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
            batches.append(batch)

        if len(batches) == 0:
            raise ValueError("All batches are empty")

        if len(batches) > 1:
            start_concat = time.time_ns()
            batch = self.model.batch_type.concatenate(batches)
            concat_ns = time.time_ns() - start_concat
        else:
            batch = batches[0]
            concat_ns = None

        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.DecodeResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            concat_ns=concat_ns,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
        )


def serve(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    uds_path: Path,
):
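    """Load the model and serve the TextGenerationService over Unix domain sockets."""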
    async def serve_inner(
        model_id: str,
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,
        speculate: Optional[int] = None,
        dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        unix_socket_template = "unix://{}-{}"
        if sharded:
            server_urls = [
                unix_socket_template.format(uds_path, rank)
                for rank in range(int(os.environ["WORLD_SIZE"]))
            ]
            local_url = server_urls[int(os.environ["RANK"])]
        else:
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]

        try:
            model = get_model(
                model_id,
                revision,
                sharded,
                quantize,
                speculate,
                dtype,
                trust_remote_code,
            )
        except Exception:
            logger.exception("Error when initializing model")
            raise

        server = aio.server(
            interceptors=[
                ExceptionInterceptor(),
                UDSOpenTelemetryAioServerInterceptor(),
            ]
        )
        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
            TextGenerationService(model, Cache(), quantize, server_urls), server
        )
        SERVICE_NAMES = (
            generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
        server.add_insecure_port(local_url)

        await server.start()

        logger.info("Server started at {}".format(local_url))

        try:
            await server.wait_for_termination()
        except KeyboardInterrupt:
            logger.info("Signal received. Shutting down")
            await server.stop(0)

    asyncio.run(
        serve_inner(
            model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
        )
    )