Commit 03d0f6a2 authored by Sean SH Choi, committed by GitHub

feat: Add HTTP completions endpoint to kv router example (#258)


Co-authored-by: Alec <35311602+alec-flowers@users.noreply.github.com>
parent 7357b432
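The new completions path accepts OpenAI-style completion payloads and validates them with vLLM's CompletionRequest, mirroring how the chat path uses ChatCompletionRequest. A minimal sketch of the request shape the added endpoint expects (not part of the diff; the model name and sampling values are placeholders):

from vllm.entrypoints.openai.protocol import CompletionRequest

# Illustrative payload; the processors only support streaming requests.
raw_request = {
    "model": "example-org/example-model",  # placeholder model name
    "prompt": "The capital of France is",
    "max_tokens": 16,
    "stream": True,
}

# Mirrors CompletionsProcessor.parse_raw_request() in the diff below.
request = CompletionRequest.parse_obj(raw_request)
print(request.prompt, request.stream)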
@@ -14,7 +14,8 @@
# limitations under the License.
import json
from typing import AsyncIterator, List, Protocol, runtime_checkable
import time
from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable
from vllm import TokensPrompt
from vllm.config import ModelConfig
@@ -22,9 +23,11 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -33,33 +36,45 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
class ProcessMixInRequired(Protocol):
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
class ProcessMixIn(ProcessMixInRequired):
"""
Mixin for pre and post processing for vLLM
Requires engine_args, engine_client, chat_processor, model_config to be initialized
Requires engine_args, engine_client, chat_processor, completions_processor, model_config to be initialized
"""
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
def __init__(self):
pass
async def _parse_raw_request(self, raw_request):
if self.chat_processor is None:
raise RuntimeError("chat_processor has not been initialized")
request = self.chat_processor.parse_raw_request(raw_request)
(
conversation,
request_prompt,
engine_prompt,
) = await self.chat_processor.preprocess(raw_request)
def _get_processor(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
# Determine the processor type based on the request structure
return (
self.chat_processor
if isinstance(raw_request, ChatCompletionRequest)
else self.completions_processor
)
async def _parse_raw_request(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
processor = self._get_processor(raw_request)
if processor is None:
raise RuntimeError("Processor has not been initialized")
request = processor.parse_raw_request(raw_request)
preprocess_result = await processor.preprocess(raw_request)
default_max_tokens = self.model_config.max_model_len - len(
engine_prompt["prompt_token_ids"]
preprocess_result.engine_prompt["prompt_token_ids"]
)
default_sampling_params = self.model_config.get_diff_sampling_param()
sampling_params = request.to_sampling_params(
@@ -67,12 +82,19 @@ class ProcessMixIn(ProcessMixInRequired):
self.model_config.logits_processor_pattern,
default_sampling_params,
)
return request, conversation, request_prompt, engine_prompt, sampling_params
return (
request,
preprocess_result.conversation,
preprocess_result.request_prompt,
preprocess_result.engine_prompt,
sampling_params,
)
async def _stream_response(self, request, generator, request_id, conversation):
if self.chat_processor is None:
raise RuntimeError("chat_processor has not been initialized")
return self.chat_processor.stream_response(
processor = self._get_processor(request)
if processor is None:
raise RuntimeError("processor has not been initialized")
return processor.stream_response(
request,
generator,
request_id,
@@ -80,6 +102,18 @@ class ProcessMixIn(ProcessMixInRequired):
)
class PreprocessResult:
def __init__(
self,
conversation: Optional[ConversationMessage],
request_prompt: RequestPrompt,
engine_prompt: TokensPrompt,
):
self.conversation = conversation
self.request_prompt = request_prompt
self.engine_prompt = engine_prompt
class ChatProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
@@ -94,12 +128,12 @@ class ChatProcessor:
chat_template_content_format="auto",
)
def parse_raw_request(self, raw_request: dict) -> ChatCompletionRequest:
def parse_raw_request(
self, raw_request: ChatCompletionRequest
) -> ChatCompletionRequest:
return ChatCompletionRequest.parse_obj(raw_request)
async def preprocess(
self, raw_request: dict
) -> tuple[ConversationMessage, RequestPrompt, TokensPrompt]:
async def preprocess(self, raw_request: ChatCompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
@@ -122,7 +156,7 @@ class ChatProcessor:
add_special_tokens=request.add_special_tokens,
)
return conversation[0], request_prompts[0], engine_prompts[0]
return PreprocessResult(conversation[0], request_prompts[0], engine_prompts[0])
async def stream_response(
self,
@@ -132,7 +166,8 @@ class ChatProcessor:
conversation: List,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
assert request.stream, "Only stream is supported"
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
@@ -146,3 +181,60 @@ class ChatProcessor:
break
response = json.loads(raw_response.lstrip("data: "))
yield response
class CompletionsProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingCompletion(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
)
def parse_raw_request(self, raw_request: CompletionRequest) -> CompletionRequest:
return CompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: CompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_completion(
request,
self.tokenizer,
input_or_inputs=request.prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(None, request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: CompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: Optional[List[ConversationMessage]] = None,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.completion_stream_generator(
request,
result_generator,
request_id,
int(time.time()), # created_time
request.model,
1, # num_prompts
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
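Both processors drain vLLM's Server-Sent-Events-style stream: each chunk arrives as a "data: {...}" line, is JSON-decoded, and the loop stops at the "data: [DONE]" sentinel. A self-contained sketch of that consumption pattern (not part of the diff; the fake generator stands in for vLLM's real stream generator):

import asyncio
import json
from typing import AsyncIterator


async def fake_stream() -> AsyncIterator[str]:
    # Stand-in for completion_stream_generator(); real chunks carry full
    # CompletionStreamResponse payloads.
    yield 'data: {"id": "cmpl-1", "choices": [{"index": 0, "text": "Paris"}]}'
    yield "data: [DONE]"


async def consume(raw_chunks: AsyncIterator[str]):
    async for raw_response in raw_chunks:
        if raw_response.startswith("data: [DONE]"):
            break
        # Same decoding step the processors perform before yielding to the caller.
        yield json.loads(raw_response.lstrip("data: "))


async def main() -> None:
    async for chunk in consume(fake_stream()):
        print(chunk["choices"][0]["text"])


asyncio.run(main())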
@@ -15,10 +15,11 @@
import asyncio
import uuid
from typing import AsyncIterator
from enum import Enum
from typing import AsyncIterator, Tuple, Union
import uvloop
from common.chat_processor import ChatProcessor, ProcessMixIn
from common.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from common.parser import parse_vllm_args
from common.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest
from transformers import AutoTokenizer
@@ -26,6 +27,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionStreamResponse,
CompletionRequest,
CompletionStreamResponse,
)
from vllm.logger import logger as vllm_logger
from vllm.outputs import RequestOutput
@@ -39,6 +42,11 @@ from triton_distributed.runtime import (
)
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
class Processor(ProcessMixIn):
"""
vLLM pre and post processing
@@ -54,6 +62,9 @@ class Processor(ProcessMixIn):
self.model_config = self.engine_args.create_model_config()
self.tokenizer = self._create_tokenizer(engine_args)
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
self.completions_processor = CompletionsProcessor(
self.tokenizer, self.model_config
)
self.router_client = router_client
self.workers_client = workers_client
@@ -71,27 +82,11 @@ class Processor(ProcessMixIn):
)
return base_tokenizer
async def generate_responses(
self, engine_generator
) -> AsyncIterator[RequestOutput]:
async for resp in engine_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
yield RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
@triton_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
async def generate(self, raw_request):
async def _generate(
self,
raw_request: Union[CompletionRequest, ChatCompletionRequest],
request_type: RequestType,
):
request_id = str(uuid.uuid4())
vllm_logger.debug(f"Got raw request: {raw_request}")
(
@@ -129,13 +124,54 @@ class Processor(ProcessMixIn):
int(worker_id),
)
output = self.generate_responses(engine_generator)
output = self._generate_responses(engine_generator, request_type)
async for response in await self._stream_response(
request, output, request_id, conversation
):
yield response
async def _generate_responses(
self, engine_generator: AsyncIterator[RequestOutput], request_type: RequestType
) -> AsyncIterator[Union[RequestOutput, Tuple[int, RequestOutput]]]:
prompt_idx = 0
async for resp in engine_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
request_output = RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
if request_type == RequestType.CHAT:
# For chat requests, yield the request_output directly.
yield request_output
elif request_type == RequestType.COMPLETION:
# Completion requests can contain multiple prompts, and the stream generator requires the prompt index
yield (prompt_idx, request_output)
else:
raise NotImplementedError(
f"Request type {request_type} not implemented"
)
@triton_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
async def generate_chat(self, raw_request):
async for response in self._generate(raw_request, RequestType.CHAT):
yield response
@triton_endpoint(CompletionRequest, CompletionStreamResponse)
async def generate_completions(self, raw_request):
async for response in self._generate(raw_request, RequestType.COMPLETION):
yield response
@triton_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
@@ -159,11 +195,16 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
preprocess_component = runtime.namespace("triton-init").component("process")
await preprocess_component.create_service()
preprocess_endpoint = preprocess_component.endpoint("chat/completions")
chat_endpoint = preprocess_component.endpoint("chat/completions")
completions_endpoint = preprocess_component.endpoint("completions")
processor = Processor(engine_args, router_client, workers_client)
assert isinstance(processor, ProcessMixIn)
await preprocess_endpoint.serve_endpoint(processor.generate)
await asyncio.gather(
chat_endpoint.serve_endpoint(processor.generate_chat),
completions_endpoint.serve_endpoint(processor.generate_completions),
)
if __name__ == "__main__":
......
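With both endpoints registered, an OpenAI-compatible HTTP frontend can route /v1/completions traffic to generate_completions and /v1/chat/completions to generate_chat. A hedged client-side sketch (the frontend address is an assumption and is not wired up in this diff; the model name is a placeholder):

import json

import requests

URL = "http://localhost:8080/v1/completions"  # assumed HTTP frontend address

payload = {
    "model": "example-org/example-model",  # placeholder model name
    "prompt": "The capital of France is",
    "max_tokens": 16,
    "stream": True,
}

with requests.post(URL, json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue  # skip SSE keep-alive blank lines
        if line == "data: [DONE]":
            break
        chunk = json.loads(line.removeprefix("data: "))
        print(chunk["choices"][0]["text"], end="", flush=True)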