Unverified Commit 428dd144 authored by afeldman-nm's avatar afeldman-nm Committed by GitHub
Browse files

[Core] Logprobs support in Multi-step (#7652)

parent 4abed65c
...@@ -3,6 +3,7 @@ from typing import List, Optional ...@@ -3,6 +3,7 @@ from typing import List, Optional
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.sampler import SamplerOutput
try: try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata from vllm.attention.backends.flash_attn import FlashAttentionMetadata
...@@ -16,8 +17,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -16,8 +17,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
PromptAdapterConfig, SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalInputs from vllm.multimodal import MultiModalInputs
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, from vllm.sequence import ExecuteModelRequest, IntermediateTensors
SamplerOutput)
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner) ModelRunner)
......
...@@ -4,8 +4,8 @@ from typing import List, Optional, Set, Tuple ...@@ -4,8 +4,8 @@ from typing import List, Optional, Set, Tuple
import torch import torch
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, from vllm.model_executor.layers.sampler import SamplerOutput
SequenceGroupMetadata) from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
......
...@@ -3,8 +3,8 @@ from typing import List, Optional, Set, Tuple ...@@ -3,8 +3,8 @@ from typing import List, Optional, Set, Tuple
import torch import torch
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, from vllm.model_executor.layers.sampler import SamplerOutput
SequenceGroupMetadata) from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
......
...@@ -4,8 +4,9 @@ from typing import Dict, List, Set, Tuple ...@@ -4,8 +4,9 @@ from typing import Dict, List, Set, Tuple
import torch import torch
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SamplerOutput, from vllm.model_executor.layers.sampler import SamplerOutput
SequenceData, SequenceGroupMetadata) from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer) SpeculativeProposer)
......
...@@ -3,7 +3,8 @@ from typing import List, Optional, Set, Tuple ...@@ -3,7 +3,8 @@ from typing import List, Optional, Set, Tuple
import torch import torch
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import SpeculativeProposer from vllm.spec_decode.interfaces import SpeculativeProposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase from vllm.worker.worker_base import LoraNotSupportedWorkerBase
......
...@@ -6,7 +6,8 @@ from vllm.distributed.parallel_state import (get_tp_group, ...@@ -6,7 +6,8 @@ from vllm.distributed.parallel_state import (get_tp_group,
init_model_parallel_group, init_model_parallel_group,
patch_tensor_parallel_group) patch_tensor_parallel_group)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
......
...@@ -8,12 +8,13 @@ from vllm.config import ParallelConfig, SpeculativeConfig ...@@ -8,12 +8,13 @@ from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.spec_decode_base_sampler import ( from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler) SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import ( from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler) TypicalAcceptanceSampler)
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SamplerOutput, SequenceGroupMetadata, HiddenStates, SequenceGroupMetadata,
get_all_seq_ids, get_all_seq_ids_and_request_ids) get_all_seq_ids, get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
......
...@@ -2,8 +2,8 @@ from typing import List, Optional, Set, Tuple ...@@ -2,8 +2,8 @@ from typing import List, Optional, Set, Tuple
import torch import torch
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, from vllm.model_executor.layers.sampler import SamplerOutput
SequenceGroupMetadata) from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer) SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
......
...@@ -4,9 +4,9 @@ from typing import Dict, List, Optional, Sequence, Tuple ...@@ -4,9 +4,9 @@ from typing import Dict, List, Optional, Sequence, Tuple
import torch import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceGroupMetadata, SequenceGroupMetadata, SequenceOutput)
SequenceOutput)
SeqId = int SeqId = int
......
...@@ -10,11 +10,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -10,11 +10,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
SchedulerConfig) SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
SequenceGroupMetadata)
from vllm.utils import make_tensor_with_pad from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerBase, ModelRunnerInputBase,
......
...@@ -16,9 +16,10 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -16,9 +16,10 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput, from vllm.sequence import (IntermediateTensors, PoolerOutput,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
from vllm.worker.model_runner import (GPUModelRunnerBase, from vllm.worker.model_runner import (GPUModelRunnerBase,
......
...@@ -29,6 +29,7 @@ from vllm.lora.layers import LoRAMapping ...@@ -29,6 +29,7 @@ from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import (supports_lora, from vllm.model_executor.models.interfaces import (supports_lora,
...@@ -41,8 +42,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest ...@@ -41,8 +42,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import ( from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager) LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
flatten_2d_lists, is_hip, is_pin_memory_available, flatten_2d_lists, is_hip, is_pin_memory_available,
supports_dynamo) supports_dynamo)
......
...@@ -5,9 +5,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, ...@@ -5,9 +5,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import torch import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
SequenceGroupMetadata)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
......
import dataclasses import dataclasses
import functools import functools
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union)
try: try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata from vllm.attention.backends.flash_attn import FlashAttentionMetadata
...@@ -15,9 +16,12 @@ import torch ...@@ -15,9 +16,12 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
SamplerOutput,
SamplingMetadata, get_logprobs,
get_pythonized_sample_results)
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SamplerOutput, SequenceGroupMetadata, Logprob, SequenceGroupMetadata, SequenceOutput)
SequenceOutput)
from vllm.worker.model_runner import (GPUModelRunnerBase, from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
...@@ -53,6 +57,8 @@ class ModelOutput: ...@@ -53,6 +57,8 @@ class ModelOutput:
sampler_output_ready_event: torch.cuda.Event sampler_output_ready_event: torch.cuda.Event
sampled_token_ids: Optional[torch.Tensor] = None sampled_token_ids: Optional[torch.Tensor] = None
pythonized: bool = False pythonized: bool = False
# On-device tensor containing the logprobs of each token.
logprobs: Optional["torch.Tensor"] = None
def pythonize(self, input_metadata: "StatefulModelInput", def pythonize(self, input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream, copy_stream: torch.cuda.Stream,
...@@ -78,7 +84,9 @@ class ModelOutput: ...@@ -78,7 +84,9 @@ class ModelOutput:
blocking: bool) -> bool: blocking: bool) -> bool:
""" """
If blocking is set, will block until the forward pass for the output is If blocking is set, will block until the forward pass for the output is
ready and pythonize the output. ready and pythonize the output. Upon completing Pythonization, erases
self.logprobs (note that a non-blocking call that is performed when
the sampler output is not yet ready, will not erase self.logprobs.)
""" """
assert self.sampled_token_ids is not None assert self.sampled_token_ids is not None
if not blocking and not self.sampler_output_ready_event.query(): if not blocking and not self.sampler_output_ready_event.query():
...@@ -89,7 +97,15 @@ class ModelOutput: ...@@ -89,7 +97,15 @@ class ModelOutput:
with torch.cuda.stream(copy_stream): with torch.cuda.stream(copy_stream):
_pythonize_sampler_output(input_metadata, self.sampler_output, _pythonize_sampler_output(input_metadata, self.sampler_output,
pinned_sampled_token_buffer, pinned_sampled_token_buffer,
self.sampled_token_ids) self.sampled_token_ids, self.logprobs)
# Erase the logprobs GPU-side tensor.
# Note that although _pythonize_sampler_output() runs in its
# own CUDA stream, nonetheless _pythonize_sampler_output()
# cannot return until Pythonization is complete; therefore
# we know that by the time the CPU reaches this point,
# `self.logprobs` is no longer needed.
self.logprobs = None
return True return True
...@@ -350,11 +366,16 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -350,11 +366,16 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
0].sampled_token_ids.cpu() 0].sampled_token_ids.cpu()
model_input.cached_outputs.append( model_input.cached_outputs.append(
ModelOutput(output[0], output_ready_event, ModelOutput(output[0], output_ready_event,
output[0].sampled_token_ids, False)) output[0].sampled_token_ids, False,
# make sure we dont try to serialize any GPU tensors output[0].logprobs))
# These GPU tensors are not required by multi-step;
# erase them to ensure they are not pythonized or
# transferred to CPU
output[0].sampled_token_ids = None output[0].sampled_token_ids = None
output[0].sampled_token_probs = None output[0].sampled_token_probs = None
output[0].logprobs = None output[0].logprobs = None
# Pythonize the output if CPU is ahead and the previous step is # Pythonize the output if CPU is ahead and the previous step is
# ready. # ready.
if not frozen_model_input.use_async_and_multi_step: if not frozen_model_input.use_async_and_multi_step:
...@@ -464,12 +485,75 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -464,12 +485,75 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
return self._base_model_runner.vocab_size return self._base_model_runner.vocab_size
def _pythonize_sampler_output(model_input: StatefulModelInput, DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]],
Optional[List[SampleLogprobs]]]
def deferred_pythonize_logprobs(
output: SamplerOutput,
sampling_metadata: SamplingMetadata,
logprobs_tensor: Optional[torch.Tensor],
) -> DeferredLogprobsReturnType:
"""Perform deferred logprob Pythonization.
1. Pythonize GPU-side sampler result tensors into CPU-side sampler result.
2. Pythonize GPU-side logprobs tensor into CPU-side logprobs lists,
utilizing the Pythonized sampler result computed in step 1.
These deferred computations are not required for single-step scheduling
or the `profile_run()` phase of multi-step scheduling.
Args:
output: sampler output (under deferred Pythonization)
sampling_metadata
Returns:
prompt_logprobs (CPU), sample_logprobs (CPU)
"""
# - Deferred pythonization of sample result
sampler_result = get_pythonized_sample_results(
output.deferred_sample_results_args)
# - Erase the GPU-side deferred sample_result
# computation args to ensure it is never
# pythonized or transferred to CPU
output.deferred_sample_results_args = None
# - Deferred pythonization of logprobs
(
prompt_logprobs,
sample_logprobs,
) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result)
assert len(prompt_logprobs) == len(sampling_metadata.seq_groups)
assert len(sample_logprobs) == len(sampling_metadata.seq_groups)
return prompt_logprobs, sample_logprobs
def _pythonize_sampler_output(
model_input: StatefulModelInput,
output: SamplerOutput, output: SamplerOutput,
pinned_sampled_token_buffer: torch.Tensor, pinned_sampled_token_buffer: torch.Tensor,
sampled_token_ids: torch.Tensor) -> None: sampled_token_ids: torch.Tensor,
logprobs_tensor: Optional[torch.Tensor],
) -> None:
""" This function is only called when the output tensors are ready. """ This function is only called when the output tensors are ready.
See ModelOutput See :class:`ModelOutput`.
Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
adding a Pythonized output data structure
(:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`.
Args:
model_input
output: sampler output
pinned_sampled_token_token_buffer: CPU-side pinned memory
(receives copy of
GPU-side token buffer.)
sampled_token_ids: GPU-side token buffer
logprobs_tensor: GPU-side tensor containing
logprobs computed during sampling
""" """
assert model_input.frozen_model_input is not None assert model_input.frozen_model_input is not None
...@@ -489,8 +573,51 @@ def _pythonize_sampler_output(model_input: StatefulModelInput, ...@@ -489,8 +573,51 @@ def _pythonize_sampler_output(model_input: StatefulModelInput,
sampling_metadata = frozen_model_input.sampling_metadata sampling_metadata = frozen_model_input.sampling_metadata
for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, skip_sampler_cpu_output = (
samples_list): frozen_model_input.sampling_metadata.skip_sampler_cpu_output)
# We are guaranteed output tensors are ready, so it is safe to
# pythonize the sampler output & obtain CPU-side logprobs.
#
# However this computation may be skipped entirely
# if no pythonization was deferred.
seq_groups = sampling_metadata.seq_groups
logprobs_are_requested = any([
sg.sampling_params.logprobs is not None
or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups
])
do_pythonize_logprobs = (skip_sampler_cpu_output
and logprobs_are_requested)
(
prompt_logprobs,
sample_logprobs,
) = (deferred_pythonize_logprobs(output, sampling_metadata,
logprobs_tensor)
if do_pythonize_logprobs else (None, None))
for sgdx, (seq_group,
sample_result) in enumerate(zip(seq_groups, samples_list)):
if do_pythonize_logprobs:
assert prompt_logprobs is not None
assert sample_logprobs is not None
(
group_prompt_logprobs,
group_sample_logprobs,
) = ( # Utilize deferred pythonization results
prompt_logprobs[sgdx],
sample_logprobs[sgdx],
)
elif logprobs_are_requested:
(
group_prompt_logprobs,
group_sample_logprobs,
) = (
# profile_run: use already-computed logprobs
output.outputs[sgdx].prompt_logprobs,
[sample.logprobs for sample in output.outputs[sgdx].samples])
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
next_token_ids = sample_result next_token_ids = sample_result
parent_ids = [0] parent_ids = [0]
...@@ -498,11 +625,19 @@ def _pythonize_sampler_output(model_input: StatefulModelInput, ...@@ -498,11 +625,19 @@ def _pythonize_sampler_output(model_input: StatefulModelInput,
if seq_group.sampling_params.logits_processors: if seq_group.sampling_params.logits_processors:
assert len(seq_group.sampling_params.logits_processors) == 0, ( assert len(seq_group.sampling_params.logits_processors) == 0, (
"Logits Processors are not supported in multi-step decoding") "Logits Processors are not supported in multi-step decoding")
for parent_id, next_token_id in zip(parent_ids, next_token_ids): for tdx, (parent_id,
# TODO(will): support logprobs next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
# Hard coded logprob
seq_outputs.append( seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id, SequenceOutput(seq_ids[parent_id], next_token_id,
{next_token_id: Logprob(logprob=-1)})) (group_sample_logprobs[tdx]
output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None)) if logprobs_are_requested else {
next_token_id:
Logprob(logprob=float('inf'),
rank=None,
decoded_token=None)
})))
output.outputs.append(
CompletionSequenceGroupOutput(
seq_outputs,
(group_prompt_logprobs if logprobs_are_requested else None)))
assert len(output.outputs) > 0 assert len(output.outputs) > 0
...@@ -5,7 +5,8 @@ from typing import Dict, List, Optional, Tuple ...@@ -5,7 +5,8 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
from vllm.distributed import broadcast_tensor_dict, get_pp_group from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.worker.model_runner_base import BroadcastableModelInput from vllm.worker.model_runner_base import BroadcastableModelInput
from vllm.worker.multi_step_model_runner import (MultiStepModelRunner, from vllm.worker.multi_step_model_runner import (MultiStepModelRunner,
StatefulModelInput) StatefulModelInput)
......
...@@ -8,11 +8,11 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, ...@@ -8,11 +8,11 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig) SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
......
...@@ -11,10 +11,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -11,10 +11,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
SchedulerConfig) SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.openvino import get_model from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SequenceGroupMetadata
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -14,7 +14,8 @@ from vllm.distributed import (broadcast_tensor_dict, ...@@ -14,7 +14,8 @@ from vllm.distributed import (broadcast_tensor_dict,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.worker.openvino_model_runner import OpenVINOModelRunner from vllm.worker.openvino_model_runner import OpenVINOModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase from vllm.worker.worker_base import LoraNotSupportedWorkerBase
......
...@@ -14,11 +14,11 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther ...@@ -14,11 +14,11 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SamplerOutput, SequenceGroupMetadata, Logprob, SequenceGroupMetadata, SequenceOutput)
SequenceOutput)
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict, _add_attn_metadata_broadcastable_dict,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment