Commit d45ad1e6 authored by MaYuhang

issue/233 fix: improve request lifecycle management and timeout handling

parent 323c78a1
......@@ -13,6 +13,9 @@ import threading
from typing import List, Optional, Union, AsyncIterator
from dataclasses import dataclass
from transformers import AutoTokenizer
from tokenizers import decoders as _dec
import infinicore
from infinilm.llm.request import (
......@@ -29,8 +32,6 @@ from infinilm.distributed import DistConfig
from infinilm.infer_engine import InferEngine
from infinilm.cache.cache import PagedKVCacheConfig, StaticKVCacheConfig
from infinilm.modeling_utils import load_model_state_dict_by_file
from transformers import AutoTokenizer
from tokenizers import decoders as _dec
logger = logging.getLogger(__name__)
......@@ -249,47 +250,36 @@ class LLMEngine:
self.scheduler.cache_manager.reset_req_blocks()
for req, token_id in zip(requests, sampled_tokens):
req.generated_token_ids.append(token_id)
if req.is_aborted():
logger.info(
f"Request {req.request_id} aborted by client, skipping update"
)
continue
if req.is_prefill:
req.is_prefill = False
# vLLM-style replacement character handling is primarily relevant for streaming.
# For offline generation (no output queue), keep the fast incremental path.
if req._output_queue is None:
token_text = self.detokenize([token_id])
req.generated_text += token_text
else:
# Streaming path: compute delta from a full decode so we can hold back
# trailing '\ufffd' (likely an incomplete UTF-8 sequence).
decoded_text = self.detokenize(req.generated_token_ids)
finished_now = False
# Update generated_text to the latest decode (used for stop-string checks and debugging)
req.generated_token_ids.append(token_id)
decoded_text = self.detokenize(req.generated_token_ids)
req.generated_text = decoded_text
holds_back_incomplete_utf8 = bool(decoded_text) and decoded_text.endswith(
"\ufffd"
)
if self._check_request_finished(req, token_id):
req.mark_finished(req.finish_reason)
finished_now = True
is_finished = self._check_request_finished(req, token_id)
# Remove stop string from generated_text if STOP_STRING finish reason
if req.finish_reason == FinishReason.STOP_STRING:
stop_strings = req.sampling_params.stop or []
for stop_str in stop_strings:
if decoded_text.endswith(stop_str):
# Remove the stop string from the end
decoded_text = decoded_text[: -len(stop_str)]
req.generated_text = decoded_text
break
holds_back_incomplete_utf8 = bool(
decoded_text
) and decoded_text.endswith("\ufffd")
# vLLM-style replacement character handling is primarily relevant for streaming.
# For offline generation (no output queue), keep the fast incremental path.
if req._output_queue is None:
if is_finished:
if holds_back_incomplete_utf8:
req.generated_text = decoded_text[:-1]
req.mark_finished(req.finish_reason)
# vLLM-style: hold back only if we are not on the final chunk.
# Suppress output when finish reason is LENGTH or STOP_STRING.
# Root cause fix: When STOP_STRING is detected, we suppress output for the token
# that completes the stop string, preventing additional tokens from being output.
if (holds_back_incomplete_utf8 and not finished_now) or (
finished_now
else:
if (holds_back_incomplete_utf8 and not is_finished) or (
is_finished
and req.finish_reason
in (FinishReason.LENGTH, FinishReason.STOP_STRING)
):
......@@ -300,30 +290,29 @@ class LLMEngine:
if token_text:
req._stream_last_yielded_length = len(decoded_text)
# For non-streaming, finish checks happen here.
if req._output_queue is None and self._check_request_finished(
req, token_id
):
if is_finished:
req.mark_finished(req.finish_reason)
# Remove stop string from generated_text if STOP_STRING finish reason
if req.finish_reason == FinishReason.STOP_STRING:
stop_strings = req.sampling_params.stop or []
for stop_str in stop_strings:
if req.generated_text.endswith(stop_str):
# Remove the stop string from the end
req.generated_text = req.generated_text[: -len(stop_str)]
break
# Put output in queue if it exists (for async streaming)
if req._output_queue is not None:
output = TokenOutput(
request_id=req.request_id,
token_id=token_id,
token_text=token_text,
finished=req.is_finished(),
finish_reason=req.finish_reason,
finished=is_finished,
finish_reason=req.finish_reason if is_finished else None,
generated_text=req.generated_text,
)
if req.is_aborted():
logger.info(
f"Request {req.request_id} aborted before putting token"
)
continue
try:
req.output_queue.sync_q.put(output)
except Exception as e:
logger.warning(
f"Failed to put token for {req.request_id}: {e}. "
f"Likely due to client disconnecting or request cancelation."
)
continue
self.scheduler.complete_requests(requests)
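For readers unfamiliar with the vLLM-style holdback used in the streaming branch above, here is a minimal, self-contained sketch (stream_delta and detokenize are illustrative names, not part of this commit): the full sequence is re-decoded each step, only the new suffix is emitted, and a trailing U+FFFD is withheld because it usually means the last token ended mid UTF-8 sequence.

def stream_delta(detokenize, token_ids, last_yielded_length, is_final_chunk):
    # Full decode of everything generated so far; slower than an incremental
    # decode but robust across multi-byte token boundaries.
    decoded_text = detokenize(token_ids)
    # A trailing replacement character usually marks an incomplete UTF-8 sequence.
    holds_back = decoded_text.endswith("\ufffd")
    if holds_back and not is_final_chunk:
        return "", last_yielded_length   # hold the chunk until the next token completes it
    end = len(decoded_text) - 1 if holds_back else len(decoded_text)
    return decoded_text[last_yielded_length:end], end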
......@@ -341,9 +330,11 @@ class LLMEngine:
return True
# Check stop strings
# Remove stop string from generated_text if STOP_STRING finish reason
stop_strings = req.sampling_params.stop or []
for stop_str in stop_strings:
if req.generated_text.endswith(stop_str):
req.generated_text = req.generated_text[: -len(stop_str)]
req.finish_reason = FinishReason.STOP_STRING
return True
......@@ -732,10 +723,19 @@ class AsyncLLMEngine:
start = time.time()
while True:
if request.is_finished() and request.output_queue.async_q.empty():
try:
if request_timeout and time.time() - start > float(request_timeout):
request.mark_timeout()
yield TokenOutput(
request_id=request.request_id,
token_id=-1,
token_text="",
finished=True,
finish_reason=FinishReason.TIMEOUT,
generated_text=request.generated_text,
)
break
try:
token_output = await asyncio.wait_for(
request.output_queue.async_q.get(), timeout=timeout
)
......@@ -747,26 +747,28 @@ class AsyncLLMEngine:
if token_output.finished:
break
except asyncio.TimeoutError:
# Enforce request-level timeout even if no tokens are produced.
if request_timeout is not None:
now = time.time()
if now - start > float(request_timeout):
request.mark_timeout()
logger.warning(
f"Timeout while waiting for token from request {request.request_id}"
)
if request.is_aborted():
while not request.output_queue.async_q.empty():
try:
token_output = request.output_queue.async_q.get_nowait()
request.output_queue.async_q.task_done()
yield token_output
except asyncio.QueueEmpty:
break
yield TokenOutput(
request_id=request.request_id,
token_id=-1,
token_text="",
finished=True,
finish_reason=FinishReason.TIMEOUT,
finish_reason=request.finish_reason,
generated_text=request.generated_text,
)
break
if request.is_finished():
break
continue
except asyncio.CancelledError:
request.mark_canceled()
break
except Exception as e:
logger.error(f"Error streaming request {request.request_id}: {e}")
await asyncio.sleep(0.01)
logger.error(f"Error while streaming request {request.request_id}: {e}")
break
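The stream_request changes above combine a short per-poll timeout with an overall request deadline. A generic sketch of that two-level pattern, using a plain asyncio.Queue instead of the janus queue in the diff (stream_with_deadline and its parameters are assumptions for illustration):

import asyncio
import time

async def stream_with_deadline(queue, on_item, poll_timeout=1.0, request_timeout=60.0):
    """Pull items from an asyncio.Queue while enforcing an overall deadline.

    poll_timeout keeps the loop responsive; request_timeout bounds the whole
    request even if the engine never produces a token. Returns "finished" or
    "timeout" so the caller can emit a matching final chunk.
    """
    start = time.time()
    while True:
        if request_timeout and time.time() - start > float(request_timeout):
            return "timeout"
        try:
            item = await asyncio.wait_for(queue.get(), timeout=poll_timeout)
        except asyncio.TimeoutError:
            continue  # nothing arrived within poll_timeout; re-check the deadline
        on_item(item)
        if getattr(item, "finished", False):
            return "finished"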
......@@ -7,9 +7,13 @@ from dataclasses import dataclass, field
from typing import List, Optional, Any
import time
import janus
import asyncio
import logging
from infinilm.llm.sampling_params import SamplingParams
logger = logging.getLogger(__name__)
class RequestStatus(Enum):
"""Status of an inference request."""
......@@ -143,6 +147,7 @@ class InferenceRequest:
# Output management (for async streaming)
self._output_queue: Optional[janus.Queue] = None
self._aborted = False
# Streaming helpers (vLLM-style UTF-8 buffering at the chunking layer)
# Used by the engine to compute "delta" text chunks from a full decode.
......@@ -185,6 +190,14 @@ class InferenceRequest:
RequestStatus.TIMEOUT,
]
def abort(self):
"""Signal that the request has been aborted and should stop generation."""
self._aborted = True
def is_aborted(self) -> bool:
"""Check if the request has been aborted."""
return self._aborted
def mark_finished(self, reason: FinishReason):
"""Mark the request as finished with the given reason."""
self.status = RequestStatus.FINISHED
......@@ -193,18 +206,21 @@ class InferenceRequest:
def mark_failed(self, reason: FinishReason = FinishReason.ERROR):
"""Mark the request as failed."""
self.abort()
self.status = RequestStatus.FAILED
self.finish_reason = reason
self.finished_time = time.time()
def mark_canceled(self):
"""Mark the request as canceled."""
self.abort()
self.status = RequestStatus.CANCELED
self.finish_reason = FinishReason.CANCELED
self.finished_time = time.time()
def mark_timeout(self):
"""Mark the request as timed out."""
self.abort()
self.status = RequestStatus.TIMEOUT
self.finish_reason = FinishReason.TIMEOUT
self.finished_time = time.time()
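mark_failed, mark_canceled, and mark_timeout now all call abort(), so the engine's update loop can skip work for requests whose client is gone. A tiny sketch of that cooperative-abort handshake (class and function names here are illustrative):

class AbortableRequest:
    def __init__(self):
        self._aborted = False

    def abort(self):
        # Idempotent flag; safe to set from any terminal transition.
        self._aborted = True

    def is_aborted(self):
        return self._aborted

def update_requests(requests, sampled_tokens, apply_update):
    for req, token_id in zip(requests, sampled_tokens):
        if req.is_aborted():
            continue  # skip detokenization and queue puts for dead clients
        apply_update(req, token_id)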
......@@ -212,9 +228,25 @@ class InferenceRequest:
async def close(self):
"""Close the output queue and clean up resources."""
if self._output_queue is not None:
await self._output_queue.async_q.join()
self.abort()
try:
while not self._output_queue.async_q.empty():
try:
self._output_queue.async_q.get_nowait()
self._output_queue.async_q.task_done()
except asyncio.QueueEmpty:
break
except Exception as e:
logger.error(
f"Error while clearing output queue for request {self.request_id}: {e}"
)
self._output_queue.close()
await self._output_queue.wait_closed()
try:
await asyncio.wait_for(self._output_queue.wait_closed(), timeout=0.5)
except asyncio.TimeoutError:
logger.warning("wait_closed timeout, force close")
def to_request_output(self) -> RequestOutput:
"""Convert to RequestOutput for external use."""
......
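InferenceRequest.close() now drains the janus queue before closing it and bounds the final wait, so cleanup cannot hang on unread tokens. A rough, standalone sketch of that drain-then-close pattern (close_queue is a hypothetical helper):

import asyncio
import janus

async def close_queue(q: janus.Queue):
    # Drop anything the consumer never read so close()/join() cannot block.
    while not q.async_q.empty():
        try:
            q.async_q.get_nowait()
            q.async_q.task_done()
        except asyncio.QueueEmpty:
            break
    q.close()
    try:
        # Bounded wait: a stuck producer should not block request cleanup forever.
        await asyncio.wait_for(q.wait_closed(), timeout=0.5)
    except asyncio.TimeoutError:
        pass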
......@@ -7,7 +7,12 @@ import queue
import janus
from typing import List, Optional
from infinilm.llm.request import RequestStatus, InferenceRequest, FinishReason
from infinilm.llm.request import (
RequestStatus,
InferenceRequest,
FinishReason,
TokenOutput,
)
logger = logging.getLogger(__name__)
......@@ -115,6 +120,21 @@ class StaticScheduler:
)
self.running_request = None
req.mark_failed(FinishReason.LENGTH)
output = TokenOutput(
request_id=req.request_id,
token_id=-1,
token_text="",
finished=True,
finish_reason=req.finish_reason,
generated_text=req.generated_text,
)
try:
req.output_queue.sync_q.put(output)
except Exception as e:
logger.warning(
f"Failed to put completion token for {req.request_id}: {e}. "
f"Likely due to client disconnecting or request cancelation."
)
continue
return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=False)
......@@ -137,6 +157,21 @@ class StaticScheduler:
)
req.mark_failed(FinishReason.LENGTH)
output = TokenOutput(
request_id=req.request_id,
token_id=-1,
token_text="",
finished=True,
finish_reason=req.finish_reason,
generated_text=req.generated_text,
)
try:
req.output_queue.sync_q.put(output)
except Exception as e:
logger.warning(
f"Failed to put completion token for {req.request_id}: {e}. "
f"Likely due to client disconnecting or request cancelation."
)
continue
req.status = RequestStatus.RUNNING
......
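Both scheduler paths above now push a terminal TokenOutput when a request is failed with FinishReason.LENGTH. Streaming consumers loop until they see finished=True, so a request dropped without a final chunk would leave them blocked until the stream timeout. A minimal illustration with a plain asyncio queue (Chunk and the script layout are illustrative):

import asyncio
from dataclasses import dataclass

@dataclass
class Chunk:
    text: str
    finished: bool = False

async def consume(queue: asyncio.Queue) -> str:
    parts = []
    while True:
        chunk = await queue.get()
        parts.append(chunk.text)
        if chunk.finished:
            return "".join(parts)

async def main():
    q: asyncio.Queue = asyncio.Queue()
    await q.put(Chunk("partial "))
    await q.put(Chunk("", finished=True))  # terminal chunk unblocks the consumer
    print(await consume(q))

asyncio.run(main())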
......@@ -11,6 +11,7 @@ import argparse
import uvicorn
import logging
import os
import asyncio
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
......@@ -351,6 +352,12 @@ class InferenceServer:
timeout=DEFAULT_STREAM_TIMEOUT,
request_timeout=DEFAULT_REQUEST_TIMEOUT,
):
# Check client disconnect
if await http_request.is_disconnected():
logger.info(f"Client disconnected for request {request_id}")
req.mark_canceled()
break
# stream_request already enforces the request-level timeout; here we only surface that state to the client.
if token_output.finish_reason == FinishReason.TIMEOUT:
logger.warning(
......@@ -368,12 +375,6 @@ class InferenceServer:
yield f"data: {error_chunk}\n\n"
break
# Check client disconnect
if await http_request.is_disconnected():
logger.info(f"Client disconnected for request {request_id}")
req.mark_canceled()
break
# Skip EOS token text for OpenAI API compatibility
# Check if this token is an EOS token by comparing token_id with eos_token_ids
eos_token_ids = self.engine.engine.eos_token_ids
......@@ -404,6 +405,12 @@ class InferenceServer:
yield f"data: {chunk}\n\n"
break
except asyncio.CancelledError:
logger.info(f"Request {request_id} was cancelled")
if req:
req.mark_canceled()
raise
except Exception as e:
logger.error(f"Stream error for {request_id}: {e}", exc_info=True)
if req:
......@@ -451,23 +458,23 @@ class InferenceServer:
timeout=DEFAULT_STREAM_TIMEOUT,
request_timeout=DEFAULT_REQUEST_TIMEOUT,
):
# Request-level timeout is handled inside stream_request.
if token_output.finish_reason == FinishReason.TIMEOUT:
logger.warning(f"Request {request_id} timed out")
break
# Check client disconnect
if await http_request.is_disconnected():
logger.info(f"Client disconnected for request {request_id}")
req.mark_canceled()
break
# Request-level timeout is handled inside stream_request.
if token_output.finish_reason == FinishReason.TIMEOUT:
logger.warning(f"Request {request_id} timed out")
break
# Skip EOS token text for OpenAI API compatibility
# Check if this token is an EOS token by comparing token_id with eos_token_ids
eos_token_ids = self.engine.engine.eos_token_ids
is_eos_token = eos_token_ids and token_output.token_id in eos_token_ids
if not is_eos_token:
if not is_eos_token and token_output.token_text:
output_text += token_output.token_text
if token_output.finished:
......@@ -488,6 +495,12 @@ class InferenceServer:
)
return response
except asyncio.CancelledError:
logger.info(f"Request {request_id} was cancelled")
if req:
req.mark_canceled()
raise
except Exception as e:
logger.error(f"Chat error for {request_id}: {e}", exc_info=True)
if req:
......
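The server-side handlers above check for client disconnects inside the token loop and also catch asyncio.CancelledError, since uvicorn cancels the handler task when the connection drops mid-await. A condensed sketch of that shape (sse_generator and stream_request are illustrative; is_disconnected() is the Starlette Request API):

import asyncio

async def sse_generator(http_request, req, stream_request):
    try:
        async for token_output in stream_request(req):
            # Stop generating as soon as the client goes away.
            if await http_request.is_disconnected():
                req.mark_canceled()
                break
            yield f"data: {token_output.token_text}\n\n"
            if token_output.finished:
                break
    except asyncio.CancelledError:
        # Raised when the connection drops while we are awaiting the next token.
        req.mark_canceled()
        raise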
......@@ -4,7 +4,6 @@ from openai import AsyncOpenAI
import argparse
import random
PROMPTS = [
"如果猫能写诗,它们会写些什么?",
"描述一个没有重力的世界。",
......@@ -25,11 +24,11 @@ PROMPTS = [
"如果你可以变成任何一种动物,你会选择什么?",
"描述一个由机器人统治的未来世界。",
"如果你能与任何虚构角色成为朋友,你会选择谁?",
"想象一下,如果每个人都能读懂他人的思想。"
"想象一下,如果每个人都能读懂他人的思想。",
]
NUM_REQUESTS = 10
CONCURRENCY = 5
NUM_REQUESTS = 64
CONCURRENCY = 20
API_URL = "http://127.0.0.1:8000"
MODEL = "FM9G-7B"
......@@ -50,7 +49,7 @@ async def benchmark_user(client, semaphore, queue, results, user_id, verbose):
stream = await client.chat.completions.create(
model=MODEL,
messages=[{"role": "user", "content": question}],
stream=True
stream=True,
)
first_token_time = None
......@@ -71,19 +70,33 @@ async def benchmark_user(client, semaphore, queue, results, user_id, verbose):
ttft = first_token_time - start_time if first_token_time else None
elapsed_time = end_time - start_time if start_time else None
ms_per_token = (elapsed_time / total_tokens * 1000) if total_tokens > 0 and elapsed_time else None
tokens_per_second = total_tokens / elapsed_time if elapsed_time > 0 else 0
ms_per_token = (
(elapsed_time / total_tokens * 1000)
if total_tokens > 0 and elapsed_time
else None
)
tokens_per_second = (
total_tokens / elapsed_time if elapsed_time > 0 else 0
)
answer = "".join(answer_chunks)
results.append((total_tokens, elapsed_time, tokens_per_second, ttft, ms_per_token))
results.append(
(total_tokens, elapsed_time, tokens_per_second, ttft, ms_per_token)
)
if verbose:
print(f"\n📝 Request #{task_id} (User #{user_id})")
if ttft is not None:
print(f" ⏱ 首字延迟 TTFT: {ttft:.3f}s")
if elapsed_time is not None:
print(f" ⏱ 总耗时: {elapsed_time:.3f}s")
print(f" 🔤 解码 token 总数: {total_tokens}")
if ms_per_token is not None:
print(f" 📏 平均 token 解码时间: {ms_per_token:.2f} ms/token")
else:
print(f" 📏 平均 token 解码时间: N/A (no token generated)")
print(f" ❓ 提问: {question}")
print(f" 💬 回答: {answer}\n")
......@@ -92,6 +105,8 @@ async def benchmark_user(client, semaphore, queue, results, user_id, verbose):
if verbose:
print(f"\n⚠️ Request #{task_id} (User #{user_id}) FAILED:")
print(f" ❌ Error: {e}\n")
queue.task_done()
async def run_benchmark(verbose=False):
client = AsyncOpenAI(base_url=API_URL, api_key="default")
......@@ -104,7 +119,9 @@ async def run_benchmark(verbose=False):
await queue.put(None)
users = [
asyncio.create_task(benchmark_user(client, semaphore, queue, results, user_id, verbose))
asyncio.create_task(
benchmark_user(client, semaphore, queue, results, user_id, verbose)
)
for user_id in range(CONCURRENCY)
]
......@@ -121,11 +138,19 @@ async def run_benchmark(verbose=False):
ms_per_token_list = [r[4] for r in results if r and r[4] is not None]
successful_requests = len(results)
requests_per_second = successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0
requests_per_second = (
successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0
)
avg_latency = sum(latencies) / len(latencies) if latencies else 0
avg_tokens_per_second = sum(tokens_per_second_list) / len(tokens_per_second_list) if tokens_per_second_list else 0
avg_tokens_per_second = (
sum(tokens_per_second_list) / len(tokens_per_second_list)
if tokens_per_second_list
else 0
)
avg_ttft = sum(ttft_list) / len(ttft_list) if ttft_list else 0
avg_ms_per_token = sum(ms_per_token_list) / len(ms_per_token_list) if ms_per_token_list else None
avg_ms_per_token = (
sum(ms_per_token_list) / len(ms_per_token_list) if ms_per_token_list else None
)
width_label = 24
sep = "-" * 60
......@@ -142,7 +167,9 @@ async def run_benchmark(verbose=False):
print(f"{'Average latency':<{width_label}}: {avg_latency:.2f} s")
print(f"{'Average TTFT':<{width_label}}: {avg_ttft:.2f} s")
print(f"{'Avg time per token':<{width_label}}: {avg_ms_per_token:.2f} ms/token")
print(f"{'Avg Token generation speed':<{width_label}}: {avg_tokens_per_second:.2f} tokens/s")
print(
f"{'Avg Token generation speed':<{width_label}}: {avg_tokens_per_second:.2f} tokens/s"
)
if __name__ == "__main__":
......@@ -150,6 +177,4 @@ if __name__ == "__main__":
parser.add_argument("--verbose", action="store_true")
args = parser.parse_args()
asyncio.run(run_benchmark(
args.verbose
))
asyncio.run(run_benchmark(args.verbose))