Unverified Commit 58405177 authored by bin.pan's avatar bin.pan Committed by GitHub
Browse files

feat: Support a dynamic default max_tokens for VLLM backend (#4156)


Signed-off-by: default avatarbin <bin.pan@daocloud.io>
parent 4f2cbec0
......@@ -31,7 +31,9 @@ logger = logging.getLogger(__name__)
def build_sampling_params(
request: Dict[str, Any], default_sampling_params: Dict[str, Any]
request: Dict[str, Any],
default_sampling_params: Dict[str, Any],
model_max_len: int | None = None,
) -> SamplingParams:
"""
Build SamplingParams from a PreprocessedRequest.
......@@ -59,6 +61,15 @@ def build_sampling_params(
continue
setattr(sampling_params, key, value)
# If max_tokens wasn't provided (None or missing), compute a dynamic default
provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None)
token_ids = request.get("token_ids", [])
input_length = len(token_ids)
if model_max_len is not None and (provided_max_tokens is None):
# Ensure at least 1 token generation by default when possible
dynamic_default = max(1, model_max_len - input_length)
sampling_params.max_tokens = dynamic_default
return sampling_params
......@@ -67,7 +78,14 @@ class BaseWorkerHandler(ABC):
Request handler for the generate and clear_kv_blocks endpoints.
"""
def __init__(self, runtime, component, engine, default_sampling_params):
def __init__(
self,
runtime,
component,
engine,
default_sampling_params,
model_max_len: int | None = None,
):
self.runtime = runtime
self.component = component
self.engine_client = engine
......@@ -76,6 +94,7 @@ class BaseWorkerHandler(ABC):
self.engine_monitor = VllmEngineMonitor(runtime, engine)
self.image_loader = ImageLoader()
self.temp_dirs: list[tempfile.TemporaryDirectory] = []
self.model_max_len = model_max_len
@abstractmethod
async def generate(self, request, context) -> AsyncGenerator[dict, None]:
......@@ -251,8 +270,11 @@ class DecodeWorkerHandler(BaseWorkerHandler):
component,
engine,
default_sampling_params,
model_max_len: int | None = None,
):
super().__init__(runtime, component, engine, default_sampling_params)
super().__init__(
runtime, component, engine, default_sampling_params, model_max_len
)
async def generate(self, request, context):
# Use context ID for request tracking and correlation
......@@ -267,7 +289,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
)
# Build sampling params from request
sampling_params = build_sampling_params(request, self.default_sampling_params)
sampling_params = build_sampling_params(
request, self.default_sampling_params, self.model_max_len
)
prefill_result = request.get("prefill_result")
if prefill_result and isinstance(prefill_result, dict):
......@@ -308,8 +332,17 @@ class DecodeWorkerHandler(BaseWorkerHandler):
class PrefillWorkerHandler(BaseWorkerHandler):
def __init__(self, runtime, component, engine, default_sampling_params):
super().__init__(runtime, component, engine, default_sampling_params)
def __init__(
self,
runtime,
component,
engine,
default_sampling_params,
model_max_len: int | None = None,
):
super().__init__(
runtime, component, engine, default_sampling_params, model_max_len
)
async def generate(self, request, context):
# Use context ID for request tracking and correlation with decode phase
......@@ -325,7 +358,9 @@ class PrefillWorkerHandler(BaseWorkerHandler):
)
# Build sampling params from request using shared utility
sampling_params = build_sampling_params(request, self.default_sampling_params)
sampling_params = build_sampling_params(
request, self.default_sampling_params, self.model_max_len
)
# Configure for prefill-only mode with remote decode
if sampling_params.extra_args is None:
......
......@@ -339,7 +339,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
) = setup_vllm_engine(config)
handler = PrefillWorkerHandler(
runtime, component, engine_client, default_sampling_params
runtime,
component,
engine_client,
default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
)
handler.add_temp_dir(prometheus_temp_dir)
......@@ -450,6 +454,7 @@ async def init(runtime: DistributedRuntime, config: Config):
component,
engine_client,
default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
)
handler.add_temp_dir(prometheus_temp_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment