server.py
import asyncio

from grpc import aio
from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import Optional, List

from bloom_inference.cache import Cache
from bloom_inference.model import BLOOM, Batch, BLOOMSharded
from bloom_inference.pb import generate_pb2_grpc, generate_pb2


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
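    """gRPC servicer exposing token generation for a (possibly sharded) BLOOM model."""
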
    def __init__(self, model: BLOOM, cache: Cache, server_urls: List[str]):
        self.cache = cache
        self.model = model
        self.server_urls = server_urls

    async def ServiceDiscovery(self, request, context):
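        """Return the URLs of every server in this (possibly sharded) deployment."""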
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
        self.cache.clear()
        return generate_pb2.ClearCacheResponse()

    async def Generate(self, request, context):
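        """Run one generation step on a new batch and cache the in-flight
        batch so a later *WithCache call can continue it."""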
        batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device)

        generated_texts, next_batch = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.GenerateResponse(
            generated_texts=[
                generated_text.to_pb() for generated_text in generated_texts
            ],
            batch=next_batch.to_pb() if next_batch else None,
        )

    async def GenerateWithCache(self, request, context):
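        """Pop cached batches, concatenate them into one if several are given,
        and run a single generation step on the result."""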
        if len(request.batches) == 0:
            raise ValueError("Must provide at least one batch")

        batches = []
        for batch_pb in request.batches:
            batch = self.cache.pop(batch_pb.id)
            if batch is None:
                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
            batches.append(batch)

        if len(batches) > 1:
            batch = Batch.concatenate(batches)
        else:
            batch = batches[0]

        generated_texts, next_batch = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.GenerateWithCacheResponse(
            generated_texts=[
                generated_text.to_pb() for generated_text in generated_texts
            ],
            batch=next_batch.to_pb() if next_batch else None,
        )

    async def GenerateUntilFinished(self, request, context):
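        """Keep running generation steps on a new batch until at least one
        request in the batch yields a finished text."""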
        batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device)

        generated_texts = []
        while not generated_texts:
            generated_texts, next_batch = self.model.generate_token(batch)
            batch = next_batch
        self.cache.set(next_batch)

        return generate_pb2.GenerateUntilFinishedResponse(
            generated_texts=[
                generated_text.to_pb() for generated_text in generated_texts
            ],
            batch=next_batch.to_pb() if next_batch else None,
        )

    async def GenerateUntilFinishedWithCache(self, request, context):
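        """Pop cached batches, concatenate them if needed, and keep generating
        until at least one request yields a finished text."""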
        if len(request.batches) == 0:
            raise ValueError("Must provide at least one batch")

        batches = []
        for batch_pb in request.batches:
            batch = self.cache.pop(batch_pb.id)
            if batch is None:
                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
            batches.append(batch)

        if len(batches) > 1:
            batch = Batch.concatenate(batches)
        else:
            batch = batches[0]

        generated_texts = []
        while not generated_texts:
            generated_texts, next_batch = self.model.generate_token(batch)
            batch = next_batch
        self.cache.set(next_batch)

        return generate_pb2.GenerateUntilFinishedWithCacheResponse(
            generated_texts=[
                generated_text.to_pb() for generated_text in generated_texts
            ],
            batch=next_batch.to_pb() if next_batch else None,
        )


def serve(model_name: str, sharded: bool, shard_directory: Optional[Path]):
    async def serve_inner(
        model_name: str,
        sharded: bool = False,
        shard_directory: Optional[Path] = None,
    ):
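        # Each shard serves on its own unix socket, keyed by rank; the
        # unsharded server simply uses socket 0.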
        unix_socket_template = "unix:///tmp/bloom-inference-{}"
        if sharded:
            if shard_directory is None:
                raise ValueError("shard_directory must be set when sharded is True")
            model = BLOOMSharded(model_name, shard_directory)
            server_urls = [
                unix_socket_template.format(rank) for rank in range(model.world_size)
            ]
            local_url = unix_socket_template.format(model.rank)
        else:
            model = BLOOM(model_name)
            local_url = unix_socket_template.format(0)
            server_urls = [local_url]

        server = aio.server()
        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
            TextGenerationService(model, Cache(), server_urls), server
        )
        SERVICE_NAMES = (
            generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
            reflection.SERVICE_NAME,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
        server.add_insecure_port(local_url)
        await server.start()
        print("Server started at {}".format(local_url))
        await server.wait_for_termination()

    asyncio.run(serve_inner(model_name, sharded, shard_directory))
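

# Example client sketch, not part of the original module. It assumes the
# generated stub follows grpc's standard `<Service>Stub` naming and that
# `ServiceDiscoveryRequest` is an empty message; neither is defined in this
# file, so treat this as a hypothetical illustration of how a router could
# discover the shard URLs.
async def _example_service_discovery(url: str = "unix:///tmp/bloom-inference-0"):
    async with aio.insecure_channel(url) as channel:
        stub = generate_pb2_grpc.TextGenerationServiceStub(channel)
        response = await stub.ServiceDiscovery(generate_pb2.ServiceDiscoveryRequest())
        return list(response.urls)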