server.py

import asyncio
import os
import torch

from grpc import aio
from loguru import logger

from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional

from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
        self.cache = cache
        self.model = model
        self.server_urls = server_urls
        # For some reason, inference_mode does not work well with GLOO, which we use on CPU
        if model.device.type == "cuda":
            # Force inference mode for the lifetime of TextGenerationService
            self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    async def Info(self, request, context):
        return self.model.info

    async def Health(self, request, context):
        if self.model.device.type == "cuda":
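            # Allocate a small tensor on the GPU to check that CUDA is still responsive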
            torch.zeros((2, 2)).cuda()
        return generate_pb2.HealthResponse()

    async def ServiceDiscovery(self, request, context):
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
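        # Drop a single cached batch if an id is provided, otherwise empty the whole cache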
        if request.HasField("id"):
            self.cache.delete(request.id)
        else:
            self.cache.clear()
        return generate_pb2.ClearCacheResponse()

    async def FilterBatch(self, request, context):
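        # Keep only the requests listed in `request.request_ids` and put the
        # filtered batch back in the cache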
        batch = self.cache.pop(request.batch_id)
        if batch is None:
            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
        filtered_batch = batch.filter(request.request_ids)
        self.cache.set(filtered_batch)

        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

    async def Warmup(self, request, context):
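        # Build the batch from the warmup request and run it through the model so
        # memory is allocated up-front and the maximum number of total tokens is known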
        if (
            self.model.batch_type == IdeficsCausalLMBatch
        ):  # Hack: I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )
        max_supported_total_tokens = self.model.warmup(batch)

        return generate_pb2.WarmupResponse(
            max_supported_total_tokens=max_supported_total_tokens
        )

    async def Prefill(self, request, context):
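        # First forward pass over a freshly created batch: produce one generation per
        # request and cache the batch state for the following Decode calls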
        if (
            self.model.batch_type == IdeficsCausalLMBatch
        ):  # Hack: I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )

        generations, next_batch = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.PrefillResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
        )

    async def Decode(self, request, context):
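        # Pop every cached batch referenced by the request, concatenate them if
        # there is more than one, then run a single decoding step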
        if len(request.batches) == 0:
            raise ValueError("Must provide at least one batch")

        batches = []
        for batch_pb in request.batches:
            batch = self.cache.pop(batch_pb.id)
            if batch is None:
                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
            batches.append(batch)

        if len(batches) == 0:
            raise ValueError("All batches are empty")

        if len(batches) > 1:
            batch = self.model.batch_type.concatenate(batches)
        else:
            batch = batches[0]

        generations, next_batch = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.DecodeResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
        )


def serve(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
    uds_path: Path,
):
    async def serve_inner(
        model_id: str,
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,
        speculate: Optional[int] = None,
        dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        unix_socket_template = "unix://{}-{}"
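        # One unix socket per shard; the full list is what ServiceDiscovery returns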
        if sharded:
            server_urls = [
                unix_socket_template.format(uds_path, rank)
                for rank in range(int(os.environ["WORLD_SIZE"]))
            ]
            local_url = server_urls[int(os.environ["RANK"])]
        else:
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]

        try:
            model = get_model(
                model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
            )
        except Exception:
            logger.exception("Error when initializing model")
            raise

        if quantize == "gptq":
            try:
                # When using GPTQ, Exllama kernels need some global buffers
                # for which we only have the final shapes after the model has loaded.
                # This will allocate those buffers.
                from text_generation_server.utils.layers import (
                    create_exllama_buffers,
                    set_device,
                )

                set_device(model.device)
                create_exllama_buffers()
            except ImportError:
                pass

        server = aio.server(
            interceptors=[
                ExceptionInterceptor(),
                UDSOpenTelemetryAioServerInterceptor(),
            ]
        )
        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
            TextGenerationService(model, Cache(), server_urls), server
        )
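        # Enable gRPC reflection so the exposed services can be discovered at runtime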
        SERVICE_NAMES = (
            generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
        server.add_insecure_port(local_url)

        await server.start()

        logger.info("Server started at {}".format(local_url))

        try:
            await server.wait_for_termination()
        except KeyboardInterrupt:
            logger.info("Signal received. Shutting down")
            await server.stop(0)

    asyncio.run(
        serve_inner(model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code)
    )