Unverified Commit 428dd144 authored by afeldman-nm's avatar afeldman-nm Committed by GitHub
Browse files

[Core] Logprobs support in Multi-step (#7652)

parent 4abed65c
import warnings import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union from typing import Dict, List, Optional, Sequence, Tuple, Union
from vllm.sequence import SampleLogprobs from vllm.sequence import Logprob, SampleLogprobs
TokensText = Tuple[List[int], str] TokensText = Tuple[List[int], str]
...@@ -38,34 +38,39 @@ TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int, ...@@ -38,34 +38,39 @@ TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
float]], float]],
SampleLogprobs]]] SampleLogprobs]]]
# Allow for tokens to be represented as str's rather than IDs
TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]],
List[Dict[str,
Logprob]]]]]
def check_logprobs_close( def check_logprobs_close(
*, *,
outputs_0_lst: Sequence[TokensTextLogprobs], outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
outputs_1_lst: Sequence[TokensTextLogprobs], outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
name_0: str, name_0: str,
name_1: str, name_1: str,
num_outputs_0_skip_tokens: int = 0, num_outputs_0_skip_tokens: int = 0,
warn_on_mismatch: bool = True, warn_on_mismatch: bool = True,
): always_check_logprobs: bool = False,
""" ) -> None:
Compare the logprobs of two sequences generated by different models, """Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal. which should be similar but not necessarily equal.
Arguments: Args:
outputs_0_lst: First sequence to compare
* outputs_0_lst: First sequence to compare outputs_0_lst: Second sequence to compare
* outputs_0_lst: Second sequence to compare name_0: sequence #0 name
* name_0: sequence #0 name name_1: sequence #1 name
* name_1: sequence #1 name num_outputs_0_skip_tokens: If > 0, specifies the number of initial
* num_outputs_0_skip_tokens: If > 0, specifies the number of initial
sequence #0 tokens & logprobs to discard sequence #0 tokens & logprobs to discard
before comparison, i.e. all before comparison, i.e. all
of sequence #1 will be compared to of sequence #1 will be compared to
sequence #0 beginning at index sequence #0 beginning at index
num_outputs_0_skip_tokens num_outputs_0_skip_tokens
* warn_on_mismatch: Issue a warning if there is token-wise or text-wise warn_on_mismatch: Issue a warning if there is token-wise or text-wise
mismatch between the two sequences mismatch between the two sequences
always_check_logprobs: If true, check logprobs even when tokens match
""" """
assert len(outputs_0_lst) == len(outputs_1_lst) assert len(outputs_0_lst) == len(outputs_1_lst)
...@@ -94,8 +99,12 @@ def check_logprobs_close( ...@@ -94,8 +99,12 @@ def check_logprobs_close(
for idx, (output_id_0, for idx, (output_id_0,
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
# If generated tokens don't match, then is_tok_mismatch = output_id_0 != output_id_1
if output_id_0 != output_id_1:
# If generated tokens don't match
# or it is desired to always check logprobs,
# then
if is_tok_mismatch or always_check_logprobs:
logprobs_elem_0 = logprobs_0[idx] logprobs_elem_0 = logprobs_0[idx]
logprobs_elem_1 = logprobs_1[idx] logprobs_elem_1 = logprobs_1[idx]
...@@ -111,7 +120,7 @@ def check_logprobs_close( ...@@ -111,7 +120,7 @@ def check_logprobs_close(
assert output_id_0 in logprobs_elem_1, fail_msg assert output_id_0 in logprobs_elem_1, fail_msg
assert output_id_1 in logprobs_elem_0, fail_msg assert output_id_1 in logprobs_elem_0, fail_msg
if warn_on_mismatch: if warn_on_mismatch and is_tok_mismatch:
with warnings.catch_warnings(): with warnings.catch_warnings():
# This ensures that repeated warnings are shown # This ensures that repeated warnings are shown
# in the output, not just the first occurrence # in the output, not just the first occurrence
......
# Test the AsyncLLMEngine with multi-step-decoding # Test the AsyncLLMEngine with multi-step-decoding
from typing import List from typing import List, Optional
import pytest import pytest
from ..utils import RemoteOpenAIServer from ..models.utils import check_logprobs_close
from ..utils import (completions_with_server_args, get_client_text_generations,
get_client_text_logprob_generations)
MODELS = [ MODELS = [
"JackFram/llama-160m", "JackFram/llama-160m",
...@@ -23,22 +25,6 @@ DEFAULT_SERVER_ARGS: List[str] = [ ...@@ -23,22 +25,6 @@ DEFAULT_SERVER_ARGS: List[str] = [
] ]
async def completions_with_server_args(prompts: List[str], model_name: str,
server_cli_args: List[str]):
outputs = None
with RemoteOpenAIServer(model_name, server_cli_args) as server:
async with server.get_async_client() as client:
outputs = await client.completions.create(model=model_name,
prompt=prompts,
temperature=0,
stream=False,
max_tokens=5)
assert outputs is not None
return outputs
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(("tp_size, pp_size"), [ @pytest.mark.parametrize(("tp_size, pp_size"), [
(1, 1), (1, 1),
...@@ -47,12 +33,43 @@ async def completions_with_server_args(prompts: List[str], model_name: str, ...@@ -47,12 +33,43 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
@pytest.mark.parametrize("eager_mode", [False, True]) @pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs", [None, 5])
@pytest.mark.parametrize("is_async", [False, True]) @pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multi_step(example_prompts, model: str, tp_size: int, async def test_multi_step(
pp_size: int, eager_mode: int, example_prompts,
num_scheduler_steps: int, num_prompts: int, model: str,
is_async: bool): tp_size: int,
pp_size: int,
eager_mode: int,
num_scheduler_steps: int,
num_prompts: int,
is_async: bool,
num_logprobs: Optional[int],
) -> None:
"""Test vLLM engine with multi-step scheduling in an OpenAI-protocol
client/server environment.
Set up an engine with single-step scheduling as a ground-truth reference.
Send a completions API request to both engines with the same prompts.
Validate:
* Generated tokens match
* Generated logprobs are all very close
Args:
example_prompts: test fixture providing example prompts
model: model under test (same for single- and multi-step engines)
tp_size: degree of tensor-parallelism
pp_size: degree of pipeline-parallelism
eager_mode
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
GPU -> CPU output transfer
num_prompts: number of example prompts under test
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> no logprobs
"""
prompts = example_prompts prompts = example_prompts
if len(prompts) < num_prompts: if len(prompts) < num_prompts:
...@@ -77,14 +94,36 @@ async def test_multi_step(example_prompts, model: str, tp_size: int, ...@@ -77,14 +94,36 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
str(pp_size), str(pp_size),
] ]
# Spin up client/server & issue completion API requests.
# Default `max_wait_seconds` is 240 but was empirically
# was raised 3x to 720 *just for this test* due to
# observed timeouts in GHA CI
ref_completions = await completions_with_server_args( ref_completions = await completions_with_server_args(
prompts, model, server_args + distributed_args) prompts,
model,
server_args + distributed_args,
num_logprobs,
max_wait_seconds=3 * 240)
test_completions = await completions_with_server_args( test_completions = await completions_with_server_args(
prompts, model, ms_server_args + distributed_args) prompts,
model,
def get_text_generations(completions): ms_server_args + distributed_args,
return [x.text for x in completions.choices] num_logprobs,
max_wait_seconds=3 * 240)
ref_generations = get_text_generations(ref_completions)
test_generations = get_text_generations(test_completions) # Assert multi-step scheduling produces identical tokens
# to single-step scheduling.
ref_generations = get_client_text_generations(ref_completions)
test_generations = get_client_text_generations(test_completions)
assert ref_generations == test_generations assert ref_generations == test_generations
# Assert multi-step scheduling produces nearly-identical logprobs
# to single-step scheduling.
ref_text_logprobs = get_client_text_logprob_generations(ref_completions)
test_text_logprobs = get_client_text_logprob_generations(test_completions)
check_logprobs_close(
outputs_0_lst=ref_text_logprobs,
outputs_1_lst=test_text_logprobs,
name_0="hf",
name_1="vllm",
)
# Test the LLMEngine with multi-step-decoding # Test the LLMEngine with multi-step-decoding
from typing import Optional
import pytest import pytest
from ..models.utils import check_outputs_equal from ..models.utils import check_logprobs_close, check_outputs_equal
MODELS = [ MODELS = [
"JackFram/llama-160m", "JackFram/llama-160m",
...@@ -18,10 +20,45 @@ NUM_PROMPTS = [10] ...@@ -18,10 +20,45 @@ NUM_PROMPTS = [10]
@pytest.mark.parametrize("enforce_eager", [True]) @pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
def test_multi_step_llm(hf_runner, vllm_runner, example_prompts, model: str, @pytest.mark.parametrize("num_logprobs", [None, 5])
dtype: str, tp_size: int, max_tokens: int, def test_multi_step_llm(
enforce_eager: int, num_scheduler_steps: int, hf_runner,
num_prompts: int) -> None: vllm_runner,
example_prompts,
model: str,
dtype: str,
tp_size: int,
max_tokens: int,
enforce_eager: int,
num_scheduler_steps: int,
num_prompts: int,
num_logprobs: Optional[int],
) -> None:
"""Test vLLM engine with multi-step scheduling via sync LLM Engine.
Set up a HuggingFace (HF) transformers model as a ground-truth reference.
Prompt them with the same example prompts.
Validate:
* Generated tokens match
* Generated logprobs are all very close
Args:
hf_runner: HF transformers model runner fixture
vllm_runner: vLLM model runner fixture
example_prompts: test fixture providing example prompts
model: model under test (same for single- and multi-step engines)
dtype: tensor datatype for engine to utilize
tp_size: degree of tensor-parallelism
max_tokens: the maximum number of tokens to generate
enforce_eager
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
GPU -> CPU output transfer
num_prompts: number of example prompts under test
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> no logprobs
"""
prompts = example_prompts prompts = example_prompts
if len(prompts) < num_prompts: if len(prompts) < num_prompts:
...@@ -29,21 +66,37 @@ def test_multi_step_llm(hf_runner, vllm_runner, example_prompts, model: str, ...@@ -29,21 +66,37 @@ def test_multi_step_llm(hf_runner, vllm_runner, example_prompts, model: str,
prompts = prompts[:num_prompts] prompts = prompts[:num_prompts]
assert len(prompts) == num_prompts assert len(prompts) == num_prompts
with vllm_runner(model, with vllm_runner(
dtype=dtype, model,
enforce_eager=enforce_eager, dtype=dtype,
gpu_memory_utilization=0.7, enforce_eager=enforce_eager,
tensor_parallel_size=tp_size, gpu_memory_utilization=0.7,
use_v2_block_manager=True, tensor_parallel_size=tp_size,
num_scheduler_steps=num_scheduler_steps) as vllm_model: use_v2_block_manager=True,
vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens) num_scheduler_steps=num_scheduler_steps,
) as vllm_model:
vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens)
if num_logprobs is None else
vllm_model.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs))
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(prompts, max_tokens) hf_outputs = (hf_model.generate_greedy(prompts, max_tokens)
if num_logprobs is None else
check_outputs_equal( hf_model.generate_greedy_logprobs_limit(
outputs_0_lst=hf_outputs, prompts, max_tokens, num_logprobs))
outputs_1_lst=vllm_outputs,
name_0="hf", if num_logprobs is None:
name_1="vllm", check_outputs_equal(
) outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
else:
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
...@@ -5,9 +5,10 @@ from unittest.mock import MagicMock ...@@ -5,9 +5,10 @@ from unittest.mock import MagicMock
import pytest import pytest
import torch import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob, from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
SamplerOutput, get_all_seq_ids) get_all_seq_ids)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
......
...@@ -7,8 +7,9 @@ from unittest.mock import MagicMock ...@@ -7,8 +7,9 @@ from unittest.mock import MagicMock
import pytest import pytest
import torch import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector, from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics) SpecDecodeWorkerMetrics)
......
...@@ -8,12 +8,12 @@ from unittest.mock import MagicMock ...@@ -8,12 +8,12 @@ from unittest.mock import MagicMock
import torch import torch
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, Logprob, CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceData, SequenceGroupMetadata, SequenceOutput)
SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner from vllm.worker.model_runner import ModelRunner
......
...@@ -2,9 +2,10 @@ from array import array ...@@ -2,9 +2,10 @@ from array import array
import pytest import pytest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, SamplerOutput, CompletionSequenceGroupOutput, SequenceData,
SequenceData, SequenceOutput) SequenceOutput)
from .core.utils import create_dummy_prompt from .core.utils import create_dummy_prompt
......
...@@ -11,9 +11,11 @@ from typing import Any, Callable, Dict, List, Optional ...@@ -11,9 +11,11 @@ from typing import Any, Callable, Dict, List, Optional
import openai import openai
import requests import requests
from openai.types.completion import Completion
from transformers import AutoTokenizer from transformers import AutoTokenizer
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from tests.models.utils import TextTextLogprobs
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
...@@ -432,3 +434,61 @@ def fork_new_process_for_each_test( ...@@ -432,3 +434,61 @@ def fork_new_process_for_each_test(
f" args {args} and kwargs {kwargs}") f" args {args} and kwargs {kwargs}")
return wrapper return wrapper
async def completions_with_server_args(
prompts: List[str],
model_name: str,
server_cli_args: List[str],
num_logprobs: Optional[int],
max_wait_seconds: int = 240,
) -> Completion:
'''Construct a remote OpenAI server, obtain an async client to the
server & invoke the completions API to obtain completions.
Args:
prompts: test prompts
model_name: model to spin up on the vLLM server
server_cli_args: CLI args for starting the server
num_logprobs: Number of logprobs to report (or `None`)
max_wait_seconds: timeout interval for bringing up server.
Default: 240sec
Returns:
OpenAI Completion instance
'''
outputs = None
with RemoteOpenAIServer(model_name,
server_cli_args,
max_wait_seconds=max_wait_seconds) as server:
client = server.get_async_client()
outputs = await client.completions.create(model=model_name,
prompt=prompts,
temperature=0,
stream=False,
max_tokens=5,
logprobs=num_logprobs)
assert outputs is not None
return outputs
def get_client_text_generations(completions: Completion) -> List[str]:
'''Extract generated tokens from the output of a
request made to an Open-AI-protocol completions endpoint.
'''
return [x.text for x in completions.choices]
def get_client_text_logprob_generations(
completions: Completion) -> List[TextTextLogprobs]:
'''Operates on the output of a request made to an Open-AI-protocol
completions endpoint; obtains top-rank logprobs for each token in
each :class:`SequenceGroup`
'''
text_generations = get_client_text_generations(completions)
text = ''.join(text_generations)
return [(text_generations, text,
(None if x.logprobs is None else x.logprobs.top_logprobs))
for x in completions.choices]
...@@ -22,11 +22,12 @@ from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, ...@@ -22,11 +22,12 @@ from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
......
...@@ -33,6 +33,7 @@ from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, ...@@ -33,6 +33,7 @@ from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory) RequestOutputFactory)
...@@ -40,8 +41,8 @@ from vllm.pooling_params import PoolingParams ...@@ -40,8 +41,8 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
SamplerOutput, Sequence, SequenceGroup, Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceGroupMetadata, SequenceStatus) SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer) init_tracer)
from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.config import try_get_generation_config
......
...@@ -4,6 +4,8 @@ from typing import Callable, List ...@@ -4,6 +4,8 @@ from typing import Callable, List
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import ( from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor) SequenceGroupOutputProcessor)
from vllm.engine.output_processor.single_step import (
single_step_process_prompt_logprob)
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -46,9 +48,16 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -46,9 +48,16 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def process_prompt_logprob(self, seq_group: SequenceGroup, def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None: outputs: List[SequenceGroupOutput]) -> None:
# TODO(sang): Prompt logprob currently not implemented in multi step """Process prompt logprobs associated with each step of a multi-step-
# workers. scheduled computation.
self._log_prompt_logprob_unsupported_warning_once()
Args:
seq_group: the outputs are associated with this :class:`SequenceGroup`
outputs: the :class:`SequenceGroupOutput`s for all scheduler steps
"""
for output in outputs:
# Concatenate single-step prompt logprob processing results.
single_step_process_prompt_logprob(self, seq_group, output)
@staticmethod @staticmethod
@functools.lru_cache() @functools.lru_cache()
......
...@@ -15,6 +15,44 @@ from vllm.utils import Counter ...@@ -15,6 +15,44 @@ from vllm.utils import Counter
logger = init_logger(__name__) logger = init_logger(__name__)
def single_step_process_prompt_logprob(
sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
output: SequenceGroupOutput) -> None:
"""Process prompt logprobs associated with the :class:`SequenceGroupOutput`
for a given step.
Do nothing if the output has no prompt logprobs.
Account for the fact that transformers do not compute first-token logprobs.
Args:
sg_output_proc: :class:`SequenceGroupOutputProcessor` instance
seq_group: the output is associated with this :class:`SequenceGroup`
output: the :class:`SequenceGroupOutput` for a single scheduler step
"""
prompt_logprobs = output.prompt_logprobs
# If this is the first (or only) "chunk" of the prefill, we need
# to prepend None to the list of prompt logprobs. The reason for this
# is that for N prompt tokens, the Sampler will generate N-1 total
# prompt logprobs during prefill since the token at idx 0 will not
# have a logprob associated with it.
if prompt_logprobs is not None:
if not seq_group.prompt_logprobs:
prompt_logprobs = [None] + prompt_logprobs
seq_group.prompt_logprobs = []
assert hasattr(sg_output_proc, 'detokenizer')
if (seq_group.sampling_params.detokenize
and sg_output_proc.detokenizer):
sg_output_proc.detokenizer.decode_prompt_logprobs_inplace(
seq_group,
prompt_logprobs,
position_offset=len(seq_group.prompt_logprobs))
seq_group.prompt_logprobs.extend(prompt_logprobs)
class SingleStepOutputProcessor(SequenceGroupOutputProcessor): class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
"""SequenceGroupOutputProcessor which handles "output processing" logic, """SequenceGroupOutputProcessor which handles "output processing" logic,
which happens after the model returns generated token ids and before which happens after the model returns generated token ids and before
...@@ -60,27 +98,16 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -60,27 +98,16 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
def process_prompt_logprob(self, seq_group: SequenceGroup, def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None: outputs: List[SequenceGroupOutput]) -> None:
"""Process prompt logprobs associated with one step of a single-step-
scheduled computation.
Args:
seq_group: the output is associated with this :class:`SequenceGroup`
output: the :class:`SequenceGroupOutput` for a single scheduler step
"""
assert len(outputs) == 1, ("Single step should only has 1 output.") assert len(outputs) == 1, ("Single step should only has 1 output.")
output = outputs[0] output = outputs[0]
prompt_logprobs = output.prompt_logprobs single_step_process_prompt_logprob(self, seq_group, output)
# If this is the first (or only) "chunk" of the prefill, we need
# to prepend None to the list of prompt logprobs. The reason for this
# is that for N prompt tokens, the Sampler will generate N-1 total
# prompt logprobs during prefill since the token at idx 0 will not
# have a logprob associated with it.
if prompt_logprobs is not None:
if not seq_group.prompt_logprobs:
prompt_logprobs = [None] + prompt_logprobs
seq_group.prompt_logprobs = []
if seq_group.sampling_params.detokenize and self.detokenizer:
self.detokenizer.decode_prompt_logprobs_inplace(
seq_group,
prompt_logprobs,
position_offset=len(seq_group.prompt_logprobs))
seq_group.prompt_logprobs.extend(prompt_logprobs)
def _process_sequence_group_outputs(self, seq_group: SequenceGroup, def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput, outputs: SequenceGroupOutput,
......
...@@ -2,7 +2,8 @@ from typing import List ...@@ -2,7 +2,8 @@ from typing import List
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Union from typing import Union
from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import PoolerOutput, SequenceGroupOutput
def create_output_by_sequence_group( def create_output_by_sequence_group(
......
...@@ -5,11 +5,11 @@ from vllm.config import DecodingConfig, ModelConfig ...@@ -5,11 +5,11 @@ from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs from vllm.inputs.data import PromptInputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
......
...@@ -11,8 +11,9 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, ...@@ -11,8 +11,9 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor) ResultHandler, WorkerMonitor)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest
from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port, from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port,
get_vllm_instance_id, make_async) get_vllm_instance_id, make_async)
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerWrapperBase
......
...@@ -6,7 +6,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase ...@@ -6,7 +6,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -6,8 +6,9 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -6,8 +6,9 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
PromptAdapterConfig, SchedulerConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig) SpeculativeConfig)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest
class ExecutorBase(ABC): class ExecutorBase(ABC):
......
...@@ -3,8 +3,9 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union ...@@ -3,8 +3,9 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async) make_async)
from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase
......
...@@ -14,7 +14,8 @@ from vllm.executor.gpu_executor import create_worker ...@@ -14,7 +14,8 @@ from vllm.executor.gpu_executor import create_worker
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor) ResultHandler, WorkerMonitor)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.triton_utils import maybe_set_triton_cache_manager from vllm.triton_utils import maybe_set_triton_cache_manager
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
get_distributed_init_method, get_open_port, get_distributed_init_method, get_open_port,
......
...@@ -3,7 +3,8 @@ from typing import List, Set, Tuple ...@@ -3,7 +3,8 @@ from typing import List, Set, Tuple
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async) make_async)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment