"vscode:/vscode.git/clone" did not exist on "5be8f1ed987fe49030acc51c62bdd6898720aa83"
server.py 3.62 KB
Newer Older
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
1
import asyncio
Olivier Dehaene's avatar
Olivier Dehaene committed
2
3
import os

Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
4
5
6
7
from grpc import aio

from grpc_reflection.v1alpha import reflection
from pathlib import Path
8
from typing import List
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
9

10
11
12
13
from text_generation.cache import Cache
from text_generation.models import Model, get_model
from text_generation.models.types import Batch
from text_generation.pb import generate_pb2_grpc, generate_pb2
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
14
15


Olivier Dehaene's avatar
Olivier Dehaene committed
16
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    """gRPC servicer exposing token-generation RPCs for a single model replica."""

    def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
        # Cache of in-flight batches, shared between Generate and
        # GenerateWithCache calls.
        self.cache = cache
        self.model = model
        # Socket URLs of every shard, reported via ServiceDiscovery.
        self.server_urls = server_urls

    async def ServiceDiscovery(self, request, context):
        """Return the socket URLs of all server shards."""
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
        """Drop every cached batch."""
        self.cache.clear()
        return generate_pb2.ClearCacheResponse()

    async def Generate(self, request, context):
        """Run one generation step on a fresh batch.

        Finished sequences are returned immediately; the still-running
        remainder of the batch is cached for GenerateWithCache.
        """
        batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device)

        generated_texts, next_batch = self.model.generate_token(batch)
        self.cache.set(next_batch)

        finished = [text.to_pb() for text in generated_texts]
        return generate_pb2.GenerateResponse(
            generated_texts=finished,
            batch=next_batch.to_pb() if next_batch else None,
        )

    async def GenerateWithCache(self, request, context):
        """Run one generation step on previously cached batches.

        Pops each requested batch out of the cache, concatenates them when
        more than one was given, then advances generation by a single token.

        Raises:
            ValueError: if no batches were requested or a batch id is not
                present in the cache.
        """
        if len(request.batches) == 0:
            raise ValueError("Must provide at least one batch")

        cached = []
        for pb in request.batches:
            entry = self.cache.pop(pb.id)
            if entry is None:
                raise ValueError(f"Batch ID {pb.id} not found in cache.")
            cached.append(entry)

        # A single cached batch can be used as-is; several must be merged.
        batch = cached[0] if len(cached) == 1 else Batch.concatenate(cached)

        generated_texts, next_batch = self.model.generate_token(batch)
        self.cache.set(next_batch)

        finished = [text.to_pb() for text in generated_texts]
        return generate_pb2.GenerateWithCacheResponse(
            generated_texts=finished,
            batch=next_batch.to_pb() if next_batch else None,
        )

Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
68

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
69
70
71
def serve(
    model_name: str,
    sharded: bool,
    quantize: bool,
    uds_path: Path,
):
    """Load a model and serve it over gRPC on a unix domain socket.

    Blocks until the server terminates (or a KeyboardInterrupt is raised
    while waiting).

    Args:
        model_name: model identifier passed to ``get_model``.
        sharded: whether this process is one rank of a sharded deployment;
            when True, ``WORLD_SIZE`` and ``RANK`` environment variables
            select this rank's socket.
        quantize: whether to load a quantized model.
        uds_path: base path used to build the unix socket addresses.
    """

    async def serve_inner(
        model_name: str,
        sharded: bool = False,
        quantize: bool = False,
    ):
        if sharded:
            # One socket per rank; WORLD_SIZE / RANK are assumed to be set
            # by the process launcher.
            world_size = int(os.environ["WORLD_SIZE"])
            server_urls = [f"unix://{uds_path}-{rank}" for rank in range(world_size)]
            local_url = server_urls[int(os.environ["RANK"])]
        else:
            local_url = f"unix://{uds_path}-0"
            server_urls = [local_url]

        model = get_model(model_name, sharded, quantize)

        server = aio.server()
        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
            TextGenerationService(model, Cache(), server_urls), server
        )
        # Expose the service (plus the reflection service itself) so tools
        # like grpcurl can discover the API.
        SERVICE_NAMES = (
            generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
        server.add_insecure_port(local_url)
        await server.start()
        print(f"Server started at {local_url}")
        try:
            await server.wait_for_termination()
        except KeyboardInterrupt:
            print("Signal received. Shutting down")
            # Grace period of 0: stop immediately.
            await server.stop(0)

    asyncio.run(serve_inner(model_name, sharded, quantize))