Unverified Commit 6b0f2e90 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Add `--max-total-tokens` (#840)

parent 1edd4e07
......@@ -19,6 +19,7 @@ import importlib
import importlib.resources
import logging
import pkgutil
import warnings
from functools import lru_cache
from typing import Optional, Type
......@@ -121,7 +122,11 @@ class ModelRunner:
# Load the model and create memory pool
self.load_model()
self.init_memory_pool(total_gpu_memory, server_args.max_num_reqs)
self.init_memory_pool(
total_gpu_memory,
server_args.max_num_reqs,
server_args.max_total_tokens,
)
self.init_cublas()
self.init_flash_infer()
......@@ -203,8 +208,18 @@ class ModelRunner:
max_num_token = int(rest_memory * (1 << 30) // cell_size)
return max_num_token
def init_memory_pool(self, total_gpu_memory, max_num_reqs=None):
def init_memory_pool(
self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
):
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
if max_total_tokens is not None:
if max_total_tokens > self.max_total_num_tokens:
warnings.warn(
f"max_total_tokens={max_total_tokens} is larger than the profiled value "
f"{self.max_total_num_tokens}. "
f"Use the profiled value instead."
)
self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)
if self.max_total_num_tokens <= 0:
raise RuntimeError(
......
......@@ -44,6 +44,7 @@ class ServerArgs:
max_prefill_tokens: Optional[int] = None
max_running_requests: Optional[int] = None
max_num_reqs: Optional[int] = None
max_total_tokens: Optional[int] = None
schedule_policy: str = "lpm"
schedule_conservativeness: float = 1.0
......@@ -231,6 +232,12 @@ class ServerArgs:
default=ServerArgs.max_num_reqs,
help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
)
parser.add_argument(
"--max-total-tokens",
type=int,
default=ServerArgs.max_total_tokens,
help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
)
parser.add_argument(
"--schedule-policy",
type=str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment