Unverified Commit 58405177 authored by bin.pan's avatar bin.pan Committed by GitHub
Browse files

feat: Support a dynamic default max_tokens for VLLM backend (#4156)


Signed-off-by: default avatarbin <bin.pan@daocloud.io>
parent 4f2cbec0
...@@ -31,7 +31,9 @@ logger = logging.getLogger(__name__) ...@@ -31,7 +31,9 @@ logger = logging.getLogger(__name__)
def build_sampling_params( def build_sampling_params(
request: Dict[str, Any], default_sampling_params: Dict[str, Any] request: Dict[str, Any],
default_sampling_params: Dict[str, Any],
model_max_len: int | None = None,
) -> SamplingParams: ) -> SamplingParams:
""" """
Build SamplingParams from a PreprocessedRequest. Build SamplingParams from a PreprocessedRequest.
...@@ -59,6 +61,15 @@ def build_sampling_params( ...@@ -59,6 +61,15 @@ def build_sampling_params(
continue continue
setattr(sampling_params, key, value) setattr(sampling_params, key, value)
# If max_tokens wasn't provided (None or missing), compute a dynamic default
provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None)
token_ids = request.get("token_ids", [])
input_length = len(token_ids)
if model_max_len is not None and (provided_max_tokens is None):
# Ensure at least 1 token generation by default when possible
dynamic_default = max(1, model_max_len - input_length)
sampling_params.max_tokens = dynamic_default
return sampling_params return sampling_params
...@@ -67,7 +78,14 @@ class BaseWorkerHandler(ABC): ...@@ -67,7 +78,14 @@ class BaseWorkerHandler(ABC):
Request handler for the generate and clear_kv_blocks endpoints. Request handler for the generate and clear_kv_blocks endpoints.
""" """
def __init__(self, runtime, component, engine, default_sampling_params): def __init__(
self,
runtime,
component,
engine,
default_sampling_params,
model_max_len: int | None = None,
):
self.runtime = runtime self.runtime = runtime
self.component = component self.component = component
self.engine_client = engine self.engine_client = engine
...@@ -76,6 +94,7 @@ class BaseWorkerHandler(ABC): ...@@ -76,6 +94,7 @@ class BaseWorkerHandler(ABC):
self.engine_monitor = VllmEngineMonitor(runtime, engine) self.engine_monitor = VllmEngineMonitor(runtime, engine)
self.image_loader = ImageLoader() self.image_loader = ImageLoader()
self.temp_dirs: list[tempfile.TemporaryDirectory] = [] self.temp_dirs: list[tempfile.TemporaryDirectory] = []
self.model_max_len = model_max_len
@abstractmethod @abstractmethod
async def generate(self, request, context) -> AsyncGenerator[dict, None]: async def generate(self, request, context) -> AsyncGenerator[dict, None]:
...@@ -251,8 +270,11 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -251,8 +270,11 @@ class DecodeWorkerHandler(BaseWorkerHandler):
component, component,
engine, engine,
default_sampling_params, default_sampling_params,
model_max_len: int | None = None,
): ):
super().__init__(runtime, component, engine, default_sampling_params) super().__init__(
runtime, component, engine, default_sampling_params, model_max_len
)
async def generate(self, request, context): async def generate(self, request, context):
# Use context ID for request tracking and correlation # Use context ID for request tracking and correlation
...@@ -267,7 +289,9 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -267,7 +289,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
) )
# Build sampling params from request # Build sampling params from request
sampling_params = build_sampling_params(request, self.default_sampling_params) sampling_params = build_sampling_params(
request, self.default_sampling_params, self.model_max_len
)
prefill_result = request.get("prefill_result") prefill_result = request.get("prefill_result")
if prefill_result and isinstance(prefill_result, dict): if prefill_result and isinstance(prefill_result, dict):
...@@ -308,8 +332,17 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -308,8 +332,17 @@ class DecodeWorkerHandler(BaseWorkerHandler):
class PrefillWorkerHandler(BaseWorkerHandler): class PrefillWorkerHandler(BaseWorkerHandler):
def __init__(self, runtime, component, engine, default_sampling_params): def __init__(
super().__init__(runtime, component, engine, default_sampling_params) self,
runtime,
component,
engine,
default_sampling_params,
model_max_len: int | None = None,
):
super().__init__(
runtime, component, engine, default_sampling_params, model_max_len
)
async def generate(self, request, context): async def generate(self, request, context):
# Use context ID for request tracking and correlation with decode phase # Use context ID for request tracking and correlation with decode phase
...@@ -325,7 +358,9 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -325,7 +358,9 @@ class PrefillWorkerHandler(BaseWorkerHandler):
) )
# Build sampling params from request using shared utility # Build sampling params from request using shared utility
sampling_params = build_sampling_params(request, self.default_sampling_params) sampling_params = build_sampling_params(
request, self.default_sampling_params, self.model_max_len
)
# Configure for prefill-only mode with remote decode # Configure for prefill-only mode with remote decode
if sampling_params.extra_args is None: if sampling_params.extra_args is None:
......
...@@ -339,7 +339,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -339,7 +339,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
) = setup_vllm_engine(config) ) = setup_vllm_engine(config)
handler = PrefillWorkerHandler( handler = PrefillWorkerHandler(
runtime, component, engine_client, default_sampling_params runtime,
component,
engine_client,
default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
) )
handler.add_temp_dir(prometheus_temp_dir) handler.add_temp_dir(prometheus_temp_dir)
...@@ -450,6 +454,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -450,6 +454,7 @@ async def init(runtime: DistributedRuntime, config: Config):
component, component,
engine_client, engine_client,
default_sampling_params, default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
) )
handler.add_temp_dir(prometheus_temp_dir) handler.add_temp_dir(prometheus_temp_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment