import asyncio
import os

from grpc import aio

from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import Optional, List

from bloom_inference.cache import Cache
from bloom_inference.model import BLOOM, Batch, BLOOMSharded
from bloom_inference.pb import generate_pb2_grpc, generate_pb2


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
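    """gRPC servicer exposing text-generation RPCs backed by a BLOOM model and a batch cache."""
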
    def __init__(self, model: BLOOM, cache: Cache, server_urls: List[str]):
        self.cache = cache
        self.model = model
        self.server_urls = server_urls

    async def ServiceDiscovery(self, request, context):
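        """Return the Unix socket URLs of every shard serving this model."""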
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
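        """Drop every cached batch."""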
        self.cache.clear()
        return generate_pb2.ClearCacheResponse()

    async def Generate(self, request, context):
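        """Run a first forward pass over a fresh batch and cache the resulting state for later calls."""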
        batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device)

        generated_texts, next_batch = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.GenerateResponse(
            generated_texts=[
                generated_text.to_pb() for generated_text in generated_texts
            ],
            batch=next_batch.to_pb() if next_batch else None,
        )

    async def GenerateWithCache(self, request, context):
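        """Resume generation from one or more cached batches, concatenating them when several are given."""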
        if len(request.batches) == 0:
            raise ValueError("Must provide at least one batch")

        batches = []
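        # Retrieve each batch referenced by the request from the cache; an unknown ID is an error.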
        for batch_pb in request.batches:
            batch = self.cache.pop(batch_pb.id)
            if batch is None:
                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
            batches.append(batch)

        if len(batches) > 1:
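            # Merging the cached batches lets sequences that arrived separately share one forward pass.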
            batch = Batch.concatenate(batches)
        else:
            batch = batches[0]

        generated_texts, next_batch = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.GenerateWithCacheResponse(
            generated_texts=[
                generated_text.to_pb() for generated_text in generated_texts
            ],
            batch=next_batch.to_pb() if next_batch else None,
        )


def serve(
    model_name: str,
    sharded: bool,
    uds_path: Path,
    shard_directory: Optional[Path] = None,
):
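    """Load the model (optionally sharded), bind a gRPC server to a Unix domain socket, and block until termination."""
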
    async def serve_inner(
        model_name: str,
        sharded: bool = False,
        shard_directory: Optional[Path] = None,
    ):
        unix_socket_template = "unix://{}-{}"
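        # One socket per shard rank, derived from uds_path: "unix://{uds_path}-{rank}".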
        if sharded:
            if shard_directory is None:
                raise ValueError("shard_directory must be set when sharded is True")
            model = BLOOMSharded(model_name, shard_directory)
            server_urls = [
                unix_socket_template.format(uds_path, rank)
                for rank in range(model.world_size)
            ]
            local_url = server_urls[model.rank]
        else:
            model = BLOOM(model_name)
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]

        server = aio.server()
        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
            TextGenerationService(model, Cache(), server_urls), server
        )
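        # Register reflection so generic gRPC tooling (e.g. grpcurl) can discover and call the service.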
        SERVICE_NAMES = (
            generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
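        # The server listens on a local Unix domain socket, so an insecure (non-TLS) port suffices.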
        server.add_insecure_port(local_url)
        await server.start()
        print("Server started at {}".format(local_url))
        try:
            await server.wait_for_termination()
        except KeyboardInterrupt:
            print("Signal received. Shutting down")
            await server.stop(0)

    asyncio.run(serve_inner(model_name, sharded, shard_directory))
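

# A minimal invocation sketch (hypothetical model name and socket path):
#
#     from pathlib import Path
#     serve("bigscience/bloom-560m", sharded=False, uds_path=Path("/tmp/bloom-server"))
#
# Clients can then connect to "unix:///tmp/bloom-server-0".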