Commit 338b35f5 authored by Ceng23333

Issue/193: feats for deployment


Signed-off-by: Ceng23333 <441651826@qq.com>
parent f2d9d397
@@ -228,12 +228,40 @@ class LLMEngine:
                 req.generated_token_ids.append(token_id)
                 if req.is_prefill:
                     req.is_prefill = False

-                token_text = self.tokenizer.decode(token_id)
-                req.generated_text += token_text
-
-                if self._check_request_finished(req, token_id):
-                    req.mark_finished(req.finish_reason)
+                # vLLM-style replacement character handling is primarily relevant for streaming.
+                # For offline generation (no output queue), keep the fast incremental path.
+                if req._output_queue is None:
+                    token_text = self.detokenize([token_id])
+                    req.generated_text += token_text
+                else:
+                    # Streaming path: compute delta from a full decode so we can hold back
+                    # trailing '\ufffd' (likely an incomplete UTF-8 sequence).
+                    decoded_text = self.detokenize(req.generated_token_ids)
+
+                    finished_now = False
+                    if self._check_request_finished(req, token_id):
+                        req.mark_finished(req.finish_reason)
+                        finished_now = True
+
+                    # Update generated_text to the latest decode (used for stop-string checks and debugging)
+                    req.generated_text = decoded_text
+                    holds_back_incomplete_utf8 = (
+                        bool(decoded_text) and decoded_text.endswith("\ufffd")
+                    )
+                    # vLLM-style: hold back only if we are not on the final chunk.
+                    if holds_back_incomplete_utf8 and not finished_now:
+                        token_text = ""
+                    else:
+                        last_len = getattr(req, "_stream_last_yielded_length", 0)
+                        token_text = decoded_text[last_len:]
+                        if token_text:
+                            req._stream_last_yielded_length = len(decoded_text)
+
+                # For non-streaming, finish checks happen here.
+                if req._output_queue is None and self._check_request_finished(req, token_id):
+                    req.mark_finished(req.finish_reason)

                 # Put output in queue if it exists (for async streaming)
                 if req._output_queue is not None:
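A minimal, self-contained sketch of the hold-back logic above (simplified helper names, not this repository's code): when the latest full decode ends in U+FFFD, the delta is withheld until the multi-byte sequence completes, unless this is the final chunk.

    # Assumed names; illustrates the technique, not the engine's API.
    def next_stream_chunk(decoded_text: str, last_yielded_len: int, finished: bool):
        """Return (delta_text, new_yielded_len) for one streaming step."""
        if decoded_text.endswith("\ufffd") and not finished:
            return "", last_yielded_len  # hold back the incomplete tail
        delta = decoded_text[last_yielded_len:]
        return delta, (len(decoded_text) if delta else last_yielded_len)

    # Example: the three UTF-8 bytes of "€" arriving one at a time.
    raw = "€".encode("utf-8")
    yielded = 0
    for i in range(1, len(raw) + 1):
        text = raw[:i].decode("utf-8", errors="replace")
        delta, yielded = next_stream_chunk(text, yielded, finished=(i == len(raw)))
        print(repr(delta))  # '', '', '€'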
@@ -283,12 +311,15 @@ class LLMEngine:
         self,
         messages: List[dict],
         add_generation_prompt: bool = True,
+        chat_template_kwargs: Optional[dict] = None,
     ) -> str:
         """Apply chat template to messages."""
+        chat_template_kwargs = chat_template_kwargs or {}
         return self.tokenizer.apply_chat_template(
             conversation=messages,
             add_generation_prompt=add_generation_prompt,
             tokenize=False,
+            **chat_template_kwargs,
         )
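The new chat_template_kwargs dict is expanded into the tokenizer call, so template-specific switches can be passed per request. A self-contained toy below (stand-in tokenizer and a hypothetical "enable_thinking" kwarg; real kwarg names depend on the model's chat template):

    from typing import List, Optional

    class ToyTokenizer:
        # Stand-in for the real tokenizer; only mimics the call signature used above.
        def apply_chat_template(self, conversation, add_generation_prompt, tokenize, **kwargs):
            text = "".join(f"<{m['role']}>{m['content']}" for m in conversation)
            if add_generation_prompt:
                text += "<assistant>"
            if kwargs.get("enable_thinking") is False:
                text += "<think off>"
            return text

    def apply_chat_template(tokenizer, messages: List[dict],
                            add_generation_prompt: bool = True,
                            chat_template_kwargs: Optional[dict] = None) -> str:
        chat_template_kwargs = chat_template_kwargs or {}
        return tokenizer.apply_chat_template(
            conversation=messages,
            add_generation_prompt=add_generation_prompt,
            tokenize=False,
            **chat_template_kwargs,
        )

    print(apply_chat_template(ToyTokenizer(), [{"role": "user", "content": "hi"}],
                              chat_template_kwargs={"enable_thinking": False}))
    # <user>hi<assistant><think off>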
@@ -486,6 +517,10 @@ class AsyncLLMEngine:
         self._running = False
         self._step_thread: Optional[threading.Thread] = None
+        self._healthy = True
+
+    def is_healthy(self) -> bool:
+        return bool(self._healthy)

     def start(self):
         """Start the background inference loop."""
@@ -520,6 +555,7 @@ class AsyncLLMEngine:
                 time.sleep(0.01)
             except Exception as e:
                 logger.error(f"Error in step loop: {e}", exc_info=True)
+                self._healthy = False
                 self._running = False
                 break
@@ -581,6 +617,8 @@ class AsyncLLMEngine:
         request_id: Optional[str] = None,
         request_data: Optional[dict] = None,
         http_request: Optional[any] = None,
+        add_generation_prompt: bool = True,
+        chat_template_kwargs: Optional[dict] = None,
     ) -> InferenceRequest:
         """Add a chat request to the engine.
@@ -594,7 +632,11 @@ class AsyncLLMEngine:
         Returns:
             The created InferenceRequest object.
         """
-        prompt = self.engine.apply_chat_template(messages, add_generation_prompt=True)
+        prompt = self.engine.apply_chat_template(
+            messages,
+            add_generation_prompt=add_generation_prompt,
+            chat_template_kwargs=chat_template_kwargs,
+        )
         return self.add_request(
             prompt=prompt,
             sampling_params=sampling_params,
@@ -607,6 +649,7 @@ class AsyncLLMEngine:
         self,
         request: InferenceRequest,
         timeout: float = 100.0,
+        request_timeout: Optional[float] = None,
     ) -> AsyncIterator[TokenOutput]:
         """Stream tokens from a request.
@@ -619,6 +662,7 @@ class AsyncLLMEngine:
         """
         import asyncio

+        start = time.time()
         while True:
             if request.is_finished() and request.output_queue.async_q.empty():
                 break
@@ -635,6 +679,20 @@ class AsyncLLMEngine:
                 if token_output.finished:
                     break
             except asyncio.TimeoutError:
+                # Enforce request-level timeout even if no tokens are produced.
+                if request_timeout is not None:
+                    now = time.time()
+                    if now - start > float(request_timeout):
+                        request.mark_timeout()
+                        yield TokenOutput(
+                            request_id=request.request_id,
+                            token_id=-1,
+                            token_text="",
+                            finished=True,
+                            finish_reason=FinishReason.TIMEOUT,
+                            generated_text=request.generated_text,
+                        )
+                        break
                 if request.is_finished():
                     break
                 continue
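The timeout branch above layers a wall-clock deadline over the short per-wait timeout: even if no token ever arrives, the stream terminates with FinishReason.TIMEOUT. A standalone sketch of the same pattern (plain asyncio.Queue and assumed names, not this engine's API):

    import asyncio
    import time

    async def stream_with_deadline(q: asyncio.Queue, wait_timeout=0.1, request_timeout=0.5):
        start = time.time()
        while True:
            try:
                item = await asyncio.wait_for(q.get(), timeout=wait_timeout)
            except asyncio.TimeoutError:
                # No item arrived within wait_timeout; enforce the overall deadline.
                if time.time() - start > request_timeout:
                    yield "[timeout]"
                    return
                continue
            if item is None:  # sentinel: producer finished
                return
            yield item

    async def main():
        q: asyncio.Queue = asyncio.Queue()  # nothing is ever produced -> deadline fires
        async for chunk in stream_with_deadline(q):
            print(chunk)

    asyncio.run(main())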
......
@@ -144,6 +144,10 @@ class InferenceRequest:
         # Output management (for async streaming)
         self._output_queue: Optional[janus.Queue] = None

+        # Streaming helpers (vLLM-style UTF-8 buffering at the chunking layer)
+        # Used by the engine to compute "delta" text chunks from a full decode.
+        self._stream_last_yielded_length: int = 0
+
     @property
     def output_queue(self) -> janus.Queue:
         """Lazy initialization of output queue."""
......
@@ -154,6 +154,10 @@ class Scheduler:
                 req = self.waiting_queue.sync_q.get_nowait()
             except queue.Empty:
                 break
+            # Skip requests that were already finished (e.g., timed out/canceled while waiting)
+            if req.is_finished():
+                self.complete_requests([req])
+                continue

             req_tokens = req.get_input_tokens()
             num_required_blocks = req.get_num_blocks_required(self.block_size)
@@ -185,6 +189,10 @@ class Scheduler:
                 req = self.running_queue.sync_q.get_nowait()
             except queue.Empty:
                 break
+            # Skip requests that were already finished (e.g., timed out/canceled while running)
+            if req.is_finished():
+                self.complete_requests([req])
+                continue

             # Decode phase: allocate slot for newly generated token
             try:
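Both drain loops now complete and skip requests that finished while still queued (for example, a timeout surfaced by stream_request), so their resources are reclaimed instead of being scheduled. A minimal sketch of the pattern with a plain queue and an assumed Request shape, as illustrated below:

    import queue
    from dataclasses import dataclass

    @dataclass
    class Request:
        name: str
        finished: bool = False

    q: "queue.Queue[Request]" = queue.Queue()
    for r in (Request("a"), Request("b", finished=True), Request("c")):
        q.put(r)

    scheduled, completed = [], []
    while True:
        try:
            req = q.get_nowait()
        except queue.Empty:
            break
        if req.finished:
            completed.append(req)  # release its resources, do not schedule it
            continue
        scheduled.append(req)

    print([r.name for r in scheduled], [r.name for r in completed])  # ['a', 'c'] ['b']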
......
@@ -10,6 +10,7 @@ import uuid
 import argparse
 import uvicorn
 import logging
+import os
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, StreamingResponse
@@ -22,7 +23,7 @@ DEFAULT_STREAM_TIMEOUT = 100.0
 DEFAULT_REQUEST_TIMEOUT = 1000.0


-def chunk_json(id_, content=None, role=None, finish_reason=None):
+def chunk_json(id_, content=None, role=None, finish_reason=None, model: str = "unknown"):
     """Generate JSON chunk for streaming response."""
     delta = {}
     if content:
@@ -33,7 +34,7 @@ def chunk_json(id_, content=None, role=None, finish_reason=None):
         "id": id_,
         "object": "chat.completion.chunk",
         "created": int(time.time()),
-        "model": "jiuge",
+        "model": model,
         "system_fingerprint": None,
         "choices": [
             {
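With the model parameter threaded through, each streamed chunk reports the served model id instead of the hard-coded "jiuge". For reference, the SSE framing the endpoint emits looks roughly like this (field values are illustrative, following the OpenAI chat.completion.chunk shape):

    import json
    import time

    chunk = {
        "id": "req-123",                      # request id (illustrative)
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": "my-model",                  # now the served model id, not "jiuge"
        "system_fingerprint": None,
        "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": None}],
    }
    print(f"data: {json.dumps(chunk, ensure_ascii=False)}\n")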
@@ -84,6 +85,8 @@ class InferenceServer:
             port: Server port number.
         """
         self.model_path = model_path
+        # vLLM-like served model id: directory name of model_path
+        self.model_id = os.path.basename(os.path.normpath(model_path)) or "model"
         self.device = device
         self.dtype = dtype
         self.tensor_parallel_size = tensor_parallel_size
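The served model id is derived from the model directory name, vLLM-style. normpath is needed so a trailing slash does not make basename return an empty string (the path below is just an example):

    import os

    for p in ("/models/Qwen2.5-7B-Instruct", "/models/Qwen2.5-7B-Instruct/"):
        print(os.path.basename(os.path.normpath(p)))  # Qwen2.5-7B-Instruct (both)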
@@ -136,7 +139,10 @@ class InferenceServer:
     def _register_routes(self, app: FastAPI):
         """Register API routes."""

+        # OpenAI-compatible chat completions endpoint.
+        # Support both legacy path and OpenAI-style /v1 prefix for proxy/router compatibility.
         @app.post("/chat/completions")
+        @app.post("/v1/chat/completions")
         async def chat_completions(request: Request):
             try:
                 data = await request.json()
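Stacking two route decorators registers the same handler under both paths, which is what keeps legacy clients and OpenAI-style clients (the /v1 prefix) working. A self-contained sketch (requires fastapi installed; names are illustrative):

    from fastapi import FastAPI

    app = FastAPI()

    @app.post("/chat/completions")
    @app.post("/v1/chat/completions")
    async def chat_completions():
        return {"ok": True}

    # Both paths now point at the same handler.
    print(sorted(r.path for r in app.routes if r.path.endswith("completions")))
    # ['/chat/completions', '/v1/chat/completions']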
@@ -169,15 +175,21 @@ class InferenceServer:
         @app.get("/health")
         async def health():
+            # Expose engine health so babysitter/registry can treat backend as unhealthy.
+            if (
+                self.engine is not None
+                and hasattr(self.engine, "is_healthy")
+                and not self.engine.is_healthy()
+            ):
+                return JSONResponse(content={"status": "unhealthy"}, status_code=503)
             return {"status": "healthy"}

-        @app.get("/v1/models")
-        async def list_models():
+        def _models_payload():
             return {
                 "object": "list",
                 "data": [
                     {
-                        "id": "jiuge",
+                        "id": self.model_id,
                         "object": "model",
                         "created": int(time.time()),
                         "owned_by": "infinilm",
@@ -185,14 +197,53 @@ class InferenceServer:
                 ],
             }

+        # Support both /v1/models (OpenAI) and /models (common legacy) for compatibility.
+        @app.get("/v1/models")
+        async def list_models():
+            return _models_payload()
+
+        @app.get("/models")
+        async def list_models_legacy():
+            return _models_payload()
+
     def _build_sampling_params(self, data: dict) -> SamplingParams:
         """Build SamplingParams from request data."""
+        # Support both:
+        # - top-level OpenAI-ish fields: temperature/top_p/top_k/max_tokens/stop
+        # - nested dict: sampling_params: { ... }
+        sp = data.get("sampling_params") or {}
+        if not isinstance(sp, dict):
+            sp = {}
+
+        def pick(key: str, default):
+            # Priority: explicit top-level field > nested sampling_params > server default
+            if key in data and data.get(key) is not None:
+                return data.get(key)
+            if key in sp and sp.get(key) is not None:
+                return sp.get(key)
+            return default
+
+        # Accept common alias
+        max_tokens = pick("max_tokens", self.max_tokens)
+        if max_tokens is None:
+            # Some clients use max_new_tokens
+            max_tokens = pick("max_new_tokens", self.max_tokens)
+
+        stop = pick("stop", None)
+        if isinstance(stop, str):
+            stop = [stop]
+
+        stop_token_ids = pick("stop_token_ids", None)
+        if isinstance(stop_token_ids, int):
+            stop_token_ids = [stop_token_ids]
+
         return SamplingParams(
-            temperature=data.get("temperature", self.temperature),
-            top_p=data.get("top_p", self.top_p),
-            top_k=data.get("top_k", self.top_k),
-            max_tokens=data.get("max_tokens", self.max_tokens),
-            stop=data.get("stop"),
+            temperature=float(pick("temperature", self.temperature)),
+            top_p=float(pick("top_p", self.top_p)),
+            top_k=int(pick("top_k", self.top_k)),
+            max_tokens=int(max_tokens) if max_tokens is not None else None,
+            stop=stop,
+            stop_token_ids=stop_token_ids,
         )

     async def _stream_chat(self, request_id: str, data: dict, http_request: Request):
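_build_sampling_params now accepts both top-level OpenAI-style fields and a nested sampling_params dict, with a fixed precedence. A small runnable illustration of that precedence (field names and defaults below are examples):

    data = {"temperature": 0.2, "sampling_params": {"temperature": 0.9, "top_p": 0.8}}
    sp = data.get("sampling_params") or {}
    defaults = {"temperature": 1.0, "top_p": 1.0, "top_k": 50}

    def pick(key):
        # Priority: explicit top-level field > nested sampling_params > default.
        if data.get(key) is not None:
            return data[key]
        if sp.get(key) is not None:
            return sp[key]
        return defaults[key]

    print(pick("temperature"), pick("top_p"), pick("top_k"))  # 0.2 0.8 50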
@@ -210,22 +261,26 @@ class InferenceServer:
             request_id=request_id,
             request_data=data,
             http_request=http_request,
+            add_generation_prompt=bool(data.get("add_generation_prompt", True)),
+            chat_template_kwargs=data.get("chat_template_kwargs") or {},
         )

         async for token_output in self.engine.stream_request(
-            req, timeout=DEFAULT_STREAM_TIMEOUT
+            req,
+            timeout=DEFAULT_STREAM_TIMEOUT,
+            request_timeout=DEFAULT_REQUEST_TIMEOUT,
         ):
-            # Check timeout
-            if time.time() - start_time > DEFAULT_REQUEST_TIMEOUT:
+            # If stream_request enforces timeout, we can just surface the state to the client.
+            if token_output.finish_reason == FinishReason.TIMEOUT:
                 logger.warning(
                     f"Request {request_id} timed out after {DEFAULT_REQUEST_TIMEOUT}s"
                 )
-                req.mark_timeout()
                 error_chunk = json.dumps(
                     chunk_json(
                         request_id,
                         content="[Request timeout]",
                         finish_reason="timeout",
+                        model=self.model_id,
                     ),
                     ensure_ascii=False,
                 )
@@ -240,7 +295,9 @@ class InferenceServer:

             # Send token
             chunk = json.dumps(
-                chunk_json(request_id, content=token_output.token_text),
+                chunk_json(
+                    request_id, content=token_output.token_text, model=self.model_id
+                ),
                 ensure_ascii=False,
             )
             yield f"data: {chunk}\n\n"
@@ -250,7 +307,9 @@ class InferenceServer:
                     token_output.finish_reason
                 )
                 chunk = json.dumps(
-                    chunk_json(request_id, finish_reason=finish_reason),
+                    chunk_json(
+                        request_id, finish_reason=finish_reason, model=self.model_id
+                    ),
                     ensure_ascii=False,
                 )
                 yield f"data: {chunk}\n\n"
@@ -262,7 +321,10 @@ class InferenceServer:
             req.mark_failed()
             error_chunk = json.dumps(
                 chunk_json(
-                    request_id, content=f"[Error: {str(e)}]", finish_reason="error"
+                    request_id,
+                    content=f"[Error: {str(e)}]",
+                    finish_reason="error",
+                    model=self.model_id,
                 ),
                 ensure_ascii=False,
             )
@@ -290,17 +352,20 @@ class InferenceServer:
             request_id=request_id,
             request_data=data,
             http_request=http_request,
+            add_generation_prompt=bool(data.get("add_generation_prompt", True)),
+            chat_template_kwargs=data.get("chat_template_kwargs") or {},
         )

         # Collect all generated tokens
         output_text = ""
         async for token_output in self.engine.stream_request(
-            req, timeout=DEFAULT_STREAM_TIMEOUT
+            req,
+            timeout=DEFAULT_STREAM_TIMEOUT,
+            request_timeout=DEFAULT_REQUEST_TIMEOUT,
         ):
-            # Check timeout
-            if time.time() - start_time > DEFAULT_REQUEST_TIMEOUT:
+            # Request-level timeout is handled inside stream_request.
+            if token_output.finish_reason == FinishReason.TIMEOUT:
                 logger.warning(f"Request {request_id} timed out")
-                req.mark_timeout()
                 break

             # Check client disconnect
@@ -322,6 +387,7 @@ class InferenceServer:
                 content=output_text,
                 role="assistant",
                 finish_reason=finish_reason or "stop",
+                model=self.model_id,
             )

             return response
......