Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
603ad848
Unverified
Commit
603ad848
authored
Apr 26, 2024
by
SangBin Cho
Committed by
GitHub
Apr 26, 2024
Browse files
[Core] Refactoring sampler and support prompt logprob for chunked prefill (#4309)
parent
a88081bf
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
862 additions
and
633 deletions
+862
-633
tests/samplers/test_logprobs.py
tests/samplers/test_logprobs.py
+41
-3
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+31
-16
tests/test_logits_processor.py
tests/test_logits_processor.py
+7
-3
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+13
-6
vllm/core/scheduler.py
vllm/core/scheduler.py
+15
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+1
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+13
-12
vllm/engine/output_processor/interfaces.py
vllm/engine/output_processor/interfaces.py
+6
-0
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+9
-0
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+14
-8
vllm/engine/output_processor/util.py
vllm/engine/output_processor/util.py
+4
-3
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+12
-15
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+366
-178
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+284
-65
vllm/sequence.py
vllm/sequence.py
+10
-1
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+15
-101
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+9
-114
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+12
-107
No files found.
tests/samplers/test_logprobs.py
View file @
603ad848
...
@@ -9,15 +9,26 @@ MODELS = ["facebook/opt-125m"]
...
@@ -9,15 +9,26 @@ MODELS = ["facebook/opt-125m"]
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"chunked_prefill_token_size"
,
[
1
,
4
,
16
,
-
1
])
@
pytest
.
mark
.
parametrize
(
"num_top_logprobs"
,
[
6
])
# 32000 == vocab_size
def
test_get_prompt_logprobs
(
def
test_get_prompt_logprobs
(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
model
,
model
,
dtype
,
dtype
,
chunked_prefill_token_size
:
int
,
num_top_logprobs
:
int
,
example_prompts
,
example_prompts
,
):
):
max_num_seqs
=
256
enable_chunked_prefill
=
False
max_num_batched_tokens
=
None
if
chunked_prefill_token_size
!=
-
1
:
enable_chunked_prefill
=
True
max_num_seqs
=
min
(
chunked_prefill_token_size
,
max_num_seqs
)
max_num_batched_tokens
=
chunked_prefill_token_size
max_tokens
=
5
max_tokens
=
5
num_top_logprobs
=
6
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_logprobs
=
hf_model
.
generate_greedy_logprobs
(
hf_logprobs
=
hf_model
.
generate_greedy_logprobs
(
example_prompts
,
example_prompts
,
...
@@ -25,10 +36,17 @@ def test_get_prompt_logprobs(
...
@@ -25,10 +36,17 @@ def test_get_prompt_logprobs(
)
)
del
hf_model
del
hf_model
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
,
max_logprobs
=
num_top_logprobs
)
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
,
max_logprobs
=
num_top_logprobs
,
enable_chunked_prefill
=
enable_chunked_prefill
,
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_seqs
=
max_num_seqs
,
)
vllm_sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
vllm_sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
logprobs
=
num_top_logprobs
,
logprobs
=
num_top_logprobs
,
prompt_logprobs
=
5
,
prompt_logprobs
=
num_top_logprobs
,
temperature
=
0.0
)
temperature
=
0.0
)
vllm_results
=
vllm_model
.
model
.
generate
(
vllm_results
=
vllm_model
.
model
.
generate
(
example_prompts
,
sampling_params
=
vllm_sampling_params
)
example_prompts
,
sampling_params
=
vllm_sampling_params
)
...
@@ -52,9 +70,18 @@ def test_get_prompt_logprobs(
...
@@ -52,9 +70,18 @@ def test_get_prompt_logprobs(
"The output text from the top logprob for each token position "
"The output text from the top logprob for each token position "
"should be the same as the output text in the result."
)
"should be the same as the output text in the result."
)
# The first prompt logprob is always None
assert
result
.
prompt_logprobs
[
0
]
is
None
for
prompt_logprobs
in
result
.
prompt_logprobs
[
1
:]:
# If the prompt token is not included in the top X
# logprob, it can return 1 more data
assert
(
len
(
prompt_logprobs
)
==
num_top_logprobs
or
len
(
prompt_logprobs
)
==
num_top_logprobs
+
1
)
# Test whether prompt logprobs are consistent with HF
# Test whether prompt logprobs are consistent with HF
for
vllm_result
,
hf_logprob
in
zip
(
vllm_results
,
hf_logprobs
):
for
vllm_result
,
hf_logprob
in
zip
(
vllm_results
,
hf_logprobs
):
# Check prompt logprobs
# Check prompt logprobs
# The first prompt logprob is always None, so we compare it from 1:.
vllm_prompt_logprobs
=
vllm_result
.
prompt_logprobs
[
1
:]
vllm_prompt_logprobs
=
vllm_result
.
prompt_logprobs
[
1
:]
for
i
,
vllm_prompt_logprob_dict
in
enumerate
(
vllm_prompt_logprobs
):
for
i
,
vllm_prompt_logprob_dict
in
enumerate
(
vllm_prompt_logprobs
):
for
token_id
,
logprob
in
vllm_prompt_logprob_dict
.
items
():
for
token_id
,
logprob
in
vllm_prompt_logprob_dict
.
items
():
...
@@ -74,6 +101,17 @@ def test_get_prompt_logprobs(
...
@@ -74,6 +101,17 @@ def test_get_prompt_logprobs(
"The token should be decoded by the time it is returned "
"The token should be decoded by the time it is returned "
" to the user."
)
" to the user."
)
# Test if prompt logprobs are correctly set.
for
vllm_result
in
vllm_results
:
token_ids
=
vllm_result
.
prompt_token_ids
prompt_logprobs
=
vllm_result
.
prompt_logprobs
# The first token doesn't have logprob.
assert
prompt_logprobs
[
0
]
is
None
for
token_id
,
logprob_dict
in
zip
(
token_ids
[
1
:],
prompt_logprobs
[
1
:]):
assert
token_id
in
logprob_dict
def
test_max_logprobs
():
def
test_max_logprobs
():
runner
=
VllmRunner
(
"facebook/opt-125m"
,
max_logprobs
=
1
)
runner
=
VllmRunner
(
"facebook/opt-125m"
,
max_logprobs
=
1
)
...
...
tests/samplers/test_sampler.py
View file @
603ad848
...
@@ -8,6 +8,7 @@ import torch
...
@@ -8,6 +8,7 @@ import torch
from
transformers
import
GenerationConfig
,
GenerationMixin
from
transformers
import
GenerationConfig
,
GenerationMixin
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
Counter
from
vllm.utils
import
Counter
...
@@ -54,6 +55,7 @@ def _do_sample(
...
@@ -54,6 +55,7 @@ def _do_sample(
sampler
:
MockLogitsSampler
,
sampler
:
MockLogitsSampler
,
model_runner
:
ModelRunner
,
model_runner
:
ModelRunner
,
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
device
:
str
,
):
):
seq_group_metadata_list
=
[]
seq_group_metadata_list
=
[]
prompt_lens
=
[]
prompt_lens
=
[]
...
@@ -68,9 +70,12 @@ def _do_sample(
...
@@ -68,9 +70,12 @@ def _do_sample(
))
))
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
SamplingMetadata
.
prepare
(
prompt_lens
,
seq_group_metadata_list
,
subquery_lens
=
prompt_lens
)
prompt_lens
,
subquery_lens
=
prompt_lens
,
device
=
device
,
pin_memory
=
model_runner
.
pin_memory
)
return
sampler
(
logits
=
input_tensor
,
sampling_metadata
=
sampling_metadata
)
return
sampler
(
logits
=
input_tensor
,
sampling_metadata
=
sampling_metadata
)
...
@@ -85,7 +90,7 @@ def test_sampler_all_greedy(seed: int, device: str):
...
@@ -85,7 +90,7 @@ def test_sampler_all_greedy(seed: int, device: str):
sampling_params
=
SamplingParams
(
temperature
=
0
)
sampling_params
=
SamplingParams
(
temperature
=
0
)
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
)
sampling_params
,
device
)
expected
=
torch
.
argmax
(
fake_logits
,
dim
=-
1
)
expected
=
torch
.
argmax
(
fake_logits
,
dim
=-
1
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
for
nth_output
in
sequence_output
.
samples
:
...
@@ -111,7 +116,7 @@ def test_sampler_all_random(seed: int, device: str):
...
@@ -111,7 +116,7 @@ def test_sampler_all_random(seed: int, device: str):
n
=
random
.
randint
(
1
,
10
),
n
=
random
.
randint
(
1
,
10
),
)
)
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
)
sampling_params
,
device
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
for
nth_output
in
sequence_output
.
samples
:
...
@@ -137,7 +142,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
...
@@ -137,7 +142,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
seed
=
random
.
randint
(
0
,
10000
),
seed
=
random
.
randint
(
0
,
10000
),
)
)
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
)
sampling_params
,
device
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
for
nth_output
in
sequence_output
.
samples
:
...
@@ -160,10 +165,10 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
...
@@ -160,10 +165,10 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
seed
=
random
.
randint
(
0
,
10000
),
seed
=
random
.
randint
(
0
,
10000
),
)
)
first_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
first_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
)
model_runner
,
sampling_params
,
device
)
second_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
second_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
)
model_runner
,
sampling_params
,
device
)
assert
first_sampler_output
==
second_sampler_output
assert
first_sampler_output
==
second_sampler_output
...
@@ -183,7 +188,8 @@ def test_sampler_all_beam(seed: int, device: str):
...
@@ -183,7 +188,8 @@ def test_sampler_all_beam(seed: int, device: str):
best_of
=
2
,
best_of
=
2
,
use_beam_search
=
True
,
use_beam_search
=
True
,
)
)
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
)
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
,
device
)
# no assertion here as I am not sure how to determine whether
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
# whether there are no exceptions in the sampler
...
@@ -443,10 +449,12 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -443,10 +449,12 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
"batch size"
)
"batch size"
)
_
,
fake_logits
,
sampler
,
model_runner
=
_prepare_test
(
batch_size
)
_
,
fake_logits
,
sampler
,
model_runner
=
_prepare_test
(
batch_size
)
sampling_metadata
=
model_runner
.
_prepare_sampl
e
(
sampling_metadata
=
SamplingMetadata
.
prepar
e
(
seq_group_metadata_list
,
seq_group_metadata_list
,
prompt_lens
=
prompt_lens
if
prompt_lens
else
None
,
prompt_lens
=
prompt_lens
if
prompt_lens
else
None
,
subquery_lens
=
prompt_lens
if
prompt_lens
else
None
)
subquery_lens
=
prompt_lens
if
prompt_lens
else
None
,
device
=
device
,
pin_memory
=
model_runner
.
pin_memory
)
# the logits tensor is modified in-place by the sampler
# the logits tensor is modified in-place by the sampler
_
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
_
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
...
@@ -530,8 +538,12 @@ def test_sampler_mixed(seed: int, device: str):
...
@@ -530,8 +538,12 @@ def test_sampler_mixed(seed: int, device: str):
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
def
test_sampling
(
model_runner
:
ModelRunner
):
def
test_sampling
(
model_runner
:
ModelRunner
):
sampling_metadata
=
model_runner
.
_prepare_sample
(
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
prompt_lens
,
subquery_lens
=
prompt_lens
)
seq_group_metadata_list
,
prompt_lens
,
subquery_lens
=
prompt_lens
,
device
=
device
,
pin_memory
=
model_runner
.
pin_memory
)
sampler_output
=
sampler
(
logits
=
fake_logits
,
sampler_output
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
sampling_metadata
=
sampling_metadata
)
...
@@ -627,9 +639,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
...
@@ -627,9 +639,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
))
))
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
SamplingMetadata
.
prepare
(
prompt_lens
,
seq_group_metadata_list
,
subquery_lens
=
prompt_lens
)
prompt_lens
,
subquery_lens
=
prompt_lens
,
device
=
device
,
pin_memory
=
model_runner
.
pin_memory
)
sample_probs
=
None
sample_probs
=
None
...
...
tests/test_logits_processor.py
View file @
603ad848
...
@@ -6,6 +6,7 @@ import pytest
...
@@ -6,6 +6,7 @@ import pytest
import
torch
import
torch
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.model_runner
import
ModelRunner
...
@@ -82,9 +83,12 @@ def test_logits_processors(seed: int, device: str):
...
@@ -82,9 +83,12 @@ def test_logits_processors(seed: int, device: str):
))
))
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
SamplingMetadata
.
prepare
(
prompt_lens
,
seq_group_metadata_list
,
subquery_lens
=
prompt_lens
)
prompt_lens
,
subquery_lens
=
prompt_lens
,
device
=
model_runner
.
device
,
pin_memory
=
model_runner
.
pin_memory
)
logits_processor_output
=
logits_processor
(
logits_processor_output
=
logits_processor
(
embedding
=
None
,
embedding
=
None
,
hidden_states
=
input_tensor
,
hidden_states
=
input_tensor
,
...
...
tests/worker/test_model_runner.py
View file @
603ad848
...
@@ -2,6 +2,7 @@ import pytest
...
@@ -2,6 +2,7 @@ import pytest
import
torch
import
torch
from
vllm.config
import
ModelConfig
,
SchedulerConfig
from
vllm.config
import
ModelConfig
,
SchedulerConfig
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
...
@@ -97,9 +98,12 @@ def test_prepare_prompt(batch_size):
...
@@ -97,9 +98,12 @@ def test_prepare_prompt(batch_size):
assert
len
(
input_positions
)
==
sum
(
prompt_lens
)
assert
len
(
input_positions
)
==
sum
(
prompt_lens
)
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
SamplingMetadata
.
prepare
(
prompt_lens
,
seq_group_metadata_list
,
subquery_lens
=
prompt_lens
)
prompt_lens
,
subquery_lens
=
prompt_lens
,
device
=
model_runner
.
device
,
pin_memory
=
model_runner
.
pin_memory
)
assert
len
(
input_tokens
)
==
sum
(
prompt_lens
)
assert
len
(
input_tokens
)
==
sum
(
prompt_lens
)
assert
len
(
input_positions
)
==
sum
(
prompt_lens
)
assert
len
(
input_positions
)
==
sum
(
prompt_lens
)
actual
=
sampling_metadata
.
selected_token_indices
actual
=
sampling_metadata
.
selected_token_indices
...
@@ -195,9 +199,12 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -195,9 +199,12 @@ def test_prepare_decode_cuda_graph(batch_size):
for
prompt_len
in
prompt_lens
:
for
prompt_len
in
prompt_lens
:
expected_selected_token_indices
.
append
(
selected_token_start_idx
)
expected_selected_token_indices
.
append
(
selected_token_start_idx
)
selected_token_start_idx
+=
1
selected_token_start_idx
+=
1
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
SamplingMetadata
.
prepare
(
prompt_lens
,
seq_group_metadata_list
,
subquery_lens
=
prompt_lens
)
prompt_lens
,
subquery_lens
=
prompt_lens
,
device
=
model_runner
.
device
,
pin_memory
=
model_runner
.
pin_memory
)
actual
=
sampling_metadata
.
selected_token_indices
actual
=
sampling_metadata
.
selected_token_indices
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
device
=
actual
.
device
,
device
=
actual
.
device
,
...
...
vllm/core/scheduler.py
View file @
603ad848
...
@@ -915,6 +915,20 @@ class Scheduler:
...
@@ -915,6 +915,20 @@ class Scheduler:
self
.
block_manager
.
get_common_computed_block_ids
(
self
.
block_manager
.
get_common_computed_block_ids
(
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)))
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)))
do_sample
=
True
if
seq_group
.
is_prefill
():
seqs
=
seq_group
.
get_seqs
()
# Prefill has only 1 sequence.
assert
len
(
seqs
)
==
1
# In the next iteration, all prompt tokens are not computed.
# It means the prefill is chunked, and we don't need sampling.
# NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previous generated
# output tokens.
if
(
token_chunk_size
+
seqs
[
0
].
data
.
get_num_computed_tokens
()
<
seqs
[
0
].
data
.
get_len
()):
do_sample
=
False
# It assumes the scheduled_seq_groups is ordered by
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
# prefill < decoding.
is_prompt
=
seq_group
.
is_prefill
()
is_prompt
=
seq_group
.
is_prefill
()
...
@@ -924,6 +938,7 @@ class Scheduler:
...
@@ -924,6 +938,7 @@ class Scheduler:
seq_data
=
seq_data
,
seq_data
=
seq_data
,
sampling_params
=
seq_group
.
sampling_params
,
sampling_params
=
seq_group
.
sampling_params
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
do_sample
=
do_sample
,
token_chunk_size
=
token_chunk_size
,
token_chunk_size
=
token_chunk_size
,
lora_request
=
seq_group
.
lora_request
,
lora_request
=
seq_group
.
lora_request
,
computed_block_nums
=
common_computed_block_nums
,
computed_block_nums
=
common_computed_block_nums
,
...
...
vllm/engine/async_llm_engine.py
View file @
603ad848
...
@@ -219,7 +219,7 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -219,7 +219,7 @@ class _AsyncLLMEngine(LLMEngine):
request_outputs
=
self
.
_process_model_outputs
(
request_outputs
=
self
.
_process_model_outputs
(
output
,
scheduler_outputs
.
scheduled_seq_groups
,
output
,
scheduler_outputs
.
scheduled_seq_groups
,
scheduler_outputs
.
ignored_seq_groups
)
scheduler_outputs
.
ignored_seq_groups
,
seq_group_metadata_list
)
# Log stats.
# Log stats.
if
self
.
log_stats
:
if
self
.
log_stats
:
...
...
vllm/engine/llm_engine.py
View file @
603ad848
...
@@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest
...
@@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
Sequence
,
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
Sequence
,
SequenceGroup
,
Sequence
Stage
)
SequenceGroup
,
Sequence
GroupMetadata
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
get_tokenizer_group
)
get_tokenizer_group
)
...
@@ -476,9 +476,12 @@ class LLMEngine:
...
@@ -476,9 +476,12 @@ class LLMEngine:
return
self
.
scheduler
.
has_unfinished_seqs
()
return
self
.
scheduler
.
has_unfinished_seqs
()
def
_process_model_outputs
(
def
_process_model_outputs
(
self
,
output
:
List
[
SamplerOutput
],
self
,
scheduled_seq_groups
:
List
[
SequenceGroup
],
output
:
List
[
SamplerOutput
],
ignored_seq_groups
:
List
[
SequenceGroup
])
->
List
[
RequestOutput
]:
scheduled_seq_groups
:
List
[
SequenceGroup
],
ignored_seq_groups
:
List
[
SequenceGroup
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
List
[
RequestOutput
]:
"""Apply the model output to the sequences in the scheduled seq groups.
"""Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client.
Returns RequestOutputs that can be returned to the client.
...
@@ -492,17 +495,15 @@ class LLMEngine:
...
@@ -492,17 +495,15 @@ class LLMEngine:
sampler_outputs
=
output
,
num_seq_groups
=
len
(
scheduled_seq_groups
))
sampler_outputs
=
output
,
num_seq_groups
=
len
(
scheduled_seq_groups
))
# Update the scheduled sequence groups with the model outputs.
# Update the scheduled sequence groups with the model outputs.
for
scheduled_seq_group
,
outputs
in
zip
(
scheduled_seq_groups
,
for
scheduled_seq_group
,
outputs
,
seq_group_meta
in
zip
(
output_by_sequence_group
):
scheduled_seq_groups
,
output_by_sequence_group
,
seq_group_metadata_list
):
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
update_num_computed_tokens
(
seq_group
.
update_num_computed_tokens
(
scheduled_seq_group
.
token_chunk_size
)
scheduled_seq_group
.
token_chunk_size
)
# If all sequences in the sequence group are in DECODE, then we can
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
outputs
)
# process the output tokens. Otherwise, they are (chunked) prefill
if
seq_group_meta
.
do_sample
:
# samples and should not be processed.
stages
=
[
seq
.
data
.
_stage
for
seq
in
seq_group
.
seqs_dict
.
values
()]
if
all
(
stage
==
SequenceStage
.
DECODE
for
stage
in
stages
):
self
.
output_processor
.
process_outputs
(
seq_group
,
outputs
)
self
.
output_processor
.
process_outputs
(
seq_group
,
outputs
)
# Free the finished sequence groups.
# Free the finished sequence groups.
...
@@ -585,7 +586,7 @@ class LLMEngine:
...
@@ -585,7 +586,7 @@ class LLMEngine:
request_outputs
=
self
.
_process_model_outputs
(
request_outputs
=
self
.
_process_model_outputs
(
output
,
scheduler_outputs
.
scheduled_seq_groups
,
output
,
scheduler_outputs
.
scheduled_seq_groups
,
scheduler_outputs
.
ignored_seq_groups
)
scheduler_outputs
.
ignored_seq_groups
,
seq_group_metadata_list
)
# Log stats.
# Log stats.
if
self
.
log_stats
:
if
self
.
log_stats
:
...
...
vllm/engine/output_processor/interfaces.py
View file @
603ad848
...
@@ -68,3 +68,9 @@ class SequenceGroupOutputProcessor(ABC):
...
@@ -68,3 +68,9 @@ class SequenceGroupOutputProcessor(ABC):
scheduler.
scheduler.
"""
"""
pass
pass
@
abstractmethod
def
process_prompt_logprob
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
"""Update prompt logprobs received from outputs to seq_group."""
pass
vllm/engine/output_processor/multi_step.py
View file @
603ad848
...
@@ -44,6 +44,15 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -44,6 +44,15 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
self
.
get_tokenizer_for_seq
=
get_tokenizer_for_seq
self
.
get_tokenizer_for_seq
=
get_tokenizer_for_seq
self
.
stop_checker
=
stop_checker
self
.
stop_checker
=
stop_checker
def
process_prompt_logprob
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
# TODO(sang): Prompt logprob currently not implemented in multi step
# workers.
logger
.
warning
(
"Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers)."
)
pass
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
"""Append new tokens in the outputs to sequences in the sequence group.
"""Append new tokens in the outputs to sequences in the sequence group.
...
...
vllm/engine/output_processor/single_step.py
View file @
603ad848
...
@@ -55,17 +55,23 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -55,17 +55,23 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
),
f
"
{
type
(
self
)
}
does not support multiple outputs per step"
),
f
"
{
type
(
self
)
}
does not support multiple outputs per step"
return
self
.
_process_sequence_group_outputs
(
sequence_group
,
outputs
[
0
])
return
self
.
_process_sequence_group_outputs
(
sequence_group
,
outputs
[
0
])
def
_
process_
sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
def
process_
prompt_logprob
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
SequenceGroupOutput
)
->
None
:
outputs
:
List
[
SequenceGroupOutput
]
)
->
None
:
assert
len
(
outputs
)
==
1
,
(
"Single step should only has 1 output."
)
# Process prompt logprobs
output
=
outputs
[
0
]
prompt_logprobs
=
output
s
.
prompt_logprobs
prompt_logprobs
=
output
.
prompt_logprobs
if
prompt_logprobs
is
not
None
and
\
if
(
prompt_logprobs
is
not
None
seq_group
.
sampling_params
.
detokenize
and
self
.
detokenizer
:
and
seq_group
.
sampling_params
.
detokenize
and
self
.
detokenizer
)
:
self
.
detokenizer
.
decode_prompt_logprobs_inplace
(
self
.
detokenizer
.
decode_prompt_logprobs_inplace
(
seq_group
,
prompt_logprobs
)
seq_group
,
prompt_logprobs
)
seq_group
.
prompt_logprobs
=
prompt_logprobs
if
not
seq_group
.
prompt_logprobs
:
# The first prompt token's logprob is None because it doesn't
# have tokens that are precedent.
seq_group
.
prompt_logprobs
=
[
None
]
seq_group
.
prompt_logprobs
.
extend
(
prompt_logprobs
)
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
SequenceGroupOutput
)
->
None
:
# Process samples
# Process samples
samples
=
outputs
.
samples
samples
=
outputs
.
samples
parent_seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)
parent_seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)
...
...
vllm/engine/output_processor/util.py
View file @
603ad848
from
typing
import
List
from
typing
import
List
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupOutput
def
create_output_by_sequence_group
(
sampler_outputs
:
List
[
SamplerOutput
],
def
create_output_by_sequence_group
(
num_seq_groups
:
int
):
sampler_outputs
:
List
[
SamplerOutput
],
num_seq_groups
:
int
)
->
List
[
List
[
SequenceGroupOutput
]]:
"""Helper method which transforms a 2d list organized by
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
[step][sequence group] into [sequence group][step].
"""
"""
...
...
vllm/model_executor/layers/logits_processor.py
View file @
603ad848
...
@@ -83,30 +83,27 @@ def _apply_logits_processors(
...
@@ -83,30 +83,27 @@ def _apply_logits_processors(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
logits_row_idx
=
0
found_logits_processors
=
False
found_logits_processors
=
False
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
logits_processed
=
0
seq_ids
,
sampling_params
=
seq_group
for
seq_group
in
sampling_metadata
.
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
sampling_params
=
seq_group
.
sampling_params
logits_processors
=
sampling_params
.
logits_processors
logits_processors
=
sampling_params
.
logits_processors
# handle prompt_logprobs by skipping rows in logits added for
# the prompt tokens (prompt logprobs are not processed)
if
(
i
<
sampling_metadata
.
num_prompts
and
sampling_params
.
prompt_logprobs
is
not
None
):
assert
len
(
seq_ids
)
==
1
logits_row_idx
+=
sampling_metadata
.
prompt_lens
[
i
]
-
1
if
logits_processors
:
if
logits_processors
:
found_logits_processors
=
True
found_logits_processors
=
True
for
seq_id
in
seq_ids
:
for
seq_id
,
logits_row_idx
in
zip
(
seq_ids
,
seq_group
.
sample_indices
):
logits_row
=
logits
[
logits_row_idx
]
logits_row
=
logits
[
logits_row_idx
]
token_ids
=
s
ampling_metadata
.
seq_data
[
seq_id
].
output_token_ids
token_ids
=
s
eq_group
.
seq_data
[
seq_id
].
output_token_ids
for
logits_processor
in
logits_processors
:
for
logits_processor
in
logits_processors
:
logits_row
=
logits_processor
(
token_ids
,
logits_row
)
logits_row
=
logits_processor
(
token_ids
,
logits_row
)
logits
[
logits_row_idx
]
=
logits_row
logits
[
logits_row_idx
]
=
logits_row
logits_row_idx
+=
1
else
:
logits_processed
+=
len
(
seq_group
.
sample_indices
)
+
len
(
logits_row_idx
+=
len
(
seq_ids
)
seq_group
.
prompt_logprob_indices
)
if
found_logits_processors
:
if
found_logits_processors
:
# verifies that no rows in logits were missed unexpectedly
# verifies that no rows in logits were missed unexpectedly
assert
logits_ro
w_idx
==
logits
.
shape
[
0
]
assert
logits_
p
ro
cessed
==
logits
.
shape
[
0
]
return
logits
return
logits
vllm/model_executor/layers/sampler.py
View file @
603ad848
...
@@ -7,11 +7,11 @@ import torch.nn as nn
...
@@ -7,11 +7,11 @@ import torch.nn as nn
from
vllm.model_executor.layers.ops.sample
import
sample
as
sample_triton
from
vllm.model_executor.layers.ops.sample
import
sample
as
sample_triton
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
SamplingTensors
)
SamplingTensors
,
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
SequenceGroupToSample
)
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
(
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
from
vllm.sequence
import
(
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
SequenceData
,
SequenceGroupOutput
,
SamplerOutput
,
SequenceGroupOutput
,
SequenceOutput
)
SequenceOutput
)
class
Sampler
(
nn
.
Module
):
class
Sampler
(
nn
.
Module
):
...
@@ -48,11 +48,14 @@ class Sampler(nn.Module):
...
@@ -48,11 +48,14 @@ class Sampler(nn.Module):
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
"""
Args:
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
"""
assert
logits
is
not
None
assert
logits
is
not
None
_
,
vocab_size
=
logits
.
shape
_
,
vocab_size
=
logits
.
shape
# Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
# have not been generated yet
logits
=
_apply_min_tokens_penalty
(
logits
,
sampling_metadata
)
logits
=
_apply_min_tokens_penalty
(
logits
,
sampling_metadata
)
# Prepare sampling tensors with pinned memory to avoid blocking.
# Prepare sampling tensors with pinned memory to avoid blocking.
...
@@ -83,7 +86,6 @@ class Sampler(nn.Module):
...
@@ -83,7 +86,6 @@ class Sampler(nn.Module):
# Compute the probabilities.
# Compute the probabilities.
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float
)
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float
)
# Compute the log probabilities.
# Compute the log probabilities.
# Use log_softmax to ensure numerical stability.
logprobs
=
torch
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float
)
logprobs
=
torch
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float
)
# Sample the next tokens.
# Sample the next tokens.
...
@@ -149,24 +151,28 @@ def _apply_min_tokens_penalty(
...
@@ -149,24 +151,28 @@ def _apply_min_tokens_penalty(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
have not been generated yet
"""
# list of indices in logits that will be set to -inf
# list of indices in logits that will be set to -inf
logits_to_penalize
=
[]
logits_to_penalize
=
[]
start_idx
=
0
logits_applied
=
0
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
)
:
for
seq_group
in
sampling_metadata
.
seq_groups
:
seq_ids
,
sampling_params
=
seq_group
seq_ids
=
seq_group
.
seq_ids
sampling_params
=
seq_group
.
sampling_params
# handle prompt_logprobs by skipping rows in logits added for the prompt
# tokens (prompt logprobs are not penalized)
sample_indices
=
seq_group
.
sample_indices
if
(
i
<
sampling_metadata
.
num_prompts
logits_applied
+=
len
(
sample_indices
)
+
len
(
and
sampling_params
.
prompt_logprobs
is
not
None
):
seq_group
.
prompt_logprob_indices
)
assert
len
(
seq_ids
)
==
1
if
not
seq_group
.
do_sample
:
start_idx
+=
sampling_metadata
.
prompt_lens
[
i
]
-
1
continue
start_idx
=
sample_indices
[
0
]
min_tokens
=
sampling_params
.
min_tokens
min_tokens
=
sampling_params
.
min_tokens
if
min_tokens
>
0
:
if
min_tokens
>
0
:
seqs_to_penalize
=
[]
seqs_to_penalize
=
[]
for
i
,
seq_id
in
enumerate
(
seq_ids
):
for
i
,
seq_id
in
enumerate
(
seq_ids
):
seq_data
=
s
ampling_metadata
.
seq_data
[
seq_id
]
seq_data
=
s
eq_group
.
seq_data
[
seq_id
]
if
len
(
seq_data
.
output_token_ids
)
<
min_tokens
:
if
len
(
seq_data
.
output_token_ids
)
<
min_tokens
:
seqs_to_penalize
.
append
(
i
)
seqs_to_penalize
.
append
(
i
)
...
@@ -180,15 +186,13 @@ def _apply_min_tokens_penalty(
...
@@ -180,15 +186,13 @@ def _apply_min_tokens_penalty(
logits_to_penalize
.
extend
(
logits_to_penalize
.
extend
(
itertools
.
product
(
seqs_to_penalize
,
token_ids_to_penalize
))
itertools
.
product
(
seqs_to_penalize
,
token_ids_to_penalize
))
start_idx
+=
len
(
seq_ids
)
if
logits_to_penalize
:
if
logits_to_penalize
:
# use zip and * to group indices along each dimension
# use zip and * to group indices along each dimension
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
logits
[
tuple
(
zip
(
*
logits_to_penalize
))]
=
-
float
(
"inf"
)
logits
[
tuple
(
zip
(
*
logits_to_penalize
))]
=
-
float
(
"inf"
)
# verifies that no rows in logits were missed unexpectedly
# verifies that no rows in logits were missed unexpectedly
assert
start_idx
==
logits
.
shape
[
0
]
assert
logits_applied
==
logits
.
shape
[
0
]
return
logits
return
logits
...
@@ -265,14 +269,30 @@ def _apply_min_p(
...
@@ -265,14 +269,30 @@ def _apply_min_p(
def
_greedy_sample
(
def
_greedy_sample
(
selected_seq_groups
:
List
[
Tuple
[
List
[
int
],
SamplingParams
]
],
selected_seq_groups
:
List
[
SequenceGroupToSample
],
samples
:
torch
.
Tensor
,
samples
:
torch
.
Tensor
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
"""Run greedy sampling on the given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
samples: (num_selected_samples,) A tensor of samples. The length of
samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
samples
=
samples
.
tolist
()
samples
=
samples
.
tolist
()
sample_idx
=
0
sample_idx
=
0
results
=
[]
results
=
[]
for
seq_group
in
selected_seq_groups
:
for
seq_group
in
selected_seq_groups
:
seq_ids
,
_
=
seq_group
if
not
seq_group
.
do_sample
:
results
.
append
(([],
[]))
continue
seq_ids
=
seq_group
.
seq_ids
num_parent_seqs
=
len
(
seq_ids
)
num_parent_seqs
=
len
(
seq_ids
)
assert
num_parent_seqs
==
1
,
(
assert
num_parent_seqs
==
1
,
(
"Greedy sampling should have only one seq."
)
"Greedy sampling should have only one seq."
)
...
@@ -284,16 +304,33 @@ def _greedy_sample(
...
@@ -284,16 +304,33 @@ def _greedy_sample(
def
_random_sample
(
def
_random_sample
(
selected_seq_groups
:
List
[
Tuple
[
List
[
int
],
SamplingParams
]],
selected_seq_groups
:
List
[
SequenceGroupToSample
],
is_prompts
:
List
[
bool
],
random_samples
:
torch
.
Tensor
,
random_samples
:
torch
.
Tensor
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
"""Run random sampling on the given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
random_samples: (num_selected_samples,) A tensor of samples. The
length of samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum best_of value of the prompt phase requests.
# Find the maximum best_of value of the prompt phase requests.
random_samples
=
random_samples
.
cpu
()
random_samples
=
random_samples
.
cpu
()
sample_idx
=
0
sample_idx
=
0
results
=
[]
results
=
[]
for
seq_group
,
is_prompt
in
zip
(
selected_seq_groups
,
is_prompts
):
for
seq_group
in
selected_seq_groups
:
seq_ids
,
sampling_params
=
seq_group
if
not
seq_group
.
do_sample
:
results
.
append
(([],
[]))
continue
seq_ids
=
seq_group
.
seq_ids
sampling_params
=
seq_group
.
sampling_params
is_prompt
=
seq_group
.
is_prompt
num_parent_seqs
=
len
(
seq_ids
)
num_parent_seqs
=
len
(
seq_ids
)
if
is_prompt
:
if
is_prompt
:
# Prompt phase.
# Prompt phase.
...
@@ -311,11 +348,20 @@ def _random_sample(
...
@@ -311,11 +348,20 @@ def _random_sample(
def
_beam_search_sample
(
def
_beam_search_sample
(
selected_seq_groups
:
List
[
Tuple
[
List
[
int
],
SamplingParams
]],
selected_seq_groups
:
List
[
SequenceGroupToSample
],
is_prompts
:
List
[
bool
],
seq_data
:
Dict
[
int
,
SequenceData
],
logprobs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
"""Run beam sampling on the given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
on selected sample indices.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# We sample 2 * beam_width candidates to make sure that with high
# We sample 2 * beam_width candidates to make sure that with high
# probability we can get `beam_width` candidates in addition to
# probability we can get `beam_width` candidates in addition to
# the finished sequences for the next iteration. See
# the finished sequences for the next iteration. See
...
@@ -327,8 +373,13 @@ def _beam_search_sample(
...
@@ -327,8 +373,13 @@ def _beam_search_sample(
# other sampling methods.
# other sampling methods.
sample_idx
=
0
sample_idx
=
0
results
=
[]
results
=
[]
for
seq_group
,
is_prompt
in
zip
(
selected_seq_groups
,
is_prompts
):
for
seq_group
in
selected_seq_groups
:
seq_ids
,
sampling_params
=
seq_group
if
not
seq_group
.
do_sample
:
results
.
append
(([],
[]))
continue
is_prompt
=
seq_group
.
is_prompt
seq_ids
,
sampling_params
=
seq_group
.
seq_ids
,
seq_group
.
sampling_params
num_parent_seqs
=
len
(
seq_ids
)
num_parent_seqs
=
len
(
seq_ids
)
beam_width
=
sampling_params
.
best_of
beam_width
=
sampling_params
.
best_of
seq_group_logprobs
=
logprobs
[
sample_idx
:
sample_idx
+
num_parent_seqs
]
seq_group_logprobs
=
logprobs
[
sample_idx
:
sample_idx
+
num_parent_seqs
]
...
@@ -343,7 +394,8 @@ def _beam_search_sample(
...
@@ -343,7 +394,8 @@ def _beam_search_sample(
else
:
else
:
# Generation phase.
# Generation phase.
cumulative_logprobs
=
[
cumulative_logprobs
=
[
seq_data
[
seq_id
].
cumulative_logprob
for
seq_id
in
seq_ids
seq_group
.
seq_data
[
seq_id
].
cumulative_logprob
for
seq_id
in
seq_ids
]
]
cumulative_logprobs
=
torch
.
tensor
(
cumulative_logprobs
=
torch
.
tensor
(
cumulative_logprobs
,
cumulative_logprobs
,
...
@@ -371,8 +423,7 @@ def _beam_search_sample(
...
@@ -371,8 +423,7 @@ def _beam_search_sample(
def
_multinomial
(
def
_multinomial
(
probs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
num_samples
:
int
,
num_samples
:
int
,
seq_groups
:
Optional
[
List
[
Tuple
[
List
[
int
],
SamplingParams
]]]
=
None
,
seq_groups
:
Optional
[
List
[
SequenceGroupToSample
]]
=
None
,
generators
:
Optional
[
List
[
torch
.
Generator
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
num_samples
>
1
:
if
num_samples
>
1
:
# This is equivalent to torch.repeat_interleaved (which also
# This is equivalent to torch.repeat_interleaved (which also
...
@@ -388,9 +439,11 @@ def _multinomial(
...
@@ -388,9 +439,11 @@ def _multinomial(
q
.
exponential_
()
q
.
exponential_
()
else
:
else
:
sample_idx
=
0
sample_idx
=
0
for
(
seq_ids
,
_
),
generator
in
zip
(
seq_groups
,
generators
):
for
seq_group
in
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
next_sample_idx
=
sample_idx
+
len
(
seq_ids
)
*
num_samples
next_sample_idx
=
sample_idx
+
len
(
seq_ids
)
*
num_samples
q
[
sample_idx
:
next_sample_idx
].
exponential_
(
generator
=
generator
)
q
[
sample_idx
:
next_sample_idx
].
exponential_
(
generator
=
seq_group
.
generator
)
sample_idx
=
next_sample_idx
sample_idx
=
next_sample_idx
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
...
@@ -405,7 +458,7 @@ def _sample_with_torch(
...
@@ -405,7 +458,7 @@ def _sample_with_torch(
categorized_seq_group_ids
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_seq_group_ids
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
_
,
sampling_params
=
seq_group
sampling_params
=
seq_group
.
sampling_params
sampling_type
=
sampling_params
.
sampling_type
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
...
@@ -429,13 +482,11 @@ def _sample_with_torch(
...
@@ -429,13 +482,11 @@ def _sample_with_torch(
num_tokens
=
len
(
sample_indices
)
num_tokens
=
len
(
sample_indices
)
if
num_tokens
==
0
:
if
num_tokens
==
0
:
continue
continue
seq_group_ids
=
categorized_seq_group_ids
[
sampling_type
]
seq_groups
=
[
sampling_metadata
.
seq_groups
[
i
]
for
i
in
seq_group_ids
]
is_prompts
=
[
i
<
sampling_metadata
.
num_prompts
for
i
in
seq_group_ids
]
sample_metadata
[
sampling_type
]
=
(
seq_group_ids
,
seq_groups
,
is_prompts
,
sample_indices
)
long_sample_indices
=
sample_indices
.
long
()
seq_group_id
=
categorized_seq_group_ids
[
sampling_type
]
seq_groups
=
[
sampling_metadata
.
seq_groups
[
i
]
for
i
in
seq_group_id
]
sample_metadata
[
sampling_type
]
=
(
seq_group_id
,
seq_groups
)
long_sample_indices
=
sample_indices
.
long
()
if
sampling_type
==
SamplingType
.
GREEDY
:
if
sampling_type
==
SamplingType
.
GREEDY
:
greedy_samples
=
torch
.
argmax
(
logprobs
[
long_sample_indices
],
greedy_samples
=
torch
.
argmax
(
logprobs
[
long_sample_indices
],
dim
=-
1
)
dim
=-
1
)
...
@@ -455,14 +506,13 @@ def _sample_with_torch(
...
@@ -455,14 +506,13 @@ def _sample_with_torch(
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
max_best_of_in_batch
=
1
max_best_of_in_batch
=
1
for
seq_group
,
is_prompt
in
zip
(
seq_groups
,
is_prompts
)
:
for
seq_group
in
seq_groups
:
if
is_prompt
:
if
seq_group
.
is_prompt
:
_
,
sampling_params
=
seq_group
sampling_params
=
seq_group
.
sampling_params
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
sampling_params
.
best_of
)
sampling_params
.
best_of
)
seeded_args
=
{}
if
sampling_type
==
SamplingType
.
RANDOM
else
{
seeded_args
=
{}
if
sampling_type
==
SamplingType
.
RANDOM
else
{
"seq_groups"
:
seq_groups
,
"seq_groups"
:
seq_groups
,
"generators"
:
sampling_metadata
.
generators
,
}
}
multinomial_samples
[
sampling_type
]
=
_multinomial
(
multinomial_samples
[
sampling_type
]
=
_multinomial
(
...
@@ -481,25 +531,22 @@ def _sample_with_torch(
...
@@ -481,25 +531,22 @@ def _sample_with_torch(
# GPU<->CPU sync happens in the loop below.
# GPU<->CPU sync happens in the loop below.
# This also converts the sample output to Python objects.
# This also converts the sample output to Python objects.
for
sampling_type
in
SamplingType
:
for
sampling_type
in
SamplingType
:
if
sampling_type
not
in
sample_metadata
:
if
sampling_type
not
in
sample_metadata
:
continue
continue
seq_group_ids
,
seq_groups
,
is_prompts
,
sample_indices
=
sample_metadata
[
(
seq_group_id
,
seq_groups
)
=
sample_metadata
[
sampling_type
]
sampling_type
]
if
sampling_type
==
SamplingType
.
GREEDY
:
if
sampling_type
==
SamplingType
.
GREEDY
:
sample_results
=
_greedy_sample
(
seq_groups
,
greedy_samples
)
sample_results
=
_greedy_sample
(
seq_groups
,
greedy_samples
)
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
sample_results
=
_random_sample
(
seq_groups
,
is_prompts
,
sample_results
=
_random_sample
(
seq_groups
,
multinomial_samples
[
sampling_type
])
multinomial_samples
[
sampling_type
])
elif
sampling_type
==
SamplingType
.
BEAM
:
elif
sampling_type
==
SamplingType
.
BEAM
:
sample_results
=
_beam_search_sample
(
seq_groups
,
is_prompts
,
sample_results
=
_beam_search_sample
(
seq_groups
,
sampling_metadata
.
seq_data
,
beam_search_logprobs
)
beam_search_logprobs
)
sample_results_dict
.
update
(
zip
(
seq_group_id
s
,
sample_results
))
sample_results_dict
.
update
(
zip
(
seq_group_id
,
sample_results
))
sample_results
=
[
sample_results
=
[
sample_results_dict
[
i
]
sample_results_dict
.
get
(
i
,
([],
[]))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
]
]
return
sample_results
,
sampled_token_ids_tensor
return
sample_results
,
sampled_token_ids_tensor
...
@@ -514,7 +561,7 @@ def _sample_with_triton_kernel(
...
@@ -514,7 +561,7 @@ def _sample_with_triton_kernel(
categorized_seq_group_ids
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_seq_group_ids
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
_
,
sampling_params
=
seq_group
sampling_params
=
seq_group
.
sampling_params
sampling_type
=
sampling_params
.
sampling_type
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
...
@@ -530,17 +577,16 @@ def _sample_with_triton_kernel(
...
@@ -530,17 +577,16 @@ def _sample_with_triton_kernel(
num_tokens
=
len
(
sample_indices
)
num_tokens
=
len
(
sample_indices
)
if
num_tokens
==
0
:
if
num_tokens
==
0
:
continue
continue
seq_group_ids
=
categorized_seq_group_ids
[
sampling_type
]
seq_group_id
=
categorized_seq_group_ids
[
sampling_type
]
seq_groups
=
[
sampling_metadata
.
seq_groups
[
i
]
for
i
in
seq_group_ids
]
seq_groups
=
[
sampling_metadata
.
seq_groups
[
i
]
for
i
in
seq_group_id
]
is_prompts
=
[
i
<
sampling_metadata
.
num_prompts
for
i
in
seq_group_ids
]
sample_metadata
[
sampling_type
]
=
(
seq_group_id
,
seq_groups
,
sample_metadata
[
sampling_type
]
=
(
seq_group_ids
,
seq_groups
,
sample_indices
,
is_prompts
,
sample_indices
,
sampled_token_indices
)
sampled_token_indices
)
if
sampling_type
in
(
SamplingType
.
GREEDY
,
SamplingType
.
RANDOM
,
if
sampling_type
in
(
SamplingType
.
GREEDY
,
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
SamplingType
.
RANDOM_SEED
):
for
seq_group
,
is_prompt
in
zip
(
seq_groups
,
is_prompts
)
:
for
seq_group
in
seq_groups
:
if
is_prompt
:
if
seq_group
.
is_prompt
:
_
,
sampling_params
=
seq_group
sampling_params
=
seq_group
.
sampling_params
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
sampling_params
.
best_of
)
sampling_params
.
best_of
)
elif
sampling_type
==
SamplingType
.
BEAM
:
elif
sampling_type
==
SamplingType
.
BEAM
:
...
@@ -564,22 +610,21 @@ def _sample_with_triton_kernel(
...
@@ -564,22 +610,21 @@ def _sample_with_triton_kernel(
for
sampling_type
in
SamplingType
:
for
sampling_type
in
SamplingType
:
if
sampling_type
not
in
sample_metadata
:
if
sampling_type
not
in
sample_metadata
:
continue
continue
(
seq_group_id
s
,
seq_groups
,
is_prompts
,
sample_indices
,
(
seq_group_id
,
seq_groups
,
sample_indices
,
sampled_token_indices
)
=
sample_metadata
[
sampling_type
]
sampled_token_indices
)
=
sample_metadata
[
sampling_type
]
if
sampling_type
==
SamplingType
.
GREEDY
:
if
sampling_type
==
SamplingType
.
GREEDY
:
sample_results
=
_greedy_sample
(
sample_results
=
_greedy_sample
(
seq_groups
,
sampled_tokens
[
sampled_token_indices
][:,
0
])
seq_groups
,
sampled_tokens
[
sampled_token_indices
][:,
0
])
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
sample_results
=
_random_sample
(
sample_results
=
_random_sample
(
seq_groups
,
is_prompts
,
sampled_tokens
[
sampled_token_indices
])
seq_groups
,
sampled_tokens
[
sampled_token_indices
])
elif
sampling_type
==
SamplingType
.
BEAM
:
elif
sampling_type
==
SamplingType
.
BEAM
:
sample_results
=
_beam_search_sample
(
seq_groups
,
is_prompts
,
sample_results
=
_beam_search_sample
(
seq_groups
,
sampling_metadata
.
seq_data
,
beam_search_logprobs
)
beam_search_logprobs
)
sample_results_dict
.
update
(
zip
(
seq_group_id
s
,
sample_results
))
sample_results_dict
.
update
(
zip
(
seq_group_id
,
sample_results
))
sample_results
=
[
sample_results
=
[
sample_results_dict
[
i
]
sample_results_dict
.
get
(
i
,
([],
[]))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
]
]
return
sample_results
return
sample_results
...
@@ -590,6 +635,18 @@ def _sample(
...
@@ -590,6 +635,18 @@ def _sample(
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
)
->
Tuple
[
List
[
Tuple
[
List
[
int
],
List
[
int
]]],
Optional
[
torch
.
Tensor
]]:
)
->
Tuple
[
List
[
Tuple
[
List
[
int
],
List
[
int
]]],
Optional
[
torch
.
Tensor
]]:
"""
Args:
probs: (num_query_tokens_in_batch, num_vocab)
logprobs: (num_query_tokens_in_batch, num_vocab)
sampling_metadata: The metadata for a batch for sampling.
sampling_tensors: Tensors that include sampling related metadata.
Returns:
(next_token_ids, parent_seq_ids) for each seq group in a batch.
If sampling is skipped, it returns ([], [])
sampled_token_ids_tensor: A tensor of sampled token ids.
"""
return
_sample_with_torch
(
return
_sample_with_torch
(
probs
,
probs
,
logprobs
,
logprobs
,
...
@@ -626,56 +683,97 @@ def _get_logprobs(
...
@@ -626,56 +683,97 @@ def _get_logprobs(
logprobs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
sample_results
:
List
[
Tuple
[
List
[
int
],
List
[
int
]]],
sample_results
:
List
[
Tuple
[
List
[
int
],
List
[
int
]]],
)
->
Tuple
[
List
[
Optional
[
List
[
Optional
[
Dict
[
int
,
float
]]]]],
List
[
List
[
Dict
[
)
->
Tuple
[
List
[
Optional
[
PromptLogprobs
]],
List
[
SampleLogprobs
]]:
int
,
float
]]]]:
"""Return sample logprobs and prompt logprobs.
# Prepare query indices
batched_logprobs_query_seq_indices
:
List
[
int
]
=
[]
The logic consists of 3 parts.
batched_logprobs_query_token_indices
:
List
[
int
]
=
[]
- Select indices to compute logprob from, ranks of token ids, and
# at least get one logprob for each token
the top k token ids from logprobs.
- Compute prompt logprobs if required.
- Compute sample logprobs if required.
Args:
logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
logprob per vocab. Sequence groups' query tokens are batched in a
single flattened tensor. For example, assuming there are N
seq groups, it is sorted by prefill tokens for seq_group_1 (if
prompt logprob is enabled), decode tokens for seq_group_1 (if
sampling is required), prefill tokens for seq_group_2, ...
sampling_metadata: The sampling metadata.
sample_results: (num_seq_groups) The tuple of (next_token_ids,
parent_ids) for each sequence group. When beam search is enabled,
sample_results can contain different number of seq_ids from
sampling_metadata.seq_groups. It is because beam search creates
2 * BEAM_WIDTH number of samples (whereas there are only up to
BEAM_WIDTH number of seq_ids).
Returns:
A tuple of prompt and sample logprobs per sequence group in a batch.
"""
# The index of query token to calculate logprobs. It includes both
# prompt and sample logprob indices.
query_indices
:
List
[
int
]
=
[]
# The next token ids to get the logprob value from.
next_token_ids
:
List
[
int
]
=
[]
# The largest requested number of logprobs. We compute as many logprobs as the
# largest num logprobs in this API.
largest_num_logprobs
=
1
largest_num_logprobs
=
1
sample_idx
=
0
for
i
,
(
seq_group
,
sample_result
)
in
enumerate
(
# Select indices to compute logprob from, ranks of token ids, and the top
zip
(
sampling_metadata
.
seq_groups
,
sample_results
)):
# k token ids from logprobs.
seq_ids
,
sampling_params
=
seq_group
for
(
seq_group
,
sample_result
)
in
zip
(
sampling_metadata
.
seq_groups
,
next_token_ids
,
parent_ids
=
sample_result
sample_results
):
num_parent_seqs
=
len
(
seq_ids
)
sampling_params
=
seq_group
.
sampling_params
if
(
i
<
sampling_metadata
.
num_prompts
# Update indices and tokens for prompt logprobs.
if
(
seq_group
.
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
):
and
sampling_params
.
prompt_logprobs
is
not
None
):
largest_num_logprobs
=
max
(
largest_num_logprobs
,
largest_num_logprobs
=
max
(
largest_num_logprobs
,
sampling_params
.
prompt_logprobs
)
sampling_params
.
prompt_logprobs
)
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
next_prompt_tokens
=
_get_next_prompt_tokens
(
seq_group
)
prompt_tokens
=
sampling_metadata
.
seq_data
[
query_indices
.
extend
(
seq_group
.
prompt_logprob_indices
)
seq_ids
[
0
]].
prompt_token_ids
next_token_ids
.
extend
(
next_prompt_tokens
)
batched_logprobs_query_seq_indices
.
extend
(
sample_idx
+
j
for
j
in
range
(
prompt_len
-
1
))
# Update indices and next tokens for sample logprob.
batched_logprobs_query_token_indices
.
extend
(
if
seq_group
.
do_sample
:
token_id
for
token_id
in
prompt_tokens
[
1
:])
token_ids
,
parent_seq_ids
=
sample_result
sample_idx
+=
prompt_len
-
1
# NOTE: We cannot directly use sample_indices because
batched_logprobs_query_seq_indices
.
extend
(
# sample_indices only contain parent seq_ids of a previous step.
[
sample_idx
+
parent_id
for
parent_id
in
parent_ids
])
# The current step may have different number of seq_ids, and
batched_logprobs_query_token_indices
.
extend
(
next_token_ids
)
# we can obtain it from `sample_result[1]`.
if
sampling_params
.
logprobs
is
not
None
:
query_idx
=
seq_group
.
sample_indices
[
0
]
largest_num_logprobs
=
max
(
largest_num_logprobs
,
query_indices
.
extend
(
sampling_params
.
logprobs
)
[
query_idx
+
parent_id
for
parent_id
in
parent_seq_ids
])
sample_idx
+=
num_parent_seqs
next_token_ids
.
extend
(
token_ids
)
assert
sample_idx
==
logprobs
.
size
(
0
)
if
sampling_params
.
logprobs
is
not
None
:
batched_logprobs_query_seq_indices_gpu
=
torch
.
tensor
(
largest_num_logprobs
=
max
(
largest_num_logprobs
,
batched_logprobs_query_seq_indices
,
device
=
logprobs
.
device
)
sampling_params
.
logprobs
)
batched_logprobs_query_token_indices_gpu
=
torch
.
tensor
(
batched_logprobs_query_token_indices
,
device
=
logprobs
.
device
)
assert
len
(
next_token_ids
)
==
len
(
query_indices
)
# Batched query for logprobs of selected token
if
len
(
query_indices
)
==
0
:
batched_logprobs_query_result
=
logprobs
[[
empty_sampled_logprob
=
[]
batched_logprobs_query_seq_indices_gpu
,
empty_prompt_logprob
=
None
batched_logprobs_query_token_indices_gpu
return
[
empty_prompt_logprob
],
[
empty_sampled_logprob
]
query_indices_gpu
=
torch
.
tensor
(
query_indices
,
device
=
logprobs
.
device
)
next_token_ids_gpu
=
torch
.
tensor
(
next_token_ids
,
device
=
logprobs
.
device
)
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# contain duplicates if beam search is enabled.
selected_logprobs
=
logprobs
[[
query_indices_gpu
,
next_token_ids_gpu
,
]]
]]
ranks
=
_get_ranks
(
logprobs
[
query_indices_gpu
],
next_token_ids_gpu
,
)
assert
selected_logprobs
.
shape
[
0
]
==
ranks
.
shape
[
0
]
batched_ranks_query_result
=
_get_ranks
(
# Logprobs of topk tokens for a batch of sequence groups.
logprobs
[
batched_logprobs_query_seq_indices_gpu
],
# (num_query_tokens_across_batch).
batched_logprobs_query_token_indices_gpu
)
# Batched query for logprobs of topk tokens
if
largest_num_logprobs
>
0
:
if
largest_num_logprobs
>
0
:
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
largest_num_logprobs
,
largest_num_logprobs
,
...
@@ -685,79 +783,136 @@ def _get_logprobs(
...
@@ -685,79 +783,136 @@ def _get_logprobs(
else
:
else
:
top_logprobs
,
top_token_ids
=
None
,
None
top_logprobs
,
top_token_ids
=
None
,
None
batched_logprobs_query_result
=
batched_logprobs_query_result
.
cpu
()
selected_logprobs
=
selected_logprobs
.
cpu
()
batched_ranks_query_result
=
batched_ranks_query_result
.
cpu
()
ranks
=
ranks
.
cpu
()
# Gather results
# Find prompt/sample logprobs.
result_prompt_logprobs
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
prompt_logprobs_per_seq_group
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
result_sample_logprobs
:
List
[
SampleLogprobs
]
=
[]
sample_logprobs_per_seq_group
:
List
[
SampleLogprobs
]
=
[]
sample_idx
=
0
top_logprob_idx
=
0
query_result_idx
=
0
selected_logprobs_idx
=
0
for
i
,
(
seq_group
,
sample_result
)
in
enumerate
(
zip
(
sampling_metadata
.
seq_groups
,
sample_results
)):
for
seq_group
,
sample_result
in
zip
(
sampling_metadata
.
seq_groups
,
seq_ids
,
sampling_params
=
seq_group
sample_results
):
next_token_ids
,
parent_ids
=
sample_result
(
prompt_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
)
=
_get_prompt_logprob_if_needed
(
seq_group
,
selected_logprobs
,
ranks
,
top_token_ids
,
top_logprobs
,
selected_logprobs_idx
,
top_logprob_idx
)
prompt_logprobs_per_seq_group
.
append
(
prompt_logprobs
)
(
sampled_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
)
=
_get_sampled_logprob_if_needed
(
seq_group
,
sample_result
,
selected_logprobs
,
ranks
,
top_token_ids
,
top_logprobs
,
selected_logprobs_idx
,
top_logprob_idx
)
sample_logprobs_per_seq_group
.
append
(
sampled_logprobs
)
return
prompt_logprobs_per_seq_group
,
sample_logprobs_per_seq_group
def
_get_prompt_logprob_if_needed
(
seq_group
:
SequenceGroupToSample
,
selected_logprobs
:
torch
.
Tensor
,
ranks
:
torch
.
Tensor
,
top_token_ids
:
torch
.
Tensor
,
top_logprobs
:
torch
.
Tensor
,
selected_logprobs_idx
:
int
,
top_logprob_idx
:
int
,
):
"""Compute the prompt logprob from a sequence group if needed."""
sampling_params
=
seq_group
.
sampling_params
is_prompt
=
seq_group
.
is_prompt
# Find prompt logprobs
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
if
(
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
):
prompt_logprobs
=
[]
num_logprobs
=
sampling_params
.
prompt_logprobs
next_prompt_tokens
=
_get_next_prompt_tokens
(
seq_group
)
for
token_id
in
next_prompt_tokens
:
# Calculate the prompt logprob of the real prompt tokens.
# Use tuple here for performance (to use to_list()).
# {token_id: (logprob, rank_from_vocab)}
prompt_logprobs_dict
:
Dict
[
int
,
Tuple
[
float
,
int
]]
=
{
token_id
:
(
selected_logprobs
[
selected_logprobs_idx
].
item
(),
ranks
[
selected_logprobs_idx
].
item
())
}
# Prompt logprobs
# Add top K prompt logprobs along with their ranks.
if
(
i
<
sampling_metadata
.
num_prompts
if
num_logprobs
>
0
:
and
sampling_params
.
prompt_logprobs
is
not
None
):
prompt_logprobs_dict
.
update
(
num_logprobs
=
sampling_params
.
prompt_logprobs
zip
(
prompt_tokens
=
sampling_metadata
.
seq_data
[
top_token_ids
[
top_logprob_idx
,
:
num_logprobs
].
tolist
(),
seq_ids
[
0
]].
prompt_token_ids
group_prompt_logprobs
:
PromptLogprobs
=
[
None
]
for
token_id
in
prompt_tokens
[
1
:]:
prompt_logprobs_dict
=
{
token_id
:
(
batched_logprobs_query_result
[
query_result_idx
].
item
(),
batched_ranks_query_result
[
query_result_idx
].
item
())
}
if
num_logprobs
>
0
:
prompt_logprobs_dict
.
update
(
zip
(
zip
(
top_token_ids
[
sample_idx
,
:
num_logprobs
].
tolist
(),
top_logprobs
[
zip
(
top_logprob_idx
,
:
num_logprobs
].
tolist
(),
top_logprobs
[
# This is ranks. Since top_logprob is sorted,
sample_idx
,
:
num_logprobs
].
tolist
(),
# we can just use a range here.
range
(
1
,
num_logprobs
+
1
))))
range
(
1
,
num_logprobs
+
1
))))
group_prompt_logprobs
.
append
({
prompt_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_rank
)
token_id
:
Logprob
(
*
logprob_and_rank
)
for
token_id
,
logprob_rank
in
prompt_logprobs_dict
.
items
()
for
token_id
,
logprob_and_rank
in
prompt_logprobs_dict
.
items
()
})
})
sample_idx
+=
1
# + 1 to go to the next prompt token.
query_result_idx
+=
1
top_logprob_idx
+=
1
result_prompt_logprobs
.
append
(
group_prompt_logprobs
)
selected_logprobs_idx
+=
1
else
:
return
prompt_logprobs
,
top_logprob_idx
,
selected_logprobs_idx
result_prompt_logprobs
.
append
(
None
)
# Sample logprobs
def
_get_sampled_logprob_if_needed
(
num_logprobs
=
sampling_params
.
logprobs
seq_group
:
SequenceGroupToSample
,
if
num_logprobs
is
None
:
sample_result
:
Tuple
[
List
[
int
],
List
[
int
]],
num_logprobs
=
0
selected_logprobs
:
torch
.
Tensor
,
group_sample_logprobs
:
SampleLogprobs
=
[]
ranks
:
torch
.
Tensor
,
for
next_token_id
,
parent_id
in
zip
(
next_token_ids
,
parent_ids
):
top_token_ids
:
torch
.
Tensor
,
sample_logprobs_dict
=
{
top_logprobs
:
torch
.
Tensor
,
selected_logprobs_idx
:
int
,
top_logprob_idx
:
int
,
):
"""Compute the sample logprob if needed."""
seq_ids
=
seq_group
.
seq_ids
num_logprobs
=
seq_group
.
sampling_params
.
logprobs
if
num_logprobs
is
None
:
num_logprobs
=
0
sampled_logprobs
:
SampleLogprobs
=
[]
next_token_ids
,
parent_seq_ids
=
sample_result
if
seq_group
.
do_sample
:
assert
len
(
next_token_ids
)
>
0
for
(
next_token_id
,
parent_id
)
in
zip
(
next_token_ids
,
parent_seq_ids
):
# Calculate the sample logprob of the real sampled tokens.
# Use tuple here for performance (to use to_list()).
# token_id: (logprob, rank_from_vocab)
sampled_logprobs_dict
:
Dict
[
int
,
Tuple
[
float
,
int
]]
=
{
next_token_id
:
next_token_id
:
(
batch
ed_logprobs
_query_result
[
query_result
_idx
].
item
(),
(
select
ed_logprobs
[
selected_logprobs
_idx
].
item
(),
batched_ranks_query_result
[
query_result
_idx
].
item
())
ranks
[
selected_logprobs
_idx
].
item
())
}
}
query_result_idx
+=
1
# +1 to go to the next sampled token. Note that
# selected_logprobs can contain duplicates unlike top_logprobs
# when beam search is enabled.
selected_logprobs_idx
+=
1
# Second, add top K logprobs along with its rank.
if
num_logprobs
>=
0
:
if
num_logprobs
>=
0
:
sample_logprobs_dict
.
update
(
sample
d
_logprobs_dict
.
update
(
zip
(
zip
(
top_token_ids
[
sample
_idx
+
top_token_ids
[
top_logprob
_idx
+
parent_id
,
:
num_logprobs
].
tolist
(),
parent_id
,
:
num_logprobs
].
tolist
(),
zip
(
zip
(
top_logprobs
[
sample
_idx
+
top_logprobs
[
top_logprob
_idx
+
parent_id
,
:
num_logprobs
].
tolist
(),
parent_id
,
:
num_logprobs
].
tolist
(),
# This is rank. Since top_logprob is sorted, we
# can just use a range here.
range
(
1
,
num_logprobs
+
1
))))
range
(
1
,
num_logprobs
+
1
))))
group_sample_logprobs
.
append
({
sampled_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_rank
)
token_id
:
Logprob
(
*
logprob_and_rank
)
for
token_id
,
logprob_rank
in
sample_logprobs_dict
.
items
()
for
token_id
,
logprob_and_rank
in
sampled_logprobs_dict
.
items
()
})
})
result_sample_logprobs
.
append
(
group_sample_logprobs
)
# There are len(seq_ids) number of sampled tokens for the current
sample_idx
+=
len
(
seq_ids
)
# sequence group in top_logprobs. Jump to the next seq_group.
top_logprob_idx
+=
len
(
seq_ids
)
return
result_prompt_logprobs
,
result_sample
_logprobs
return
sampled_logprobs
,
top_logprob_idx
,
selected
_logprobs
_idx
def
_modify_greedy_probs_inplace
(
logprobs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
def
_modify_greedy_probs_inplace
(
logprobs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
...
@@ -832,7 +987,7 @@ def _build_sampler_output(
...
@@ -832,7 +987,7 @@ def _build_sampler_output(
group_sample_logprobs
)
in
zip
(
sampling_metadata
.
seq_groups
,
group_sample_logprobs
)
in
zip
(
sampling_metadata
.
seq_groups
,
sample_results
,
prompt_logprobs
,
sample_results
,
prompt_logprobs
,
sample_logprobs
):
sample_logprobs
):
seq_ids
,
_
=
seq_group
seq_ids
=
seq_group
.
seq_ids
next_token_ids
,
parent_ids
=
sample_result
next_token_ids
,
parent_ids
=
sample_result
seq_outputs
=
[]
seq_outputs
=
[]
for
parent_id
,
next_token_id
,
logprobs
in
zip
(
parent_ids
,
for
parent_id
,
next_token_id
,
logprobs
in
zip
(
parent_ids
,
...
@@ -854,3 +1009,36 @@ def _build_sampler_output(
...
@@ -854,3 +1009,36 @@ def _build_sampler_output(
sampled_token_probs
=
sampled_token_probs
,
sampled_token_probs
=
sampled_token_probs
,
sampled_token_ids
=
sampled_token_ids
,
sampled_token_ids
=
sampled_token_ids
,
)
)
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
    """Return the prompt token ids that follow this step's query tokens.

    Used to compute prompt logprobs: the logprob of a query token is read
    at the id of the prompt token that comes right after it, so each query
    token processed in the current step needs its "next" prompt token id.

    This must only be called when the caller knows ``seq_group`` is in the
    prefill stage.

    Args:
        seq_group: A prefill-stage sequence group holding exactly one
            sequence and a non-None ``subquery_len``.

    Returns:
        The slice of ``prompt_token_ids`` starting one position after the
        number of already computed tokens and spanning at most
        ``subquery_len`` entries; it is cut short at the end of the prompt.
    """
    assert seq_group.is_prompt, (
        "Caller should ensure the sequence group is in a prefill stage.")
    seq_ids = seq_group.seq_ids
    subquery_len = seq_group.subquery_len
    assert subquery_len is not None
    # prompt has only 1 seq id.
    assert len(seq_ids) == 1
    seq_data = seq_group.seq_data[seq_ids[0]]
    computed_len = seq_data.get_num_computed_tokens()
    prompt_tokens = seq_data.prompt_token_ids
    # +1 because we are looking for a next prompt token.
    next_token_index_start = computed_len + 1
    # Clamp to the prompt length: the last prompt token has no successor
    # inside the prompt.
    next_token_index_end = min(computed_len + subquery_len + 1,
                               len(prompt_tokens))
    return prompt_tokens[next_token_index_start:next_token_index_end]
vllm/model_executor/sampling_metadata.py
View file @
603ad848
...
@@ -6,57 +6,275 @@ import torch
...
@@ -6,57 +6,275 @@ import torch
from
vllm.model_executor.layers.ops.sample
import
get_num_triton_sampler_splits
from
vllm.model_executor.layers.ops.sample
import
get_num_triton_sampler_splits
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
(
async_tensor_h2d
,
is_pin_memory_available
,
maybe_expand_dim
)
_SAMPLING_EPS
=
1e-5
_SAMPLING_EPS
=
1e-5
_SEED_0_REPLACEMENT
=
3403598558
_SEED_0_REPLACEMENT
=
3403598558
@dataclass
class SequenceGroupToSample:
    """Sampling-related view of one sequence group for a single step.

    Construction validates that prompt logprob indices are only supplied
    when the sampling params request prompt logprobs, and that a prefill
    group carries both ``prompt_len`` and ``subquery_len``.
    """
    # Ids of the sequences in this group, as of the previous step.
    seq_ids: List[int]
    sampling_params: SamplingParams
    # Maps seq_id to its sequence data.
    seq_data: Dict[int, SequenceData]
    # Prompt length of the sequence group. None in the decode stage.
    prompt_len: Optional[int]
    # Number of query tokens computed in the current step. None in the
    # decode stage. subquery_len <= prompt_len.
    subquery_len: Optional[int]
    # Random number generator used for sampling.
    generator: Optional[torch.Generator]
    # True in the prefill stage, False in the decode stage.
    is_prompt: bool
    # Indices into the logits of the query tokens whose prompt logprob is
    # computed. Empty when prompt logprobs are not required.
    prompt_logprob_indices: List[int]
    # Indices into the logits to sample from. Empty when sampling is not
    # required.
    sample_indices: List[int]

    @property
    def do_sample(self):
        """Whether this group has at least one index to sample from."""
        num_sample_indices = len(self.sample_indices)
        return num_sample_indices > 0

    def __post_init__(self):
        """Check that the fields are mutually consistent."""
        has_prompt_logprob_indices = len(self.prompt_logprob_indices) > 0
        if has_prompt_logprob_indices:
            assert self.sampling_params.prompt_logprobs is not None
        if not self.is_prompt:
            return
        # Prefill groups must know both lengths.
        assert self.prompt_len is not None
        assert self.subquery_len is not None
class
SamplingMetadata
:
class
SamplingMetadata
:
"""Metadata for input sequences. Used in sampler.
"""Metadata for input sequences. Used in sampler.
The usage is as follow;
```
hidden_states = execute_model(...)
logits = hidden_states[sampling_metadata.selected_token_indices]
sample(logits)
def sample(logits):
# Use categorized_sample_indices for sampling....
```
Args:
Args:
seq_groups: List of (seq_ids, sampling_params).
seq_groups: List of batched sequence groups.
seq_data: Seq_id -> SequenceData.
selected_token_indices: (num_query_tokens_to_logprob). Indices to find
prompt_lens: Lengths of prompts.
logits from the initial model output hidden states.
selected_token_indices: Token indices selected for sampling.
categorized_sample_indices: SamplingType -> token indices to sample.
categorized_sample_indices: SamplingType -> token indices to sample.
generators: List of torch.Generators to use for seeded sampling
Each token indices is 2D tensor of (num_indices, num_indices) where
perform_sampling: Whether to perform sampling. This option is used to
the first item means the sample index within the returned logit
make the sampling only happens in the driver worker, and disable
(before pruning padding), and the second item means the sample
sampling in other worker processes.
index after pruning using selected_token_indices.
For example, if the returned logit is [1, 2, 3], and we select
[1, 2] for sampling, the pruned logit will be [2, 3]. In this case,
The first tuple is [1, 2] (sampled index within original logit),
and the second tuple is [0, 1] (sampled index within pruned logit).
num_prompts: Number of prompt sequence groups in seq_groups.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
seq_groups
:
Optional
[
List
[
Tuple
[
List
[
int
],
SamplingParams
]]],
seq_groups
:
List
[
SequenceGroupToSample
],
seq_data
:
Optional
[
Dict
[
int
,
SequenceData
]],
prompt_lens
:
Optional
[
List
[
int
]],
selected_token_indices
:
torch
.
Tensor
,
selected_token_indices
:
torch
.
Tensor
,
categorized_sample_indices
:
Optional
[
Dict
[
SamplingType
,
torch
.
Tensor
]],
categorized_sample_indices
:
Dict
[
SamplingType
,
torch
.
Tensor
],
generators
:
Optional
[
List
[
torch
.
Generator
]]
=
None
,
num_prompts
:
int
,
perform_sampling
:
bool
=
True
,
)
->
None
:
)
->
None
:
self
.
seq_groups
=
seq_groups
self
.
seq_groups
=
seq_groups
self
.
seq_data
=
seq_data
self
.
prompt_lens
=
prompt_lens
self
.
selected_token_indices
=
selected_token_indices
self
.
selected_token_indices
=
selected_token_indices
self
.
categorized_sample_indices
=
categorized_sample_indices
self
.
categorized_sample_indices
=
categorized_sample_indices
self
.
generators
=
generators
self
.
num_prompts
=
num_prompts
self
.
perform_sampling
=
perform_sampling
self
.
num_prompts
=
len
(
prompt_lens
)
if
prompt_lens
is
not
None
else
0
@staticmethod
def prepare(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    prompt_lens: List[int],
    subquery_lens: Optional[List[int]],
    device: str,
    pin_memory: bool,
) -> "SamplingMetadata":
    """Build a ``SamplingMetadata`` for the given batch of sequence groups.

    The index bookkeeping is delegated to ``_prepare_seq_groups``; this
    method then converts the resulting index lists to tensors on
    ``device`` via ``async_tensor_h2d``.

    Args:
        seq_group_metadata_list: Sequence groups batched for this step.
        prompt_lens: Prompt lengths, forwarded to ``_prepare_seq_groups``.
        subquery_lens: Query lengths, forwarded to ``_prepare_seq_groups``.
        device: Target device for the index tensors; also forwarded to
            ``_prepare_seq_groups``.
        pin_memory: Forwarded to ``async_tensor_h2d``.

    Returns:
        The assembled ``SamplingMetadata``.
    """
    prepared = _prepare_seq_groups(seq_group_metadata_list, prompt_lens,
                                   subquery_lens, device)
    seq_groups, token_index_list, index_pairs_by_type, num_prompts = prepared

    selected_token_indices = async_tensor_h2d(token_index_list,
                                              dtype=torch.long,
                                              target_device=device,
                                              pin_memory=pin_memory)

    categorized_sample_indices = {}
    for sampling_type, index_pairs in index_pairs_by_type.items():
        pairs_tensor = async_tensor_h2d(index_pairs,
                                        dtype=torch.int,
                                        target_device=device,
                                        pin_memory=pin_memory)
        categorized_sample_indices[sampling_type] = maybe_expand_dim(
            pairs_tensor, 2, 2)

    return SamplingMetadata(
        seq_groups=seq_groups,
        selected_token_indices=selected_token_indices,
        categorized_sample_indices=categorized_sample_indices,
        num_prompts=num_prompts,
    )
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
return
(
"SamplingMetadata("
"SamplingMetadata("
f
"seq_groups=
{
self
.
seq_groups
}
, "
f
"seq_groups=
{
self
.
seq_groups
}
, "
f
"seq_data=
{
self
.
seq_data
}
, "
f
"prompt_lens=
{
self
.
prompt_lens
}
, "
f
"selected_token_indices=
{
self
.
selected_token_indices
}
, "
f
"selected_token_indices=
{
self
.
selected_token_indices
}
, "
f
"categorized_sample_indices=
{
self
.
categorized_sample_indices
}
), "
f
"categorized_sample_indices=
{
self
.
categorized_sample_indices
}
), "
)
f
"perform_sampling=
{
self
.
perform_sampling
}
)"
)
def _prepare_seq_groups(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    prompt_lens: List[int],
    subquery_lens: Optional[List[int]],
    device: str,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
        SamplingType, List[Tuple[int, int]]], int]:
    """Prepare sequence groups and indices for sampling.

    For a prefill group whose sampling params carry a seed, a fresh
    ``torch.Generator`` seeded with it is stored on
    ``seq_group_metadata.state.generator`` as a side effect.

    Args:
        seq_group_metadata_list: A list of sequence groups to batch.
        prompt_lens: A list of prompt lens per sequence group.
            Index of prompt len should match with seq_group_metadata_list.
        subquery_lens: A list of query lengths. Prompt lens include the
            length of entire prompt tokens, and it could be shorter.
        device: A device to use for random number generator,
            `SequenceGroupToSample.generator`.

    Returns:
        seq_groups: A list of sequence groups to sample.
        selected_token_indices: See the definition from `SamplingMetadata`.
        categorized_sample_indices: See the definition from
            `SamplingMetadata`.
        num_prompts: Total number of prompts from
            `seq_group_metadata_list`.
    """
    # Batched sequence groups for the current model forward step.
    seq_groups: List[SequenceGroupToSample] = []
    # A list of token indices to sample/compute logprob. It is used to
    # prune the outcome logits from the model for the performance.
    selected_token_indices: List[int] = []
    # Used for selected_token_indices.
    model_output_idx = 0

    # Sampling type -> (
    # indices to sample/prompt logprob within pruned output logits,
    # indices to sample within pruned logits)
    categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
        t: []
        for t in SamplingType
    }
    # Index of logits to compute logprob. Logits include both prompt logprob
    # and sample logprob indices.
    logit_idx = 0
    # Index to sample from a sample tensor. It is used by triton sample
    # kernel. See `_sample_with_triton_kernel` for more details.
    sample_idx = 0
    # Total number of prompts from given sequence groups.
    num_prompts = 0

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = list(seq_group_metadata.seq_data.keys())
        sampling_params = seq_group_metadata.sampling_params
        is_prompt = seq_group_metadata.is_prompt
        generator: Optional[torch.Generator] = None
        # If the current seq group is in decode stage, it is None.
        prompt_len: Optional[int] = None
        subquery_len: Optional[int] = None
        prompt_logprob_indices: List[int] = []
        sample_indices: List[int] = []
        do_sample = seq_group_metadata.do_sample
        # NOTE: 0 is a distinct value from None here, so every check below
        # uses `is not None` rather than truthiness. Mixing the two would
        # leave selected_token_indices and the logit indices out of sync.
        need_prompt_logprobs = sampling_params.prompt_logprobs is not None

        if is_prompt:
            if sampling_params.seed is not None:
                seq_group_metadata.state.generator = torch.Generator(
                    device=device).manual_seed(sampling_params.seed)

            num_prompts += 1
            num_prefill_sample = len(seq_ids)
            assert num_prefill_sample == 1
            assert subquery_lens is not None and prompt_lens is not None
            subquery_len, prompt_len = subquery_lens[i], prompt_lens[i]
            # If we need sampling, exclude num_prefill_sample tokens from
            # prompt logprob.
            prompt_logprob_len = (subquery_len - num_prefill_sample
                                  if do_sample else subquery_len)
            sample_len = num_prefill_sample if do_sample else 0
        else:
            # Decode
            prompt_logprob_len = 0
            sample_len = len(seq_ids) if do_sample else 0

        # Update indices to select from the model output.
        #
        # This block computes selected_token_indices which is used in the
        # following way.
        #
        # hidden_states = model(...)
        # logits = hidden_states[selected_token_indices]
        if need_prompt_logprobs:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + prompt_logprob_len))
        # The model output holds every query token, so the cursor advances
        # whether or not the prompt positions were selected.
        model_output_idx += prompt_logprob_len
        if do_sample:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
        model_output_idx += sample_len

        # We now find indices for logprob computation and sampling.
        #
        # This block computes categorized_sample_indices which is used in
        # the following way.
        #
        # hidden_states = model(...)
        # logits = hidden_states[selected_token_indices]
        # def sample(logits):
        #    # Use categorized_sample_indices for sampling.
        #    # prompt_logprob_indices to find prompt logprob indices.
        #    # sample_indices to find sample indices.
        if need_prompt_logprobs:
            prompt_logprob_indices.extend(
                range(logit_idx, logit_idx + prompt_logprob_len))
            # The pruned logits only contain selected positions, so this
            # cursor advances only when the prompt positions were selected.
            logit_idx += prompt_logprob_len
        if do_sample:
            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
            categorized_sample_indices[sampling_params.sampling_type].extend(
                zip(range(logit_idx, logit_idx + sample_len),
                    range(sample_idx, sample_idx + sample_len)))
            logit_idx += sample_len
            sample_idx += sample_len

        if sampling_params.seed is not None:
            generator = seq_group_metadata.state.generator

        # The two index lists are created fresh for every group, so they
        # can be handed over without copying.
        seq_groups.append(
            SequenceGroupToSample(
                seq_ids=seq_ids,
                sampling_params=sampling_params,
                seq_data=seq_group_metadata.seq_data,
                prompt_len=prompt_len,
                subquery_len=subquery_len,
                generator=generator,
                is_prompt=is_prompt,
                prompt_logprob_indices=prompt_logprob_indices,
                sample_indices=sample_indices))
    return (seq_groups, selected_token_indices, categorized_sample_indices,
            num_prompts)
@
dataclass
@
dataclass
...
@@ -112,11 +330,10 @@ class SamplingTensors:
...
@@ -112,11 +330,10 @@ class SamplingTensors:
seeds_to_generate
=
(
extra_seeds_to_generate
+
seeds_to_generate
=
(
extra_seeds_to_generate
+
get_num_triton_sampler_splits
(
vocab_size
))
get_num_triton_sampler_splits
(
vocab_size
))
sample_indices_start_idx
=
0
assert
sampling_metadata
.
seq_groups
is
not
None
assert
sampling_metadata
.
seq_groups
is
not
None
assert
sampling_metadata
.
seq_
data
is
not
None
for
seq_group
in
sampling_metadata
.
seq_
groups
:
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
seq_ids
=
seq_group
.
seq_ids
seq_ids
,
sampling_params
=
seq_group
sampling_params
=
seq_group
.
sampling_params
temperature
=
sampling_params
.
temperature
temperature
=
sampling_params
.
temperature
p
=
sampling_params
.
presence_penalty
p
=
sampling_params
.
presence_penalty
f
=
sampling_params
.
frequency_penalty
f
=
sampling_params
.
frequency_penalty
...
@@ -145,45 +362,46 @@ class SamplingTensors:
...
@@ -145,45 +362,46 @@ class SamplingTensors:
or
abs
(
r
-
1.0
)
>=
_SAMPLING_EPS
):
or
abs
(
r
-
1.0
)
>=
_SAMPLING_EPS
):
do_penalties
=
True
do_penalties
=
True
if
(
i
<
sampling_metadata
.
num_prompts
is_prompt
=
seq_group
.
is_prompt
if
(
seq_group
.
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
):
and
sampling_params
.
prompt_logprobs
is
not
None
):
# For tokens in the prompt that we only need to get
# For tokens in the prompt that we only need to get
# their logprobs
# their logprobs
assert
sampling_metadata
.
prompt_lens
is
not
None
subquery_len
=
seq_group
.
subquery_len
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
assert
subquery_len
is
not
None
temperatures
+=
[
temperature
]
*
(
prompt_len
-
1
)
prefill_len
=
len
(
seq_group
.
prompt_logprob_indices
)
top_ps
+=
[
top_p
]
*
(
prompt_len
-
1
)
temperatures
+=
[
temperature
]
*
prefill_len
top_ks
+=
[
top_k
]
*
(
prompt_len
-
1
)
top_ps
+=
[
top_p
]
*
prefill_len
min_ps
+=
[
min_p
]
*
(
prompt_len
-
1
)
top_ks
+=
[
top_k
]
*
prefill_len
presence_penalties
+=
[
0
]
*
(
prompt_len
-
1
)
min_ps
+=
[
min_p
]
*
prefill_len
frequency_penalties
+=
[
0
]
*
(
prompt_len
-
1
)
presence_penalties
+=
[
0
]
*
prefill_len
repetition_penalties
+=
[
1
]
*
(
prompt_len
-
1
)
frequency_penalties
+=
[
0
]
*
prefill_len
prompt_tokens
.
extend
([]
for
_
in
range
(
prompt_len
-
1
))
repetition_penalties
+=
[
1
]
*
prefill_len
output_tokens
.
extend
([]
for
_
in
range
(
prompt_len
-
1
))
prompt_tokens
.
extend
([]
for
_
in
range
(
prefill_len
))
for
seq_id
in
seq_ids
:
output_tokens
.
extend
([]
for
_
in
range
(
prefill_len
))
seq_data
=
sampling_metadata
.
seq_data
[
seq_id
]
prompt_tokens
.
append
(
seq_data
.
prompt_token_ids
)
if
seq_group
.
do_sample
:
output_tokens
.
append
(
seq_data
.
output_token_ids
)
sample_lens
=
len
(
seq_group
.
sample_indices
)
temperatures
+=
[
temperature
]
*
len
(
seq_ids
)
assert
sample_lens
==
len
(
seq_ids
)
top_ps
+=
[
top_p
]
*
len
(
seq_ids
)
for
seq_id
in
seq_ids
:
top_ks
+=
[
top_k
]
*
len
(
seq_ids
)
seq_data
=
seq_group
.
seq_data
[
seq_id
]
min_ps
+=
[
min_p
]
*
len
(
seq_ids
)
prompt_tokens
.
append
(
seq_data
.
prompt_token_ids
)
presence_penalties
+=
[
p
]
*
len
(
seq_ids
)
output_tokens
.
append
(
seq_data
.
output_token_ids
)
frequency_penalties
+=
[
f
]
*
len
(
seq_ids
)
temperatures
+=
[
temperature
]
*
len
(
seq_ids
)
repetition_penalties
+=
[
r
]
*
len
(
seq_ids
)
top_ps
+=
[
top_p
]
*
len
(
seq_ids
)
top_ks
+=
[
top_k
]
*
len
(
seq_ids
)
is_prompt
=
i
<
sampling_metadata
.
num_prompts
min_ps
+=
[
min_p
]
*
len
(
seq_ids
)
presence_penalties
+=
[
p
]
*
len
(
seq_ids
)
frequency_penalties
+=
[
f
]
*
len
(
seq_ids
)
repetition_penalties
+=
[
r
]
*
len
(
seq_ids
)
if
is_prompt
:
if
is_prompt
:
prompt_best_of
.
append
(
sampling_params
.
best_of
)
prompt_best_of
.
append
(
sampling_params
.
best_of
)
assert
sampling_metadata
.
prompt_lens
is
not
None
subquery_len
=
seq_group
.
subquery_len
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
assert
subquery_len
is
not
None
if
sampling_params
.
prompt_logprobs
is
not
None
:
# NOTE: the sampling position is the last token
# in the prompt
sample_indices_start_idx
+=
prompt_len
-
1
for
seq_id
in
seq_ids
:
for
seq_id
in
seq_ids
:
seq_data
=
s
ampling_metadata
.
seq_data
[
seq_id
]
seq_data
=
s
eq_group
.
seq_data
[
seq_id
]
extra_entropy
=
extra_entropy
or
()
extra_entropy
=
extra_entropy
or
()
seq_seeds
=
cls
.
_get_sequence_seeds
(
seq_seeds
=
cls
.
_get_sequence_seeds
(
seed
,
seed
,
...
@@ -193,8 +411,7 @@ class SamplingTensors:
...
@@ -193,8 +411,7 @@ class SamplingTensors:
seeds_to_generate
=
seeds_to_generate
,
seeds_to_generate
=
seeds_to_generate
,
is_greedy
=
is_greedy
)
is_greedy
=
is_greedy
)
sampling_seeds
.
append
(
seq_seeds
)
sampling_seeds
.
append
(
seq_seeds
)
sample_indices
.
append
(
sample_indices_start_idx
)
sample_indices
.
extend
(
seq_group
.
sample_indices
)
sample_indices_start_idx
+=
1
sampling_tensors
=
SamplingTensors
.
from_lists
(
sampling_tensors
=
SamplingTensors
.
from_lists
(
temperatures
,
top_ps
,
top_ks
,
min_ps
,
presence_penalties
,
temperatures
,
top_ps
,
top_ks
,
min_ps
,
presence_penalties
,
...
@@ -217,12 +434,14 @@ class SamplingTensors:
...
@@ -217,12 +434,14 @@ class SamplingTensors:
# Note that the performance will be very bad without
# Note that the performance will be very bad without
# pinned memory.
# pinned memory.
pin_memory
=
is_pin_memory_available
()
pin_memory
=
is_pin_memory_available
()
prompt_max_len
=
max
(
len
(
tokens
)
for
tokens
in
prompt_tokens
)
prompt_max_len
=
max
([
len
(
tokens
)
for
tokens
in
prompt_tokens
],
default
=
0
)
prompt_padded_tokens
=
[
prompt_padded_tokens
=
[
tokens
+
[
vocab_size
]
*
(
prompt_max_len
-
len
(
tokens
))
tokens
+
[
vocab_size
]
*
(
prompt_max_len
-
len
(
tokens
))
for
tokens
in
prompt_tokens
for
tokens
in
prompt_tokens
]
]
output_max_len
=
max
(
len
(
tokens
)
for
tokens
in
output_tokens
)
output_max_len
=
max
([
len
(
tokens
)
for
tokens
in
output_tokens
],
default
=
0
)
output_padded_tokens
=
[
output_padded_tokens
=
[
tokens
+
[
vocab_size
]
*
(
output_max_len
-
len
(
tokens
))
tokens
+
[
vocab_size
]
*
(
output_max_len
-
len
(
tokens
))
for
tokens
in
output_tokens
for
tokens
in
output_tokens
...
...
vllm/sequence.py
View file @
603ad848
...
@@ -28,7 +28,10 @@ class Logprob:
...
@@ -28,7 +28,10 @@ class Logprob:
decoded_token
:
Optional
[
str
]
=
None
decoded_token
:
Optional
[
str
]
=
None
# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
PromptLogprobs
=
List
[
Optional
[
Dict
[
int
,
Logprob
]]]
PromptLogprobs
=
List
[
Optional
[
Dict
[
int
,
Logprob
]]]
# {token_id -> logprob} for each sequence group.
SampleLogprobs
=
List
[
Dict
[
int
,
Logprob
]]
SampleLogprobs
=
List
[
Dict
[
int
,
Logprob
]]
...
@@ -215,7 +218,7 @@ class Sequence:
...
@@ -215,7 +218,7 @@ class Sequence:
self
.
eos_token_id
=
eos_token_id
self
.
eos_token_id
=
eos_token_id
self
.
lora_request
=
lora_request
self
.
lora_request
=
lora_request
self
.
data
=
SequenceData
(
prompt_token_ids
)
self
.
data
:
SequenceData
=
SequenceData
(
prompt_token_ids
)
self
.
output_logprobs
:
SampleLogprobs
=
[]
self
.
output_logprobs
:
SampleLogprobs
=
[]
self
.
output_text
=
""
self
.
output_text
=
""
...
@@ -559,6 +562,9 @@ class SequenceGroupMetadata:
...
@@ -559,6 +562,9 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs.
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
block_tables: The block tables. (Seq id -> list of physical block
numbers)
numbers)
do_sample: True if sampling is required. Sampling is not required when
e.g., prefill is chunked, and the current iteration only computes
query tokens for prefill, we don't need sampling.
token_chunk_size: The number of tokens to be processed (per sequence).
token_chunk_size: The number of tokens to be processed (per sequence).
None if chunking is not required.
None if chunking is not required.
state: Internal state tied to this sequence group.
state: Internal state tied to this sequence group.
...
@@ -573,6 +579,7 @@ class SequenceGroupMetadata:
...
@@ -573,6 +579,7 @@ class SequenceGroupMetadata:
seq_data
:
Dict
[
int
,
SequenceData
],
seq_data
:
Dict
[
int
,
SequenceData
],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
block_tables
:
Dict
[
int
,
List
[
int
]],
block_tables
:
Dict
[
int
,
List
[
int
]],
do_sample
:
bool
=
True
,
token_chunk_size
:
Optional
[
int
]
=
None
,
token_chunk_size
:
Optional
[
int
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
...
@@ -589,6 +596,7 @@ class SequenceGroupMetadata:
...
@@ -589,6 +596,7 @@ class SequenceGroupMetadata:
self
.
multi_modal_data
=
multi_modal_data
self
.
multi_modal_data
=
multi_modal_data
self
.
state
=
SequenceGroupState
()
if
state
is
None
else
state
self
.
state
=
SequenceGroupState
()
if
state
is
None
else
state
self
.
_token_chunk_size
=
token_chunk_size
self
.
_token_chunk_size
=
token_chunk_size
self
.
do_sample
=
do_sample
if
self
.
_token_chunk_size
is
None
:
if
self
.
_token_chunk_size
is
None
:
if
is_prompt
:
if
is_prompt
:
...
@@ -650,6 +658,7 @@ class SequenceGroupOutput:
...
@@ -650,6 +658,7 @@ class SequenceGroupOutput:
prompt_logprobs
:
Optional
[
PromptLogprobs
],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
)
->
None
:
)
->
None
:
self
.
samples
=
samples
self
.
samples
=
samples
# Prompt logprob for each prompt query token.
self
.
prompt_logprobs
=
prompt_logprobs
self
.
prompt_logprobs
=
prompt_logprobs
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
...
...
vllm/worker/cpu_model_runner.py
View file @
603ad848
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -10,9 +10,8 @@ from vllm.distributed import broadcast_tensor_dict
...
@@ -10,9 +10,8 @@ from vllm.distributed import broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
from
vllm.utils
import
make_tensor_with_pad
,
maybe_expand_dim
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -38,6 +37,8 @@ class CPUModelRunner:
...
@@ -38,6 +37,8 @@ class CPUModelRunner:
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
# Currently, CPU worker doesn't support chunked prefill.
assert
self
.
scheduler_config
.
chunked_prefill_enabled
is
False
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
vision_language_config
=
vision_language_config
self
.
vision_language_config
=
vision_language_config
self
.
load_config
=
load_config
self
.
load_config
=
load_config
...
@@ -252,99 +253,6 @@ class CPUModelRunner:
...
@@ -252,99 +253,6 @@ class CPUModelRunner:
attn_metadata
,
attn_metadata
,
)
)
def _prepare_sample(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    prompt_lens: List[int],
) -> SamplingMetadata:
    """Build the ``SamplingMetadata`` for a batch of sequence groups.

    Walks the groups once and records, per group, which model output
    positions are kept (``selected_token_indices``) and which positions
    are sampled, grouped by sampling type. For a prompt group with a seed,
    a new generator seeded with it is stored on the group's ``state``.

    Args:
        seq_group_metadata_list: Sequence groups batched for this step.
        prompt_lens: Prompt length per group; only read for prompt groups.

    Returns:
        The assembled ``SamplingMetadata``.
    """
    seq_groups: List[Tuple[List[int], SamplingParams]] = []
    kept_token_positions: List[int] = []
    generators: List[torch.Generator] = []
    pairs_by_type: Dict[SamplingType, List[Tuple[int, int]]] = {
        t: []
        for t in SamplingType
    }
    # Cursor over the model output tokens.
    token_cursor = 0
    # Cursor over the kept logits, and cursor over the sampled tokens.
    logit_cursor = 0
    sampled_cursor = 0

    for i, group_meta in enumerate(seq_group_metadata_list):
        seq_ids = list(group_meta.seq_data.keys())
        sampling_params = group_meta.sampling_params
        seq_groups.append((seq_ids, sampling_params))
        has_seed = sampling_params.seed is not None

        if group_meta.is_prompt:
            assert len(seq_ids) == 1
            subquery_len = prompt_lens[i]
            wants_prompt_logprobs = (sampling_params.prompt_logprobs
                                     is not None)
            if wants_prompt_logprobs:
                # NOTE: prompt token positions do not need sample, skip
                logit_cursor += subquery_len - 1

            pairs_by_type[sampling_params.sampling_type].append(
                (logit_cursor, sampled_cursor))
            logit_cursor += 1
            sampled_cursor += 1

            # The last prompt position is always kept; the ones before it
            # are kept only when prompt logprobs are requested.
            last_position = token_cursor + subquery_len - 1
            if wants_prompt_logprobs:
                kept_token_positions.extend(range(token_cursor, last_position))
            kept_token_positions.append(last_position)
            token_cursor += subquery_len

            if has_seed:
                group_meta.state.generator = torch.Generator(
                    device=self.device).manual_seed(sampling_params.seed)
        else:
            num_seqs = len(seq_ids)
            kept_token_positions.extend(
                range(token_cursor, token_cursor + num_seqs))
            token_cursor += num_seqs

            pairs_by_type[sampling_params.sampling_type].extend(
                (logit_cursor + offset, sampled_cursor + offset)
                for offset in range(num_seqs))
            logit_cursor += num_seqs
            sampled_cursor += num_seqs

        if has_seed:
            generators.append(group_meta.state.generator)

    selected_token_indices = torch.tensor(kept_token_positions,
                                          dtype=torch.long)

    categorized_sample_indices = {}
    for sampling_type, index_pairs in pairs_by_type.items():
        pairs_tensor = torch.tensor(index_pairs, dtype=torch.int)
        categorized_sample_indices[sampling_type] = maybe_expand_dim(
            pairs_tensor, 2, 2)

    # Merge the per-group seq_id -> SequenceData mappings into one dict.
    seq_data: Dict[int, SequenceData] = {}
    for group_meta in seq_group_metadata_list:
        seq_data.update(group_meta.seq_data)

    return SamplingMetadata(
        seq_groups=seq_groups,
        seq_data=seq_data,
        prompt_lens=prompt_lens,
        selected_token_indices=selected_token_indices,
        categorized_sample_indices=categorized_sample_indices,
        generators=generators,
    )
def
prepare_input_tensors
(
def
prepare_input_tensors
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
...
@@ -364,8 +272,15 @@ class CPUModelRunner:
...
@@ -364,8 +272,15 @@ class CPUModelRunner:
(
input_tokens
,
input_positions
,
(
input_tokens
,
input_positions
,
attn_metadata
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
attn_metadata
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
prompt_lens
=
[]
prompt_lens
=
[]
sampling_metadata
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
SamplingMetadata
.
prepare
(
prompt_lens
)
seq_group_metadata_list
,
prompt_lens
,
# subquery_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use prompt_lens instead.
prompt_lens
,
self
.
device
,
pin_memory
=
False
)
# Broadcast the metadata.
# Broadcast the metadata.
metadata_dict
=
{
metadata_dict
=
{
"input_tokens"
:
input_tokens
,
"input_tokens"
:
input_tokens
,
...
@@ -389,7 +304,6 @@ class CPUModelRunner:
...
@@ -389,7 +304,6 @@ class CPUModelRunner:
selected_token_indices
=
selected_token_indices
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
None
,
categorized_sample_indices
=
None
,
generators
=
None
,
generators
=
None
,
perform_sampling
=
False
,
)
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
return
(
input_tokens
,
input_positions
,
attn_metadata
,
...
@@ -421,7 +335,7 @@ class CPUModelRunner:
...
@@ -421,7 +335,7 @@ class CPUModelRunner:
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
# Only perform sampling in the driver worker.
# Only perform sampling in the driver worker.
if
not
s
ampling_metadata
.
perform_sampling
:
if
not
s
elf
.
is_driver_worker
:
return
None
return
None
# Sample the next token.
# Sample the next token.
...
...
vllm/worker/model_runner.py
View file @
603ad848
...
@@ -20,12 +20,11 @@ from vllm.lora.request import LoRARequest
...
@@ -20,12 +20,11 @@ from vllm.lora.request import LoRARequest
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
SequenceData
,
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
async_tensor_h2d
,
is_hip
,
from
vllm.utils
import
(
CudaMemoryProfiler
,
is_hip
,
is_pin_memory_available
,
is_pin_memory_available
,
make_tensor_with_pad
,
make_tensor_with_pad
)
maybe_expand_dim
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -547,108 +546,6 @@ class ModelRunner:
...
@@ -547,108 +546,6 @@ class ModelRunner:
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
)
)
def
_prepare_sample
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
prompt_lens
:
List
[
int
],
subquery_lens
:
Optional
[
List
[
int
]],
)
->
SamplingMetadata
:
seq_groups
:
List
[
Tuple
[
List
[
int
],
SamplingParams
]]
=
[]
selected_token_indices
:
List
[
int
]
=
[]
generators
:
List
[
torch
.
Generator
]
=
[]
selected_token_start_idx
=
0
categorized_sample_indices
:
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices_start_idx
=
0
categorized_sampled_token_indices_start_idx
=
0
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
sampling_params
=
seq_group_metadata
.
sampling_params
seq_groups
.
append
((
seq_ids
,
sampling_params
))
if
seq_group_metadata
.
is_prompt
:
assert
len
(
seq_ids
)
==
1
assert
subquery_lens
is
not
None
subquery_len
=
subquery_lens
[
i
]
if
sampling_params
.
prompt_logprobs
is
not
None
:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx
+=
subquery_len
-
1
categorized_sample_indices
[
sampling_params
.
sampling_type
].
append
(
(
categorized_sample_indices_start_idx
,
categorized_sampled_token_indices_start_idx
))
categorized_sample_indices_start_idx
+=
1
categorized_sampled_token_indices_start_idx
+=
1
if
sampling_params
.
prompt_logprobs
is
not
None
:
selected_token_indices
.
extend
(
range
(
selected_token_start_idx
,
selected_token_start_idx
+
subquery_len
-
1
))
selected_token_indices
.
append
(
selected_token_start_idx
+
subquery_len
-
1
)
selected_token_start_idx
+=
subquery_len
if
sampling_params
.
seed
is
not
None
:
seq_group_metadata
.
state
.
generator
=
torch
.
Generator
(
device
=
self
.
device
).
manual_seed
(
sampling_params
.
seed
)
else
:
num_seqs
=
len
(
seq_ids
)
selected_token_indices
.
extend
(
range
(
selected_token_start_idx
,
selected_token_start_idx
+
num_seqs
))
selected_token_start_idx
+=
num_seqs
categorized_sample_indices
[
sampling_params
.
sampling_type
].
extend
(
list
(
zip
(
range
(
categorized_sample_indices_start_idx
,
categorized_sample_indices_start_idx
+
num_seqs
),
range
(
categorized_sampled_token_indices_start_idx
,
categorized_sampled_token_indices_start_idx
+
num_seqs
))))
categorized_sample_indices_start_idx
+=
num_seqs
categorized_sampled_token_indices_start_idx
+=
num_seqs
if
sampling_params
.
seed
is
not
None
:
generators
.
append
(
seq_group_metadata
.
state
.
generator
)
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
dtype
=
torch
.
long
,
target_device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
)
categorized_sample_indices
=
{
t
:
maybe_expand_dim
(
async_tensor_h2d
(
seq_ids
,
dtype
=
torch
.
int
,
target_device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
),
2
,
2
)
for
t
,
seq_ids
in
categorized_sample_indices
.
items
()
}
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_data
.
update
(
seq_group_metadata
.
seq_data
)
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
seq_groups
,
seq_data
=
seq_data
,
prompt_lens
=
prompt_lens
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
categorized_sample_indices
,
generators
=
generators
,
)
return
sampling_metadata
def
prepare_input_tensors
(
def
prepare_input_tensors
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
...
@@ -685,9 +582,9 @@ class ModelRunner:
...
@@ -685,9 +582,9 @@ class ModelRunner:
decode_lora_requests
,
decode_lora_requests
,
decode_slot_mapping
,
decode_slot_mapping
,
)
=
self
.
_prepare_decode
(
decode_reqs
)
)
=
self
.
_prepare_decode
(
decode_reqs
)
sampling_metadata
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
SamplingMetadata
.
prepare
(
prompt
_lens
,
seq_group_metadata_list
,
prompt_lens
,
subquery
_lens
,
subquery_lens
)
self
.
device
,
self
.
pin_memory
)
if
not
self
.
scheduler_config
.
chunked_prefill_enabled
:
if
not
self
.
scheduler_config
.
chunked_prefill_enabled
:
assert
(
len
(
prefill_reqs
)
and
len
(
decode_reqs
))
==
0
assert
(
len
(
prefill_reqs
)
and
len
(
decode_reqs
))
==
0
...
@@ -788,12 +685,9 @@ class ModelRunner:
...
@@ -788,12 +685,9 @@ class ModelRunner:
**
metadata_dict
)
**
metadata_dict
)
sampling_metadata
=
SamplingMetadata
(
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
None
,
seq_groups
=
None
,
seq_data
=
None
,
prompt_lens
=
None
,
selected_token_indices
=
selected_token_indices
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
None
,
categorized_sample_indices
=
None
,
generators
=
None
,
num_prompts
=
0
,
perform_sampling
=
False
,
)
)
# if it is a mixed batch, decode attn_metadata is broadcasted
# if it is a mixed batch, decode attn_metadata is broadcasted
...
@@ -852,7 +746,7 @@ class ModelRunner:
...
@@ -852,7 +746,7 @@ class ModelRunner:
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
# Only perform sampling in the driver worker.
# Only perform sampling in the driver worker.
if
not
s
ampling_metadata
.
perform_sampling
:
if
not
s
elf
.
is_driver_worker
:
return
None
return
None
# Sample the next token.
# Sample the next token.
...
@@ -860,6 +754,7 @@ class ModelRunner:
...
@@ -860,6 +754,7 @@ class ModelRunner:
logits
=
logits
,
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
)
)
return
output
return
output
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
...
vllm/worker/neuron_model_runner.py
View file @
603ad848
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -8,10 +8,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
...
@@ -8,10 +8,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader.neuron
import
get_neuron_model
from
vllm.model_executor.model_loader.neuron
import
get_neuron_model
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.utils
import
(
async_tensor_h2d
,
is_pin_memory_available
,
make_tensor_with_pad
,
maybe_expand_dim
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -141,106 +139,6 @@ class NeuronModelRunner:
...
@@ -141,106 +139,6 @@ class NeuronModelRunner:
return
input_tokens
,
input_positions
,
input_block_ids
return
input_tokens
,
input_positions
,
input_block_ids
def
_prepare_sample
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
prompt_lens
:
List
[
int
],
)
->
SamplingMetadata
:
seq_groups
:
List
[
Tuple
[
List
[
int
],
SamplingParams
]]
=
[]
selected_token_indices
:
List
[
int
]
=
[]
generators
:
List
[
torch
.
Generator
]
=
[]
selected_token_start_idx
=
0
categorized_sample_indices
:
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices_start_idx
=
0
categorized_sampled_token_indices_start_idx
=
0
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
sampling_params
=
seq_group_metadata
.
sampling_params
seq_groups
.
append
((
seq_ids
,
sampling_params
))
if
seq_group_metadata
.
is_prompt
:
assert
len
(
seq_ids
)
==
1
assert
prompt_lens
is
not
None
prompt_len
=
prompt_lens
[
i
]
if
sampling_params
.
prompt_logprobs
is
not
None
:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx
+=
prompt_len
-
1
categorized_sample_indices
[
sampling_params
.
sampling_type
].
append
(
(
categorized_sample_indices_start_idx
,
categorized_sampled_token_indices_start_idx
))
categorized_sample_indices_start_idx
+=
1
categorized_sampled_token_indices_start_idx
+=
1
if
sampling_params
.
prompt_logprobs
is
not
None
:
selected_token_indices
.
extend
(
range
(
selected_token_start_idx
,
selected_token_start_idx
+
prompt_len
-
1
))
selected_token_indices
.
append
(
selected_token_start_idx
+
prompt_len
-
1
)
selected_token_start_idx
+=
prompt_len
if
sampling_params
.
seed
is
not
None
:
seq_group_metadata
.
state
.
generator
=
torch
.
Generator
(
device
=
self
.
device
).
manual_seed
(
sampling_params
.
seed
)
else
:
num_seqs
=
len
(
seq_ids
)
selected_token_indices
.
extend
(
range
(
selected_token_start_idx
,
selected_token_start_idx
+
num_seqs
))
selected_token_start_idx
+=
num_seqs
categorized_sample_indices
[
sampling_params
.
sampling_type
].
extend
(
zip
(
range
(
categorized_sample_indices_start_idx
,
categorized_sample_indices_start_idx
+
num_seqs
),
range
(
categorized_sampled_token_indices_start_idx
,
categorized_sampled_token_indices_start_idx
+
num_seqs
)))
categorized_sample_indices_start_idx
+=
num_seqs
categorized_sampled_token_indices_start_idx
+=
num_seqs
if
sampling_params
.
seed
is
not
None
:
generators
.
append
(
seq_group_metadata
.
state
.
generator
)
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
dtype
=
torch
.
long
,
target_device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
)
categorized_sample_indices
=
{
t
:
maybe_expand_dim
(
async_tensor_h2d
(
seq_ids
,
dtype
=
torch
.
int
,
target_device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
),
2
,
2
)
for
t
,
seq_ids
in
categorized_sample_indices
.
items
()
}
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_data
.
update
(
seq_group_metadata
.
seq_data
)
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
seq_groups
,
seq_data
=
seq_data
,
prompt_lens
=
prompt_lens
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
categorized_sample_indices
,
generators
=
generators
,
)
return
sampling_metadata
def
prepare_input_tensors
(
def
prepare_input_tensors
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
...
@@ -256,8 +154,15 @@ class NeuronModelRunner:
...
@@ -256,8 +154,15 @@ class NeuronModelRunner:
(
input_tokens
,
input_positions
,
(
input_tokens
,
input_positions
,
input_block_ids
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
input_block_ids
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
prompt_lens
=
[]
prompt_lens
=
[]
sampling_metadata
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
SamplingMetadata
.
prepare
(
prompt_lens
)
seq_group_metadata_list
,
prompt_lens
,
# subquery_lens is not needed if chunked prefill is not
# supported. Since neuron worker doesn't support chunked prefill
# just use prompt_lens instead.
prompt_lens
,
self
.
device
,
self
.
pin_memory
)
return
(
input_tokens
,
input_positions
,
input_block_ids
,
return
(
input_tokens
,
input_positions
,
input_block_ids
,
sampling_metadata
)
sampling_metadata
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment