Unverified Commit 428dd144 authored by afeldman-nm's avatar afeldman-nm Committed by GitHub
Browse files

[Core] Logprobs support in Multi-step (#7652)

parent 4abed65c
......@@ -3,6 +3,7 @@ from typing import List, Optional
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.sampler import SamplerOutput
try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
......@@ -16,8 +17,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.multimodal import MultiModalInputs
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput)
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
......
......@@ -4,8 +4,8 @@ from typing import List, Optional, Set, Tuple
import torch
from vllm.model_executor import SamplingMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
......
......@@ -3,8 +3,8 @@ from typing import List, Optional, Set, Tuple
import torch
from vllm.model_executor import SamplingMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
......
......@@ -4,8 +4,9 @@ from typing import Dict, List, Set, Tuple
import torch
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SamplerOutput,
SequenceData, SequenceGroupMetadata)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
......
......@@ -3,7 +3,8 @@ from typing import List, Optional, Set, Tuple
import torch
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
......
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import SpeculativeProposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
......
......@@ -6,7 +6,8 @@ from vllm.distributed.parallel_state import (get_tp_group,
init_model_parallel_group,
patch_tensor_parallel_group)
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
......
......@@ -8,12 +8,13 @@ from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SamplerOutput, SequenceGroupMetadata,
HiddenStates, SequenceGroupMetadata,
get_all_seq_ids, get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
......
......@@ -2,8 +2,8 @@ from typing import List, Optional, Set, Tuple
import torch
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
......
......@@ -4,9 +4,9 @@ from typing import Dict, List, Optional, Sequence, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceGroupMetadata,
SequenceOutput)
SequenceGroupMetadata, SequenceOutput)
SeqId = int
......
......@@ -10,11 +10,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
......
......@@ -16,9 +16,10 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput,
from vllm.sequence import (IntermediateTensors, PoolerOutput,
SequenceGroupMetadata)
from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
from vllm.worker.model_runner import (GPUModelRunnerBase,
......
......@@ -29,6 +29,7 @@ from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import (supports_lora,
......@@ -41,8 +42,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
flatten_2d_lists, is_hip, is_pin_memory_available,
supports_dynamo)
......
......@@ -5,9 +5,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
......
import dataclasses
import functools
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union)
try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
......@@ -15,9 +16,12 @@ import torch
from vllm import _custom_ops as ops
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
SamplerOutput,
SamplingMetadata, get_logprobs,
get_pythonized_sample_results)
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceOutput)
Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
......@@ -53,6 +57,8 @@ class ModelOutput:
sampler_output_ready_event: torch.cuda.Event
sampled_token_ids: Optional[torch.Tensor] = None
pythonized: bool = False
# On-device tensor containing the logprobs of each token.
logprobs: Optional["torch.Tensor"] = None
def pythonize(self, input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream,
......@@ -78,7 +84,9 @@ class ModelOutput:
blocking: bool) -> bool:
"""
If blocking is set, will block until the forward pass for the output is
ready and pythonize the output.
ready and pythonize the output. Upon completing Pythonization, erases
self.logprobs (note that a non-blocking call that is performed when
the sampler output is not yet ready, will not erase self.logprobs.)
"""
assert self.sampled_token_ids is not None
if not blocking and not self.sampler_output_ready_event.query():
......@@ -89,7 +97,15 @@ class ModelOutput:
with torch.cuda.stream(copy_stream):
_pythonize_sampler_output(input_metadata, self.sampler_output,
pinned_sampled_token_buffer,
self.sampled_token_ids)
self.sampled_token_ids, self.logprobs)
# Erase the logprobs GPU-side tensor.
# Note that although _pythonize_sampler_output() runs in its
# own CUDA stream, nonetheless _pythonize_sampler_output()
# cannot return until Pythonization is complete; therefore
# we know that by the time the CPU reaches this point,
# `self.logprobs` is no longer needed.
self.logprobs = None
return True
......@@ -350,11 +366,16 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
0].sampled_token_ids.cpu()
model_input.cached_outputs.append(
ModelOutput(output[0], output_ready_event,
output[0].sampled_token_ids, False))
# make sure we dont try to serialize any GPU tensors
output[0].sampled_token_ids, False,
output[0].logprobs))
# These GPU tensors are not required by multi-step;
# erase them to ensure they are not pythonized or
# transferred to CPU
output[0].sampled_token_ids = None
output[0].sampled_token_probs = None
output[0].logprobs = None
# Pythonize the output if CPU is ahead and the previous step is
# ready.
if not frozen_model_input.use_async_and_multi_step:
......@@ -464,12 +485,75 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
return self._base_model_runner.vocab_size
def _pythonize_sampler_output(model_input: StatefulModelInput,
output: SamplerOutput,
pinned_sampled_token_buffer: torch.Tensor,
sampled_token_ids: torch.Tensor) -> None:
DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]],
Optional[List[SampleLogprobs]]]
def deferred_pythonize_logprobs(
output: SamplerOutput,
sampling_metadata: SamplingMetadata,
logprobs_tensor: Optional[torch.Tensor],
) -> DeferredLogprobsReturnType:
"""Perform deferred logprob Pythonization.
1. Pythonize GPU-side sampler result tensors into CPU-side sampler result.
2. Pythonize GPU-side logprobs tensor into CPU-side logprobs lists,
utilizing the Pythonized sampler result computed in step 1.
These deferred computations are not required for single-step scheduling
or the `profile_run()` phase of multi-step scheduling.
Args:
output: sampler output (under deferred Pythonization)
sampling_metadata
Returns:
prompt_logprobs (CPU), sample_logprobs (CPU)
"""
# - Deferred pythonization of sample result
sampler_result = get_pythonized_sample_results(
output.deferred_sample_results_args)
# - Erase the GPU-side deferred sample_result
# computation args to ensure it is never
# pythonized or transferred to CPU
output.deferred_sample_results_args = None
# - Deferred pythonization of logprobs
(
prompt_logprobs,
sample_logprobs,
) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result)
assert len(prompt_logprobs) == len(sampling_metadata.seq_groups)
assert len(sample_logprobs) == len(sampling_metadata.seq_groups)
return prompt_logprobs, sample_logprobs
def _pythonize_sampler_output(
model_input: StatefulModelInput,
output: SamplerOutput,
pinned_sampled_token_buffer: torch.Tensor,
sampled_token_ids: torch.Tensor,
logprobs_tensor: Optional[torch.Tensor],
) -> None:
""" This function is only called when the output tensors are ready.
See ModelOutput
See :class:`ModelOutput`.
Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
adding a Pythonized output data structure
(:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`.
Args:
model_input
output: sampler output
pinned_sampled_token_token_buffer: CPU-side pinned memory
(receives copy of
GPU-side token buffer.)
sampled_token_ids: GPU-side token buffer
logprobs_tensor: GPU-side tensor containing
logprobs computed during sampling
"""
assert model_input.frozen_model_input is not None
......@@ -489,8 +573,51 @@ def _pythonize_sampler_output(model_input: StatefulModelInput,
sampling_metadata = frozen_model_input.sampling_metadata
for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
samples_list):
skip_sampler_cpu_output = (
frozen_model_input.sampling_metadata.skip_sampler_cpu_output)
# We are guaranteed output tensors are ready, so it is safe to
# pythonize the sampler output & obtain CPU-side logprobs.
#
# However this computation may be skipped entirely
# if no pythonization was deferred.
seq_groups = sampling_metadata.seq_groups
logprobs_are_requested = any([
sg.sampling_params.logprobs is not None
or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups
])
do_pythonize_logprobs = (skip_sampler_cpu_output
and logprobs_are_requested)
(
prompt_logprobs,
sample_logprobs,
) = (deferred_pythonize_logprobs(output, sampling_metadata,
logprobs_tensor)
if do_pythonize_logprobs else (None, None))
for sgdx, (seq_group,
sample_result) in enumerate(zip(seq_groups, samples_list)):
if do_pythonize_logprobs:
assert prompt_logprobs is not None
assert sample_logprobs is not None
(
group_prompt_logprobs,
group_sample_logprobs,
) = ( # Utilize deferred pythonization results
prompt_logprobs[sgdx],
sample_logprobs[sgdx],
)
elif logprobs_are_requested:
(
group_prompt_logprobs,
group_sample_logprobs,
) = (
# profile_run: use already-computed logprobs
output.outputs[sgdx].prompt_logprobs,
[sample.logprobs for sample in output.outputs[sgdx].samples])
seq_ids = seq_group.seq_ids
next_token_ids = sample_result
parent_ids = [0]
......@@ -498,11 +625,19 @@ def _pythonize_sampler_output(model_input: StatefulModelInput,
if seq_group.sampling_params.logits_processors:
assert len(seq_group.sampling_params.logits_processors) == 0, (
"Logits Processors are not supported in multi-step decoding")
for parent_id, next_token_id in zip(parent_ids, next_token_ids):
# TODO(will): support logprobs
# Hard coded logprob
for tdx, (parent_id,
next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id,
{next_token_id: Logprob(logprob=-1)}))
output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None))
(group_sample_logprobs[tdx]
if logprobs_are_requested else {
next_token_id:
Logprob(logprob=float('inf'),
rank=None,
decoded_token=None)
})))
output.outputs.append(
CompletionSequenceGroupOutput(
seq_outputs,
(group_prompt_logprobs if logprobs_are_requested else None)))
assert len(output.outputs) > 0
......@@ -5,7 +5,8 @@ from typing import Dict, List, Optional, Tuple
import torch
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.worker.model_runner_base import BroadcastableModelInput
from vllm.worker.multi_step_model_runner import (MultiStepModelRunner,
StatefulModelInput)
......
......@@ -8,11 +8,11 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
......
......@@ -11,10 +11,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import SequenceGroupMetadata
logger = init_logger(__name__)
......
......@@ -14,7 +14,8 @@ from vllm.distributed import (broadcast_tensor_dict,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.worker.openvino_model_runner import OpenVINOModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
......
......@@ -14,11 +14,11 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceOutput)
Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment