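"""gRPC server for text_generation_server.

Exposes the TextGenerationService (info, health, cache management, warmup,
prefill and decode) over Unix domain sockets, one per shard.
"""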
import asyncio
import os
import torch
import time

from grpc import aio
from loguru import logger

from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional

from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
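    """gRPC servicer handling text generation requests.

    Keeps in-flight batches in a Cache between Prefill and Decode calls and
    delegates token generation to the underlying Model.
    """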
    def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
        self.cache = cache
        self.model = model
        self.server_urls = server_urls
        # For some reason, inference_mode does not work well with GLOO, which we use on CPU
        if model.device.type == "cuda":
            # Force inference mode for the lifetime of TextGenerationService
            self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    async def Info(self, request, context):
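        """Return static information about the loaded model."""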
        return self.model.info

    async def Health(self, request, context):
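        """Check that the server is responsive.

        On CUDA devices, allocate a small tensor on the GPU to verify that the
        device is still reachable.
        """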
        if self.model.device.type == "cuda":
            torch.zeros((2, 2)).cuda()
        return generate_pb2.HealthResponse()

    async def ServiceDiscovery(self, request, context):
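        """Return the Unix socket URLs of all shards serving this model."""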
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
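        """Remove one batch from the cache, or clear the whole cache."""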
        if request.HasField("id"):
            self.cache.delete(request.id)
        else:
            self.cache.clear()
        return generate_pb2.ClearCacheResponse()

    async def FilterBatch(self, request, context):
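        """Keep only the given request ids in a cached batch and return it."""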
        batch = self.cache.pop(request.batch_id)
        if batch is None:
            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
        filtered_batch = batch.filter(request.request_ids)
        self.cache.set(filtered_batch)

        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

    async def Warmup(self, request, context):
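        """Run the model on the warmup batch and return the maximum number of supported total tokens."""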
        if (
            self.model.batch_type == IdeficsCausalLMBatch
        ):  # Hack: I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )
        max_supported_total_tokens = self.model.warmup(batch)

        return generate_pb2.WarmupResponse(
            max_supported_total_tokens=max_supported_total_tokens
        )

    async def Prefill(self, request, context):
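        """Run a first forward pass on a new batch and cache the resulting state."""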
        start = time.time_ns()
        if (
            self.model.batch_type == IdeficsCausalLMBatch
        ):  # Hack: I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )

        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.PrefillResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
        )

    async def Decode(self, request, context):
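        """Run one decoding step on cached batches.

        Pops the requested batches from the cache, concatenates them if there
        is more than one, generates the next tokens and caches the resulting
        batch.
        """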
        start = time.time_ns()
        if len(request.batches) == 0:
            raise ValueError("Must provide at least one batch")

        batches = []
        for batch_pb in request.batches:
            batch = self.cache.pop(batch_pb.id)
            if batch is None:
                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
            batches.append(batch)

        if len(batches) == 0:
            raise ValueError("All batches are empty")

        if len(batches) > 1:
            start_concat = time.time_ns()
            batch = self.model.batch_type.concatenate(batches)
            concat_ns = time.time_ns() - start_concat
        else:
            batch = batches[0]
            concat_ns = None

        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.DecodeResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            concat_ns=concat_ns,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
        )


def serve(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    uds_path: Path,
):
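    """Load the model and serve it over a Unix domain socket until terminated."""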
    async def serve_inner(
        model_id: str,
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,
        speculate: Optional[int] = None,
        dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
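        """Build the gRPC server, bind it to its Unix socket and run it until termination."""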
        unix_socket_template = "unix://{}-{}"
        if sharded:
            server_urls = [
                unix_socket_template.format(uds_path, rank)
                for rank in range(int(os.environ["WORLD_SIZE"]))
            ]
            local_url = server_urls[int(os.environ["RANK"])]
        else:
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]

        try:
            model = get_model(
                model_id,
                revision,
                sharded,
                quantize,
                speculate,
                dtype,
                trust_remote_code,
            )
        except Exception:
            logger.exception("Error when initializing model")
            raise

        if quantize == "gptq":
            try:
                # When using GPTQ, the Exllama kernels need some global buffers
                # whose final shapes are only known after the model has loaded.
                # This will allocate those buffers.
                from text_generation_server.utils.layers import (
                    create_exllama_buffers,
                    set_device,
                )

                set_device(model.device)
                create_exllama_buffers()
            except ImportError:
                pass

        server = aio.server(
            interceptors=[
                ExceptionInterceptor(),
                UDSOpenTelemetryAioServerInterceptor(),
            ]
        )
        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
            TextGenerationService(model, Cache(), server_urls), server
        )
        SERVICE_NAMES = (
            generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
        server.add_insecure_port(local_url)

        await server.start()

        logger.info("Server started at {}".format(local_url))

        try:
            await server.wait_for_termination()
        except KeyboardInterrupt:
            logger.info("Signal received. Shutting down")
            await server.stop(0)

    asyncio.run(
        serve_inner(
            model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
        )
    )