Unverified commit 96ecf490 authored by Haojie Wang, committed by GitHub

Merge pull request #203 from InfiniTensor/issue/193

Issue/193: adapt inference_server to deployment requirements
parents f2d9d397 0cddb99e
......@@ -261,6 +261,12 @@ class BlockManager:
def get_num_free_blocks(self) -> int:
return len(self.free_block_ids)
def get_total_usable_blocks(self) -> int:
freeable_used_blocks = sum(
1 for bid in self.used_block_ids if self.blocks[bid].ref_count == 0
)
return len(self.free_block_ids) + freeable_used_blocks
def __repr__(self):
return (
f"BlockManager(blocks={self.num_blocks}, block_size={self.block_size}, "
......
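For context, a minimal sketch of the accounting behind get_total_usable_blocks, using simplified stand-in classes (the real Block/BlockManager carry more state): usable capacity is the free list plus any used blocks whose ref_count has dropped to zero and can therefore be reclaimed.

```python
from dataclasses import dataclass

@dataclass
class Block:
    block_id: int
    ref_count: int = 0

class TinyBlockManager:
    """Simplified stand-in for BlockManager, for illustration only."""

    def __init__(self, num_blocks: int):
        self.blocks = [Block(i) for i in range(num_blocks)]
        self.free_block_ids = list(range(num_blocks))
        self.used_block_ids = set()

    def get_total_usable_blocks(self) -> int:
        # Free blocks plus used blocks that nothing references anymore.
        freeable = sum(
            1 for bid in self.used_block_ids if self.blocks[bid].ref_count == 0
        )
        return len(self.free_block_ids) + freeable

mgr = TinyBlockManager(num_blocks=8)
# Simulate three allocated blocks, one of which is no longer referenced.
for bid in (0, 1, 2):
    mgr.free_block_ids.remove(bid)
    mgr.used_block_ids.add(bid)
    mgr.blocks[bid].ref_count = 1
mgr.blocks[2].ref_count = 0
print(mgr.get_total_usable_blocks())  # 5 free + 1 freeable = 6
```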
......@@ -228,12 +228,63 @@ class LLMEngine:
req.generated_token_ids.append(token_id)
if req.is_prefill:
req.is_prefill = False
# vLLM-style replacement character handling is primarily relevant for streaming.
# For offline generation (no output queue), keep the fast incremental path.
if req._output_queue is None:
token_text = self.detokenize([token_id])
req.generated_text += token_text
else:
# Streaming path: compute delta from a full decode so we can hold back
# trailing '\ufffd' (likely an incomplete UTF-8 sequence).
decoded_text = self.detokenize(req.generated_token_ids)
finished_now = False
# Update generated_text to the latest decode (used for stop-string checks and debugging)
req.generated_text = decoded_text
if self._check_request_finished(req, token_id):
req.mark_finished(req.finish_reason)
finished_now = True
# Remove stop string from generated_text if STOP_STRING finish reason
if req.finish_reason == FinishReason.STOP_STRING:
stop_strings = req.sampling_params.stop or []
for stop_str in stop_strings:
if decoded_text.endswith(stop_str):
# Remove the stop string from the end
decoded_text = decoded_text[:-len(stop_str)]
req.generated_text = decoded_text
break
holds_back_incomplete_utf8 = (
bool(decoded_text) and decoded_text.endswith("\ufffd")
)
token_text = self.tokenizer.decode(token_id)
req.generated_text += token_text
if self._check_request_finished(req, token_id):
# vLLM-style: hold back only if we are not on the final chunk.
# Suppress output when finish reason is LENGTH or STOP_STRING.
# Root cause fix: When STOP_STRING is detected, we suppress output for the token
# that completes the stop string, preventing additional tokens from being output.
if (holds_back_incomplete_utf8 and not finished_now) or (
finished_now and req.finish_reason in (FinishReason.LENGTH, FinishReason.STOP_STRING)
):
token_text = ""
else:
last_len = getattr(req, "_stream_last_yielded_length", 0)
token_text = decoded_text[last_len:]
if token_text:
req._stream_last_yielded_length = len(decoded_text)
# For non-streaming, finish checks happen here.
if req._output_queue is None and self._check_request_finished(req, token_id):
req.mark_finished(req.finish_reason)
# Remove stop string from generated_text if STOP_STRING finish reason
if req.finish_reason == FinishReason.STOP_STRING:
stop_strings = req.sampling_params.stop or []
for stop_str in stop_strings:
if req.generated_text.endswith(stop_str):
# Remove the stop string from the end
req.generated_text = req.generated_text[:-len(stop_str)]
break
# Put output in queue if it exists (for async streaming)
if req._output_queue is not None:
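To illustrate the hold-back behaviour in the streaming path above, here is a self-contained sketch with a toy byte-level tokenizer (not the real detokenizer): the full sequence is re-decoded each step, a trailing U+FFFD on a non-final step is held back as a likely incomplete UTF-8 sequence, and only the delta beyond the last yielded length is emitted.

```python
# Illustrative only: each "token id" is one byte, so the two-byte UTF-8
# character "é" (0xC3 0xA9) is split across two tokens.
token_ids = list("héllo".encode("utf-8"))

def detokenize(ids):
    return bytes(ids).decode("utf-8", errors="replace")

last_yielded = 0
for i in range(1, len(token_ids) + 1):
    decoded = detokenize(token_ids[:i])
    is_last_step = i == len(token_ids)
    if decoded.endswith("\ufffd") and not is_last_step:
        # Likely an incomplete UTF-8 sequence: hold this chunk back.
        continue
    delta = decoded[last_yielded:]
    if delta:
        last_yielded = len(decoded)
        print(repr(delta))
# prints 'h', 'é', 'l', 'l', 'o' -- the dangling 0xC3 byte never reaches the client
```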
......@@ -283,12 +334,15 @@ class LLMEngine:
self,
messages: List[dict],
add_generation_prompt: bool = True,
chat_template_kwargs: Optional[dict] = None,
) -> str:
"""Apply chat template to messages."""
chat_template_kwargs = chat_template_kwargs or {}
return self.tokenizer.apply_chat_template(
conversation=messages,
add_generation_prompt=add_generation_prompt,
tokenize=False,
**chat_template_kwargs,
)
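A hypothetical call showing how the new chat_template_kwargs parameter is forwarded; engine stands for an already constructed LLMEngine, and whether an extra kwarg such as enable_thinking has any effect depends entirely on the chat template shipped with the model.

```python
# Hypothetical usage sketch; "engine" and "enable_thinking" are assumptions,
# not part of this diff.
prompt = engine.apply_chat_template(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    add_generation_prompt=True,
    chat_template_kwargs={"enable_thinking": False},
)
```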
......@@ -486,6 +540,10 @@ class AsyncLLMEngine:
self._running = False
self._step_thread: Optional[threading.Thread] = None
self._healthy = True
def is_healthy(self) -> bool:
return bool(self._healthy)
def start(self):
"""Start the background inference loop."""
......@@ -520,6 +578,7 @@ class AsyncLLMEngine:
time.sleep(0.01)
except Exception as e:
logger.error(f"Error in step loop: {e}", exc_info=True)
self._healthy = False
self._running = False
break
......@@ -581,6 +640,8 @@ class AsyncLLMEngine:
request_id: Optional[str] = None,
request_data: Optional[dict] = None,
http_request: Optional[any] = None,
add_generation_prompt: bool = True,
chat_template_kwargs: Optional[dict] = None,
) -> InferenceRequest:
"""Add a chat request to the engine.
......@@ -594,7 +655,11 @@ class AsyncLLMEngine:
Returns:
The created InferenceRequest object.
"""
prompt = self.engine.apply_chat_template(messages, add_generation_prompt=True)
prompt = self.engine.apply_chat_template(
messages,
add_generation_prompt=add_generation_prompt,
chat_template_kwargs=chat_template_kwargs,
)
return self.add_request(
prompt=prompt,
sampling_params=sampling_params,
......@@ -607,6 +672,7 @@ class AsyncLLMEngine:
self,
request: InferenceRequest,
timeout: float = 100.0,
request_timeout: Optional[float] = None,
) -> AsyncIterator[TokenOutput]:
"""Stream tokens from a request.
......@@ -619,6 +685,7 @@ class AsyncLLMEngine:
"""
import asyncio
start = time.time()
while True:
if request.is_finished() and request.output_queue.async_q.empty():
break
......@@ -635,6 +702,20 @@ class AsyncLLMEngine:
if token_output.finished:
break
except asyncio.TimeoutError:
# Enforce request-level timeout even if no tokens are produced.
if request_timeout is not None:
now = time.time()
if now - start > float(request_timeout):
request.mark_timeout()
yield TokenOutput(
request_id=request.request_id,
token_id=-1,
token_text="",
finished=True,
finish_reason=FinishReason.TIMEOUT,
generated_text=request.generated_text,
)
break
if request.is_finished():
break
continue
......
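The loop above layers a short per-poll timeout (so disconnect and finish checks keep running) under an optional request-level deadline. A standalone sketch of that pattern, using a plain asyncio.Queue instead of the janus-backed request queue:

```python
import asyncio
import time
from typing import AsyncIterator, Optional

async def stream_with_deadline(
    queue: "asyncio.Queue[dict]",
    poll_timeout: float = 1.0,
    request_timeout: Optional[float] = None,
) -> AsyncIterator[dict]:
    """Illustrative only: short polls keep the loop responsive, while the
    optional overall deadline ends the stream with a timeout marker."""
    start = time.time()
    while True:
        try:
            item = await asyncio.wait_for(queue.get(), timeout=poll_timeout)
        except asyncio.TimeoutError:
            # No token arrived in this poll window; check the overall deadline.
            if request_timeout is not None and time.time() - start > request_timeout:
                yield {"finished": True, "finish_reason": "timeout"}
                return
            continue
        yield item
        if item.get("finished"):
            return
```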
......@@ -144,6 +144,10 @@ class InferenceRequest:
# Output management (for async streaming)
self._output_queue: Optional[janus.Queue] = None
# Streaming helpers (vLLM-style UTF-8 buffering at the chunking layer)
# Used by the engine to compute "delta" text chunks from a full decode.
self._stream_last_yielded_length: int = 0
@property
def output_queue(self) -> janus.Queue:
"""Lazy initialization of output queue."""
......
......@@ -15,7 +15,7 @@ class SamplingParams:
top_k: int = 1
max_tokens: Optional[int] = None
stop: Optional[List[str]] = None
stop_token_ids: Optional[List[int]] = None
stop_token_ids: Optional[List[int]] = None # Placeholder for future usage, not currently handled
def __post_init__(self):
if self.stop is None:
......
......@@ -155,12 +155,21 @@ class Scheduler:
except queue.Empty:
break
if not self.can_accept_request(req):
self.waiting_queue.sync_q.put(req)
break
# Skip requests that were already finished (e.g., timed out/canceled while waiting)
if req.is_finished():
self.complete_requests([req])
continue
req_tokens = req.get_input_tokens()
num_required_blocks = req.get_num_blocks_required(self.block_size)
if not self.cache_manager.can_allocate(num_required_blocks):
if not self.cache_manager.try_free_blocks(num_required_blocks):
raise RuntimeError("No available cache blocks")
raise RuntimeError("No available cache blocks for new request")
# Allocate blocks with automatic prefix caching support
req.block_table, req.slot_mapping, req.num_cached_tokens = (
......@@ -185,6 +194,10 @@ class Scheduler:
req = self.running_queue.sync_q.get_nowait()
except queue.Empty:
break
# Skip requests that were already finished (e.g., timed out/canceled while running)
if req.is_finished():
self.complete_requests([req])
continue
# Decode phase: allocate slot for newly generated token
try:
......@@ -197,7 +210,7 @@ class Scheduler:
scheduled_requests.append(req)
except RuntimeError as e:
raise RuntimeError("No available cache blocks") from e
raise RuntimeError("No available cache blocks for new token") from e
# Return decode batch if any running requests were scheduled
if scheduled_requests:
......@@ -237,6 +250,31 @@ class Scheduler:
# Still running, put back in running queue
self.running_queue.sync_q.put(req)
def can_accept_request(self, request: InferenceRequest) -> bool:
total_required_blocks = 0
# Calculate blocks needed for running requests
running_queue_size = self.running_queue.sync_q.qsize()
for _ in range(running_queue_size):
req = self.running_queue.sync_q.get()
remaining_tokens = (
req.sampling_params.max_tokens - req.get_num_generated_tokens()
)
num_blocks_needed = (
remaining_tokens + self.block_size - 1
) // self.block_size
total_required_blocks += num_blocks_needed
self.running_queue.sync_q.put(req)
# Calculate blocks needed for the new request
total_length = request.get_prompt_length()
total_length += request.sampling_params.max_tokens
num_blocks_needed = (total_length + self.block_size - 1) // self.block_size
total_required_blocks += num_blocks_needed
# Compare with total usable blocks in cache manager
return total_required_blocks <= self.cache_manager.get_total_usable_blocks()
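A worked example of the block arithmetic in can_accept_request, with assumed numbers (block_size 16, a single running request): every request reserves enough blocks for its full max_tokens budget, rounded up to whole blocks, and admission requires the sum to fit within the usable block count.

```python
def blocks_needed(num_tokens: int, block_size: int) -> int:
    # Ceiling division: a partially filled block still occupies a whole block.
    return (num_tokens + block_size - 1) // block_size

block_size = 16
# One running request: max_tokens=128 with 40 tokens already generated.
running = blocks_needed(128 - 40, block_size)   # 88 tokens -> 6 blocks
# Incoming request: 50-token prompt plus max_tokens=100.
incoming = blocks_needed(50 + 100, block_size)  # 150 tokens -> 10 blocks
print(running + incoming)  # 16 blocks must be usable for the new request to be admitted
```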
def get_cache_stats(self) -> dict:
"""Get cache statistics."""
return {
......
......@@ -10,6 +10,7 @@ import uuid
import argparse
import uvicorn
import logging
import os
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
......@@ -22,7 +23,7 @@ DEFAULT_STREAM_TIMEOUT = 100.0
DEFAULT_REQUEST_TIMEOUT = 1000.0
def chunk_json(id_, content=None, role=None, finish_reason=None):
def chunk_json(id_, content=None, role=None, finish_reason=None, model: str = "unknown"):
"""Generate JSON chunk for streaming response."""
delta = {}
if content:
......@@ -33,7 +34,7 @@ def chunk_json(id_, content=None, role=None, finish_reason=None):
"id": id_,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": "jiuge",
"model": model,
"system_fingerprint": None,
"choices": [
{
......@@ -84,6 +85,8 @@ class InferenceServer:
port: Server port number.
"""
self.model_path = model_path
# vLLM-like served model id: directory name of model_path
self.model_id = os.path.basename(os.path.normpath(model_path)) or "model"
self.device = device
self.dtype = dtype
self.tensor_parallel_size = tensor_parallel_size
......@@ -136,7 +139,10 @@ class InferenceServer:
def _register_routes(self, app: FastAPI):
"""Register API routes."""
# OpenAI-compatible chat completions endpoint.
# Support both legacy path and OpenAI-style /v1 prefix for proxy/router compatibility.
@app.post("/chat/completions")
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
try:
data = await request.json()
......@@ -169,15 +175,21 @@ class InferenceServer:
@app.get("/health")
async def health():
# Expose engine health so babysitter/registry can treat backend as unhealthy.
if (
self.engine is not None
and hasattr(self.engine, "is_healthy")
and not self.engine.is_healthy()
):
return JSONResponse(content={"status": "unhealthy"}, status_code=503)
return {"status": "healthy"}
@app.get("/v1/models")
async def list_models():
def _models_payload():
return {
"object": "list",
"data": [
{
"id": "jiuge",
"id": self.model_id,
"object": "model",
"created": int(time.time()),
"owned_by": "infinilm",
......@@ -185,20 +197,54 @@ class InferenceServer:
],
}
# Support both /v1/models (OpenAI) and /models (common legacy) for compatibility.
@app.get("/v1/models")
async def list_models():
return _models_payload()
@app.get("/models")
async def list_models_legacy():
return _models_payload()
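A quick smoke test of the health and model-listing routes, assuming a server listening on http://localhost:8000 and the third-party requests package:

```python
# Assumed base URL; /models returns the same payload as /v1/models,
# and /health returns 503 with {"status": "unhealthy"} when the engine is down.
import requests

base = "http://localhost:8000"
print(requests.get(f"{base}/health").json())
print(requests.get(f"{base}/v1/models").json())
```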
def _build_sampling_params(self, data: dict) -> SamplingParams:
"""Build SamplingParams from request data."""
# Support both:
# - top-level OpenAI-ish fields: temperature/top_p/top_k/max_tokens/stop
# - nested dict: sampling_params: { ... }
sp = data.get("sampling_params") or {}
if not isinstance(sp, dict):
sp = {}
def pick(key: str, default):
# Priority: explicit top-level field > nested sampling_params > server default
if key in data and data.get(key) is not None:
return data.get(key)
if key in sp and sp.get(key) is not None:
return sp.get(key)
return default
# Accept common alias
max_tokens = pick("max_tokens", self.max_tokens)
if max_tokens is None:
# Some clients use max_new_tokens
max_tokens = pick("max_new_tokens", self.max_tokens)
stop = pick("stop", None)
if isinstance(stop, str):
stop = [stop]
return SamplingParams(
temperature=data.get("temperature", self.temperature),
top_p=data.get("top_p", self.top_p),
top_k=data.get("top_k", self.top_k),
max_tokens=data.get("max_tokens", self.max_tokens),
stop=data.get("stop"),
temperature=float(pick("temperature", self.temperature)),
top_p=float(pick("top_p", self.top_p)),
top_k=int(pick("top_k", self.top_k)),
max_tokens=int(max_tokens) if max_tokens is not None else None,
stop=stop,
)
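A hypothetical request body showing the resolution rules: top-level fields override the nested sampling_params block, max_new_tokens is honoured as an alias (here assuming the server-level max_tokens default is None), and a string stop value is normalised to a list.

```python
# Hypothetical client payload; values chosen only to illustrate precedence.
data = {
    "temperature": 0.8,
    "sampling_params": {"temperature": 0.2, "top_p": 0.9, "max_new_tokens": 256},
    "stop": "</s>",
}
# With the resolution rules above this is equivalent to:
# SamplingParams(temperature=0.8, top_p=0.9, top_k=<server default>,
#                max_tokens=256, stop=["</s>"])
```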
async def _stream_chat(self, request_id: str, data: dict, http_request: Request):
"""Handle streaming chat request."""
req = None
start_time = time.time()
try:
messages = data.get("messages", [])
......@@ -210,22 +256,26 @@ class InferenceServer:
request_id=request_id,
request_data=data,
http_request=http_request,
add_generation_prompt=bool(data.get("add_generation_prompt", True)),
chat_template_kwargs=data.get("chat_template_kwargs") or {},
)
async for token_output in self.engine.stream_request(
req, timeout=DEFAULT_STREAM_TIMEOUT
req,
timeout=DEFAULT_STREAM_TIMEOUT,
request_timeout=DEFAULT_REQUEST_TIMEOUT,
):
# Check timeout
if time.time() - start_time > DEFAULT_REQUEST_TIMEOUT:
# If stream_request enforces timeout, we can just surface the state to the client.
if token_output.finish_reason == FinishReason.TIMEOUT:
logger.warning(
f"Request {request_id} timed out after {DEFAULT_REQUEST_TIMEOUT}s"
)
req.mark_timeout()
error_chunk = json.dumps(
chunk_json(
request_id,
content="[Request timeout]",
finish_reason="timeout",
model=self.model_id,
),
ensure_ascii=False,
)
......@@ -238,19 +288,31 @@ class InferenceServer:
req.mark_canceled()
break
# Send token
chunk = json.dumps(
chunk_json(request_id, content=token_output.token_text),
ensure_ascii=False,
# Skip EOS token text for OpenAI API compatibility
# Check if this token is an EOS token by comparing token_id with eos_token_ids
eos_token_ids = self.engine.engine.eos_token_ids
is_eos_token = (
eos_token_ids and token_output.token_id in eos_token_ids
)
yield f"data: {chunk}\n\n"
if not is_eos_token and token_output.token_text:
# Send token
chunk = json.dumps(
chunk_json(
request_id, content=token_output.token_text, model=self.model_id
),
ensure_ascii=False,
)
yield f"data: {chunk}\n\n"
if token_output.finished:
finish_reason = self._convert_finish_reason(
token_output.finish_reason
)
chunk = json.dumps(
chunk_json(request_id, finish_reason=finish_reason),
chunk_json(
request_id, finish_reason=finish_reason, model=self.model_id
),
ensure_ascii=False,
)
yield f"data: {chunk}\n\n"
......@@ -262,7 +324,10 @@ class InferenceServer:
req.mark_failed()
error_chunk = json.dumps(
chunk_json(
request_id, content=f"[Error: {str(e)}]", finish_reason="error"
request_id,
content=f"[Error: {str(e)}]",
finish_reason="error",
model=self.model_id,
),
ensure_ascii=False,
)
......@@ -278,7 +343,6 @@ class InferenceServer:
async def _chat(self, request_id: str, data: dict, http_request: Request):
"""Handle non-streaming chat request."""
req = None
start_time = time.time()
try:
messages = data.get("messages", [])
......@@ -290,17 +354,20 @@ class InferenceServer:
request_id=request_id,
request_data=data,
http_request=http_request,
add_generation_prompt=bool(data.get("add_generation_prompt", True)),
chat_template_kwargs=data.get("chat_template_kwargs") or {},
)
# Collect all generated tokens
output_text = ""
async for token_output in self.engine.stream_request(
req, timeout=DEFAULT_STREAM_TIMEOUT
req,
timeout=DEFAULT_STREAM_TIMEOUT,
request_timeout=DEFAULT_REQUEST_TIMEOUT,
):
# Check timeout
if time.time() - start_time > DEFAULT_REQUEST_TIMEOUT:
# Request-level timeout is handled inside stream_request.
if token_output.finish_reason == FinishReason.TIMEOUT:
logger.warning(f"Request {request_id} timed out")
req.mark_timeout()
break
# Check client disconnect
......@@ -309,7 +376,15 @@ class InferenceServer:
req.mark_canceled()
break
output_text += token_output.token_text
# Skip EOS token text for OpenAI API compatibility
# Check if this token is an EOS token by comparing token_id with eos_token_ids
eos_token_ids = self.engine.engine.eos_token_ids
is_eos_token = (
eos_token_ids and token_output.token_id in eos_token_ids
)
if not is_eos_token:
output_text += token_output.token_text
if token_output.finished:
break
......@@ -322,6 +397,7 @@ class InferenceServer:
content=output_text,
role="assistant",
finish_reason=finish_reason or "stop",
model=self.model_id,
)
return response
......
......@@ -4,7 +4,6 @@ import argparse
import time
import re
import csv
from datasets import load_dataset, Dataset
import numpy as np
import infinicore
from infinilm.modeling_utils import load_model_state_dict_by_file
......@@ -12,6 +11,7 @@ from infinilm.distributed import DistConfig
from infinilm.cache import StaticKVCacheConfig
from infinilm.infer_engine import GenerationConfig, InferEngine
from infinilm.cache import StaticKVCacheConfig
from datasets import load_dataset, Dataset
from abc import ABC, abstractmethod
......