Unverified Commit 169313b9 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Misc] Make handling of SamplingParams clearer in n>1 case (#26032)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
parent 0b018d8b
...@@ -289,13 +289,19 @@ class AsyncLLM(EngineClient): ...@@ -289,13 +289,19 @@ class AsyncLLM(EngineClient):
await self._add_request(request, prompt_str, None, 0, queue) await self._add_request(request, prompt_str, None, 0, queue)
return queue return queue
# Get the updated SamplingParams from the request, which
# were cloned/updated in processor.process_inputs above.
parent_params = request.sampling_params
assert parent_params is not None
# Fan out child requests (for n>1). # Fan out child requests (for n>1).
parent_request = ParentRequest(request_id, request.sampling_params) parent_request = ParentRequest(request_id, parent_params)
for idx in range(params.n): for idx in range(parent_params.n):
request_id, params = parent_request.get_child_info(idx) request_id, child_params = parent_request.get_child_info(idx)
child_request = request if idx == params.n - 1 else copy(request) child_request = request if idx == parent_params.n - 1 else copy(
request)
child_request.request_id = request_id child_request.request_id = request_id
child_request.sampling_params = params child_request.sampling_params = child_params
await self._add_request(child_request, prompt_str, parent_request, await self._add_request(child_request, prompt_str, parent_request,
idx, queue) idx, queue)
return queue return queue
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment