Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
30172b49
Unverified
Commit
30172b49
authored
Feb 18, 2025
by
Nick Hill
Committed by
GitHub
Feb 18, 2025
Browse files
[V1] Optimize handling of sampling metadata and req_ids list (#13244)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
a4d577b3
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
255 additions
and
298 deletions
+255
-298
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+2
-7
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+18
-26
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+21
-26
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+22
-11
vllm/model_executor/layers/utils.py
vllm/model_executor/layers/utils.py
+3
-3
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+4
-2
vllm/v1/sample/metadata.py
vllm/v1/sample/metadata.py
+10
-11
vllm/v1/sample/ops/penalties.py
vllm/v1/sample/ops/penalties.py
+6
-7
vllm/v1/sample/ops/topk_topp_sampler.py
vllm/v1/sample/ops/topk_topp_sampler.py
+20
-28
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+2
-0
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+6
-7
vllm/v1/utils.py
vllm/v1/utils.py
+11
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+102
-111
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+28
-57
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+0
-2
No files found.
tests/v1/sample/test_rejection_sampler.py
View file @
30172b49
...
...
@@ -26,17 +26,13 @@ def create_logits_tensor(token_ids: List[int],
def
create_sampling_metadata
(
spec_tokens
:
List
[
List
[
int
]])
->
SamplingMetadata
:
batch_size
=
len
(
spec_tokens
)
return
SamplingMetadata
(
temperature
=
0.0
,
temperature
=
torch
.
tensor
([])
,
all_greedy
=
True
,
all_random
=
False
,
rejection_sampling
=
True
,
spec_token_ids
=
spec_tokens
,
top_p
=
None
,
top_k
=
None
,
no_top_p
=
False
,
no_top_k
=
False
,
min_p
=
torch
.
empty
(
batch_size
,
),
no_min_p
=
True
,
generators
=
{},
max_num_logprobs
=
0
,
no_penalties
=
False
,
...
...
@@ -45,8 +41,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
presence_penalties
=
torch
.
tensor
([]),
repetition_penalties
=
torch
.
tensor
([]),
output_token_ids
=
[],
min_tokens
=
[],
stop_token_ids
=
[],
min_tokens
=
{},
logit_bias
=
[
None
]
*
batch_size
,
)
...
...
tests/v1/sample/test_sampler.py
View file @
30172b49
...
...
@@ -77,25 +77,20 @@ def _create_default_sampling_metadata(
temperature
=
torch
.
full
((
batch_size
,
),
0.0
),
all_greedy
=
True
,
all_random
=
False
,
rejection_sampling
=
False
,
top_p
=
torch
.
empty
(
batch_size
,
),
top_k
=
torch
.
empty
(
batch_size
,
),
no_top_p
=
True
,
no_top_k
=
True
,
min_p
=
torch
.
empty
(
batch_size
,
),
no_min_p
=
True
,
top_p
=
None
,
top_k
=
None
,
min_p
=
None
,
generators
=
{},
max_num_logprobs
=
0
,
prompt_token_ids
=
_create_prompt_tokens_tensor
(
prompt_token_ids
,
vocab_size
,
device
),
output_token_ids
=
output_token_ids
,
spec_token_ids
=
[]
,
spec_token_ids
=
None
,
frequency_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
presence_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
repetition_penalties
=
_create_penalty_tensor
(
batch_size
,
1.0
,
device
),
no_penalties
=
True
,
min_tokens
=
[],
stop_token_ids
=
[],
min_tokens
=
{},
logit_bias
=
[
None
]
*
batch_size
,
)
return
fake_sampling_metadata
...
...
@@ -104,10 +99,10 @@ def _create_default_sampling_metadata(
def
_generate_min_token_penalties_and_stop_tokens
(
num_output_tokens
:
int
,
batch_size
:
int
,
vocab_size
:
int
,
batch_indices_for_min_token_penalty
:
List
[
int
]
)
->
Tuple
[
List
[
int
]
,
List
[
Set
[
int
]]]:
)
->
Dict
[
int
,
Tuple
[
int
,
Set
[
int
]]]:
"""
Generates and returns a
lis
t of minimum token penalties
(`min_tokens`)
and a
corresponding
list of
stop token IDs (`stop_token_ids`) for each
Generates and returns a
dic
t of minimum token penalties
and
corresponding stop token IDs (
`min_tokens`,
`stop_token_ids`) for each
batch.
If a batch index is included in `batch_indices_for_min_token_penalty`,
...
...
@@ -115,22 +110,19 @@ def _generate_min_token_penalties_and_stop_tokens(
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
stop_token_ids
:
List
[
Set
[
int
]]
=
[]
min_tokens
:
List
[
int
]
=
[]
min_tokens
:
Dict
[
int
,
Tuple
[
int
,
Set
[
int
]]]
=
{}
for
index
in
range
(
batch_size
):
if
index
in
batch_indices_for_min_token_penalty
:
min_tokens
.
append
(
min_tokens
[
index
]
=
(
np
.
random
.
randint
(
num_output_tokens
+
1
,
2
*
num_output_tokens
))
stop_token_ids
.
append
(
2
*
num_output_tokens
),
set
(
np
.
random
.
randint
(
0
,
vocab_size
-
1
)
for
_
in
range
(
np
.
random
.
randint
(
0
,
vocab_size
))))
else
:
min_tokens
.
append
(
np
.
random
.
randint
(
0
,
num_output_tokens
))
stop_token_ids
.
append
(
set
())
return
(
min_tokens
,
stop_token_ids
)
min_tokens
[
index
]
=
(
np
.
random
.
randint
(
0
,
num_output_tokens
),
set
())
return
min_tokens
def
_create_weighted_output_token_list
(
...
...
@@ -165,7 +157,7 @@ def _create_weighted_output_token_list(
output_token_ids_for_batch
.
extend
(
[
token_id
for
_
in
range
(
index
+
1
)])
output_token_ids
.
append
(
output_token_ids_for_batch
)
return
(
output_token_ids
,
sorted_token_ids_in_output
)
return
output_token_ids
,
sorted_token_ids_in_output
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
...
...
@@ -182,17 +174,17 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
torch
.
device
(
device
))
batch_indices_for_min_token_penalty
=
np
.
random
.
randint
(
0
,
batch_size
-
1
,
size
=
np
.
random
.
randint
(
0
,
batch_size
)).
tolist
()
min_tokens
,
stop_token_ids
=
_generate_min_token_penalties_and_stop_tokens
(
min_tokens
=
_generate_min_token_penalties_and_stop_tokens
(
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
batch_indices_for_min_token_penalty
)
sampling_metadata
.
min_tokens
=
min_tokens
sampling_metadata
.
stop_token_ids
=
stop_token_ids
sampler
=
Sampler
()
logits
=
sampler
.
apply_penalties
(
fake_logits
,
sampling_metadata
)
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
for
token_id
in
range
(
VOCAB_SIZE
):
if
token_id
in
stop_token_ids
[
batch_idx
]:
_
,
stop_token_ids
=
min_tokens
.
get
(
batch_idx
,
(
0
,
set
()))
if
token_id
in
stop_token_ids
:
assert
logits
[
batch_idx
][
token_id
]
==
-
float
(
"inf"
)
else
:
assert
logits
[
batch_idx
][
token_id
]
!=
-
float
(
"inf"
)
...
...
tests/v1/worker/test_gpu_input_batch.py
View file @
30172b49
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
,
List
,
Set
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
import
numpy
as
np
import
pytest
...
...
@@ -41,7 +41,7 @@ def _remove_requests(
for
index
in
req_indices_to_remove
:
input_batch
.
remove_request
(
reqs
[
index
].
req_id
)
req_ids_to_remove
.
add
(
reqs
[
index
].
req_id
)
return
(
req_ids_to_remove
,
req_indices_to_remove_list
)
return
req_ids_to_remove
,
req_indices_to_remove_list
def
_construct_expected_sampling_metadata
(
...
...
@@ -64,8 +64,7 @@ def _construct_expected_sampling_metadata(
top_p
=
[
0.0
for
_
in
range
(
num_reqs
)]
min_p
=
[
0.0
for
_
in
range
(
num_reqs
)]
temperature
=
[
0.0
for
_
in
range
(
num_reqs
)]
stop_token_ids
:
List
[
Set
[
int
]]
=
[
set
()
for
_
in
range
(
num_reqs
)]
min_tokens
=
[
0
for
_
in
range
(
num_reqs
)]
min_tokens
=
{}
logit_bias
=
[
None
]
*
num_reqs
for
req
in
reqs
:
if
req
.
req_id
not
in
req_ids_retained
:
...
...
@@ -83,22 +82,21 @@ def _construct_expected_sampling_metadata(
top_p
[
index_in_input_batch
]
=
req
.
sampling_params
.
top_p
min_p
[
index_in_input_batch
]
=
req
.
sampling_params
.
min_p
temperature
[
index_in_input_batch
]
=
req
.
sampling_params
.
temperature
stop
_token
_ids
[
index_in_input_batch
]
=
req
.
sampling_params
.
all_stop
_token
_id
s
min_tokens
[
index_in_input_batch
]
=
req
.
sampling_params
.
min
_token
s
min
_token
s
[
index_in_input_batch
]
=
(
req
.
sampling_params
.
min
_tokens
,
req
.
sampling_params
.
all_stop
_token
_ids
)
logit_bias
[
index_in_input_batch
]
=
req
.
sampling_params
.
logit_bias
return
SamplingMetadata
(
temperature
=
torch
.
tensor
(
temperature
,
dtype
=
torch
.
float
,
device
=
device
),
all_greedy
=
False
,
all_random
=
True
,
rejection_sampling
=
False
,
top_p
=
torch
.
tensor
(
top_p
,
dtype
=
torch
.
float
,
device
=
device
),
top_k
=
torch
.
tensor
(
top_k
,
dtype
=
torch
.
int
,
device
=
device
),
no_top_p
=
all
(
x
==
1.0
for
x
in
top_p
),
no_top_k
=
all
(
x
==
0
for
x
in
top_k
),
min_p
=
torch
.
tensor
(
min_p
,
dtype
=
torch
.
float
,
device
=
device
),
no_min_p
=
all
(
x
==
0.0
for
x
in
min_p
),
top_p
=
None
if
all
(
x
==
1.0
for
x
in
top_p
)
else
torch
.
tensor
(
top_p
,
dtype
=
torch
.
float
,
device
=
device
),
top_k
=
None
if
all
(
x
==
0
for
x
in
top_k
)
else
torch
.
tensor
(
top_k
,
dtype
=
torch
.
int
,
device
=
device
),
min_p
=
None
if
all
(
x
==
0.0
for
x
in
min_p
)
else
torch
.
tensor
(
min_p
,
dtype
=
torch
.
float
,
device
=
device
),
generators
=
{},
max_num_logprobs
=
0
,
prompt_token_ids
=
make_tensor_with_pad
(
...
...
@@ -117,9 +115,8 @@ def _construct_expected_sampling_metadata(
dtype
=
torch
.
float
,
device
=
device
),
output_token_ids
=
output_token_ids
,
spec_token_ids
=
[]
,
spec_token_ids
=
None
,
min_tokens
=
min_tokens
,
stop_token_ids
=
stop_token_ids
,
no_penalties
=
(
all
(
x
==
0
for
x
in
presence_penalties
)
and
all
(
x
==
0
for
x
in
frequency_penalties
)
and
all
(
x
==
1
for
x
in
repetition_penalties
)),
...
...
@@ -206,8 +203,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch
.
condense
(
req_indices_to_remove
)
# Generate the sampling metadata
sampling_metadata
=
input_batch
.
make_sampling_metadata
(
req_id_output_token_ids
,
req_id_to_spec_token_ids
=
{},
skip_copy
=
False
)
sampling_metadata
=
input_batch
.
_make_sampling_metadata
()
# Create expected output.
expected_sampling_metadata
=
_construct_expected_sampling_metadata
(
...
...
@@ -216,13 +212,16 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch
.
req_id_to_index
,
device
=
torch
.
device
(
device
))
def
same
(
t1
:
Optional
[
torch
.
Tensor
],
t2
:
Optional
[
torch
.
Tensor
])
->
bool
:
return
(
t1
is
None
and
t2
is
None
)
or
(
t1
is
not
None
and
t2
is
not
None
and
torch
.
allclose
(
t1
,
t2
))
# Assert the actual and expected output.
assert
torch
.
allclose
(
expected_sampling_metadata
.
temperature
,
sampling_metadata
.
temperature
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
top_p
,
sampling_metadata
.
top_p
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
top_k
,
sampling_metadata
.
top_k
)
assert
same
(
expected_sampling_metadata
.
top_p
,
sampling_metadata
.
top_p
)
assert
same
(
expected_sampling_metadata
.
top_k
,
sampling_metadata
.
top_k
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
frequency_penalties
,
...
...
@@ -240,10 +239,6 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
assert
(
expected_sampling_metadata
.
output_token_ids
==
sampling_metadata
.
output_token_ids
)
assert
expected_sampling_metadata
.
min_tokens
==
sampling_metadata
.
min_tokens
assert
expected_sampling_metadata
.
stop_token_ids
==
\
sampling_metadata
.
stop_token_ids
assert
expected_sampling_metadata
.
no_penalties
==
\
sampling_metadata
.
no_penalties
assert
expected_sampling_metadata
.
no_top_p
==
sampling_metadata
.
no_top_p
assert
expected_sampling_metadata
.
no_top_k
==
sampling_metadata
.
no_top_k
assert
expected_sampling_metadata
.
logit_bias
==
sampling_metadata
.
logit_bias
tests/v1/worker/test_gpu_model_runner.py
View file @
30172b49
...
...
@@ -5,6 +5,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.scheduler_output
import
(
CachedRequestData
,
NewRequestData
,
SchedulerOutput
)
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
...
...
@@ -82,14 +83,21 @@ def _is_req_added(model_runner, req_id: str) -> bool:
return
req_id
in
model_runner
.
requests
def
_is_sampling_metadata_changed
(
model_runner
,
sampling_metadata_before
:
SamplingMetadata
):
return
model_runner
.
input_batch
.
sampling_metadata
is
not
(
sampling_metadata_before
)
def
test_update_states_new_request
(
model_runner
):
req_id
=
"req_0"
# new req
scheduler_output
=
_schedule_new_request
(
req_id
)
batch_changed
=
model_runner
.
_update_states
(
scheduler_output
)
assert
batch_changed
is
True
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
...
...
@@ -117,8 +125,9 @@ def test_update_states_request_finished(model_runner):
free_encoder_input_ids
=
[],
)
batch_changed
=
model_runner
.
_update_states
(
scheduler_output
)
assert
batch_changed
is
True
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
not
_is_req_added
(
model_runner
,
req_id
)
assert
not
_is_req_scheduled
(
model_runner
,
req_id
)
...
...
@@ -142,7 +151,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
{}
,
finished_req_ids
=
set
()
,
free_encoder_input_ids
=
[],
)
...
...
@@ -171,8 +180,9 @@ def test_update_states_request_resumed(model_runner):
free_encoder_input_ids
=
[],
)
batch_changed
=
model_runner
.
_update_states
(
scheduler_output
)
assert
batch_changed
is
True
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
...
...
@@ -200,8 +210,9 @@ def test_update_states_no_changes(model_runner):
free_encoder_input_ids
=
[],
)
batch_changed
=
model_runner
.
_update_states
(
scheduler_output
)
assert
batch_changed
is
False
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
model_runner
.
_update_states
(
scheduler_output
)
assert
not
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
...
...
@@ -233,8 +244,8 @@ def test_update_states_request_unscheduled(model_runner):
free_encoder_input_ids
=
[],
)
batch_changed
=
model_runner
.
_update_states
(
scheduler_output
)
assert
batch_changed
is
True
metadata_before
=
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_ids
[
0
])
assert
_is_req_scheduled
(
model_runner
,
req_ids
[
0
])
...
...
vllm/model_executor/layers/utils.py
View file @
30172b49
...
...
@@ -45,7 +45,7 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
vocab_size
,
num_seqs
)
output_bin_counts
,
output_mask
=
get_token_bin_counts_and_mask
(
output_tokens_tensor
,
vocab_size
,
num_seqs
)
repetition_penalties
=
repetition_penalties
.
unsqueeze
_
(
dim
=
1
).
repeat
(
repetition_penalties
=
repetition_penalties
.
unsqueeze
(
dim
=
1
).
repeat
(
1
,
vocab_size
)
logits
[
logits
>
0
]
/=
torch
.
where
(
prompt_mask
|
output_mask
,
repetition_penalties
,
1.0
)[
logits
>
0
]
...
...
@@ -53,6 +53,6 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
repetition_penalties
,
1.0
)[
logits
<=
0
]
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits
-=
frequency_penalties
.
unsqueeze
_
(
dim
=
1
)
*
output_bin_counts
logits
-=
presence_penalties
.
unsqueeze
_
(
dim
=
1
)
*
output_mask
logits
-=
frequency_penalties
.
unsqueeze
(
dim
=
1
)
*
output_bin_counts
logits
-=
presence_penalties
.
unsqueeze
(
dim
=
1
)
*
output_mask
return
logits
vllm/v1/core/scheduler.py
View file @
30172b49
...
...
@@ -195,8 +195,10 @@ class Scheduler:
request
.
num_computed_tokens
-
request
.
num_tokens
)
if
num_scheduled_spec_tokens
>
0
:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
del
request
.
spec_token_ids
[
num_scheduled_spec_tokens
:]
scheduled_spec_decode_tokens
[
request
.
request_id
]
=
(
request
.
spec_token_ids
[:
num_scheduled_spec_tokens
]
)
request
.
spec_token_ids
)
# Encoder-related.
if
encoder_inputs_to_schedule
:
...
...
@@ -567,7 +569,7 @@ class Scheduler:
outputs
.
append
(
EngineCoreOutput
(
request_id
=
req_id
,
new_token_ids
=
new_token_ids
or
[]
,
new_token_ids
=
new_token_ids
,
finish_reason
=
request
.
get_finished_reason
(),
new_logprobs
=
new_logprobs
,
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
...
...
vllm/v1/sample/metadata.py
View file @
30172b49
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Set
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
import
torch
...
...
@@ -12,15 +12,13 @@ class SamplingMetadata:
temperature
:
torch
.
Tensor
all_greedy
:
bool
all_random
:
bool
rejection_sampling
:
bool
spec_token_ids
:
List
[
List
[
int
]]
top_p
:
torch
.
T
ens
or
top_k
:
torch
.
Tensor
no_top_p
:
bool
no_
top_
k
:
bool
min_p
:
torch
.
Tensor
no_
min_p
:
bool
# None when there are no speculated tok
ens
.
spec_token_ids
:
Optional
[
List
[
List
[
int
]]]
top_
p
:
Optional
[
torch
.
Tensor
]
top_k
:
Optional
[
torch
.
Tensor
]
min_p
:
Optional
[
torch
.
Tensor
]
generators
:
Dict
[
int
,
torch
.
Generator
]
...
...
@@ -34,7 +32,8 @@ class SamplingMetadata:
repetition_penalties
:
torch
.
Tensor
output_token_ids
:
List
[
List
[
int
]]
min_tokens
:
List
[
int
]
stop_token_ids
:
List
[
Set
[
int
]]
# req_index -> (min_tokens, stop_token_ids)
min_tokens
:
Dict
[
int
,
Tuple
[
int
,
Set
[
int
]]]
logit_bias
:
List
[
Optional
[
Dict
[
int
,
float
]]]
vllm/v1/sample/ops/penalties.py
View file @
30172b49
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Set
,
Tuple
from
typing
import
Dict
,
List
,
Set
,
Tuple
import
torch
...
...
@@ -8,18 +8,17 @@ from vllm.model_executor.layers.utils import apply_penalties
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
def
apply_min_token_penalties
(
logits
:
torch
.
Tensor
,
output_token_ids
:
List
[
List
[
int
]],
stop_token_ids
:
List
[
Set
[
int
]],
min_tokens
:
List
[
int
])
->
None
:
def
apply_min_token_penalties
(
logits
:
torch
.
Tensor
,
output_token_ids
:
List
[
List
[
int
]],
min_tokens
:
Dict
[
int
,
Tuple
[
int
,
Set
[
int
]]])
->
None
:
"""
Applies minimum token penalty by setting the logits of the stop tokens
to -inf.
"""
min_tokens_logits_to_penalize
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
index
,
min_token
in
enumerate
(
min_tokens
):
for
index
,
(
min_token
,
stop_token_ids
)
in
min_tokens
.
items
(
):
if
len
(
output_token_ids
[
index
])
<
min_token
:
for
stop_token_id
in
stop_token_ids
[
index
]
:
for
stop_token_id
in
stop_token_ids
:
min_tokens_logits_to_penalize
.
append
((
index
,
stop_token_id
))
if
min_tokens_logits_to_penalize
:
logits
[
tuple
(
zip
(
*
min_tokens_logits_to_penalize
))]
=
-
float
(
"inf"
)
...
...
vllm/v1/sample/ops/topk_topp_sampler.py
View file @
30172b49
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
from
typing
import
Dict
,
Optional
import
torch
import
torch.nn
as
nn
...
...
@@ -55,13 +55,11 @@ class TopKTopPSampler(nn.Module):
self
,
logits
:
torch
.
Tensor
,
generators
:
Dict
[
int
,
torch
.
Generator
],
no_top_k
:
bool
,
k
:
torch
.
Tensor
,
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
"""PyTorch-native implementation of top-k and top-p sampling."""
logits
=
apply_top_k_top_p
(
logits
,
no_top_k
,
k
,
no_top_p
,
p
)
logits
=
apply_top_k_top_p
(
logits
,
k
,
p
)
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
return
random_sample
(
probs
,
generators
)
...
...
@@ -69,37 +67,33 @@ class TopKTopPSampler(nn.Module):
self
,
logits
:
torch
.
Tensor
,
generators
:
Dict
[
int
,
torch
.
Generator
],
no_top_k
:
bool
,
k
:
torch
.
Tensor
,
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
"""More optimized implementation for top-k and top-p sampling."""
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
if
no_top_k
and
no_top_p
:
if
k
is
None
and
p
is
None
:
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
return
random_sample
(
probs
,
generators
)
return
flashinfer_sample
(
probs
,
no_top_k
,
k
,
no_top_p
,
p
,
generators
)
return
flashinfer_sample
(
probs
,
k
,
p
,
generators
)
def
apply_top_k_top_p
(
logits
:
torch
.
Tensor
,
no_top_k
:
bool
,
k
:
torch
.
Tensor
,
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
"""Apply top-k and top-p masks to the logits.
This function sorts the logits tensor, which can be slow for large batches.
"""
if
no_top_k
and
no_top_p
:
if
k
is
None
and
p
is
None
:
return
logits
logits_sort
,
logits_idx
=
logits
.
sort
(
dim
=-
1
,
descending
=
False
)
if
not
no_top_k
:
if
k
is
not
None
:
# Apply top-k.
top_k_mask
=
logits_sort
.
size
(
1
)
-
k
.
to
(
torch
.
long
)
# Get all the top_k values.
...
...
@@ -107,7 +101,7 @@ def apply_top_k_top_p(
top_k_mask
=
logits_sort
<
top_k_mask
logits_sort
.
masked_fill_
(
top_k_mask
,
-
float
(
"inf"
))
if
not
no_top_p
:
if
p
is
not
None
:
# Apply top-p.
probs_sort
=
logits_sort
.
softmax
(
dim
=-
1
)
probs_sum
=
probs_sort
.
cumsum
(
dim
=-
1
)
...
...
@@ -147,10 +141,8 @@ def random_sample(
def
flashinfer_sample
(
probs
:
torch
.
Tensor
,
no_top_k
:
bool
,
k
:
torch
.
Tensor
,
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
generators
:
Dict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
"""Sample from the probabilities using FlashInfer.
...
...
@@ -167,7 +159,7 @@ def flashinfer_sample(
does not. Call this function at the end of the forward pass to minimize
the synchronization overhead.
"""
assert
not
(
no_top_k
and
no_top_p
)
assert
not
(
k
is
None
and
p
is
None
)
max_top_k_round
=
32
batch_size
=
probs
.
shape
[
0
]
uniform_samples
=
torch
.
empty
((
max_top_k_round
,
batch_size
),
...
...
@@ -178,11 +170,11 @@ def flashinfer_sample(
for
i
,
generator
in
generators
.
items
():
uniform_samples
[:,
i
].
uniform_
(
generator
=
generator
)
if
no_top_k
:
if
k
is
None
:
# Top-p only.
next_token_ids
,
success
=
flashinfer
.
sampling
.
top_p_sampling_from_probs
(
probs
,
uniform_samples
,
p
,
deterministic
=
True
)
elif
no_top_p
:
elif
p
is
None
:
# Top-k only.
next_token_ids
,
success
=
flashinfer
.
sampling
.
top_k_sampling_from_probs
(
probs
,
uniform_samples
,
k
,
deterministic
=
True
)
...
...
@@ -194,9 +186,9 @@ def flashinfer_sample(
# NOTE: CPU-GPU synchronization happens here.
if
not
success
.
all
():
if
not
no_top_k
:
if
k
is
not
None
:
probs
=
flashinfer
.
sampling
.
top_k_renorm_prob
(
probs
,
k
)
if
not
no_top_p
:
if
p
is
not
None
:
probs
=
flashinfer
.
sampling
.
top_p_renorm_prob
(
probs
,
p
)
next_token_ids
=
flashinfer
.
sampling
.
sampling_from_probs
(
probs
,
uniform_samples
[
0
],
deterministic
=
True
)
...
...
vllm/v1/sample/rejection_sampler.py
View file @
30172b49
...
...
@@ -68,6 +68,7 @@ class RejectionSampler(nn.Module):
# NOTE: The following input preparationg can be moved
# to the model runner with a persistent manner for better
# performance.
assert
sampling_metadata
.
spec_token_ids
is
not
None
spec_token_ids
=
sampling_metadata
.
spec_token_ids
max_spec_len
=
max
(
len
(
s
)
for
s
in
spec_token_ids
)
batch_size
=
len
(
spec_token_ids
)
...
...
@@ -119,6 +120,7 @@ class RejectionSampler(nn.Module):
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
SamplerOutput
:
assert
sampling_metadata
.
spec_token_ids
is
not
None
spec_lens
=
[
len
(
x
)
for
x
in
sampling_metadata
.
spec_token_ids
]
# Add 1 to include the 'bonus' token.
sample_lens
=
[
x
+
1
for
x
in
spec_lens
]
...
...
vllm/v1/sample/sampler.py
View file @
30172b49
...
...
@@ -26,7 +26,7 @@ class Sampler(nn.Module):
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
SamplerOutput
:
if
sampling_metadata
.
rejection_sampling
:
if
sampling_metadata
.
spec_token_ids
:
if
sampling_metadata
.
max_num_logprobs
:
raise
NotImplementedError
(
"Rejection sampling does not support logprobs."
)
...
...
@@ -104,16 +104,14 @@ class Sampler(nn.Module):
logits
=
self
.
apply_temperature
(
logits
,
sampling_metadata
.
temperature
)
# Apply min_p.
if
not
sampling_metadata
.
no_
min_p
:
if
sampling_metadata
.
min_p
is
not
None
:
logits
=
self
.
apply_min_p
(
logits
,
sampling_metadata
.
min_p
)
# Apply top_k and/or top_p.
random_sampled
=
self
.
topk_topp_sampler
(
logits
,
sampling_metadata
.
generators
,
sampling_metadata
.
no_top_k
,
sampling_metadata
.
top_k
,
sampling_metadata
.
no_top_p
,
sampling_metadata
.
top_p
,
)
...
...
@@ -179,8 +177,9 @@ class Sampler(nn.Module):
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
apply_min_token_penalties
(
logits
,
sampling_metadata
.
output_token_ids
,
sampling_metadata
.
stop_token_ids
,
if
sampling_metadata
.
min_tokens
:
apply_min_token_penalties
(
logits
,
sampling_metadata
.
output_token_ids
,
sampling_metadata
.
min_tokens
)
if
not
sampling_metadata
.
no_penalties
:
assert
sampling_metadata
.
prompt_token_ids
is
not
None
...
...
vllm/v1/utils.py
View file @
30172b49
...
...
@@ -188,3 +188,14 @@ def bind_kv_cache(
for
layer_name
,
kv_cache
in
kv_caches
.
items
():
# NOTE: Use list because of v0 PP virtual engine.
forward_context
[
layer_name
].
kv_cache
=
[
kv_cache
]
def
copy_slice
(
from_tensor
:
torch
.
Tensor
,
to_tensor
:
torch
.
Tensor
,
length
:
int
)
->
None
:
"""
Copy the first length elements of a tensor into another tensor in a
non-blocking manner.
Used to copy pinned CPU tensor data to pre-allocated GPU tensors.
"""
to_tensor
[:
length
].
copy_
(
from_tensor
[:
length
],
non_blocking
=
True
)
vllm/v1/worker/gpu_input_batch.py
View file @
30172b49
# SPDX-License-Identifier: Apache-2.0
# Datastructures defining an input batch
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
cast
import
numpy
as
np
import
torch
...
...
@@ -12,6 +11,7 @@ from vllm.lora.request import LoRARequest
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.utils
import
copy_slice
from
vllm.v1.worker.block_table
import
BlockTable
_SAMPLING_EPS
=
1e-5
...
...
@@ -63,7 +63,7 @@ class InputBatch:
self
.
pin_memory
=
pin_memory
self
.
vocab_size
=
vocab_size
self
.
req_ids
:
List
[
Optional
[
str
]]
=
[
None
]
*
max_num_reqs
self
.
_
req_ids
:
List
[
Optional
[
str
]]
=
[
]
self
.
req_id_to_index
:
Dict
[
str
,
int
]
=
{}
# TODO(woosuk): This buffer could be too large if max_model_len is big.
...
...
@@ -171,11 +171,8 @@ class InputBatch:
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_reqs
:
Set
[
str
]
=
set
()
self
.
min_tokens
:
List
[
int
]
=
[
0
]
*
max_num_reqs
self
.
stop_token_ids
:
List
[
Set
[
int
]]
=
[
set
()
for
_
in
range
(
max_num_reqs
)
]
self
.
prompt_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
# req_index -> (min_tokens, stop_token_ids)
self
.
min_tokens
:
Dict
[
int
,
Tuple
[
int
,
Set
[
int
]]]
=
{}
# lora related
self
.
request_lora_mapping
=
np
.
zeros
((
self
.
max_num_reqs
,
),
...
...
@@ -196,6 +193,17 @@ class InputBatch:
self
.
logit_bias
:
List
[
Optional
[
Dict
[
int
,
float
]]]
=
[
None
]
*
max_num_reqs
self
.
req_output_token_ids
:
List
[
Optional
[
List
[
int
]]]
=
[]
# This is updated each time the batch constituents change.
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
()
@
property
def
req_ids
(
self
)
->
List
[
str
]:
# None elements should only be present transiently
# while performing state updates to the batch.
return
cast
(
List
[
str
],
self
.
_req_ids
)
def
add_request
(
self
,
request
:
"CachedRequestState"
,
...
...
@@ -206,7 +214,13 @@ class InputBatch:
assert
req_index
<
self
.
max_num_reqs
req_id
=
request
.
req_id
self
.
req_ids
[
req_index
]
=
req_id
if
req_index
==
len
(
self
.
_req_ids
):
self
.
_req_ids
.
append
(
req_id
)
self
.
req_output_token_ids
.
append
(
request
.
output_token_ids
)
else
:
self
.
_req_ids
[
req_index
]
=
req_id
self
.
req_output_token_ids
[
req_index
]
=
request
.
output_token_ids
self
.
req_id_to_index
[
req_id
]
=
req_index
# Copy the prompt token ids and output token ids.
...
...
@@ -255,8 +269,9 @@ class InputBatch:
req_index
]
=
sampling_params
.
repetition_penalty
if
sampling_params
.
repetition_penalty
!=
1.0
:
self
.
repetition_penalties_reqs
.
add
(
req_id
)
self
.
min_tokens
[
req_index
]
=
sampling_params
.
min_tokens
self
.
stop_token_ids
[
req_index
]
=
sampling_params
.
all_stop_token_ids
if
sampling_params
.
min_tokens
:
self
.
min_tokens
[
req_index
]
=
(
sampling_params
.
min_tokens
,
sampling_params
.
all_stop_token_ids
)
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
...
...
@@ -284,16 +299,20 @@ class InputBatch:
self
.
request_lora_mapping
[
req_index
]
=
0
def
remove_request
(
self
,
req_id
:
str
)
->
Optional
[
int
]:
"""This method must always be followed by a call to condense()."""
req_index
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
if
req_index
is
None
:
return
None
self
.
req_ids
[
req_index
]
=
None
self
.
_req_ids
[
req_index
]
=
None
self
.
req_output_token_ids
[
req_index
]
=
None
self
.
greedy_reqs
.
discard
(
req_id
)
self
.
random_reqs
.
discard
(
req_id
)
self
.
top_p_reqs
.
discard
(
req_id
)
self
.
top_k_reqs
.
discard
(
req_id
)
self
.
min_p_reqs
.
discard
(
req_id
)
self
.
min_tokens
.
pop
(
req_index
,
None
)
self
.
frequency_penalties_reqs
.
discard
(
req_id
)
self
.
presence_penalties_reqs
.
discard
(
req_id
)
self
.
repetition_penalties_reqs
.
discard
(
req_id
)
...
...
@@ -313,33 +332,17 @@ class InputBatch:
self
.
logit_bias
[
req_index
]
=
None
return
req_index
def
clear
(
self
)
->
None
:
self
.
req_ids
=
[
None
]
*
self
.
max_num_reqs
self
.
req_id_to_index
.
clear
()
self
.
greedy_reqs
.
clear
()
self
.
random_reqs
.
clear
()
self
.
top_p_reqs
.
clear
()
self
.
top_k_reqs
.
clear
()
self
.
min_p_reqs
.
clear
()
self
.
frequency_penalties_reqs
.
clear
()
self
.
presence_penalties_reqs
.
clear
()
self
.
repetition_penalties_reqs
.
clear
()
self
.
generators
.
clear
()
self
.
num_logprobs
.
clear
()
self
.
num_prompt_logprobs
.
clear
()
self
.
request_lora_mapping
.
fill
(
0
)
self
.
lora_id_to_lora_request
.
clear
()
self
.
lora_id_to_request_ids
.
clear
()
self
.
logit_bias
=
[
None
]
*
self
.
max_num_reqs
def
condense
(
self
,
empty_req_indices
:
List
[
int
])
->
None
:
if
self
.
num_reqs
==
0
:
num_reqs
=
self
.
num_reqs
if
num_reqs
==
0
:
# The batched states are empty.
self
.
_req_ids
.
clear
()
self
.
req_output_token_ids
.
clear
()
return
# NOTE(woosuk): This function assumes that the empty_req_indices
# is sorted in descending order.
last_req_index
=
self
.
num_reqs
+
len
(
empty_req_indices
)
-
1
last_req_index
=
num_reqs
+
len
(
empty_req_indices
)
-
1
while
empty_req_indices
:
# Find the largest non-empty index.
while
last_req_index
in
empty_req_indices
:
...
...
@@ -351,10 +354,13 @@ class InputBatch:
break
# Swap the states.
req_id
=
self
.
req_ids
[
last_req_index
]
req_id
=
self
.
_req_ids
[
last_req_index
]
output_token_ids
=
self
.
req_output_token_ids
[
last_req_index
]
assert
req_id
is
not
None
self
.
req_ids
[
empty_index
]
=
req_id
self
.
req_ids
[
last_req_index
]
=
None
self
.
_req_ids
[
empty_index
]
=
req_id
self
.
_req_ids
[
last_req_index
]
=
None
self
.
req_output_token_ids
[
empty_index
]
=
output_token_ids
self
.
req_output_token_ids
[
last_req_index
]
=
None
self
.
req_id_to_index
[
req_id
]
=
empty_index
num_tokens
=
self
.
num_tokens
[
last_req_index
]
...
...
@@ -379,13 +385,14 @@ class InputBatch:
self
.
repetition_penalties_cpu
[
empty_index
]
=
self
.
repetition_penalties_cpu
[
last_req_index
]
self
.
min_p_cpu
[
empty_index
]
=
self
.
min_p_cpu
[
last_req_index
]
self
.
min_tokens
[
empty_index
]
=
self
.
min_tokens
[
last_req_index
]
self
.
stop_token_ids
[
empty_index
]
=
self
.
stop_token_ids
[
last_req_index
]
generator
=
self
.
generators
.
pop
(
last_req_index
,
None
)
if
generator
is
not
None
:
self
.
generators
[
empty_index
]
=
generator
min_token
=
self
.
min_tokens
.
pop
(
last_req_index
,
None
)
if
min_token
is
not
None
:
self
.
min_tokens
[
empty_index
]
=
min_token
self
.
request_lora_mapping
[
empty_index
]
=
self
.
request_lora_mapping
[
last_req_index
]
...
...
@@ -394,87 +401,71 @@ class InputBatch:
# Decrement last_req_index since it is now empty.
last_req_index
-=
1
def
make_sampling_metadata
(
self
,
req_id_output_token_ids
:
Dict
[
str
,
List
[
int
]],
req_id_to_spec_token_ids
:
Dict
[
str
,
List
[
int
]],
skip_copy
:
bool
=
False
,
)
->
SamplingMetadata
:
if
not
skip_copy
:
self
.
temperature
[:
self
.
num_reqs
].
copy_
(
self
.
temperature_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
self
.
top_p
[:
self
.
num_reqs
].
copy_
(
self
.
top_p_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
self
.
top_k
[:
self
.
num_reqs
].
copy_
(
self
.
top_k_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
self
.
min_p
[:
self
.
num_reqs
].
copy_
(
self
.
min_p_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
# Trim lists to the batch size.
del
self
.
_req_ids
[
self
.
num_reqs
:]
del
self
.
req_output_token_ids
[
self
.
num_reqs
:]
def
refresh_sampling_metadata
(
self
):
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
()
def
_make_sampling_metadata
(
self
)
->
SamplingMetadata
:
num_reqs
=
self
.
num_reqs
copy_slice
(
self
.
temperature_cpu_tensor
,
self
.
temperature
,
num_reqs
)
if
not
self
.
no_top_p
:
copy_slice
(
self
.
top_p_cpu_tensor
,
self
.
top_p
,
num_reqs
)
if
not
self
.
no_top_k
:
copy_slice
(
self
.
top_k_cpu_tensor
,
self
.
top_k
,
num_reqs
)
if
not
self
.
no_min_p
:
copy_slice
(
self
.
min_p_cpu_tensor
,
self
.
min_p
,
num_reqs
)
if
not
self
.
no_penalties
:
# Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require
# penalties to be applied during sampling.
self
.
frequency_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
frequency_penalties_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
,
)
self
.
presence_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
presence_penalties_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
,
)
self
.
repetition_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
repetition_penalties_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
,
)
copy_slice
(
self
.
frequency_penalties_cpu_tensor
,
self
.
frequency_penalties
,
num_reqs
)
copy_slice
(
self
.
presence_penalties_cpu_tensor
,
self
.
presence_penalties
,
num_reqs
)
copy_slice
(
self
.
repetition_penalties_cpu_tensor
,
self
.
repetition_penalties
,
num_reqs
)
# The prompt tokens are used only for applying penalties during
# the sampling process. Hence copy these tensors only when
# there are requests which need penalties to be applied.
self
.
prompt_token_ids
=
self
.
_make_prompt_token_ids_tensor
()
output_token_ids
:
List
[
List
[
int
]]
=
[]
spec_token_ids
:
List
[
List
[
int
]]
=
[]
rejection_sampling
=
False
for
req_id
in
self
.
req_ids
[:
self
.
num_reqs
]:
assert
req_id
is
not
None
# Currently we create a tensor for output_token_ids from scratch
# at each step. However, for the penalties computation what we
# need is stats about the token ids present in the output. This
# stats can be maintained incrementally instead of computing it
# from scratch at each step.
# TODO - Replace this with incremental update to output token
# statistics.
output_token_ids
.
append
(
req_id_output_token_ids
[
req_id
])
req_spec_token_ids
=
req_id_to_spec_token_ids
.
get
(
req_id
,
[])
spec_token_ids
.
append
(
req_spec_token_ids
)
if
req_spec_token_ids
:
# If any of the requests require speculative decoding, set the
# flag to True.
rejection_sampling
=
True
prompt_token_ids
=
self
.
_make_prompt_token_ids_tensor
()
else
:
prompt_token_ids
=
None
return
SamplingMetadata
(
temperature
=
self
.
temperature
[:
self
.
num_reqs
],
temperature
=
self
.
temperature
[:
num_reqs
],
all_greedy
=
self
.
all_greedy
,
all_random
=
self
.
all_random
,
rejection_sampling
=
rejection_sampling
,
top_p
=
self
.
top_p
[:
self
.
num_reqs
],
top_k
=
self
.
top_k
[:
self
.
num_reqs
],
min_p
=
self
.
min_p
[:
self
.
num_reqs
],
no_min_p
=
self
.
no_min_p
,
no_top_p
=
self
.
no_top_p
,
no_top_k
=
self
.
no_top_k
,
top_p
=
None
if
self
.
no_top_p
else
self
.
top_p
[:
num_reqs
],
top_k
=
None
if
self
.
no_top_k
else
self
.
top_k
[:
num_reqs
],
min_p
=
None
if
self
.
no_min_p
else
self
.
min_p
[:
num_reqs
],
generators
=
self
.
generators
,
max_num_logprobs
=
self
.
max_num_logprobs
,
prompt_token_ids
=
self
.
prompt_token_ids
,
frequency_penalties
=
self
.
frequency_penalties
[:
self
.
num_reqs
],
presence_penalties
=
self
.
presence_penalties
[:
self
.
num_reqs
],
repetition_penalties
=
self
.
repetition_penalties
[:
self
.
num_reqs
],
output_token_ids
=
output_token_ids
,
spec_token_ids
=
spec_token_ids
,
min_tokens
=
self
.
min_tokens
[:
self
.
num_reqs
],
stop_token_ids
=
self
.
stop_token_ids
[:
self
.
num_reqs
],
prompt_token_ids
=
prompt_token_ids
,
frequency_penalties
=
self
.
frequency_penalties
[:
num_reqs
],
presence_penalties
=
self
.
presence_penalties
[:
num_reqs
],
repetition_penalties
=
self
.
repetition_penalties
[:
num_reqs
],
output_token_ids
=
cast
(
List
[
List
[
int
]],
self
.
req_output_token_ids
),
spec_token_ids
=
None
,
min_tokens
=
self
.
min_tokens
,
no_penalties
=
self
.
no_penalties
,
logit_bias
=
self
.
logit_bias
[:
self
.
num_reqs
],
logit_bias
=
self
.
logit_bias
[:
num_reqs
],
)
def
get_sampling_metadata
(
self
,
req_id_to_spec_token_ids
:
Dict
[
str
,
List
[
int
]],
)
->
SamplingMetadata
:
# Set the new spec token ids in the cached sampling metadata.
self
.
sampling_metadata
.
spec_token_ids
=
[
req_id_to_spec_token_ids
.
get
(
req_id
,
[])
for
req_id
in
self
.
req_ids
]
if
req_id_to_spec_token_ids
else
None
return
self
.
sampling_metadata
def
_make_prompt_token_ids_tensor
(
self
)
->
torch
.
Tensor
:
max_prompt_len
=
self
.
num_prompt_tokens
[:
self
.
num_reqs
].
max
()
prompt_token_ids_cpu_tensor
=
torch
.
empty
(
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
30172b49
...
...
@@ -31,7 +31,6 @@ from vllm.v1.engine.mm_input_cache import MMInputCacheClient
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
from
vllm.v1.outputs
import
LogprobsTensors
,
ModelRunnerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
INVALID_TOKEN_ID
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.utils
import
bind_kv_cache
...
...
@@ -224,16 +223,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pin_memory
=
self
.
pin_memory
)
self
.
seq_lens_np
=
self
.
seq_lens_cpu
.
numpy
()
def
_update_states
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
def
_update_states
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""Update the cached states and the persistent batch with the scheduler
output.
The updated states are used by the `_prepare_inputs` function to create
the input GPU tensors for the model.
Returns:
True if there is a new/resumed/paused/finished request in the batch.
If False, we can skip copying SamplingMetadata to the GPU.
The SamplingMetadata is updated and copied to the GPU if there is a
new/resumed/paused/finished request in the batch.
"""
# Remove finished requests from the cached states.
for
req_id
in
scheduler_output
.
finished_req_ids
:
...
...
@@ -344,9 +342,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_new_tokens
=
(
num_computed_tokens
+
len
(
req_data
.
new_token_ids
)
-
req_state
.
num_tokens
)
new_token_ids
=
(
req_data
.
new_token_ids
[
-
num_new_tokens
:]
if
num_new_tokens
>
0
else
[])
req_state
.
output_token_ids
.
extend
(
new_token_ids
)
if
num_new_tokens
==
1
:
# Avoid slicing list in most common case.
req_state
.
output_token_ids
.
append
(
req_data
.
new_token_ids
[
-
1
])
elif
num_new_tokens
>
0
:
req_state
.
output_token_ids
.
extend
(
req_data
.
new_token_ids
[
-
num_new_tokens
:])
# Update the block IDs.
if
not
req_data
.
resumed_from_preemption
:
# Append the new blocks to the existing block IDs.
...
...
@@ -380,7 +381,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
input_batch
.
num_tokens_no_spec
[
req_index
]
=
end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids
=
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
[]
)
req_id
,
()
)
if
spec_token_ids
:
start_index
=
end_token_index
end_token_index
+=
len
(
spec_token_ids
)
...
...
@@ -410,7 +411,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
removed_req_indices
:
self
.
input_batch
.
condense
(
removed_req_indices
)
return
batch_changed
if
batch_changed
:
self
.
input_batch
.
refresh_sampling_metadata
()
def
_prepare_inputs
(
self
,
...
...
@@ -429,8 +431,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens
=
np
.
empty
(
num_reqs
,
dtype
=
np
.
int32
)
max_num_scheduled_tokens
=
0
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_scheduled_tokens
[
i
]
=
num_tokens
max_num_scheduled_tokens
=
max
(
max_num_scheduled_tokens
,
...
...
@@ -669,10 +670,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
_calc_mrope_positions
(
self
,
scheduler_output
:
"SchedulerOutput"
):
mrope_pos_ptr
=
0
num_reqs
=
self
.
input_batch
.
num_reqs
for
index
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
[:
num_reqs
]):
assert
req_id
is
not
None
for
index
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
req
=
self
.
requests
[
req_id
]
assert
req
.
mrope_positions
is
not
None
...
...
@@ -726,12 +724,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
,
scheduler_output
:
"SchedulerOutput"
,
cu_num_tokens
:
np
.
ndarray
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
# Get the number of spec decode tokens for each request.
num_reqs
=
self
.
input_batch
.
num_reqs
num_spec_decode_tokens
=
np
.
empty
(
num_reqs
,
dtype
=
np
.
int32
)
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
num_spec_decode_tokens
[
i
]
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
()))
...
...
@@ -769,22 +766,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return
torch
.
from_numpy
(
spec_decode_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
def
_prepare_sampling
(
self
,
batch_changed
:
bool
,
req_to_spec_token_ids
:
Dict
[
str
,
List
[
int
]],
)
->
SamplingMetadata
:
# Create the sampling metadata.
req_id_output_token_ids
:
Dict
[
str
,
List
[
int
]]
=
\
{
req_id
:
req
.
output_token_ids
\
for
req_id
,
req
in
self
.
requests
.
items
()}
sampling_metadata
=
self
.
input_batch
.
make_sampling_metadata
(
req_id_output_token_ids
,
req_to_spec_token_ids
,
skip_copy
=
not
batch_changed
)
return
sampling_metadata
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
if
not
scheduled_encoder_inputs
:
...
...
@@ -838,9 +819,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output
:
"SchedulerOutput"
,
)
->
List
[
torch
.
Tensor
]:
encoder_outputs
:
List
[
torch
.
Tensor
]
=
[]
num_reqs
=
self
.
input_batch
.
num_reqs
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]:
assert
req_id
is
not
None
for
req_id
in
self
.
input_batch
.
req_ids
:
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
...
...
@@ -882,7 +861,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output
:
"SchedulerOutput"
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
ModelRunnerOutput
,
torch
.
Tensor
]:
batch_changed
=
self
.
_update_states
(
scheduler_output
)
self
.
_update_states
(
scheduler_output
)
if
self
.
is_multimodal_model
:
# Run the multimodal encoder if any.
...
...
@@ -964,8 +943,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
# Sample the next token and get logprobs if needed.
sampling_metadata
=
self
.
_prepare_sampling
(
batch_changed
,
scheduler_output
.
scheduled_spec_decode_tokens
)
sampling_metadata
=
self
.
input_batch
.
get_sampling_metadata
(
scheduler_output
.
scheduled_spec_decode_tokens
)
sampler_output
=
self
.
model
.
sample
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
...
...
@@ -973,14 +952,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
num_reqs
=
self
.
input_batch
.
num_reqs
req_ids
:
List
[
str
]
=
[]
# Because `input_batch.req_ids` is a list of length `max_num_reqs`,
# we need to stop at `num_reqs`.
# FIXME(woosuk): This is hacky. Refactor.
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
req_ids
.
append
(
req_id
)
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
req_state
=
self
.
requests
[
req_id
]
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
...
...
@@ -1027,7 +999,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
valid_sampled_token_ids
)
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
req_ids
,
req_ids
=
self
.
input_batch
.
req_ids
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
sampled_token_ids
=
valid_sampled_token_ids
,
spec_token_ids
=
spec_token_ids
,
...
...
@@ -1041,19 +1013,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampled_token_ids
:
List
[
List
[
int
]],
)
->
List
[
List
[
int
]]:
# TODO(woosuk): Optimize.
num_reqs
=
len
(
sampled_token_ids
)
draft_token_ids
:
List
[
List
[
int
]]
=
[]
for
i
in
range
(
num_reqs
):
if
len
(
sampled_token_ids
[
i
])
==
0
:
for
i
,
sampled_ids
in
enumerate
(
sampled_token_ids
):
num_sampled_ids
=
len
(
sampled_ids
)
if
not
num_sampled_ids
:
# Skip speculative decoding.
draft_token_ids
.
append
([])
continue
# Add sampled_token_ids to token_ids_cpu.
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
i
]
end_idx
=
start_idx
+
len
(
sampled_token_ids
[
i
])
self
.
input_batch
.
token_ids_cpu
[
i
,
start_idx
:
end_idx
]
=
sampled_token_ids
[
i
]
end_idx
=
start_idx
+
num_sampled_ids
self
.
input_batch
.
token_ids_cpu
[
i
,
start_idx
:
end_idx
]
=
sampled_ids
drafter_output
=
self
.
drafter
.
propose
(
self
.
input_batch
.
token_ids_cpu
[
i
,
:
end_idx
],
self
.
speculative_config
.
ngram_prompt_lookup_min
,
...
...
@@ -1204,7 +1175,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# multiplying the list, to avoid Dynamo from treating them as
# tensor aliasing.
dummy_kv_caches
=
[
torch
.
tensor
(
[]
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
torch
.
tensor
(
()
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
for
_
in
range
(
self
.
num_attn_layers
)
]
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
30172b49
...
...
@@ -1048,8 +1048,6 @@ def swap_positions(b: InputBatch, id_1, id_2):
b
.
min_tokens
[
id_1
],
b
.
min_tokens
[
id_2
]
=
b
.
min_tokens
[
id_2
],
b
.
min_tokens
[
id_1
]
b
.
stop_token_ids
[
id_1
],
b
.
stop_token_ids
[
id_2
]
=
b
.
stop_token_ids
[
id_2
],
b
.
stop_token_ids
[
id_1
]
gen_1
=
b
.
generators
.
pop
(
id_1
,
None
)
gen_2
=
b
.
generators
.
pop
(
id_2
,
None
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment