Unverified commit 4fe58953, authored by Nick Hill and committed by GitHub
Browse files

[AsyncScheduling] Make async overlap work with logprobs (#27615)


Signed-off-by: Nick Hill <nhill@redhat.com>
parent 111faf11
...@@ -831,8 +831,9 @@ class VllmRunner: ...@@ -831,8 +831,9 @@ class VllmRunner:
images: PromptImageInput | None = None, images: PromptImageInput | None = None,
videos: PromptVideoInput | None = None, videos: PromptVideoInput | None = None,
audios: PromptAudioInput | None = None, audios: PromptAudioInput | None = None,
return_logprobs: bool = False,
**kwargs: Any, **kwargs: Any,
) -> list[tuple[list[list[int]], list[str]]]: ) -> list[tuple[list[list[int]], list[str]]] | tuple[list, list]:
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
req_outputs = self.llm.generate( req_outputs = self.llm.generate(
...@@ -840,18 +841,23 @@ class VllmRunner: ...@@ -840,18 +841,23 @@ class VllmRunner:
) )
outputs: list[tuple[list[list[int]], list[str]]] = [] outputs: list[tuple[list[list[int]], list[str]]] = []
logprobs = []
for req_output in req_outputs: for req_output in req_outputs:
prompt_str = req_output.prompt prompt_str = req_output.prompt
prompt_ids = req_output.prompt_token_ids prompt_ids = req_output.prompt_token_ids
req_sample_output_ids: list[list[int]] = [] req_sample_output_ids: list[list[int]] = []
req_sample_output_strs: list[str] = [] req_sample_output_strs: list[str] = []
req_logprobs = []
for sample in req_output.outputs: for sample in req_output.outputs:
output_str = sample.text output_str = sample.text
output_ids = list(sample.token_ids) output_ids = list(sample.token_ids)
req_sample_output_ids.append(prompt_ids + output_ids) req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append((prompt_str or "") + output_str) req_sample_output_strs.append((prompt_str or "") + output_str)
if sample.logprobs:
req_logprobs.extend(sample.logprobs)
outputs.append((req_sample_output_ids, req_sample_output_strs)) outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs logprobs.append(req_logprobs)
return outputs if not return_logprobs else (outputs, logprobs)
@staticmethod @staticmethod
def _final_steps_generate_w_logprobs( def _final_steps_generate_w_logprobs(
......
...@@ -6,6 +6,7 @@ import pytest ...@@ -6,6 +6,7 @@ import pytest
import torch._dynamo.config as dynamo_config import torch._dynamo.config as dynamo_config
from vllm import SamplingParams from vllm import SamplingParams
from vllm.logprobs import Logprob
from ...conftest import VllmRunner from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal from ...models.utils import check_outputs_equal
...@@ -32,6 +33,8 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch): ...@@ -32,6 +33,8 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
# dict(min_tokens=20), # dict(min_tokens=20),
dict(presence_penalty=-1.0), dict(presence_penalty=-1.0),
dict(bad_words=["the", " the"]), dict(bad_words=["the", " the"]),
dict(logprobs=2),
dict(logprobs=2, presence_penalty=-1.0),
] ]
default_params = dict( default_params = dict(
...@@ -77,29 +80,33 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch): ...@@ -77,29 +80,33 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
sampling_params=SamplingParams( sampling_params=SamplingParams(
**default_params, **override_params **default_params, **override_params
), ),
return_logprobs=True,
) )
) )
if not outputs: if not outputs:
# First check that the different parameter configs # First check that the different parameter configs
# actually result in different output. # actually result in different output.
for other_test, params in zip( for (other_test_outs, other_test_logprobs), params in zip(
results[1:], sampling_param_tests[1:] results[1:], sampling_param_tests[1:]
): ):
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
check_outputs_equal( check_outputs_equal(
outputs_0_lst=results[0], outputs_0_lst=results[0][0],
outputs_1_lst=other_test, outputs_1_lst=other_test_outs,
name_0=f"baseline params={params}", name_0=f"baseline params={params}",
name_1=f"other params={params}", name_1=f"other params={params}",
) )
assert _all_logprobs_match(
results[0][1], other_test_logprobs
)
outputs.append((test_config, results)) outputs.append((test_config, results))
baseline_config, baseline_tests = outputs[0] baseline_config, baseline_tests = outputs[0]
for test_config, test_outputs in outputs[1:]: for test_config, test_outputs in outputs[1:]:
for base_outs, test_outs, params in zip( for (base_outs, base_logprobs), (test_outs, test_logprobs), params in zip(
baseline_tests, test_outputs, sampling_param_tests baseline_tests, test_outputs, sampling_param_tests
): ):
check_outputs_equal( check_outputs_equal(
...@@ -108,5 +115,27 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch): ...@@ -108,5 +115,27 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
name_0=f"baseline=[{baseline_config}], params={params}", name_0=f"baseline=[{baseline_config}], params={params}",
name_1=f"config=[{test_config}], params={params}", name_1=f"config=[{test_config}], params={params}",
) )
assert _all_logprobs_match(base_logprobs, test_logprobs)
print(f"PASSED: config=[{test_config}], params={params}") print(f"PASSED: config=[{test_config}], params={params}")
def _all_logprobs_match(req_a, req_b) -> bool:
    """Return True when two nested logprob collections agree.

    Both arguments are sequences of sequences; the innermost elements are
    handed pairwise to ``_logprobs_match``. Exact equality short-circuits
    to True; otherwise the outer lengths, every inner length, and every
    paired inner element must match.
    """
    # Fast path: identical content needs no tolerance-based comparison.
    if req_a == req_b:
        return True
    if len(req_a) != len(req_b):
        return False
    for seq_a, seq_b in zip(req_a, req_b):
        if len(seq_a) != len(seq_b):
            return False
        for lps_a, lps_b in zip(seq_a, seq_b):
            if not _logprobs_match(lps_a, lps_b):
                return False
    return True
def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
    """Return True when two per-position logprob dicts agree.

    The dicts must have exactly the same keys, and for every key the two
    entries must have equal ``decoded_token`` and ``rank`` and a ``logprob``
    that matches within ``rel=1e-3`` / ``abs=1e-6``.

    Args:
        lps_a: First mapping from key to logprob entry.
        lps_b: Second mapping from key to logprob entry.

    Returns:
        True if the mappings match as described above, False otherwise.
    """
    # Compare key sets rather than just lengths: two dicts of the same size
    # with different keys are a mismatch, not a KeyError on lookup below.
    if lps_a.keys() != lps_b.keys():
        return False
    return all(
        entry_a.decoded_token == lps_b[key].decoded_token
        and entry_a.rank == lps_b[key].rank
        and entry_a.logprob == pytest.approx(lps_b[key].logprob, rel=1e-3, abs=1e-6)
        for key, entry_a in lps_a.items()
    )
...@@ -59,6 +59,15 @@ class LogprobsTensors(NamedTuple): ...@@ -59,6 +59,15 @@ class LogprobsTensors(NamedTuple):
cu_num_generated_tokens, cu_num_generated_tokens,
) )
def to_cpu_nonblocking(self) -> "LogprobsTensors":
    """Return these tensors on the CPU, copying with ``non_blocking=True``.

    If ``logprob_token_ids`` already lives on the CPU, ``self`` is returned
    unchanged; the other fields are not inspected. Otherwise a new
    ``LogprobsTensors`` is built from CPU copies of ``logprob_token_ids``,
    ``logprobs`` and ``selected_token_ranks``, in that order.
    """
    # NOTE(review): because the copies are non-blocking, callers presumably
    # need to synchronize before reading the result -- confirm at call sites.
    already_on_host = self.logprob_token_ids.device.type == "cpu"
    if already_on_host:
        return self
    token_ids_cpu, logprobs_cpu, ranks_cpu = (
        tensor.to("cpu", non_blocking=True)
        for tensor in (
            self.logprob_token_ids,
            self.logprobs,
            self.selected_token_ranks,
        )
    )
    return LogprobsTensors(token_ids_cpu, logprobs_cpu, ranks_cpu)
@staticmethod @staticmethod
def empty_cpu( def empty_cpu(
num_positions: int, num_tokens_per_position: int num_positions: int, num_tokens_per_position: int
......
...@@ -164,6 +164,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): ...@@ -164,6 +164,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
self, self,
model_runner_output: ModelRunnerOutput, model_runner_output: ModelRunnerOutput,
sampled_token_ids: torch.Tensor, sampled_token_ids: torch.Tensor,
logprobs_tensors: torch.Tensor | None,
invalid_req_indices: list[int], invalid_req_indices: list[int],
async_output_copy_stream: torch.cuda.Stream, async_output_copy_stream: torch.cuda.Stream,
): ):
...@@ -176,6 +177,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): ...@@ -176,6 +177,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
# Keep a reference to the device tensor to avoid it being # Keep a reference to the device tensor to avoid it being
# deallocated until we finish copying it to the host. # deallocated until we finish copying it to the host.
self._sampled_token_ids = sampled_token_ids self._sampled_token_ids = sampled_token_ids
self._logprobs_tensors = logprobs_tensors
# Initiate the copy on a separate stream, but do not synchronize it. # Initiate the copy on a separate stream, but do not synchronize it.
default_stream = torch.cuda.current_stream() default_stream = torch.cuda.current_stream()
...@@ -184,6 +186,11 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): ...@@ -184,6 +186,11 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
self.sampled_token_ids_cpu = self._sampled_token_ids.to( self.sampled_token_ids_cpu = self._sampled_token_ids.to(
"cpu", non_blocking=True "cpu", non_blocking=True
) )
self._logprobs_tensors_cpu = (
self._logprobs_tensors.to_cpu_nonblocking()
if self._logprobs_tensors
else None
)
self.async_copy_ready_event.record() self.async_copy_ready_event.record()
def get_output(self) -> ModelRunnerOutput: def get_output(self) -> ModelRunnerOutput:
...@@ -193,7 +200,8 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): ...@@ -193,7 +200,8 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
""" """
self.async_copy_ready_event.synchronize() self.async_copy_ready_event.synchronize()
# Release the device tensor once the copy has completed # Release the device tensors once the copy has completed.
del self._logprobs_tensors
del self._sampled_token_ids del self._sampled_token_ids
valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
...@@ -202,6 +210,10 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): ...@@ -202,6 +210,10 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
output = self._model_runner_output output = self._model_runner_output
output.sampled_token_ids = valid_sampled_token_ids output.sampled_token_ids = valid_sampled_token_ids
if self._logprobs_tensors_cpu:
# NOTE(nick): this will need to be updated to use cu_num_accepted_tokens
# for async sched + spec decode + logprobs compatibility.
output.logprobs = self._logprobs_tensors_cpu.tolists()
return output return output
...@@ -2334,11 +2346,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2334,11 +2346,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cu_num_accepted_tokens[-1] + len(sampled_ids) cu_num_accepted_tokens[-1] + len(sampled_ids)
) )
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_lists = ( logprobs_lists = (
logprobs_tensors.tolists(cu_num_accepted_tokens) logprobs_tensors.tolists(cu_num_accepted_tokens)
if logprobs_tensors is not None if not self.use_async_scheduling and logprobs_tensors is not None
else None else None
) )
...@@ -2664,6 +2674,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2664,6 +2674,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
async_output = AsyncGPUModelRunnerOutput( async_output = AsyncGPUModelRunnerOutput(
model_runner_output=output, model_runner_output=output,
sampled_token_ids=sampler_output.sampled_token_ids, sampled_token_ids=sampler_output.sampled_token_ids,
logprobs_tensors=sampler_output.logprobs_tensors,
invalid_req_indices=invalid_req_indices, invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream, async_output_copy_stream=self.async_output_copy_stream,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment