Unverified Commit 32c40b95 authored by Avishek Goswami's avatar Avishek Goswami Committed by GitHub
Browse files

[BugFix] bad_words filtering ineffective when n > 1 (#29313)


Signed-off-by: default avatarGOavi101 <1704178@kiit.ac.in>
parent db290610
...@@ -72,6 +72,14 @@ class EngineCoreRequest( ...@@ -72,6 +72,14 @@ class EngineCoreRequest(
trace_headers: Mapping[str, str] | None = None trace_headers: Mapping[str, str] | None = None
@property
def params(self) -> SamplingParams | PoolingParams:
"""Return the processed params (sampling or pooling)."""
if self.sampling_params is not None:
return self.sampling_params
assert self.pooling_params is not None
return self.pooling_params
class EngineCoreEventType(enum.IntEnum): class EngineCoreEventType(enum.IntEnum):
"""The type of engine core request event.""" """The type of engine core request event."""
......
...@@ -321,14 +321,15 @@ class AsyncLLM(EngineClient): ...@@ -321,14 +321,15 @@ class AsyncLLM(EngineClient):
elif isinstance(prompt, Mapping): elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt")) prompt_text = cast(str | None, prompt.get("prompt"))
# Use cloned params that may have been updated in process_inputs()
params = request.params
if is_pooling or params.n == 1: if is_pooling or params.n == 1:
await self._add_request(request, prompt_text, None, 0, queue) await self._add_request(request, prompt_text, None, 0, queue)
return queue return queue
# Get the updated SamplingParams from the request, which parent_params = params
# were cloned/updated in processor.process_inputs above. assert isinstance(parent_params, SamplingParams)
parent_params = request.sampling_params
assert parent_params is not None
# Fan out child requests (for n>1). # Fan out child requests (for n>1).
parent_request = ParentRequest(request_id, parent_params) parent_request = ParentRequest(request_id, parent_params)
......
...@@ -250,6 +250,9 @@ class LLMEngine: ...@@ -250,6 +250,9 @@ class LLMEngine:
elif isinstance(prompt, Mapping): elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt")) prompt_text = cast(str | None, prompt.get("prompt"))
# Use cloned params that may have been updated in process_inputs()
params = request.params
n = params.n if isinstance(params, SamplingParams) else 1 n = params.n if isinstance(params, SamplingParams) else 1
if n == 1: if n == 1:
...@@ -262,10 +265,10 @@ class LLMEngine: ...@@ -262,10 +265,10 @@ class LLMEngine:
# Fan out child requests (for n>1). # Fan out child requests (for n>1).
parent_req = ParentRequest(request_id, params) parent_req = ParentRequest(request_id, params)
for idx in range(n): for idx in range(n):
request_id, params = parent_req.get_child_info(idx) request_id, child_params = parent_req.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request) child_request = request if idx == n - 1 else copy(request)
child_request.request_id = request_id child_request.request_id = request_id
child_request.sampling_params = params child_request.sampling_params = child_params
# Make a new RequestState and queue. # Make a new RequestState and queue.
self.output_processor.add_request( self.output_processor.add_request(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment