"vscode:/vscode.git/clone" did not exist on "ee02a111d77039ebd9509257d8e3803e1f6c7bfd"
server.py 7.38 KB
Newer Older
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
1
import asyncio
import os
import torch

from grpc import aio
from loguru import logger

from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional

from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
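    """gRPC servicer implementing the TextGenerationService protocol from generate_pb2,
    backed by a loaded model and a batch cache."""
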
    def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
        self.cache = cache
        self.model = model
        self.server_urls = server_urls
        # For some reason, inference_mode does not work well with GLOO, which we use on CPU
        if model.device.type == "cuda":
            # Force inference mode for the lifetime of TextGenerationService
            self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    async def Info(self, request, context):
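        """Return the model's static info protobuf."""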
        return self.model.info

    async def Health(self, request, context):
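        """Health check; on GPU, run a small CUDA allocation to verify the device responds."""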
        if self.model.device.type == "cuda":
            torch.zeros((2, 2)).cuda()
        return generate_pb2.HealthResponse()

    async def ServiceDiscovery(self, request, context):
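        """Return the unix socket URLs of all server shards."""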
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
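        """Delete a single cached batch by id, or clear the entire cache."""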
        if request.HasField("id"):
            self.cache.delete(request.id)
        else:
            self.cache.clear()
        return generate_pb2.ClearCacheResponse()

    async def FilterBatch(self, request, context):
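        """Filter a cached batch down to the given request ids and return the filtered batch."""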
        batch = self.cache.pop(request.batch_id)
        if batch is None:
            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
        filtered_batch = batch.filter(request.request_ids)
        self.cache.set(filtered_batch)

        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

    async def Warmup(self, request, context):
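        """Run a warmup pass and report the maximum number of total tokens the model supports."""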
        if (
            self.model.batch_type == IdeficsCausalLMBatch
        ):  # Hack: I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )
        max_supported_total_tokens = self.model.warmup(batch)

        return generate_pb2.WarmupResponse(
            max_supported_total_tokens=max_supported_total_tokens
        )

    async def Prefill(self, request, context):
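        """Build a batch from the request, run the prefill step, cache the resulting batch,
        and return the generations."""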
        if (
            self.model.batch_type == IdeficsCausalLMBatch
        ):  # Hack: I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )

        generations, next_batch = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.PrefillResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
        )

    async def Decode(self, request, context):
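        """Pop the requested batches from the cache, concatenate them if there is more than one,
        run one decode step, and re-cache the result."""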
        if len(request.batches) == 0:
            raise ValueError("Must provide at least one batch")

        batches = []
        for batch_pb in request.batches:
            batch = self.cache.pop(batch_pb.id)
            if batch is None:
                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
            batches.append(batch)

        if len(batches) == 0:
            raise ValueError("All batches are empty")

        if len(batches) > 1:
            batch = self.model.batch_type.concatenate(batches)
        else:
            batch = batches[0]

        generations, next_batch = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.DecodeResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
        )


def serve(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    uds_path: Path,
):
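    """Run the text-generation gRPC server for this model configuration inside an asyncio loop."""
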
    async def serve_inner(
        model_id: str,
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,
        speculate: Optional[int] = None,
        dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
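        """Load the model, start the unix-domain-socket gRPC server, and serve until termination."""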
        unix_socket_template = "unix://{}-{}"
        if sharded:
            server_urls = [
                unix_socket_template.format(uds_path, rank)
                for rank in range(int(os.environ["WORLD_SIZE"]))
            ]
            local_url = server_urls[int(os.environ["RANK"])]
        else:
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]

        try:
            model = get_model(
                model_id,
                revision,
                sharded,
                quantize,
                speculate,
                dtype,
                trust_remote_code,
            )
        except Exception:
            logger.exception("Error when initializing model")
            raise

        if quantize == "gptq":
            try:
                # When using GPTQ, the Exllama kernels need some global buffers
                # whose final shapes are only known after the model has loaded.
                # This will allocate those buffers.
                from text_generation_server.utils.layers import (
                    create_exllama_buffers,
                    set_device,
                )

                set_device(model.device)
                create_exllama_buffers()
            except ImportError:
                pass

        server = aio.server(
            interceptors=[
                ExceptionInterceptor(),
                UDSOpenTelemetryAioServerInterceptor(),
            ]
        )
        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
            TextGenerationService(model, Cache(), server_urls), server
        )
        SERVICE_NAMES = (
            generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
        server.add_insecure_port(local_url)

        await server.start()

        logger.info("Server started at {}".format(local_url))

        try:
            await server.wait_for_termination()
        except KeyboardInterrupt:
            logger.info("Signal received. Shutting down")
            await server.stop(0)

    asyncio.run(
        serve_inner(
            model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
        )
    )