Commit d45ad1e6 authored by MaYuhang

issue/233 fix: improve request lifecycle management and timeout handling

parent 323c78a1
@@ -13,6 +13,9 @@ import threading
 from typing import List, Optional, Union, AsyncIterator
 from dataclasses import dataclass
+from transformers import AutoTokenizer
+from tokenizers import decoders as _dec
+
 import infinicore
 from infinilm.llm.request import (
@@ -29,8 +32,6 @@ from infinilm.distributed import DistConfig
 from infinilm.infer_engine import InferEngine
 from infinilm.cache.cache import PagedKVCacheConfig, StaticKVCacheConfig
 from infinilm.modeling_utils import load_model_state_dict_by_file
-from transformers import AutoTokenizer
-from tokenizers import decoders as _dec
 
 logger = logging.getLogger(__name__)
@@ -249,47 +250,36 @@ class LLMEngine:
             self.scheduler.cache_manager.reset_req_blocks()
 
         for req, token_id in zip(requests, sampled_tokens):
+            req.generated_token_ids.append(token_id)
+            if req.is_aborted():
+                logger.info(
+                    f"Request {req.request_id} aborted by client, skipping update"
+                )
+                continue
             if req.is_prefill:
                 req.is_prefill = False
-            req.generated_token_ids.append(token_id)
+            decoded_text = self.detokenize(req.generated_token_ids)
+            req.generated_text = decoded_text
+            holds_back_incomplete_utf8 = bool(decoded_text) and decoded_text.endswith(
+                "\ufffd"
+            )
+            is_finished = self._check_request_finished(req, token_id)
 
             # vLLM-style replacement character handling is primarily relevant for streaming.
             # For offline generation (no output queue), keep the fast incremental path.
             if req._output_queue is None:
-                token_text = self.detokenize([token_id])
-                req.generated_text += token_text
+                if is_finished:
+                    if holds_back_incomplete_utf8:
+                        req.generated_text = decoded_text[:-1]
+                    req.mark_finished(req.finish_reason)
             else:
-                # Streaming path: compute delta from a full decode so we can hold back
-                # trailing '\ufffd' (likely an incomplete UTF-8 sequence).
-                decoded_text = self.detokenize(req.generated_token_ids)
-                finished_now = False
-                # Update generated_text to the latest decode (used for stop-string checks and debugging)
-                req.generated_text = decoded_text
-                if self._check_request_finished(req, token_id):
-                    req.mark_finished(req.finish_reason)
-                    finished_now = True
-                    # Remove stop string from generated_text if STOP_STRING finish reason
-                    if req.finish_reason == FinishReason.STOP_STRING:
-                        stop_strings = req.sampling_params.stop or []
-                        for stop_str in stop_strings:
-                            if decoded_text.endswith(stop_str):
-                                # Remove the stop string from the end
-                                decoded_text = decoded_text[: -len(stop_str)]
-                                req.generated_text = decoded_text
-                                break
-                holds_back_incomplete_utf8 = bool(
-                    decoded_text
-                ) and decoded_text.endswith("\ufffd")
-                # vLLM-style: hold back only if we are not on the final chunk.
-                # Suppress output when finish reason is LENGTH or STOP_STRING.
-                # Root cause fix: When STOP_STRING is detected, we suppress output for the token
-                # that completes the stop string, preventing additional tokens from being output.
-                if (holds_back_incomplete_utf8 and not finished_now) or (
-                    finished_now
+                if (holds_back_incomplete_utf8 and not is_finished) or (
+                    is_finished
                     and req.finish_reason
                     in (FinishReason.LENGTH, FinishReason.STOP_STRING)
                 ):
@@ -300,30 +290,29 @@ class LLMEngine:
                 if token_text:
                     req._stream_last_yielded_length = len(decoded_text)
 
-            # For non-streaming, finish checks happen here.
-            if req._output_queue is None and self._check_request_finished(
-                req, token_id
-            ):
-                req.mark_finished(req.finish_reason)
-                # Remove stop string from generated_text if STOP_STRING finish reason
-                if req.finish_reason == FinishReason.STOP_STRING:
-                    stop_strings = req.sampling_params.stop or []
-                    for stop_str in stop_strings:
-                        if req.generated_text.endswith(stop_str):
-                            # Remove the stop string from the end
-                            req.generated_text = req.generated_text[: -len(stop_str)]
-                            break
-
-            # Put output in queue if it exists (for async streaming)
-            if req._output_queue is not None:
+                if is_finished:
+                    req.mark_finished(req.finish_reason)
+
                 output = TokenOutput(
                     request_id=req.request_id,
                     token_id=token_id,
                     token_text=token_text,
-                    finished=req.is_finished(),
-                    finish_reason=req.finish_reason,
+                    finished=is_finished,
+                    finish_reason=req.finish_reason if is_finished else None,
                     generated_text=req.generated_text,
                 )
-                req.output_queue.sync_q.put(output)
+                if req.is_aborted():
+                    logger.info(
+                        f"Request {req.request_id} aborted before putting token"
+                    )
+                    continue
+                try:
+                    req.output_queue.sync_q.put(output)
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to put token for {req.request_id}: {e}. "
+                        f"Likely due to client disconnecting or request cancelation."
+                    )
+                    continue
 
         self.scheduler.complete_requests(requests)
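The two hunks above move finish detection ahead of output handling and hold back a trailing U+FFFD replacement character, which usually marks a multi-byte UTF-8 sequence split across tokens. A minimal standalone sketch of the hold-back idea; the byte-level `detokenize` below is a stand-in for the engine's tokenizer, not its real API:

```python
# "你" is 3 UTF-8 bytes; token 0 carries the first byte, token 1 the rest.
TOKEN_BYTES = {0: "你".encode("utf-8")[:1], 1: "你".encode("utf-8")[1:]}

def detokenize(token_ids):
    # Decode all generated tokens at once; errors="replace" turns an
    # incomplete multi-byte sequence at the end into U+FFFD.
    data = b"".join(TOKEN_BYTES[t] for t in token_ids)
    return data.decode("utf-8", errors="replace")

def stream_delta(token_ids, last_yielded_len):
    """Return (text_to_emit, new_yielded_len), holding back a trailing U+FFFD."""
    decoded = detokenize(token_ids)
    if decoded.endswith("\ufffd"):
        # Likely an incomplete UTF-8 sequence: wait for the next token.
        return "", last_yielded_len
    return decoded[last_yielded_len:], len(decoded)

emitted = 0
for n in range(1, 3):
    chunk, emitted = stream_delta(list(range(n)), emitted)
    print(repr(chunk))  # '' after token 0, '你' after token 1
```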
@@ -341,9 +330,11 @@ class LLMEngine:
             return True
 
         # Check stop strings
+        # Remove stop string from generated_text if STOP_STRING finish reason
         stop_strings = req.sampling_params.stop or []
         for stop_str in stop_strings:
             if req.generated_text.endswith(stop_str):
+                req.generated_text = req.generated_text[: -len(stop_str)]
                 req.finish_reason = FinishReason.STOP_STRING
                 return True
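`_check_request_finished` now trims the matched stop string at the point of detection, so neither the offline nor the streaming path has to re-trim it afterwards. A small sketch of the same check in isolation; the standalone function is illustrative, not part of the engine:

```python
def check_stop_strings(generated_text, stop_strings):
    """Return (finished, trimmed_text); trim the stop string where it is detected."""
    for stop_str in stop_strings or []:
        if generated_text.endswith(stop_str):
            # Strip the stop string so it never reaches the client.
            return True, generated_text[: -len(stop_str)]
    return False, generated_text

print(check_stop_strings("Hello world<|eot|>", ["<|eot|>"]))  # (True, 'Hello world')
print(check_stop_strings("Hello wor", ["<|eot|>"]))           # (False, 'Hello wor')
```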
@@ -732,10 +723,19 @@ class AsyncLLMEngine:
         start = time.time()
         while True:
-            if request.is_finished() and request.output_queue.async_q.empty():
-                break
             try:
+                if request_timeout and time.time() - start > float(request_timeout):
+                    request.mark_timeout()
+                    yield TokenOutput(
+                        request_id=request.request_id,
+                        token_id=-1,
+                        token_text="",
+                        finished=True,
+                        finish_reason=FinishReason.TIMEOUT,
+                        generated_text=request.generated_text,
+                    )
+                    break
                 token_output = await asyncio.wait_for(
                     request.output_queue.async_q.get(), timeout=timeout
                 )
@@ -747,26 +747,28 @@ class AsyncLLMEngine:
                 if token_output.finished:
                     break
             except asyncio.TimeoutError:
-                # Enforce request-level timeout even if no tokens are produced.
-                if request_timeout is not None:
-                    now = time.time()
-                    if now - start > float(request_timeout):
-                        request.mark_timeout()
-                        yield TokenOutput(
-                            request_id=request.request_id,
-                            token_id=-1,
-                            token_text="",
-                            finished=True,
-                            finish_reason=FinishReason.TIMEOUT,
-                            generated_text=request.generated_text,
-                        )
-                        break
-                if request.is_finished():
+                logger.warning(
+                    f"Timeout while waiting for token from request {request.request_id}"
+                )
+                if request.is_aborted():
+                    while not request.output_queue.async_q.empty():
+                        try:
+                            token_output = request.output_queue.async_q.get_nowait()
+                            request.output_queue.async_q.task_done()
+                            yield token_output
+                        except asyncio.QueueEmpty:
+                            break
+                    yield TokenOutput(
+                        request_id=request.request_id,
+                        token_id=-1,
+                        token_text="",
+                        finished=True,
+                        finish_reason=request.finish_reason,
+                        generated_text=request.generated_text,
+                    )
                     break
                 continue
+            except asyncio.CancelledError:
+                request.mark_canceled()
+                break
             except Exception as e:
-                logger.error(f"Error streaming request {request.request_id}: {e}")
-                await asyncio.sleep(0.01)
+                logger.error(f"Error while streaming request {request.request_id}: {e}")
+                break
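The rewritten `stream_request` enforces two timeouts: a wall-clock `request_timeout` checked at the top of every iteration, and a shorter per-`get` timeout that only bounds how long to block between tokens. A runnable sketch of this two-level pattern, with a plain `asyncio.Queue` standing in for the janus queue:

```python
import asyncio
import time

async def stream(queue, per_get_timeout=1.0, request_timeout=5.0):
    # Sketch of the two-level timeout in stream_request: an overall
    # wall-clock deadline plus a shorter per-token wait.
    start = time.time()
    while True:
        if request_timeout and time.time() - start > request_timeout:
            yield ("TIMEOUT", None)     # final sentinel chunk, as in the diff
            break
        try:
            item = await asyncio.wait_for(queue.get(), timeout=per_get_timeout)
        except asyncio.TimeoutError:
            continue                    # re-check the deadline, then wait again
        yield ("TOKEN", item)
        if item is None:                # producer signals completion with None
            break

async def main():
    q = asyncio.Queue()
    for tok in ["a", "b", None]:
        q.put_nowait(tok)
    async for kind, item in stream(q):
        print(kind, item)

asyncio.run(main())
```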
@@ -7,9 +7,13 @@ from dataclasses import dataclass, field
 from typing import List, Optional, Any
 import time
 import janus
+import asyncio
+import logging
 
 from infinilm.llm.sampling_params import SamplingParams
 
+logger = logging.getLogger(__name__)
+
 
 class RequestStatus(Enum):
     """Status of an inference request."""
@@ -143,6 +147,7 @@ class InferenceRequest:
         # Output management (for async streaming)
         self._output_queue: Optional[janus.Queue] = None
+        self._aborted = False
 
         # Streaming helpers (vLLM-style UTF-8 buffering at the chunking layer)
         # Used by the engine to compute "delta" text chunks from a full decode.
@@ -185,6 +190,14 @@ class InferenceRequest:
             RequestStatus.TIMEOUT,
         ]
 
+    def abort(self):
+        """Signal that the request has been aborted and should stop generation."""
+        self._aborted = True
+
+    def is_aborted(self) -> bool:
+        """Check if the request has been aborted."""
+        return self._aborted
+
     def mark_finished(self, reason: FinishReason):
         """Mark the request as finished with the given reason."""
         self.status = RequestStatus.FINISHED
@@ -193,18 +206,21 @@ class InferenceRequest:
     def mark_failed(self, reason: FinishReason = FinishReason.ERROR):
         """Mark the request as failed."""
+        self.abort()
         self.status = RequestStatus.FAILED
         self.finish_reason = reason
         self.finished_time = time.time()
 
     def mark_canceled(self):
         """Mark the request as canceled."""
+        self.abort()
         self.status = RequestStatus.CANCELED
         self.finish_reason = FinishReason.CANCELED
         self.finished_time = time.time()
 
     def mark_timeout(self):
         """Mark the request as timed out."""
+        self.abort()
         self.status = RequestStatus.TIMEOUT
         self.finish_reason = FinishReason.TIMEOUT
         self.finished_time = time.time()
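Every terminal transition now raises the `_aborted` flag first, so the engine loop and the scheduler can stop spending cycles on a dead request with a single `is_aborted()` check. A condensed sketch of this cooperative-cancellation pattern, with the class trimmed to the relevant fields:

```python
class Request:
    # Condensed from InferenceRequest: the abort flag is the one signal
    # every worker checks before doing more work for a request.
    def __init__(self, request_id):
        self.request_id = request_id
        self._aborted = False
        self.finish_reason = None

    def abort(self):
        self._aborted = True

    def is_aborted(self):
        return self._aborted

    def mark_canceled(self):
        self.abort()                    # terminal states always abort first
        self.finish_reason = "canceled"

def engine_step(requests):
    for req in requests:
        if req.is_aborted():
            continue                    # skip token updates for dead requests
        print(f"decode one token for {req.request_id}")

r = Request("req-1")
engine_step([r])     # decodes
r.mark_canceled()
engine_step([r])     # skipped
```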
@@ -212,9 +228,25 @@ class InferenceRequest:
     async def close(self):
         """Close the output queue and clean up resources."""
         if self._output_queue is not None:
-            await self._output_queue.async_q.join()
+            self.abort()
+            try:
+                while not self._output_queue.async_q.empty():
+                    try:
+                        self._output_queue.async_q.get_nowait()
+                        self._output_queue.async_q.task_done()
+                    except asyncio.QueueEmpty:
+                        break
+            except Exception as e:
+                logger.error(
+                    f"Error while clearing output queue for request {self.request_id}: {e}"
+                )
             self._output_queue.close()
-            await self._output_queue.wait_closed()
+            try:
+                await asyncio.wait_for(self._output_queue.wait_closed(), timeout=0.5)
+            except asyncio.TimeoutError:
+                logger.warning("wait_closed timeout, force close")
 
     def to_request_output(self) -> RequestOutput:
         """Convert to RequestOutput for external use."""
...
@@ -7,7 +7,12 @@ import queue
 import janus
 from typing import List, Optional
 
-from infinilm.llm.request import RequestStatus, InferenceRequest, FinishReason
+from infinilm.llm.request import (
+    RequestStatus,
+    InferenceRequest,
+    FinishReason,
+    TokenOutput,
+)
 
 logger = logging.getLogger(__name__)
@@ -115,6 +120,21 @@ class StaticScheduler:
                 )
                 self.running_request = None
                 req.mark_failed(FinishReason.LENGTH)
+                output = TokenOutput(
+                    request_id=req.request_id,
+                    token_id=-1,
+                    token_text="",
+                    finished=True,
+                    finish_reason=req.finish_reason,
+                    generated_text=req.generated_text,
+                )
+                try:
+                    req.output_queue.sync_q.put(output)
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to put completion token for {req.request_id}: {e}. "
+                        f"Likely due to client disconnecting or request cancelation."
+                    )
                 continue
 
         return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=False)
@@ -137,6 +157,21 @@ class StaticScheduler:
                 )
                 req.mark_failed(FinishReason.LENGTH)
+                output = TokenOutput(
+                    request_id=req.request_id,
+                    token_id=-1,
+                    token_text="",
+                    finished=True,
+                    finish_reason=req.finish_reason,
+                    generated_text=req.generated_text,
+                )
+                try:
+                    req.output_queue.sync_q.put(output)
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to put completion token for {req.request_id}: {e}. "
+                        f"Likely due to client disconnecting or request cancelation."
+                    )
                 continue
 
             req.status = RequestStatus.RUNNING
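Both scheduler paths now enqueue one final `finished=True` output when a request is failed for exceeding the length budget, so a consumer blocked on the queue wakes up instead of waiting for a token that will never arrive. A minimal sketch of that sentinel pattern, with a plain `queue.Queue` standing in for janus:

```python
import queue

def fail_request(req_id, out_q):
    # Sketch of the scheduler change: always enqueue a final sentinel
    # so the streaming consumer observes finished=True and exits.
    sentinel = {"request_id": req_id, "token_id": -1, "token_text": "",
                "finished": True, "finish_reason": "LENGTH"}
    try:
        out_q.put(sentinel)
    except Exception as e:  # client gone / queue closed: log and continue
        print(f"failed to put completion token for {req_id}: {e}")

q = queue.Queue()
fail_request("req-1", q)
print(q.get())  # consumer unblocks on the sentinel
```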
...
@@ -11,6 +11,7 @@ import argparse
 import uvicorn
 import logging
 import os
+import asyncio
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, StreamingResponse
@@ -351,6 +352,12 @@ class InferenceServer:
                     timeout=DEFAULT_STREAM_TIMEOUT,
                     request_timeout=DEFAULT_REQUEST_TIMEOUT,
                 ):
+                    # Check client disconnect
+                    if await http_request.is_disconnected():
+                        logger.info(f"Client disconnected for request {request_id}")
+                        req.mark_canceled()
+                        break
+
                     # If stream_request enforces timeout, we can just surface the state to the client.
                     if token_output.finish_reason == FinishReason.TIMEOUT:
                         logger.warning(
@@ -368,12 +375,6 @@ class InferenceServer:
                         yield f"data: {error_chunk}\n\n"
                         break
 
-                    # Check client disconnect
-                    if await http_request.is_disconnected():
-                        logger.info(f"Client disconnected for request {request_id}")
-                        req.mark_canceled()
-                        break
-
                     # Skip EOS token text for OpenAI API compatibility
                     # Check if this token is an EOS token by comparing token_id with eos_token_ids
                     eos_token_ids = self.engine.engine.eos_token_ids
@@ -404,6 +405,12 @@ class InferenceServer:
                         yield f"data: {chunk}\n\n"
                         break
 
+            except asyncio.CancelledError:
+                logger.info(f"Request {request_id} was cancelled")
+                if req:
+                    req.mark_canceled()
+                raise
             except Exception as e:
                 logger.error(f"Stream error for {request_id}: {e}", exc_info=True)
                 if req:
@@ -451,23 +458,23 @@ class InferenceServer:
                 timeout=DEFAULT_STREAM_TIMEOUT,
                 request_timeout=DEFAULT_REQUEST_TIMEOUT,
             ):
-                # Request-level timeout is handled inside stream_request.
-                if token_output.finish_reason == FinishReason.TIMEOUT:
-                    logger.warning(f"Request {request_id} timed out")
-                    break
-
                 # Check client disconnect
                 if await http_request.is_disconnected():
                     logger.info(f"Client disconnected for request {request_id}")
                     req.mark_canceled()
                     break
 
+                # Request-level timeout is handled inside stream_request.
+                if token_output.finish_reason == FinishReason.TIMEOUT:
+                    logger.warning(f"Request {request_id} timed out")
+                    break
+
                 # Skip EOS token text for OpenAI API compatibility
                 # Check if this token is an EOS token by comparing token_id with eos_token_ids
                 eos_token_ids = self.engine.engine.eos_token_ids
                 is_eos_token = eos_token_ids and token_output.token_id in eos_token_ids
 
-                if not is_eos_token:
+                if not is_eos_token and token_output.token_text:
                     output_text += token_output.token_text
 
                 if token_output.finished:
@@ -488,6 +495,12 @@ class InferenceServer:
             )
 
             return response
 
+        except asyncio.CancelledError:
+            logger.info(f"Request {request_id} was cancelled")
+            if req:
+                req.mark_canceled()
+            raise
         except Exception as e:
             logger.error(f"Chat error for {request_id}: {e}", exc_info=True)
             if req:
...
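Both server handlers now check `http_request.is_disconnected()` before doing any per-token work, and treat `asyncio.CancelledError` as a client cancellation. A condensed FastAPI sketch of that loop shape; the endpoint and token source are illustrative, not the server's real ones:

```python
import asyncio
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()  # run with an ASGI server, e.g. uvicorn module:app

async def fake_tokens():
    # Stand-in for AsyncLLMEngine.stream_request().
    for tok in ["Hello", " ", "world"]:
        await asyncio.sleep(0.05)
        yield tok

@app.get("/stream")
async def stream(http_request: Request):
    async def gen():
        try:
            async for tok in fake_tokens():
                # Check disconnect first, before per-token work.
                if await http_request.is_disconnected():
                    break  # mark_canceled() would go here in the real server
                yield f"data: {tok}\n\n"
        except asyncio.CancelledError:
            # ASGI servers cancel the task on disconnect; re-raise after cleanup.
            raise
    return StreamingResponse(gen(), media_type="text/event-stream")
```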
@@ -4,7 +4,6 @@ from openai import AsyncOpenAI
 import argparse
 import random
 
-
 PROMPTS = [
     "如果猫能写诗,它们会写些什么?",
     "描述一个没有重力的世界。",
@@ -25,11 +24,11 @@ PROMPTS = [
     "如果你可以变成任何一种动物,你会选择什么?",
     "描述一个由机器人统治的未来世界。",
     "如果你能与任何虚构角色成为朋友,你会选择谁?",
-    "想象一下,如果每个人都能读懂他人的思想。"
+    "想象一下,如果每个人都能读懂他人的思想。",
 ]
 
-NUM_REQUESTS = 10
-CONCURRENCY = 5
+NUM_REQUESTS = 64
+CONCURRENCY = 20
 
 API_URL = "http://127.0.0.1:8000"
 MODEL = "FM9G-7B"
@@ -43,14 +42,14 @@ async def benchmark_user(client, semaphore, queue, results, user_id, verbose):
             break
         question = random.choice(PROMPTS)
         try:
             print(f"🚀 User#{user_id} Sending request #{task_id}")
             start_time = time.time()
 
             stream = await client.chat.completions.create(
                 model=MODEL,
                 messages=[{"role": "user", "content": question}],
-                stream=True
+                stream=True,
             )
 
             first_token_time = None
@@ -71,19 +70,33 @@ async def benchmark_user(client, semaphore, queue, results, user_id, verbose):
             ttft = first_token_time - start_time if first_token_time else None
             elapsed_time = end_time - start_time if start_time else None
-            ms_per_token = (elapsed_time / total_tokens * 1000) if total_tokens > 0 and elapsed_time else None
-            tokens_per_second = total_tokens / elapsed_time if elapsed_time > 0 else 0
+            ms_per_token = (
+                (elapsed_time / total_tokens * 1000)
+                if total_tokens > 0 and elapsed_time
+                else None
+            )
+            tokens_per_second = (
+                total_tokens / elapsed_time if elapsed_time > 0 else 0
+            )
 
             answer = "".join(answer_chunks)
-            results.append((total_tokens, elapsed_time, tokens_per_second, ttft, ms_per_token))
+            results.append(
+                (total_tokens, elapsed_time, tokens_per_second, ttft, ms_per_token)
+            )
 
             if verbose:
                 print(f"\n📝 Request #{task_id} (User #{user_id})")
-                print(f"   ⏱ 首字延迟 TTFT: {ttft:.3f}s")
-                print(f"   ⏱ 总耗时: {elapsed_time:.3f}s")
+                if ttft is not None:
+                    print(f"   ⏱ 首字延迟 TTFT: {ttft:.3f}s")
+                if elapsed_time is not None:
+                    print(f"   ⏱ 总耗时: {elapsed_time:.3f}s")
                 print(f"   🔤 解码 token 总数: {total_tokens}")
-                print(f"   📏 平均 token 解码时间: {ms_per_token:.2f} ms/token")
+                if ms_per_token is not None:
+                    print(f"   📏 平均 token 解码时间: {ms_per_token:.2f} ms/token")
+                else:
+                    print(f"   📏 平均 token 解码时间: N/A (no token generated)")
                 print(f"   ❓ 提问: {question}")
                 print(f"   💬 回答: {answer}\n")
@@ -92,6 +105,8 @@ async def benchmark_user(client, semaphore, queue, results, user_id, verbose):
             if verbose:
                 print(f"\n⚠️ Request #{task_id} (User #{user_id}) FAILED:")
                 print(f"   ❌ Error: {e}\n")
 
+        queue.task_done()
 
 async def run_benchmark(verbose=False):
     client = AsyncOpenAI(base_url=API_URL, api_key="default")
@@ -104,7 +119,9 @@ async def run_benchmark(verbose=False):
         await queue.put(None)
 
     users = [
-        asyncio.create_task(benchmark_user(client, semaphore, queue, results, user_id, verbose))
+        asyncio.create_task(
+            benchmark_user(client, semaphore, queue, results, user_id, verbose)
+        )
         for user_id in range(CONCURRENCY)
     ]
@@ -121,11 +138,19 @@ async def run_benchmark(verbose=False):
     ms_per_token_list = [r[4] for r in results if r and r[4] is not None]
 
     successful_requests = len(results)
-    requests_per_second = successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0
+    requests_per_second = (
+        successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0
+    )
     avg_latency = sum(latencies) / len(latencies) if latencies else 0
-    avg_tokens_per_second = sum(tokens_per_second_list) / len(tokens_per_second_list) if tokens_per_second_list else 0
+    avg_tokens_per_second = (
+        sum(tokens_per_second_list) / len(tokens_per_second_list)
+        if tokens_per_second_list
+        else 0
+    )
     avg_ttft = sum(ttft_list) / len(ttft_list) if ttft_list else 0
-    avg_ms_per_token = sum(ms_per_token_list) / len(ms_per_token_list) if ms_per_token_list else None
+    avg_ms_per_token = (
+        sum(ms_per_token_list) / len(ms_per_token_list) if ms_per_token_list else None
+    )
 
     width_label = 24
     sep = "-" * 60
@@ -142,7 +167,9 @@ async def run_benchmark(verbose=False):
     print(f"{'Average latency':<{width_label}}: {avg_latency:.2f} s")
     print(f"{'Average TTFT':<{width_label}}: {avg_ttft:.2f} s")
     print(f"{'Avg time per token':<{width_label}}: {avg_ms_per_token:.2f} ms/token")
-    print(f"{'Avg Token generation speed':<{width_label}}: {avg_tokens_per_second:.2f} tokens/s")
+    print(
+        f"{'Avg Token generation speed':<{width_label}}: {avg_tokens_per_second:.2f} tokens/s"
+    )
 
 
 if __name__ == "__main__":
@@ -150,6 +177,4 @@ if __name__ == "__main__":
     parser.add_argument("--verbose", action="store_true")
     args = parser.parse_args()
 
-    asyncio.run(run_benchmark(
-        args.verbose
-    ))
+    asyncio.run(run_benchmark(args.verbose))
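The benchmark worker now acknowledges every dequeued task with `queue.task_done()`, including failed ones; without that, a `join()`-style barrier on the queue can hang after the first error. A runnable sketch of the pattern:

```python
import asyncio

async def worker(q: asyncio.Queue):
    # Sketch of the benchmark fix: task_done() must run for every
    # dequeued item, including failures, or queue.join() hangs forever.
    while True:
        task = await q.get()
        if task is None:
            q.task_done()
            break
        try:
            if task == "bad":
                raise RuntimeError("simulated request failure")
            print(f"ok: {task}")
        except Exception as e:
            print(f"failed: {e}")
        finally:
            q.task_done()   # always acknowledge the item

async def main():
    q = asyncio.Queue()
    for t in ["a", "bad", "b", None]:
        q.put_nowait(t)
    await asyncio.gather(worker(q), q.join())

asyncio.run(main())
```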