Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
428dd144
Unverified
Commit
428dd144
authored
Aug 29, 2024
by
afeldman-nm
Committed by
GitHub
Aug 29, 2024
Browse files
[Core] Logprobs support in Multi-step (#7652)
parent
4abed65c
Changes
103
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
193 additions
and
49 deletions
+193
-49
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+2
-2
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+2
-2
vllm/spec_decode/mlp_speculator_worker.py
vllm/spec_decode/mlp_speculator_worker.py
+2
-2
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+3
-2
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+2
-1
vllm/spec_decode/proposer_worker_base.py
vllm/spec_decode/proposer_worker_base.py
+2
-1
vllm/spec_decode/smaller_tp_proposer_worker.py
vllm/spec_decode/smaller_tp_proposer_worker.py
+2
-1
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+2
-1
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+2
-2
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+2
-2
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+2
-2
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+2
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+2
-2
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+2
-2
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+154
-19
vllm/worker/multi_step_worker.py
vllm/worker/multi_step_worker.py
+2
-1
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+2
-2
vllm/worker/openvino_model_runner.py
vllm/worker/openvino_model_runner.py
+2
-1
vllm/worker/openvino_worker.py
vllm/worker/openvino_worker.py
+2
-1
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+2
-2
No files found.
vllm/spec_decode/draft_model_runner.py
View file @
428dd144
...
@@ -3,6 +3,7 @@ from typing import List, Optional
...
@@ -3,6 +3,7 @@ from typing import List, Optional
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.sampler
import
SamplerOutput
try
:
try
:
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
...
@@ -16,8 +17,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
...
@@ -16,8 +17,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
PromptAdapterConfig
,
SchedulerConfig
)
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalInputs
from
vllm.multimodal
import
MultiModalInputs
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
SamplerOutput
)
from
vllm.worker.model_runner
import
(
ModelInputForGPUWithSamplingMetadata
,
from
vllm.worker.model_runner
import
(
ModelInputForGPUWithSamplingMetadata
,
ModelRunner
)
ModelRunner
)
...
...
vllm/spec_decode/medusa_worker.py
View file @
428dd144
...
@@ -4,8 +4,8 @@ from typing import List, Optional, Set, Tuple
...
@@ -4,8 +4,8 @@ from typing import List, Optional, Set, Tuple
import
torch
import
torch
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.
sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
from
vllm.
model_executor.layers.sampler
import
SamplerOutput
SequenceGroupMetadata
)
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
...
...
vllm/spec_decode/mlp_speculator_worker.py
View file @
428dd144
...
@@ -3,8 +3,8 @@ from typing import List, Optional, Set, Tuple
...
@@ -3,8 +3,8 @@ from typing import List, Optional, Set, Tuple
import
torch
import
torch
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.
sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
from
vllm.
model_executor.layers.sampler
import
SamplerOutput
SequenceGroupMetadata
)
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
...
...
vllm/spec_decode/multi_step_worker.py
View file @
428dd144
...
@@ -4,8 +4,9 @@ from typing import Dict, List, Set, Tuple
...
@@ -4,8 +4,9 @@ from typing import Dict, List, Set, Tuple
import
torch
import
torch
from
vllm.sequence
import
(
ExecuteModelRequest
,
HiddenStates
,
SamplerOutput
,
from
vllm.model_executor.layers.sampler
import
SamplerOutput
SequenceData
,
SequenceGroupMetadata
)
from
vllm.sequence
import
(
ExecuteModelRequest
,
HiddenStates
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
SpeculativeProposer
)
...
...
vllm/spec_decode/ngram_worker.py
View file @
428dd144
...
@@ -3,7 +3,8 @@ from typing import List, Optional, Set, Tuple
...
@@ -3,7 +3,8 @@ from typing import List, Optional, Set, Tuple
import
torch
import
torch
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
...
...
vllm/spec_decode/proposer_worker_base.py
View file @
428dd144
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
List
,
Optional
,
Set
,
Tuple
from
typing
import
List
,
Optional
,
Set
,
Tuple
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.interfaces
import
SpeculativeProposer
from
vllm.spec_decode.interfaces
import
SpeculativeProposer
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
...
...
vllm/spec_decode/smaller_tp_proposer_worker.py
View file @
428dd144
...
@@ -6,7 +6,8 @@ from vllm.distributed.parallel_state import (get_tp_group,
...
@@ -6,7 +6,8 @@ from vllm.distributed.parallel_state import (get_tp_group,
init_model_parallel_group
,
init_model_parallel_group
,
patch_tensor_parallel_group
)
patch_tensor_parallel_group
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
428dd144
...
@@ -8,12 +8,13 @@ from vllm.config import ParallelConfig, SpeculativeConfig
...
@@ -8,12 +8,13 @@ from vllm.config import ParallelConfig, SpeculativeConfig
from
vllm.distributed.communication_op
import
broadcast_tensor_dict
from
vllm.distributed.communication_op
import
broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
,
SpecDecodeStochasticBaseSampler
)
SpecDecodeBaseSampler
,
SpecDecodeStochasticBaseSampler
)
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
TypicalAcceptanceSampler
)
TypicalAcceptanceSampler
)
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
ExecuteModelRequest
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
ExecuteModelRequest
,
HiddenStates
,
SamplerOutput
,
SequenceGroupMetadata
,
HiddenStates
,
SequenceGroupMetadata
,
get_all_seq_ids
,
get_all_seq_ids_and_request_ids
)
get_all_seq_ids
,
get_all_seq_ids_and_request_ids
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
...
...
vllm/spec_decode/top1_proposer.py
View file @
428dd144
...
@@ -2,8 +2,8 @@ from typing import List, Optional, Set, Tuple
...
@@ -2,8 +2,8 @@ from typing import List, Optional, Set, Tuple
import
torch
import
torch
from
vllm.
sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
from
vllm.
model_executor.layers.sampler
import
SamplerOutput
SequenceGroupMetadata
)
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
...
...
vllm/spec_decode/util.py
View file @
428dd144
...
@@ -4,9 +4,9 @@ from typing import Dict, List, Optional, Sequence, Tuple
...
@@ -4,9 +4,9 @@ from typing import Dict, List, Optional, Sequence, Tuple
import
torch
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroupMetadata
,
SequenceOutput
)
SequenceOutput
)
SeqId
=
int
SeqId
=
int
...
...
vllm/worker/cpu_model_runner.py
View file @
428dd144
...
@@ -10,11 +10,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
...
@@ -10,11 +10,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
SchedulerConfig
)
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalInputs
)
MultiModalInputs
)
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
SequenceGroupMetadata
)
from
vllm.utils
import
make_tensor_with_pad
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerBase
,
ModelRunnerInputBase
,
...
...
vllm/worker/enc_dec_model_runner.py
View file @
428dd144
...
@@ -16,9 +16,10 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
...
@@ -16,9 +16,10 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
SamplerOutput
,
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.utils
import
STR_NOT_IMPL_ENC_DEC_BACKEND
,
make_tensor_with_pad
from
vllm.utils
import
STR_NOT_IMPL_ENC_DEC_BACKEND
,
make_tensor_with_pad
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
...
...
vllm/worker/model_runner.py
View file @
428dd144
...
@@ -29,6 +29,7 @@ from vllm.lora.layers import LoRAMapping
...
@@ -29,6 +29,7 @@ from vllm.lora.layers import LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.model_executor
import
SamplingMetadata
,
SamplingMetadataCache
from
vllm.model_executor
import
SamplingMetadata
,
SamplingMetadataCache
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.models.interfaces
import
(
supports_lora
,
from
vllm.model_executor.models.interfaces
import
(
supports_lora
,
...
@@ -41,8 +42,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
...
@@ -41,8 +42,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.prompt_adapter.worker_manager
import
(
from
vllm.prompt_adapter.worker_manager
import
(
LRUCacheWorkerPromptAdapterManager
)
LRUCacheWorkerPromptAdapterManager
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
SequenceGroupMetadata
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
PyObjectCache
,
async_tensor_h2d
,
from
vllm.utils
import
(
CudaMemoryProfiler
,
PyObjectCache
,
async_tensor_h2d
,
flatten_2d_lists
,
is_hip
,
is_pin_memory_available
,
flatten_2d_lists
,
is_hip
,
is_pin_memory_available
,
supports_dynamo
)
supports_dynamo
)
...
...
vllm/worker/model_runner_base.py
View file @
428dd144
...
@@ -5,9 +5,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
...
@@ -5,9 +5,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import
torch
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
SequenceGroupMetadata
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
...
...
vllm/worker/multi_step_model_runner.py
View file @
428dd144
import
dataclasses
import
dataclasses
import
functools
import
functools
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
)
try
:
try
:
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
...
@@ -15,9 +16,12 @@ import torch
...
@@ -15,9 +16,12 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed
import
get_pp_group
from
vllm.distributed
import
get_pp_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.sampler
import
(
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
SamplingMetadata
,
get_logprobs
,
get_pythonized_sample_results
)
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
Logprob
,
SequenceGroupMetadata
,
SequenceOutput
)
SequenceOutput
)
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUWithSamplingMetadata
)
ModelInputForGPUWithSamplingMetadata
)
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
...
@@ -53,6 +57,8 @@ class ModelOutput:
...
@@ -53,6 +57,8 @@ class ModelOutput:
sampler_output_ready_event
:
torch
.
cuda
.
Event
sampler_output_ready_event
:
torch
.
cuda
.
Event
sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
pythonized
:
bool
=
False
pythonized
:
bool
=
False
# On-device tensor containing the logprobs of each token.
logprobs
:
Optional
[
"torch.Tensor"
]
=
None
def
pythonize
(
self
,
input_metadata
:
"StatefulModelInput"
,
def
pythonize
(
self
,
input_metadata
:
"StatefulModelInput"
,
copy_stream
:
torch
.
cuda
.
Stream
,
copy_stream
:
torch
.
cuda
.
Stream
,
...
@@ -78,7 +84,9 @@ class ModelOutput:
...
@@ -78,7 +84,9 @@ class ModelOutput:
blocking
:
bool
)
->
bool
:
blocking
:
bool
)
->
bool
:
"""
"""
If blocking is set, will block until the forward pass for the output is
If blocking is set, will block until the forward pass for the output is
ready and pythonize the output.
ready and pythonize the output. Upon completing Pythonization, erases
self.logprobs (note that a non-blocking call that is performed when
the sampler output is not yet ready, will not erase self.logprobs.)
"""
"""
assert
self
.
sampled_token_ids
is
not
None
assert
self
.
sampled_token_ids
is
not
None
if
not
blocking
and
not
self
.
sampler_output_ready_event
.
query
():
if
not
blocking
and
not
self
.
sampler_output_ready_event
.
query
():
...
@@ -89,7 +97,15 @@ class ModelOutput:
...
@@ -89,7 +97,15 @@ class ModelOutput:
with
torch
.
cuda
.
stream
(
copy_stream
):
with
torch
.
cuda
.
stream
(
copy_stream
):
_pythonize_sampler_output
(
input_metadata
,
self
.
sampler_output
,
_pythonize_sampler_output
(
input_metadata
,
self
.
sampler_output
,
pinned_sampled_token_buffer
,
pinned_sampled_token_buffer
,
self
.
sampled_token_ids
)
self
.
sampled_token_ids
,
self
.
logprobs
)
# Erase the logprobs GPU-side tensor.
# Note that although _pythonize_sampler_output() runs in its
# own CUDA stream, nonetheless _pythonize_sampler_output()
# cannot return until Pythonization is complete; therefore
# we know that by the time the CPU reaches this point,
# `self.logprobs` is no longer needed.
self
.
logprobs
=
None
return
True
return
True
...
@@ -350,11 +366,16 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -350,11 +366,16 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
0
].
sampled_token_ids
.
cpu
()
0
].
sampled_token_ids
.
cpu
()
model_input
.
cached_outputs
.
append
(
model_input
.
cached_outputs
.
append
(
ModelOutput
(
output
[
0
],
output_ready_event
,
ModelOutput
(
output
[
0
],
output_ready_event
,
output
[
0
].
sampled_token_ids
,
False
))
output
[
0
].
sampled_token_ids
,
False
,
# make sure we dont try to serialize any GPU tensors
output
[
0
].
logprobs
))
# These GPU tensors are not required by multi-step;
# erase them to ensure they are not pythonized or
# transferred to CPU
output
[
0
].
sampled_token_ids
=
None
output
[
0
].
sampled_token_ids
=
None
output
[
0
].
sampled_token_probs
=
None
output
[
0
].
sampled_token_probs
=
None
output
[
0
].
logprobs
=
None
output
[
0
].
logprobs
=
None
# Pythonize the output if CPU is ahead and the previous step is
# Pythonize the output if CPU is ahead and the previous step is
# ready.
# ready.
if
not
frozen_model_input
.
use_async_and_multi_step
:
if
not
frozen_model_input
.
use_async_and_multi_step
:
...
@@ -464,12 +485,75 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -464,12 +485,75 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
return
self
.
_base_model_runner
.
vocab_size
return
self
.
_base_model_runner
.
vocab_size
def
_pythonize_sampler_output
(
model_input
:
StatefulModelInput
,
DeferredLogprobsReturnType
=
Tuple
[
Optional
[
List
[
Optional
[
PromptLogprobs
]]],
Optional
[
List
[
SampleLogprobs
]]]
def
deferred_pythonize_logprobs
(
output
:
SamplerOutput
,
sampling_metadata
:
SamplingMetadata
,
logprobs_tensor
:
Optional
[
torch
.
Tensor
],
)
->
DeferredLogprobsReturnType
:
"""Perform deferred logprob Pythonization.
1. Pythonize GPU-side sampler result tensors into CPU-side sampler result.
2. Pythonize GPU-side logprobs tensor into CPU-side logprobs lists,
utilizing the Pythonized sampler result computed in step 1.
These deferred computations are not required for single-step scheduling
or the `profile_run()` phase of multi-step scheduling.
Args:
output: sampler output (under deferred Pythonization)
sampling_metadata
Returns:
prompt_logprobs (CPU), sample_logprobs (CPU)
"""
# - Deferred pythonization of sample result
sampler_result
=
get_pythonized_sample_results
(
output
.
deferred_sample_results_args
)
# - Erase the GPU-side deferred sample_result
# computation args to ensure it is never
# pythonized or transferred to CPU
output
.
deferred_sample_results_args
=
None
# - Deferred pythonization of logprobs
(
prompt_logprobs
,
sample_logprobs
,
)
=
get_logprobs
(
logprobs_tensor
,
sampling_metadata
,
sampler_result
)
assert
len
(
prompt_logprobs
)
==
len
(
sampling_metadata
.
seq_groups
)
assert
len
(
sample_logprobs
)
==
len
(
sampling_metadata
.
seq_groups
)
return
prompt_logprobs
,
sample_logprobs
def
_pythonize_sampler_output
(
model_input
:
StatefulModelInput
,
output
:
SamplerOutput
,
output
:
SamplerOutput
,
pinned_sampled_token_buffer
:
torch
.
Tensor
,
pinned_sampled_token_buffer
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
)
->
None
:
sampled_token_ids
:
torch
.
Tensor
,
logprobs_tensor
:
Optional
[
torch
.
Tensor
],
)
->
None
:
""" This function is only called when the output tensors are ready.
""" This function is only called when the output tensors are ready.
See ModelOutput
See :class:`ModelOutput`.
Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
adding a Pythonized output data structure
(:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`.
Args:
model_input
output: sampler output
pinned_sampled_token_token_buffer: CPU-side pinned memory
(receives copy of
GPU-side token buffer.)
sampled_token_ids: GPU-side token buffer
logprobs_tensor: GPU-side tensor containing
logprobs computed during sampling
"""
"""
assert
model_input
.
frozen_model_input
is
not
None
assert
model_input
.
frozen_model_input
is
not
None
...
@@ -489,8 +573,51 @@ def _pythonize_sampler_output(model_input: StatefulModelInput,
...
@@ -489,8 +573,51 @@ def _pythonize_sampler_output(model_input: StatefulModelInput,
sampling_metadata
=
frozen_model_input
.
sampling_metadata
sampling_metadata
=
frozen_model_input
.
sampling_metadata
for
(
seq_group
,
sample_result
)
in
zip
(
sampling_metadata
.
seq_groups
,
skip_sampler_cpu_output
=
(
samples_list
):
frozen_model_input
.
sampling_metadata
.
skip_sampler_cpu_output
)
# We are guaranteed output tensors are ready, so it is safe to
# pythonize the sampler output & obtain CPU-side logprobs.
#
# However this computation may be skipped entirely
# if no pythonization was deferred.
seq_groups
=
sampling_metadata
.
seq_groups
logprobs_are_requested
=
any
([
sg
.
sampling_params
.
logprobs
is
not
None
or
sg
.
sampling_params
.
prompt_logprobs
is
not
None
for
sg
in
seq_groups
])
do_pythonize_logprobs
=
(
skip_sampler_cpu_output
and
logprobs_are_requested
)
(
prompt_logprobs
,
sample_logprobs
,
)
=
(
deferred_pythonize_logprobs
(
output
,
sampling_metadata
,
logprobs_tensor
)
if
do_pythonize_logprobs
else
(
None
,
None
))
for
sgdx
,
(
seq_group
,
sample_result
)
in
enumerate
(
zip
(
seq_groups
,
samples_list
)):
if
do_pythonize_logprobs
:
assert
prompt_logprobs
is
not
None
assert
sample_logprobs
is
not
None
(
group_prompt_logprobs
,
group_sample_logprobs
,
)
=
(
# Utilize deferred pythonization results
prompt_logprobs
[
sgdx
],
sample_logprobs
[
sgdx
],
)
elif
logprobs_are_requested
:
(
group_prompt_logprobs
,
group_sample_logprobs
,
)
=
(
# profile_run: use already-computed logprobs
output
.
outputs
[
sgdx
].
prompt_logprobs
,
[
sample
.
logprobs
for
sample
in
output
.
outputs
[
sgdx
].
samples
])
seq_ids
=
seq_group
.
seq_ids
seq_ids
=
seq_group
.
seq_ids
next_token_ids
=
sample_result
next_token_ids
=
sample_result
parent_ids
=
[
0
]
parent_ids
=
[
0
]
...
@@ -498,11 +625,19 @@ def _pythonize_sampler_output(model_input: StatefulModelInput,
...
@@ -498,11 +625,19 @@ def _pythonize_sampler_output(model_input: StatefulModelInput,
if
seq_group
.
sampling_params
.
logits_processors
:
if
seq_group
.
sampling_params
.
logits_processors
:
assert
len
(
seq_group
.
sampling_params
.
logits_processors
)
==
0
,
(
assert
len
(
seq_group
.
sampling_params
.
logits_processors
)
==
0
,
(
"Logits Processors are not supported in multi-step decoding"
)
"Logits Processors are not supported in multi-step decoding"
)
for
parent_id
,
next_token_id
in
zip
(
parent_ids
,
next_token_ids
):
for
tdx
,
(
parent_id
,
# TODO(will): support logprobs
next_token_id
)
in
enumerate
(
zip
(
parent_ids
,
next_token_ids
)):
# Hard coded logprob
seq_outputs
.
append
(
seq_outputs
.
append
(
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
{
next_token_id
:
Logprob
(
logprob
=-
1
)}))
(
group_sample_logprobs
[
tdx
]
output
.
outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
if
logprobs_are_requested
else
{
next_token_id
:
Logprob
(
logprob
=
float
(
'inf'
),
rank
=
None
,
decoded_token
=
None
)
})))
output
.
outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
(
group_prompt_logprobs
if
logprobs_are_requested
else
None
)))
assert
len
(
output
.
outputs
)
>
0
assert
len
(
output
.
outputs
)
>
0
vllm/worker/multi_step_worker.py
View file @
428dd144
...
@@ -5,7 +5,8 @@ from typing import Dict, List, Optional, Tuple
...
@@ -5,7 +5,8 @@ from typing import Dict, List, Optional, Tuple
import
torch
import
torch
from
vllm.distributed
import
broadcast_tensor_dict
,
get_pp_group
from
vllm.distributed
import
broadcast_tensor_dict
,
get_pp_group
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.worker.model_runner_base
import
BroadcastableModelInput
from
vllm.worker.model_runner_base
import
BroadcastableModelInput
from
vllm.worker.multi_step_model_runner
import
(
MultiStepModelRunner
,
from
vllm.worker.multi_step_model_runner
import
(
MultiStepModelRunner
,
StatefulModelInput
)
StatefulModelInput
)
...
...
vllm/worker/neuron_model_runner.py
View file @
428dd144
...
@@ -8,11 +8,11 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
...
@@ -8,11 +8,11 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig
)
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.neuron
import
get_neuron_model
from
vllm.model_executor.model_loader.neuron
import
get_neuron_model
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalInputs
)
MultiModalInputs
)
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
SequenceGroupMetadata
)
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.worker.model_runner_base
import
ModelRunnerBase
,
ModelRunnerInputBase
from
vllm.worker.model_runner_base
import
ModelRunnerBase
,
ModelRunnerInputBase
...
...
vllm/worker/openvino_model_runner.py
View file @
428dd144
...
@@ -11,10 +11,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
...
@@ -11,10 +11,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
SchedulerConfig
)
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.openvino
import
get_model
from
vllm.model_executor.model_loader.openvino
import
get_model
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalInputs
)
MultiModalInputs
)
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
SequenceGroupMetadata
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/worker/openvino_worker.py
View file @
428dd144
...
@@ -14,7 +14,8 @@ from vllm.distributed import (broadcast_tensor_dict,
...
@@ -14,7 +14,8 @@ from vllm.distributed import (broadcast_tensor_dict,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.worker.openvino_model_runner
import
OpenVINOModelRunner
from
vllm.worker.openvino_model_runner
import
OpenVINOModelRunner
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
...
...
vllm/worker/tpu_model_runner.py
View file @
428dd144
...
@@ -14,11 +14,11 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
...
@@ -14,11 +14,11 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
ModelConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
ParallelConfig
,
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
Logprob
,
SequenceGroupMetadata
,
SequenceOutput
)
SequenceOutput
)
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerBase
,
ModelRunnerInputBase
,
_add_attn_metadata_broadcastable_dict
,
_add_attn_metadata_broadcastable_dict
,
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment