Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4ff58b66
Commit
4ff58b66
authored
Apr 14, 2025
by
lizhigong
Browse files
debug v0 zero overhead schedule
parent
54294854
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
43 additions
and
34 deletions
+43
-34
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+5
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+2
-0
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+1
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+1
-1
vllm/zero_overhead/v0/llm_engine.py
vllm/zero_overhead/v0/llm_engine.py
+18
-17
vllm/zero_overhead/v0/model_runner.py
vllm/zero_overhead/v0/model_runner.py
+1
-1
vllm/zero_overhead/v0/sampler.py
vllm/zero_overhead/v0/sampler.py
+7
-10
vllm/zero_overhead/v0/stop_check.py
vllm/zero_overhead/v0/stop_check.py
+1
-1
vllm/zero_overhead/v0/update_input.py
vllm/zero_overhead/v0/update_input.py
+7
-1
No files found.
vllm/attention/backends/utils.py
View file @
4ff58b66
...
@@ -239,8 +239,11 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -239,8 +239,11 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
from_numpy
(
input_block_tables
).
to
(
# block_tables = torch.from_numpy(input_block_tables).to(
device
,
non_blocking
=
True
)
# device, non_blocking=True)
block_tables
=
async_tensor_h2d
(
input_block_tables
.
tolist
(),
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
else
:
else
:
block_tables
=
make_tensor_with_pad
(
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
self
.
block_tables
,
...
...
vllm/entrypoints/llm.py
View file @
4ff58b66
...
@@ -1450,6 +1450,8 @@ class LLM:
...
@@ -1450,6 +1450,8 @@ class LLM:
if
use_tqdm
:
if
use_tqdm
:
pbar
.
close
()
pbar
.
close
()
if
is_zero_overhead
():
self
.
llm_engine
.
finish_thread
()
# Sort the outputs by request ID.
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# This is necessary because some requests may be finished earlier than
# its previous requests.
# its previous requests.
...
...
vllm/model_executor/layers/sampler.py
View file @
4ff58b66
...
@@ -21,7 +21,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
...
@@ -21,7 +21,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput
,
Logprob
,
CompletionSequenceGroupOutput
,
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
)
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
)
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
from
vllm.zero_overhead.v0.sampler
import
ZeroOverheadSampler
from
vllm.zero_overhead.v0.utils
import
is_zero_overhead
from
vllm.zero_overhead.v0.utils
import
is_zero_overhead
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
...
@@ -41,6 +40,7 @@ def get_sampler() -> torch.nn.Module:
...
@@ -41,6 +40,7 @@ def get_sampler() -> torch.nn.Module:
from
vllm.v1.sample.sampler
import
Sampler
as
V1Sampler
from
vllm.v1.sample.sampler
import
Sampler
as
V1Sampler
return
V1Sampler
()
return
V1Sampler
()
if
is_zero_overhead
():
if
is_zero_overhead
():
from
vllm.zero_overhead.v0.sampler
import
ZeroOverheadSampler
return
ZeroOverheadSampler
()
return
ZeroOverheadSampler
()
return
Sampler
()
return
Sampler
()
...
...
vllm/worker/model_runner.py
View file @
4ff58b66
...
@@ -60,7 +60,6 @@ from vllm.worker.model_runner_base import (
...
@@ -60,7 +60,6 @@ from vllm.worker.model_runner_base import (
_add_sampling_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_sampling_metadata_from_tensor_dict
)
_init_sampling_metadata_from_tensor_dict
)
from
vllm.zero_overhead.v0.model_runner
import
ZeroOverheadModelInputForGpuBuilder
from
vllm.zero_overhead.v0.utils
import
is_zero_overhead
from
vllm.zero_overhead.v0.utils
import
is_zero_overhead
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -1639,6 +1638,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1639,6 +1638,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
ModelInputForGPUWithSamplingMetadata
)
ModelInputForGPUWithSamplingMetadata
)
_builder_cls
:
Type
[
ModelInputForGPUBuilder
]
=
ModelInputForGPUBuilder
_builder_cls
:
Type
[
ModelInputForGPUBuilder
]
=
ModelInputForGPUBuilder
if
is_zero_overhead
():
if
is_zero_overhead
():
from
vllm.zero_overhead.v0.model_runner
import
ZeroOverheadModelInputForGpuBuilder
_builder_cls
=
ZeroOverheadModelInputForGpuBuilder
_builder_cls
=
ZeroOverheadModelInputForGpuBuilder
def
make_model_input_from_broadcasted_tensor_dict
(
def
make_model_input_from_broadcasted_tensor_dict
(
...
...
vllm/zero_overhead/v0/llm_engine.py
View file @
4ff58b66
from
collections
import
Counter
from
functools
import
partial
from
functools
import
partial
import
os
import
os
import
queue
import
queue
...
@@ -13,7 +12,7 @@ from vllm.core.scheduler import ScheduledSequenceGroup
...
@@ -13,7 +12,7 @@ from vllm.core.scheduler import ScheduledSequenceGroup
from
vllm.engine.llm_engine
import
_LOCAL_LOGGING_INTERVAL_SEC
,
LLMEngine
,
SchedulerContext
,
SchedulerOutputState
from
vllm.engine.llm_engine
import
_LOCAL_LOGGING_INTERVAL_SEC
,
LLMEngine
,
SchedulerContext
,
SchedulerOutputState
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.engine.output_processor.interfaces
import
SequenceGroupOutputProcessor
from
vllm.engine.output_processor.interfaces
import
SequenceGroupOutputProcessor
from
vllm.
entrypoints
import
logger
from
vllm.
logger
import
init_
logger
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.inputs.data
import
ProcessorInputs
from
vllm.inputs.data
import
ProcessorInputs
...
@@ -31,8 +30,10 @@ from vllm.sampling_params import SamplingParams
...
@@ -31,8 +30,10 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
,
ParallelSampleSequenceGroup
,
SequenceGroup
,
SequenceGroupBase
,
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
,
ParallelSampleSequenceGroup
,
SequenceGroup
,
SequenceGroupBase
,
SequenceGroupMetadata
from
vllm.tracing
import
init_tracer
from
vllm.tracing
import
init_tracer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.usage.usage_lib
import
UsageContext
,
is_usage_stats_enabled
from
vllm.usage.usage_lib
import
UsageContext
,
is_usage_stats_enabled
from
vllm.utils
import
resolve_obj_by_qualname
,
weak_bind
from
vllm.utils
import
resolve_obj_by_qualname
,
weak_bind
,
Counter
from
vllm.zero_overhead.v0.sampler
import
SampleRecorder
,
get_last_sampler
from
vllm.zero_overhead.v0.sequence
import
ZeroOverheadSequence
from
vllm.zero_overhead.v0.sequence
import
ZeroOverheadSequence
from
vllm.zero_overhead.v0.stop_check
import
ZeroOverheadStopChecker
from
vllm.zero_overhead.v0.stop_check
import
ZeroOverheadStopChecker
from
vllm.zero_overhead.v0.tokenizer
import
ZeroOverheadDetokenizer
from
vllm.zero_overhead.v0.tokenizer
import
ZeroOverheadDetokenizer
...
@@ -40,6 +41,8 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
...
@@ -40,6 +41,8 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message
)
usage_message
)
from
vllm.profiler.prof
import
profile
from
vllm.profiler.prof
import
profile
logger
=
init_logger
(
__name__
)
class
ZeroOverheadEngine
(
LLMEngine
):
class
ZeroOverheadEngine
(
LLMEngine
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -77,7 +80,7 @@ class ZeroOverheadEngine(LLMEngine):
...
@@ -77,7 +80,7 @@ class ZeroOverheadEngine(LLMEngine):
logger
.
info
(
logger
.
info
(
"Initializing a V0 LLM engine (v%s) with config: %s, "
"Initializing a V0 LLM engine (v%s) with config: %s, "
"use_cached_outputs=%s, "
,
"use_cached_outputs=%s, "
,
ZLIB
_VERSION
,
VLLM
_VERSION
,
vllm_config
,
vllm_config
,
use_cached_outputs
,
use_cached_outputs
,
)
)
...
@@ -259,6 +262,7 @@ class ZeroOverheadEngine(LLMEngine):
...
@@ -259,6 +262,7 @@ class ZeroOverheadEngine(LLMEngine):
self
.
_skip_scheduling_next_step
=
False
self
.
_skip_scheduling_next_step
=
False
self
.
async_d2h
=
None
self
.
async_d2h
=
None
self
.
last_record
=
None
self
.
last_record
=
None
assert
os
.
environ
.
get
(
'HIP_ALLOC_INITIALIZE'
)
==
'0'
self
.
async_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
async_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
zero_thread
=
threading
.
Thread
(
target
=
self
.
thread_zero_overhead
)
self
.
zero_thread
=
threading
.
Thread
(
target
=
self
.
thread_zero_overhead
)
self
.
q_recorder
=
queue
.
Queue
()
self
.
q_recorder
=
queue
.
Queue
()
...
@@ -277,6 +281,7 @@ class ZeroOverheadEngine(LLMEngine):
...
@@ -277,6 +281,7 @@ class ZeroOverheadEngine(LLMEngine):
self
.
sem_m2s
.
release
()
self
.
sem_m2s
.
release
()
def
thread_zero_overhead
(
self
):
def
thread_zero_overhead
(
self
):
logger
.
info
(
'zero overhead thread start!'
)
try
:
try
:
while
True
:
while
True
:
self
.
sem_m2s
.
acquire
()
self
.
sem_m2s
.
acquire
()
...
@@ -290,12 +295,9 @@ class ZeroOverheadEngine(LLMEngine):
...
@@ -290,12 +295,9 @@ class ZeroOverheadEngine(LLMEngine):
(
seq_group_metadata_list
,
scheduler_outputs
,
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
last_outputs_ids
=
None
last_outputs_tensor
=
None
if
self
.
last_record
is
not
None
:
if
self
.
last_record
is
not
None
:
last_output
=
self
.
last_record
[
0
][
0
]
last_sampler
=
self
.
last_record
[
1
]
last_outputs_ids
,
last_outputs_tensor
=
last_output
.
sampler_out_ids
,
last_output
.
sampler_out_tenosr
self
.
async_d2h
=
last_sampler
.
sampled_token_ids_tensor
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
async_d2h
=
last_outputs_tensor
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
async_event
.
record
()
self
.
async_event
.
record
()
self
.
q_recorder
.
put
(
self
.
last_record
)
self
.
q_recorder
.
put
(
self
.
last_record
)
else
:
else
:
...
@@ -322,10 +324,7 @@ class ZeroOverheadEngine(LLMEngine):
...
@@ -322,10 +324,7 @@ class ZeroOverheadEngine(LLMEngine):
finished_requests_ids
=
finished_requests_ids
,
finished_requests_ids
=
finished_requests_ids
,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
,
last_sampled_token_ids
=
last_sampled_token_ids
)
last_outputs_ids
=
last_outputs_ids
,
last_outputs_sample
=
last_outputs_tensor
)
outputs
=
self
.
model_executor
.
execute_model
(
outputs
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
execute_model_req
=
execute_model_req
)
...
@@ -334,7 +333,8 @@ class ZeroOverheadEngine(LLMEngine):
...
@@ -334,7 +333,8 @@ class ZeroOverheadEngine(LLMEngine):
outputs
[
0
],
seq_group_metadata_list
,
outputs
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
=
[
item
for
item
in
scheduler_outputs
.
scheduled_seq_groups
]
#deep copy
scheduler_outputs
.
scheduled_seq_groups
=
[
item
for
item
in
scheduler_outputs
.
scheduled_seq_groups
]
#deep copy
self
.
last_record
=
[
outputs
,
seq_group_metadata_list
,
scheduler_outputs
]
last_sampler
=
get_last_sampler
()
self
.
last_record
=
[
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
]
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"thread_zero_overhead error :
{
e
}
"
)
print
(
f
"thread_zero_overhead error :
{
e
}
"
)
...
@@ -353,12 +353,12 @@ class ZeroOverheadEngine(LLMEngine):
...
@@ -353,12 +353,12 @@ class ZeroOverheadEngine(LLMEngine):
virtual_engine
=
0
virtual_engine
=
0
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
ctx
.
request_outputs
.
clear
()
ctx
.
request_outputs
.
clear
()
outputs
,
seq_group_metadata_list
,
scheduler_outputs
=
recode_output
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
=
recode_output
ctx
.
seq_group_metadata_list
=
seq_group_metadata_list
ctx
.
seq_group_metadata_list
=
seq_group_metadata_list
ctx
.
scheduler_outputs
=
scheduler_outputs
ctx
.
scheduler_outputs
=
scheduler_outputs
self
.
async_event
.
synchronize
()
self
.
async_event
.
synchronize
()
self
.
_fix_last_step
(
self
.
_fix_last_step
(
outputs
,
seq_group_metadata_list
,
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
)
# is_first_step_output is True only when the num_steps of all
# is_first_step_output is True only when the num_steps of all
...
@@ -398,12 +398,13 @@ class ZeroOverheadEngine(LLMEngine):
...
@@ -398,12 +398,13 @@ class ZeroOverheadEngine(LLMEngine):
def
_fix_last_step
(
def
_fix_last_step
(
self
,
output
:
List
[
SamplerOutput
],
self
,
output
:
List
[
SamplerOutput
],
last_sampler
:
SampleRecorder
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
])
->
None
:
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
])
->
None
:
#sample_out_list = output[0].sampler_out_tenosr.cpu().tolist()
#sample_out_list = output[0].sampler_out_tenosr.cpu().tolist()
sample_out_list
=
self
.
async_d2h
.
tolist
()
sample_out_list
=
self
.
async_d2h
.
tolist
()
sample_out_ids
=
output
[
0
].
sampler
_out
_id
s
.
tolist
()
sample_out_ids
=
last_
sampler
.
seq
_id
.
tolist
()
for
seq_group_metadata
,
sequence_group_outputs
,
scheduled_seq_group
in
\
for
seq_group_metadata
,
sequence_group_outputs
,
scheduled_seq_group
in
\
zip
(
seq_group_metadata_list
,
output
[
0
],
scheduled_seq_groups
):
zip
(
seq_group_metadata_list
,
output
[
0
],
scheduled_seq_groups
):
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
...
...
vllm/zero_overhead/v0/model_runner.py
View file @
4ff58b66
...
@@ -35,7 +35,7 @@ class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
...
@@ -35,7 +35,7 @@ class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
def
build
(
self
)
->
ModelInputForGPU
:
def
build
(
self
)
->
ModelInputForGPU
:
model_input
=
super
().
build
()
model_input
=
super
().
build
()
last_sampler
=
get_last_sampler
()
last_sampler
=
get_last_sampler
()
if
last_sampler
.
sampled_token_ids_tensor
is
not
None
:
if
last_sampler
is
not
None
:
input_ids
=
async_tensor_h2d
(
self
.
req_ids
,
torch
.
long
,
input_ids
=
async_tensor_h2d
(
self
.
req_ids
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
self
.
runner
.
pin_memory
)
...
...
vllm/zero_overhead/v0/sampler.py
View file @
4ff58b66
...
@@ -3,11 +3,10 @@ from typing import Dict, List, Optional
...
@@ -3,11 +3,10 @@ from typing import Dict, List, Optional
import
torch
import
torch
from
vllm
import
envs
from
vllm
import
envs
from
vllm.model_executor.layers.rejection_sampler
import
_multinomial
from
vllm.model_executor.layers.sampler
import
MultinomialSamplesType
,
SampleMetadataType
,
\
from
vllm.model_executor.layers.sampler
import
MultinomialSamplesType
,
SampleMetadataType
,
\
SampleResultArgsType
,
SampleResultType
,
SampleResultsDictType
,
SampleReturnType
,
Sampler
,
\
SampleResultArgsType
,
SampleResultType
,
SampleResultsDictType
,
SampleReturnType
,
Sampler
,
\
SamplerOutput
,
_apply_min_p
,
_apply_min_tokens_penalty
,
_apply_top_k_top_p
,
_build_sampler_output
,
\
SamplerOutput
,
_apply_min_p
,
_apply_min_tokens_penalty
,
_apply_top_k_top_p
,
_build_sampler_output
,
\
_modify_greedy_probs_inplace
,
_top_k_top_p_multinomial_with_flashinfer
,
get_logprobs
_modify_greedy_probs_inplace
,
_top_k_top_p_multinomial_with_flashinfer
,
get_logprobs
,
_multinomial
from
vllm.model_executor.layers.utils
import
apply_penalties
from
vllm.model_executor.layers.utils
import
apply_penalties
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
,
SamplingTensors
,
SequenceGroupToSample
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
,
SamplingTensors
,
SequenceGroupToSample
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
...
@@ -17,13 +16,16 @@ if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
...
@@ -17,13 +16,16 @@ if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
# yapf: disable
# yapf: disable
from
flashinfer.sampling
import
(
from
flashinfer.sampling
import
(
top_k_top_p_sampling_from_probs
as
flashinfer_top_k_top_p_sampling
)
top_k_top_p_sampling_from_probs
as
flashinfer_top_k_top_p_sampling
)
# yapf: enable
else
:
flashinfer_top_k_top_p_sampling
=
None
class
SampleRecorder
:
class
SampleRecorder
:
def
__init__
(
self
):
def
__init__
(
self
):
self
.
seq_id
:
torch
.
Tensor
=
None
self
.
seq_id
:
torch
.
Tensor
=
None
self
.
sampled_token_ids_tensor
:
torch
.
Tensor
=
None
self
.
sampled_token_ids_tensor
:
torch
.
Tensor
=
None
last_sampler
=
SampleRecorder
()
last_sampler
=
None
def
get_last_sampler
():
def
get_last_sampler
():
return
last_sampler
return
last_sampler
...
@@ -55,6 +57,8 @@ class ZeroOverheadSampler(Sampler):
...
@@ -55,6 +57,8 @@ class ZeroOverheadSampler(Sampler):
logits: (num_tokens, vocab_size).
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
sampling_metadata: Metadata for sampling.
"""
"""
global
last_sampler
last_sampler
=
SampleRecorder
()
assert
logits
is
not
None
assert
logits
is
not
None
_
,
vocab_size
=
logits
.
shape
_
,
vocab_size
=
logits
.
shape
...
@@ -282,7 +286,6 @@ def _sample_with_torch(
...
@@ -282,7 +286,6 @@ def _sample_with_torch(
sample_metadata
:
SampleMetadataType
=
{}
sample_metadata
:
SampleMetadataType
=
{}
multinomial_samples
:
MultinomialSamplesType
=
{}
multinomial_samples
:
MultinomialSamplesType
=
{}
greedy_samples
:
Optional
[
torch
.
Tensor
]
=
None
greedy_samples
:
Optional
[
torch
.
Tensor
]
=
None
beam_search_logprobs
:
Optional
[
torch
.
Tensor
]
=
None
# Create output tensor for sampled token ids.
# Create output tensor for sampled token ids.
if
include_gpu_probs_tensor
:
if
include_gpu_probs_tensor
:
...
@@ -356,11 +359,6 @@ def _sample_with_torch(
...
@@ -356,11 +359,6 @@ def _sample_with_torch(
sampled_token_ids_tensor
[
long_sample_indices
]
=
\
sampled_token_ids_tensor
[
long_sample_indices
]
=
\
multinomial_samples
[
sampling_type
].
to
(
torch
.
long
)
multinomial_samples
[
sampling_type
].
to
(
torch
.
long
)
elif
sampling_type
==
SamplingType
.
BEAM
:
beam_search_logprobs
=
logprobs
[
sample_indices
]
else
:
raise
ValueError
(
f
"Unsupported sampling type:
{
sampling_type
}
"
)
# Encapsulate arguments for computing Pythonized sampler
# Encapsulate arguments for computing Pythonized sampler
# results, whether deferred or otherwise.
# results, whether deferred or otherwise.
maybe_deferred_args
=
SampleResultArgsType
(
maybe_deferred_args
=
SampleResultArgsType
(
...
@@ -368,7 +366,6 @@ def _sample_with_torch(
...
@@ -368,7 +366,6 @@ def _sample_with_torch(
sample_metadata
=
sample_metadata
,
sample_metadata
=
sample_metadata
,
multinomial_samples
=
multinomial_samples
,
multinomial_samples
=
multinomial_samples
,
greedy_samples
=
greedy_samples
,
greedy_samples
=
greedy_samples
,
beam_search_logprobs
=
beam_search_logprobs
,
sample_results_dict
=
sample_results_dict
)
sample_results_dict
=
sample_results_dict
)
if
not
sampling_metadata
.
skip_sampler_cpu_output
:
if
not
sampling_metadata
.
skip_sampler_cpu_output
:
...
...
vllm/zero_overhead/v0/stop_check.py
View file @
4ff58b66
...
@@ -44,7 +44,7 @@ class ZeroOverheadStopChecker(StopChecker):
...
@@ -44,7 +44,7 @@ class ZeroOverheadStopChecker(StopChecker):
# Check if a stop token was encountered.
# Check if a stop token was encountered.
# This assumes a single token produced per step.
# This assumes a single token produced per step.
last_token_id
=
seq
.
get_last_token_id
(
self
.
zero_overhead
)
last_token_id
=
seq
.
zero_overhead_
get_last_token_id
()
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
if
new_char_count
and
(
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
not
sampling_params
.
include_stop_str_in_output
):
...
...
vllm/zero_overhead/v0/update_input.py
View file @
4ff58b66
...
@@ -23,6 +23,12 @@ def _update_input_tokens(
...
@@ -23,6 +23,12 @@ def _update_input_tokens(
output_token
=
tl
.
load
(
sample_output
+
i
)
output_token
=
tl
.
load
(
sample_output
+
i
)
tl
.
store
(
input_tokens
+
pid
,
output_token
)
tl
.
store
(
input_tokens
+
pid
,
output_token
)
_update_input_tokens_ptr
=
None
def
UpdateInputTokens
(
input_tokens
,
input_seq_ids
,
last_sample
,
last_ids
):
def
UpdateInputTokens
(
input_tokens
,
input_seq_ids
,
last_sample
,
last_ids
):
global
_update_input_tokens_ptr
grid
=
[
input_seq_ids
.
shape
[
0
],
1
,
1
]
grid
=
[
input_seq_ids
.
shape
[
0
],
1
,
1
]
_update_input_tokens
[
grid
](
last_sample
,
last_ids
,
input_tokens
,
input_seq_ids
,
last_ids
.
shape
[
0
],
input_seq_ids
.
shape
[
0
])
if
_update_input_tokens_ptr
is
None
:
\ No newline at end of file
_update_input_tokens_ptr
=
_update_input_tokens
[
grid
](
last_sample
,
last_ids
,
input_tokens
,
input_seq_ids
,
last_ids
.
shape
[
0
],
input_seq_ids
.
shape
[
0
])
else
:
_update_input_tokens_ptr
[
grid
](
last_sample
,
last_ids
,
input_tokens
,
input_seq_ids
,
last_ids
.
shape
[
0
],
input_seq_ids
.
shape
[
0
])
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment