Unverified Commit 7fbfbb0d authored by OlivierDehaene, committed by GitHub

feat: Add token streaming using Server-Sent Events support (#36)

Add token streaming using Server-Sent Events (SSE).

The SSE event payloads have the following signatures:

```rust
// Metadata about the whole generation, sent with the final event
struct Details {
    finish_reason: String,
    generated_tokens: u32,
    seed: Option<u64>,
}

// Emitted once per generated token
struct StreamResponse {
    token: Token,
    // Only populated on the last event of the stream
    generated_text: Option<String>,
    details: Option<Details>,
}

// Emitted if generation fails
struct ErrorResponse {
    error: String,
}
```
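For illustration, here is a minimal Python sketch of how a client could consume this stream. The route name (`/generate_stream`) and the request payload fields (`inputs`, `parameters`) are assumptions not shown in this commit, so treat them as placeholders.

```python
# Minimal client sketch (not part of this commit). The URL and request payload
# shape are assumptions; adjust them to the actual HTTP API.
import json
import requests


def stream_tokens(prompt: str, url: str = "http://localhost:3000/generate_stream"):
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 20}}
    with requests.post(url, json=payload, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            # SSE frames arrive as "data: {...}" lines separated by blank lines
            if not line or not line.startswith(b"data:"):
                continue
            event = json.loads(line[len(b"data:"):])
            if "error" in event:
                # ErrorResponse
                raise RuntimeError(event["error"])
            # StreamResponse: one token per event; the Token fields themselves
            # are not shown in this commit, so the raw object is printed here
            print(event["token"])
            if event.get("details") is not None:
                # generated_text / details are only expected on the final event
                print("finish_reason:", event["details"]["finish_reason"])


if __name__ == "__main__":
    stream_tokens("Hello, my name is")
```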
parent cd298bc5
```diff
@@ -29,26 +29,61 @@ class Batch(ABC):
     def concatenate(cls, batches: List["Batch"]) -> "Batch":
         raise NotImplementedError
 
+    @abstractmethod
+    def __len__(self):
+        raise NotImplementedError
+
 
 @dataclass
 class GeneratedText:
-    request: generate_pb2.Request
-    output_text: str
+    text: str
     generated_tokens: int
-    tokens: List[str]
-    token_ids: List[int]
-    logprobs: List[float]
-    reason: str
+    finish_reason: str
     seed: Optional[int]
 
     def to_pb(self) -> generate_pb2.GeneratedText:
         return generate_pb2.GeneratedText(
-            request=self.request,
-            output_text=self.output_text,
+            text=self.text,
             generated_tokens=self.generated_tokens,
-            tokens=self.tokens,
-            token_ids=self.token_ids,
-            logprobs=self.logprobs,
-            finish_reason=self.reason,
+            finish_reason=self.finish_reason,
             seed=self.seed,
         )
 
+
+@dataclass
+class PrefillTokens:
+    token_ids: List[int]
+    logprobs: List[float]
+    texts: List[str]
+
+    def to_pb(self) -> generate_pb2.PrefillTokens:
+        return generate_pb2.PrefillTokens(
+            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
+        )
+
+    def __len__(self):
+        return len(self.token_ids)
+
+
+@dataclass
+class Generation:
+    request_id: int
+    prefill_tokens: Optional[PrefillTokens]
+    token_id: int
+    token_logprob: float
+    token_text: str
+    generated_text: Optional[GeneratedText]
+
+    def to_pb(self) -> generate_pb2.Generation:
+        return generate_pb2.Generation(
+            request_id=self.request_id,
+            prefill_tokens=self.prefill_tokens.to_pb()
+            if self.prefill_tokens is not None
+            else None,
+            token_id=self.token_id,
+            token_logprob=self.token_logprob,
+            token_text=self.token_text,
+            generated_text=self.generated_text.to_pb()
+            if self.generated_text is not None
+            else None,
+        )
```
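As a rough illustration of how these dataclasses fit together, the sketch below shows how a single decoding step might produce a `Generation` for one request. The argument names (`request_id`, `all_token_ids`, `stopped`) and the fixed `finish_reason` are hypothetical; the real per-model logic lives in `generate_token` and is not part of this hunk.

```python
# Hypothetical example of building the new dataclasses for one request at one
# decoding step; it mirrors the shapes above but is not code from this commit.
from typing import List, Optional


def build_generation(
    request_id: int,
    token_id: int,
    token_logprob: float,
    all_token_ids: List[int],
    stopped: bool,
    tokenizer,
) -> Generation:
    token_text = tokenizer.decode([token_id], skip_special_tokens=False)

    generated_text: Optional[GeneratedText] = None
    if stopped:
        # Only the final step carries the full text, token count and finish reason
        generated_text = GeneratedText(
            text=tokenizer.decode(all_token_ids, skip_special_tokens=True),
            generated_tokens=len(all_token_ids),
            finish_reason="length",  # placeholder finish reason
            seed=None,
        )

    return Generation(
        request_id=request_id,
        prefill_tokens=None,  # only filled during the prefill pass
        token_id=token_id,
        token_logprob=token_logprob,
        token_text=token_text,
        generated_text=generated_text,
    )
```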
```diff
@@ -27,22 +27,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache.clear()
         return generate_pb2.ClearCacheResponse()
 
-    async def Generate(self, request, context):
+    async def Prefill(self, request, context):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.device
         )
 
-        generated_texts, next_batch = self.model.generate_token(batch)
+        generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
 
-        return generate_pb2.GenerateResponse(
-            generated_texts=[
-                generated_text.to_pb() for generated_text in generated_texts
-            ],
+        return generate_pb2.PrefillResponse(
+            generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
         )
 
-    async def GenerateWithCache(self, request, context):
+    async def Decode(self, request, context):
         if len(request.batches) == 0:
             raise ValueError("Must provide at least one batch")
@@ -58,13 +56,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         else:
             batch = batches[0]
 
-        generated_texts, next_batch = self.model.generate_token(batch)
+        generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
 
-        return generate_pb2.GenerateWithCacheResponse(
-            generated_texts=[
-                generated_text.to_pb() for generated_text in generated_texts
-            ],
+        return generate_pb2.DecodeResponse(
+            generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
         )
```
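To show how the renamed RPCs are meant to be driven, here is a hedged sketch of a caller issuing one `Prefill` followed by repeated `Decode` calls. The request message names (`PrefillRequest`, `DecodeRequest`) and the `handle_generations` callback are assumptions inferred from this diff, not confirmed by it.

```python
# Hedged sketch of a gRPC caller; message names other than PrefillResponse /
# DecodeResponse are inferred and may differ from the actual proto definitions.
async def generate_until_finished(stub, batch_pb, handle_generations):
    # Prefill runs the first forward pass and caches the batch server-side
    response = await stub.Prefill(generate_pb2.PrefillRequest(batch=batch_pb))
    handle_generations(response.generations)

    # Decode is then called until no request in the batch is still generating
    while response.HasField("batch"):
        response = await stub.Decode(
            generate_pb2.DecodeRequest(batches=[response.batch])
        )
        handle_generations(response.generations)
```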