Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
30172b49
Unverified
Commit
30172b49
authored
Feb 18, 2025
by
Nick Hill
Committed by
GitHub
Feb 18, 2025
Browse files
[V1] Optimize handling of sampling metadata and req_ids list (#13244)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
a4d577b3
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
255 additions
and
298 deletions
+255
-298
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+2
-7
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+18
-26
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+21
-26
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+22
-11
vllm/model_executor/layers/utils.py
vllm/model_executor/layers/utils.py
+3
-3
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+4
-2
vllm/v1/sample/metadata.py
vllm/v1/sample/metadata.py
+10
-11
vllm/v1/sample/ops/penalties.py
vllm/v1/sample/ops/penalties.py
+6
-7
vllm/v1/sample/ops/topk_topp_sampler.py
vllm/v1/sample/ops/topk_topp_sampler.py
+20
-28
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+2
-0
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+6
-7
vllm/v1/utils.py
vllm/v1/utils.py
+11
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+102
-111
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+28
-57
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+0
-2
No files found.
tests/v1/sample/test_rejection_sampler.py
View file @
30172b49
...
@@ -26,17 +26,13 @@ def create_logits_tensor(token_ids: List[int],
...
@@ -26,17 +26,13 @@ def create_logits_tensor(token_ids: List[int],
def
create_sampling_metadata
(
spec_tokens
:
List
[
List
[
int
]])
->
SamplingMetadata
:
def
create_sampling_metadata
(
spec_tokens
:
List
[
List
[
int
]])
->
SamplingMetadata
:
batch_size
=
len
(
spec_tokens
)
batch_size
=
len
(
spec_tokens
)
return
SamplingMetadata
(
return
SamplingMetadata
(
temperature
=
0.0
,
temperature
=
torch
.
tensor
([])
,
all_greedy
=
True
,
all_greedy
=
True
,
all_random
=
False
,
all_random
=
False
,
rejection_sampling
=
True
,
spec_token_ids
=
spec_tokens
,
spec_token_ids
=
spec_tokens
,
top_p
=
None
,
top_p
=
None
,
top_k
=
None
,
top_k
=
None
,
no_top_p
=
False
,
no_top_k
=
False
,
min_p
=
torch
.
empty
(
batch_size
,
),
min_p
=
torch
.
empty
(
batch_size
,
),
no_min_p
=
True
,
generators
=
{},
generators
=
{},
max_num_logprobs
=
0
,
max_num_logprobs
=
0
,
no_penalties
=
False
,
no_penalties
=
False
,
...
@@ -45,8 +41,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
...
@@ -45,8 +41,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
presence_penalties
=
torch
.
tensor
([]),
presence_penalties
=
torch
.
tensor
([]),
repetition_penalties
=
torch
.
tensor
([]),
repetition_penalties
=
torch
.
tensor
([]),
output_token_ids
=
[],
output_token_ids
=
[],
min_tokens
=
[],
min_tokens
=
{},
stop_token_ids
=
[],
logit_bias
=
[
None
]
*
batch_size
,
logit_bias
=
[
None
]
*
batch_size
,
)
)
...
...
tests/v1/sample/test_sampler.py
View file @
30172b49
...
@@ -77,25 +77,20 @@ def _create_default_sampling_metadata(
...
@@ -77,25 +77,20 @@ def _create_default_sampling_metadata(
temperature
=
torch
.
full
((
batch_size
,
),
0.0
),
temperature
=
torch
.
full
((
batch_size
,
),
0.0
),
all_greedy
=
True
,
all_greedy
=
True
,
all_random
=
False
,
all_random
=
False
,
rejection_sampling
=
False
,
top_p
=
None
,
top_p
=
torch
.
empty
(
batch_size
,
),
top_k
=
None
,
top_k
=
torch
.
empty
(
batch_size
,
),
min_p
=
None
,
no_top_p
=
True
,
no_top_k
=
True
,
min_p
=
torch
.
empty
(
batch_size
,
),
no_min_p
=
True
,
generators
=
{},
generators
=
{},
max_num_logprobs
=
0
,
max_num_logprobs
=
0
,
prompt_token_ids
=
_create_prompt_tokens_tensor
(
prompt_token_ids
,
prompt_token_ids
=
_create_prompt_tokens_tensor
(
prompt_token_ids
,
vocab_size
,
device
),
vocab_size
,
device
),
output_token_ids
=
output_token_ids
,
output_token_ids
=
output_token_ids
,
spec_token_ids
=
[]
,
spec_token_ids
=
None
,
frequency_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
frequency_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
presence_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
presence_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
repetition_penalties
=
_create_penalty_tensor
(
batch_size
,
1.0
,
device
),
repetition_penalties
=
_create_penalty_tensor
(
batch_size
,
1.0
,
device
),
no_penalties
=
True
,
no_penalties
=
True
,
min_tokens
=
[],
min_tokens
=
{},
stop_token_ids
=
[],
logit_bias
=
[
None
]
*
batch_size
,
logit_bias
=
[
None
]
*
batch_size
,
)
)
return
fake_sampling_metadata
return
fake_sampling_metadata
...
@@ -104,10 +99,10 @@ def _create_default_sampling_metadata(
...
@@ -104,10 +99,10 @@ def _create_default_sampling_metadata(
def
_generate_min_token_penalties_and_stop_tokens
(
def
_generate_min_token_penalties_and_stop_tokens
(
num_output_tokens
:
int
,
batch_size
:
int
,
vocab_size
:
int
,
num_output_tokens
:
int
,
batch_size
:
int
,
vocab_size
:
int
,
batch_indices_for_min_token_penalty
:
List
[
int
]
batch_indices_for_min_token_penalty
:
List
[
int
]
)
->
Tuple
[
List
[
int
]
,
List
[
Set
[
int
]]]:
)
->
Dict
[
int
,
Tuple
[
int
,
Set
[
int
]]]:
"""
"""
Generates and returns a
lis
t of minimum token penalties
(`min_tokens`)
Generates and returns a
dic
t of minimum token penalties
and
and a
corresponding
list of
stop token IDs (`stop_token_ids`) for each
corresponding stop token IDs (
`min_tokens`,
`stop_token_ids`) for each
batch.
batch.
If a batch index is included in `batch_indices_for_min_token_penalty`,
If a batch index is included in `batch_indices_for_min_token_penalty`,
...
@@ -115,22 +110,19 @@ def _generate_min_token_penalties_and_stop_tokens(
...
@@ -115,22 +110,19 @@ def _generate_min_token_penalties_and_stop_tokens(
and a random set of stop token IDs is created. Otherwise, a lower
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
"""
stop_token_ids
:
List
[
Set
[
int
]]
=
[]
min_tokens
:
Dict
[
int
,
Tuple
[
int
,
Set
[
int
]]]
=
{}
min_tokens
:
List
[
int
]
=
[]
for
index
in
range
(
batch_size
):
for
index
in
range
(
batch_size
):
if
index
in
batch_indices_for_min_token_penalty
:
if
index
in
batch_indices_for_min_token_penalty
:
min_tokens
.
append
(
min_tokens
[
index
]
=
(
np
.
random
.
randint
(
num_output_tokens
+
1
,
np
.
random
.
randint
(
num_output_tokens
+
1
,
2
*
num_output_tokens
))
2
*
num_output_tokens
),
stop_token_ids
.
append
(
set
(
set
(
np
.
random
.
randint
(
0
,
vocab_size
-
1
)
np
.
random
.
randint
(
0
,
vocab_size
-
1
)
for
_
in
range
(
np
.
random
.
randint
(
0
,
vocab_size
))))
for
_
in
range
(
np
.
random
.
randint
(
0
,
vocab_size
))))
else
:
else
:
min_tokens
.
append
(
np
.
random
.
randint
(
0
,
num_output_tokens
))
min_tokens
[
index
]
=
(
np
.
random
.
randint
(
0
,
stop_token_ids
.
append
(
set
())
num_output_tokens
),
set
())
return
(
min_tokens
,
stop_token_ids
)
return
min_tokens
def
_create_weighted_output_token_list
(
def
_create_weighted_output_token_list
(
...
@@ -165,7 +157,7 @@ def _create_weighted_output_token_list(
...
@@ -165,7 +157,7 @@ def _create_weighted_output_token_list(
output_token_ids_for_batch
.
extend
(
output_token_ids_for_batch
.
extend
(
[
token_id
for
_
in
range
(
index
+
1
)])
[
token_id
for
_
in
range
(
index
+
1
)])
output_token_ids
.
append
(
output_token_ids_for_batch
)
output_token_ids
.
append
(
output_token_ids_for_batch
)
return
(
output_token_ids
,
sorted_token_ids_in_output
)
return
output_token_ids
,
sorted_token_ids_in_output
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
...
@@ -182,17 +174,17 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
...
@@ -182,17 +174,17 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
torch
.
device
(
device
))
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
torch
.
device
(
device
))
batch_indices_for_min_token_penalty
=
np
.
random
.
randint
(
batch_indices_for_min_token_penalty
=
np
.
random
.
randint
(
0
,
batch_size
-
1
,
size
=
np
.
random
.
randint
(
0
,
batch_size
)).
tolist
()
0
,
batch_size
-
1
,
size
=
np
.
random
.
randint
(
0
,
batch_size
)).
tolist
()
min_tokens
,
stop_token_ids
=
_generate_min_token_penalties_and_stop_tokens
(
min_tokens
=
_generate_min_token_penalties_and_stop_tokens
(
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
batch_indices_for_min_token_penalty
)
batch_indices_for_min_token_penalty
)
sampling_metadata
.
min_tokens
=
min_tokens
sampling_metadata
.
min_tokens
=
min_tokens
sampling_metadata
.
stop_token_ids
=
stop_token_ids
sampler
=
Sampler
()
sampler
=
Sampler
()
logits
=
sampler
.
apply_penalties
(
fake_logits
,
sampling_metadata
)
logits
=
sampler
.
apply_penalties
(
fake_logits
,
sampling_metadata
)
logits
=
logits
.
cpu
()
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
for
batch_idx
in
range
(
batch_size
):
for
token_id
in
range
(
VOCAB_SIZE
):
for
token_id
in
range
(
VOCAB_SIZE
):
if
token_id
in
stop_token_ids
[
batch_idx
]:
_
,
stop_token_ids
=
min_tokens
.
get
(
batch_idx
,
(
0
,
set
()))
if
token_id
in
stop_token_ids
:
assert
logits
[
batch_idx
][
token_id
]
==
-
float
(
"inf"
)
assert
logits
[
batch_idx
][
token_id
]
==
-
float
(
"inf"
)
else
:
else
:
assert
logits
[
batch_idx
][
token_id
]
!=
-
float
(
"inf"
)
assert
logits
[
batch_idx
][
token_id
]
!=
-
float
(
"inf"
)
...
...
tests/v1/worker/test_gpu_input_batch.py
View file @
30172b49
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
,
List
,
Set
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
...
@@ -41,7 +41,7 @@ def _remove_requests(
...
@@ -41,7 +41,7 @@ def _remove_requests(
for
index
in
req_indices_to_remove
:
for
index
in
req_indices_to_remove
:
input_batch
.
remove_request
(
reqs
[
index
].
req_id
)
input_batch
.
remove_request
(
reqs
[
index
].
req_id
)
req_ids_to_remove
.
add
(
reqs
[
index
].
req_id
)
req_ids_to_remove
.
add
(
reqs
[
index
].
req_id
)
return
(
req_ids_to_remove
,
req_indices_to_remove_list
)
return
req_ids_to_remove
,
req_indices_to_remove_list
def
_construct_expected_sampling_metadata
(
def
_construct_expected_sampling_metadata
(
...
@@ -64,8 +64,7 @@ def _construct_expected_sampling_metadata(
...
@@ -64,8 +64,7 @@ def _construct_expected_sampling_metadata(
top_p
=
[
0.0
for
_
in
range
(
num_reqs
)]
top_p
=
[
0.0
for
_
in
range
(
num_reqs
)]
min_p
=
[
0.0
for
_
in
range
(
num_reqs
)]
min_p
=
[
0.0
for
_
in
range
(
num_reqs
)]
temperature
=
[
0.0
for
_
in
range
(
num_reqs
)]
temperature
=
[
0.0
for
_
in
range
(
num_reqs
)]
stop_token_ids
:
List
[
Set
[
int
]]
=
[
set
()
for
_
in
range
(
num_reqs
)]
min_tokens
=
{}
min_tokens
=
[
0
for
_
in
range
(
num_reqs
)]
logit_bias
=
[
None
]
*
num_reqs
logit_bias
=
[
None
]
*
num_reqs
for
req
in
reqs
:
for
req
in
reqs
:
if
req
.
req_id
not
in
req_ids_retained
:
if
req
.
req_id
not
in
req_ids_retained
:
...
@@ -83,22 +82,21 @@ def _construct_expected_sampling_metadata(
...
@@ -83,22 +82,21 @@ def _construct_expected_sampling_metadata(
top_p
[
index_in_input_batch
]
=
req
.
sampling_params
.
top_p
top_p
[
index_in_input_batch
]
=
req
.
sampling_params
.
top_p
min_p
[
index_in_input_batch
]
=
req
.
sampling_params
.
min_p
min_p
[
index_in_input_batch
]
=
req
.
sampling_params
.
min_p
temperature
[
index_in_input_batch
]
=
req
.
sampling_params
.
temperature
temperature
[
index_in_input_batch
]
=
req
.
sampling_params
.
temperature
stop
_token
_ids
[
min
_token
s
[
index_in_input_batch
]
=
(
index_in_input_batch
]
=
req
.
sampling_params
.
all_stop
_token
_id
s
req
.
sampling_params
.
min
_tokens
,
min_tokens
[
index_in_input_batch
]
=
req
.
sampling_params
.
min
_token
s
req
.
sampling_params
.
all_stop
_token
_ids
)
logit_bias
[
index_in_input_batch
]
=
req
.
sampling_params
.
logit_bias
logit_bias
[
index_in_input_batch
]
=
req
.
sampling_params
.
logit_bias
return
SamplingMetadata
(
return
SamplingMetadata
(
temperature
=
torch
.
tensor
(
temperature
,
dtype
=
torch
.
float
,
temperature
=
torch
.
tensor
(
temperature
,
dtype
=
torch
.
float
,
device
=
device
),
device
=
device
),
all_greedy
=
False
,
all_greedy
=
False
,
all_random
=
True
,
all_random
=
True
,
rejection_sampling
=
False
,
top_p
=
None
if
all
(
x
==
1.0
for
x
in
top_p
)
else
torch
.
tensor
(
top_p
=
torch
.
tensor
(
top_p
,
dtype
=
torch
.
float
,
device
=
device
),
top_p
,
dtype
=
torch
.
float
,
device
=
device
),
top_k
=
torch
.
tensor
(
top_k
,
dtype
=
torch
.
int
,
device
=
device
),
top_k
=
None
if
all
(
x
==
0
for
x
in
top_k
)
else
torch
.
tensor
(
no_top_p
=
all
(
x
==
1.0
for
x
in
top_p
),
top_k
,
dtype
=
torch
.
int
,
device
=
device
),
no_top_k
=
all
(
x
==
0
for
x
in
top_k
),
min_p
=
None
if
all
(
x
==
0.0
for
x
in
min_p
)
else
torch
.
tensor
(
min_p
=
torch
.
tensor
(
min_p
,
dtype
=
torch
.
float
,
device
=
device
),
min_p
,
dtype
=
torch
.
float
,
device
=
device
),
no_min_p
=
all
(
x
==
0.0
for
x
in
min_p
),
generators
=
{},
generators
=
{},
max_num_logprobs
=
0
,
max_num_logprobs
=
0
,
prompt_token_ids
=
make_tensor_with_pad
(
prompt_token_ids
=
make_tensor_with_pad
(
...
@@ -117,9 +115,8 @@ def _construct_expected_sampling_metadata(
...
@@ -117,9 +115,8 @@ def _construct_expected_sampling_metadata(
dtype
=
torch
.
float
,
dtype
=
torch
.
float
,
device
=
device
),
device
=
device
),
output_token_ids
=
output_token_ids
,
output_token_ids
=
output_token_ids
,
spec_token_ids
=
[]
,
spec_token_ids
=
None
,
min_tokens
=
min_tokens
,
min_tokens
=
min_tokens
,
stop_token_ids
=
stop_token_ids
,
no_penalties
=
(
all
(
x
==
0
for
x
in
presence_penalties
)
no_penalties
=
(
all
(
x
==
0
for
x
in
presence_penalties
)
and
all
(
x
==
0
for
x
in
frequency_penalties
)
and
all
(
x
==
0
for
x
in
frequency_penalties
)
and
all
(
x
==
1
for
x
in
repetition_penalties
)),
and
all
(
x
==
1
for
x
in
repetition_penalties
)),
...
@@ -206,8 +203,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
...
@@ -206,8 +203,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch
.
condense
(
req_indices_to_remove
)
input_batch
.
condense
(
req_indices_to_remove
)
# Generate the sampling metadata
# Generate the sampling metadata
sampling_metadata
=
input_batch
.
make_sampling_metadata
(
sampling_metadata
=
input_batch
.
_make_sampling_metadata
()
req_id_output_token_ids
,
req_id_to_spec_token_ids
=
{},
skip_copy
=
False
)
# Create expected output.
# Create expected output.
expected_sampling_metadata
=
_construct_expected_sampling_metadata
(
expected_sampling_metadata
=
_construct_expected_sampling_metadata
(
...
@@ -216,13 +212,16 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
...
@@ -216,13 +212,16 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch
.
req_id_to_index
,
input_batch
.
req_id_to_index
,
device
=
torch
.
device
(
device
))
device
=
torch
.
device
(
device
))
def
same
(
t1
:
Optional
[
torch
.
Tensor
],
t2
:
Optional
[
torch
.
Tensor
])
->
bool
:
return
(
t1
is
None
and
t2
is
None
)
or
(
t1
is
not
None
and
t2
is
not
None
and
torch
.
allclose
(
t1
,
t2
))
# Assert the actual and expected output.
# Assert the actual and expected output.
assert
torch
.
allclose
(
expected_sampling_metadata
.
temperature
,
assert
torch
.
allclose
(
expected_sampling_metadata
.
temperature
,
sampling_metadata
.
temperature
)
sampling_metadata
.
temperature
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
top_p
,
assert
same
(
expected_sampling_metadata
.
top_p
,
sampling_metadata
.
top_p
)
sampling_metadata
.
top_p
)
assert
same
(
expected_sampling_metadata
.
top_k
,
sampling_metadata
.
top_k
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
top_k
,
sampling_metadata
.
top_k
)
assert
torch
.
allclose
(
assert
torch
.
allclose
(
expected_sampling_metadata
.
frequency_penalties
,
expected_sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
frequency_penalties
,
...
@@ -240,10 +239,6 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
...
@@ -240,10 +239,6 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
assert
(
expected_sampling_metadata
.
output_token_ids
==
assert
(
expected_sampling_metadata
.
output_token_ids
==
sampling_metadata
.
output_token_ids
)
sampling_metadata
.
output_token_ids
)
assert
expected_sampling_metadata
.
min_tokens
==
sampling_metadata
.
min_tokens
assert
expected_sampling_metadata
.
min_tokens
==
sampling_metadata
.
min_tokens
assert
expected_sampling_metadata
.
stop_token_ids
==
\
sampling_metadata
.
stop_token_ids
assert
expected_sampling_metadata
.
no_penalties
==
\
assert
expected_sampling_metadata
.
no_penalties
==
\
sampling_metadata
.
no_penalties
sampling_metadata
.
no_penalties
assert
expected_sampling_metadata
.
no_top_p
==
sampling_metadata
.
no_top_p
assert
expected_sampling_metadata
.
no_top_k
==
sampling_metadata
.
no_top_k
assert
expected_sampling_metadata
.
logit_bias
==
sampling_metadata
.
logit_bias
assert
expected_sampling_metadata
.
logit_bias
==
sampling_metadata
.
logit_bias
tests/v1/worker/test_gpu_model_runner.py
View file @
30172b49
...
@@ -5,6 +5,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
...
@@ -5,6 +5,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.scheduler_output
import
(
CachedRequestData
,
NewRequestData
,
from
vllm.v1.core.scheduler_output
import
(
CachedRequestData
,
NewRequestData
,
SchedulerOutput
)
SchedulerOutput
)
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
...
@@ -82,14 +83,21 @@ def _is_req_added(model_runner, req_id: str) -> bool:
...
@@ -82,14 +83,21 @@ def _is_req_added(model_runner, req_id: str) -> bool:
return
req_id
in
model_runner
.
requests
return
req_id
in
model_runner
.
requests
def
_is_sampling_metadata_changed
(
model_runner
,
sampling_metadata_before
:
SamplingMetadata
):
return
model_runner
.
input_batch
.
sampling_metadata
is
not
(
sampling_metadata_before
)
def
test_update_states_new_request
(
model_runner
):
def
test_update_states_new_request
(
model_runner
):
req_id
=
"req_0"
req_id
=
"req_0"
# new req
# new req
scheduler_output
=
_schedule_new_request
(
req_id
)
scheduler_output
=
_schedule_new_request
(
req_id
)
batch_changed
=
model_runner
.
_update_states
(
scheduler_output
)
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
assert
batch_changed
is
True
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
...
@@ -117,8 +125,9 @@ def test_update_states_request_finished(model_runner):
...
@@ -117,8 +125,9 @@ def test_update_states_request_finished(model_runner):
free_encoder_input_ids
=
[],
free_encoder_input_ids
=
[],
)
)
batch_changed
=
model_runner
.
_update_states
(
scheduler_output
)
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
assert
batch_changed
is
True
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
not
_is_req_added
(
model_runner
,
req_id
)
assert
not
_is_req_added
(
model_runner
,
req_id
)
assert
not
_is_req_scheduled
(
model_runner
,
req_id
)
assert
not
_is_req_scheduled
(
model_runner
,
req_id
)
...
@@ -142,7 +151,7 @@ def test_update_states_request_resumed(model_runner):
...
@@ -142,7 +151,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_spec_decode_tokens
=
{},
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
num_common_prefix_blocks
=
0
,
finished_req_ids
=
{}
,
finished_req_ids
=
set
()
,
free_encoder_input_ids
=
[],
free_encoder_input_ids
=
[],
)
)
...
@@ -171,8 +180,9 @@ def test_update_states_request_resumed(model_runner):
...
@@ -171,8 +180,9 @@ def test_update_states_request_resumed(model_runner):
free_encoder_input_ids
=
[],
free_encoder_input_ids
=
[],
)
)
batch_changed
=
model_runner
.
_update_states
(
scheduler_output
)
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
assert
batch_changed
is
True
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
...
@@ -200,8 +210,9 @@ def test_update_states_no_changes(model_runner):
...
@@ -200,8 +210,9 @@ def test_update_states_no_changes(model_runner):
free_encoder_input_ids
=
[],
free_encoder_input_ids
=
[],
)
)
batch_changed
=
model_runner
.
_update_states
(
scheduler_output
)
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
assert
batch_changed
is
False
model_runner
.
_update_states
(
scheduler_output
)
assert
not
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
...
@@ -233,8 +244,8 @@ def test_update_states_request_unscheduled(model_runner):
...
@@ -233,8 +244,8 @@ def test_update_states_request_unscheduled(model_runner):
free_encoder_input_ids
=
[],
free_encoder_input_ids
=
[],
)
)
batch_changed
=
model_runner
.
_update_states
(
scheduler_output
)
metadata_before
=
model_runner
.
_update_states
(
scheduler_output
)
assert
batch_changed
is
True
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_ids
[
0
])
assert
_is_req_added
(
model_runner
,
req_ids
[
0
])
assert
_is_req_scheduled
(
model_runner
,
req_ids
[
0
])
assert
_is_req_scheduled
(
model_runner
,
req_ids
[
0
])
...
...
vllm/model_executor/layers/utils.py
View file @
30172b49
...
@@ -45,7 +45,7 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
...
@@ -45,7 +45,7 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
vocab_size
,
num_seqs
)
vocab_size
,
num_seqs
)
output_bin_counts
,
output_mask
=
get_token_bin_counts_and_mask
(
output_bin_counts
,
output_mask
=
get_token_bin_counts_and_mask
(
output_tokens_tensor
,
vocab_size
,
num_seqs
)
output_tokens_tensor
,
vocab_size
,
num_seqs
)
repetition_penalties
=
repetition_penalties
.
unsqueeze
_
(
dim
=
1
).
repeat
(
repetition_penalties
=
repetition_penalties
.
unsqueeze
(
dim
=
1
).
repeat
(
1
,
vocab_size
)
1
,
vocab_size
)
logits
[
logits
>
0
]
/=
torch
.
where
(
prompt_mask
|
output_mask
,
logits
[
logits
>
0
]
/=
torch
.
where
(
prompt_mask
|
output_mask
,
repetition_penalties
,
1.0
)[
logits
>
0
]
repetition_penalties
,
1.0
)[
logits
>
0
]
...
@@ -53,6 +53,6 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
...
@@ -53,6 +53,6 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
repetition_penalties
,
1.0
)[
logits
<=
0
]
repetition_penalties
,
1.0
)[
logits
<=
0
]
# We follow the definition in OpenAI API.
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits
-=
frequency_penalties
.
unsqueeze
_
(
dim
=
1
)
*
output_bin_counts
logits
-=
frequency_penalties
.
unsqueeze
(
dim
=
1
)
*
output_bin_counts
logits
-=
presence_penalties
.
unsqueeze
_
(
dim
=
1
)
*
output_mask
logits
-=
presence_penalties
.
unsqueeze
(
dim
=
1
)
*
output_mask
return
logits
return
logits
vllm/v1/core/scheduler.py
View file @
30172b49
...
@@ -195,8 +195,10 @@ class Scheduler:
...
@@ -195,8 +195,10 @@ class Scheduler:
request
.
num_computed_tokens
-
request
.
num_computed_tokens
-
request
.
num_tokens
)
request
.
num_tokens
)
if
num_scheduled_spec_tokens
>
0
:
if
num_scheduled_spec_tokens
>
0
:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
del
request
.
spec_token_ids
[
num_scheduled_spec_tokens
:]
scheduled_spec_decode_tokens
[
request
.
request_id
]
=
(
scheduled_spec_decode_tokens
[
request
.
request_id
]
=
(
request
.
spec_token_ids
[:
num_scheduled_spec_tokens
]
)
request
.
spec_token_ids
)
# Encoder-related.
# Encoder-related.
if
encoder_inputs_to_schedule
:
if
encoder_inputs_to_schedule
:
...
@@ -567,7 +569,7 @@ class Scheduler:
...
@@ -567,7 +569,7 @@ class Scheduler:
outputs
.
append
(
outputs
.
append
(
EngineCoreOutput
(
EngineCoreOutput
(
request_id
=
req_id
,
request_id
=
req_id
,
new_token_ids
=
new_token_ids
or
[]
,
new_token_ids
=
new_token_ids
,
finish_reason
=
request
.
get_finished_reason
(),
finish_reason
=
request
.
get_finished_reason
(),
new_logprobs
=
new_logprobs
,
new_logprobs
=
new_logprobs
,
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
...
...
vllm/v1/sample/metadata.py
View file @
30172b49
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Set
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
import
torch
import
torch
...
@@ -12,15 +12,13 @@ class SamplingMetadata:
...
@@ -12,15 +12,13 @@ class SamplingMetadata:
temperature
:
torch
.
Tensor
temperature
:
torch
.
Tensor
all_greedy
:
bool
all_greedy
:
bool
all_random
:
bool
all_random
:
bool
rejection_sampling
:
bool
spec_token_ids
:
List
[
List
[
int
]]
top_p
:
torch
.
T
ens
or
# None when there are no speculated tok
ens
.
top_k
:
torch
.
Tensor
spec_token_ids
:
Optional
[
List
[
List
[
int
]]]
no_top_p
:
bool
no_
top_
k
:
bool
top_
p
:
Optional
[
torch
.
Tensor
]
min_p
:
torch
.
Tensor
top_k
:
Optional
[
torch
.
Tensor
]
no_
min_p
:
bool
min_p
:
Optional
[
torch
.
Tensor
]
generators
:
Dict
[
int
,
torch
.
Generator
]
generators
:
Dict
[
int
,
torch
.
Generator
]
...
@@ -34,7 +32,8 @@ class SamplingMetadata:
...
@@ -34,7 +32,8 @@ class SamplingMetadata:
repetition_penalties
:
torch
.
Tensor
repetition_penalties
:
torch
.
Tensor
output_token_ids
:
List
[
List
[
int
]]
output_token_ids
:
List
[
List
[
int
]]
min_tokens
:
List
[
int
]
stop_token_ids
:
List
[
Set
[
int
]]
# req_index -> (min_tokens, stop_token_ids)
min_tokens
:
Dict
[
int
,
Tuple
[
int
,
Set
[
int
]]]
logit_bias
:
List
[
Optional
[
Dict
[
int
,
float
]]]
logit_bias
:
List
[
Optional
[
Dict
[
int
,
float
]]]
vllm/v1/sample/ops/penalties.py
View file @
30172b49
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Set
,
Tuple
from
typing
import
Dict
,
List
,
Set
,
Tuple
import
torch
import
torch
...
@@ -8,18 +8,17 @@ from vllm.model_executor.layers.utils import apply_penalties
...
@@ -8,18 +8,17 @@ from vllm.model_executor.layers.utils import apply_penalties
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
def
apply_min_token_penalties
(
logits
:
torch
.
Tensor
,
def
apply_min_token_penalties
(
output_token_ids
:
List
[
List
[
int
]],
logits
:
torch
.
Tensor
,
output_token_ids
:
List
[
List
[
int
]],
stop_token_ids
:
List
[
Set
[
int
]],
min_tokens
:
Dict
[
int
,
Tuple
[
int
,
Set
[
int
]]])
->
None
:
min_tokens
:
List
[
int
])
->
None
:
"""
"""
Applies minimum token penalty by setting the logits of the stop tokens
Applies minimum token penalty by setting the logits of the stop tokens
to -inf.
to -inf.
"""
"""
min_tokens_logits_to_penalize
:
List
[
Tuple
[
int
,
int
]]
=
[]
min_tokens_logits_to_penalize
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
index
,
min_token
in
enumerate
(
min_tokens
):
for
index
,
(
min_token
,
stop_token_ids
)
in
min_tokens
.
items
(
):
if
len
(
output_token_ids
[
index
])
<
min_token
:
if
len
(
output_token_ids
[
index
])
<
min_token
:
for
stop_token_id
in
stop_token_ids
[
index
]
:
for
stop_token_id
in
stop_token_ids
:
min_tokens_logits_to_penalize
.
append
((
index
,
stop_token_id
))
min_tokens_logits_to_penalize
.
append
((
index
,
stop_token_id
))
if
min_tokens_logits_to_penalize
:
if
min_tokens_logits_to_penalize
:
logits
[
tuple
(
zip
(
*
min_tokens_logits_to_penalize
))]
=
-
float
(
"inf"
)
logits
[
tuple
(
zip
(
*
min_tokens_logits_to_penalize
))]
=
-
float
(
"inf"
)
...
...
vllm/v1/sample/ops/topk_topp_sampler.py
View file @
30172b49
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
from
typing
import
Dict
,
Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -55,13 +55,11 @@ class TopKTopPSampler(nn.Module):
...
@@ -55,13 +55,11 @@ class TopKTopPSampler(nn.Module):
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
generators
:
Dict
[
int
,
torch
.
Generator
],
generators
:
Dict
[
int
,
torch
.
Generator
],
no_top_k
:
bool
,
k
:
Optional
[
torch
.
Tensor
],
k
:
torch
.
Tensor
,
p
:
Optional
[
torch
.
Tensor
],
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""PyTorch-native implementation of top-k and top-p sampling."""
"""PyTorch-native implementation of top-k and top-p sampling."""
logits
=
apply_top_k_top_p
(
logits
,
no_top_k
,
k
,
no_top_p
,
p
)
logits
=
apply_top_k_top_p
(
logits
,
k
,
p
)
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
return
random_sample
(
probs
,
generators
)
return
random_sample
(
probs
,
generators
)
...
@@ -69,37 +67,33 @@ class TopKTopPSampler(nn.Module):
...
@@ -69,37 +67,33 @@ class TopKTopPSampler(nn.Module):
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
generators
:
Dict
[
int
,
torch
.
Generator
],
generators
:
Dict
[
int
,
torch
.
Generator
],
no_top_k
:
bool
,
k
:
Optional
[
torch
.
Tensor
],
k
:
torch
.
Tensor
,
p
:
Optional
[
torch
.
Tensor
],
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""More optimized implementation for top-k and top-p sampling."""
"""More optimized implementation for top-k and top-p sampling."""
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
if
no_top_k
and
no_top_p
:
if
k
is
None
and
p
is
None
:
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
# CPU-GPU synchronization while `flashinfer_sample` does.
return
random_sample
(
probs
,
generators
)
return
random_sample
(
probs
,
generators
)
return
flashinfer_sample
(
probs
,
no_top_k
,
k
,
no_top_p
,
p
,
generators
)
return
flashinfer_sample
(
probs
,
k
,
p
,
generators
)
def
apply_top_k_top_p
(
def
apply_top_k_top_p
(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
no_top_k
:
bool
,
k
:
Optional
[
torch
.
Tensor
],
k
:
torch
.
Tensor
,
p
:
Optional
[
torch
.
Tensor
],
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Apply top-k and top-p masks to the logits.
"""Apply top-k and top-p masks to the logits.
This function sorts the logits tensor, which can be slow for large batches.
This function sorts the logits tensor, which can be slow for large batches.
"""
"""
if
no_top_k
and
no_top_p
:
if
k
is
None
and
p
is
None
:
return
logits
return
logits
logits_sort
,
logits_idx
=
logits
.
sort
(
dim
=-
1
,
descending
=
False
)
logits_sort
,
logits_idx
=
logits
.
sort
(
dim
=-
1
,
descending
=
False
)
if
not
no_top_k
:
if
k
is
not
None
:
# Apply top-k.
# Apply top-k.
top_k_mask
=
logits_sort
.
size
(
1
)
-
k
.
to
(
torch
.
long
)
top_k_mask
=
logits_sort
.
size
(
1
)
-
k
.
to
(
torch
.
long
)
# Get all the top_k values.
# Get all the top_k values.
...
@@ -107,7 +101,7 @@ def apply_top_k_top_p(
...
@@ -107,7 +101,7 @@ def apply_top_k_top_p(
top_k_mask
=
logits_sort
<
top_k_mask
top_k_mask
=
logits_sort
<
top_k_mask
logits_sort
.
masked_fill_
(
top_k_mask
,
-
float
(
"inf"
))
logits_sort
.
masked_fill_
(
top_k_mask
,
-
float
(
"inf"
))
if
not
no_top_p
:
if
p
is
not
None
:
# Apply top-p.
# Apply top-p.
probs_sort
=
logits_sort
.
softmax
(
dim
=-
1
)
probs_sort
=
logits_sort
.
softmax
(
dim
=-
1
)
probs_sum
=
probs_sort
.
cumsum
(
dim
=-
1
)
probs_sum
=
probs_sort
.
cumsum
(
dim
=-
1
)
...
@@ -147,10 +141,8 @@ def random_sample(
...
@@ -147,10 +141,8 @@ def random_sample(
def
flashinfer_sample
(
def
flashinfer_sample
(
probs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
no_top_k
:
bool
,
k
:
Optional
[
torch
.
Tensor
],
k
:
torch
.
Tensor
,
p
:
Optional
[
torch
.
Tensor
],
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
generators
:
Dict
[
int
,
torch
.
Generator
],
generators
:
Dict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Sample from the probabilities using FlashInfer.
"""Sample from the probabilities using FlashInfer.
...
@@ -167,7 +159,7 @@ def flashinfer_sample(
...
@@ -167,7 +159,7 @@ def flashinfer_sample(
does not. Call this function at the end of the forward pass to minimize
does not. Call this function at the end of the forward pass to minimize
the synchronization overhead.
the synchronization overhead.
"""
"""
assert
not
(
no_top_k
and
no_top_p
)
assert
not
(
k
is
None
and
p
is
None
)
max_top_k_round
=
32
max_top_k_round
=
32
batch_size
=
probs
.
shape
[
0
]
batch_size
=
probs
.
shape
[
0
]
uniform_samples
=
torch
.
empty
((
max_top_k_round
,
batch_size
),
uniform_samples
=
torch
.
empty
((
max_top_k_round
,
batch_size
),
...
@@ -178,11 +170,11 @@ def flashinfer_sample(
...
@@ -178,11 +170,11 @@ def flashinfer_sample(
for
i
,
generator
in
generators
.
items
():
for
i
,
generator
in
generators
.
items
():
uniform_samples
[:,
i
].
uniform_
(
generator
=
generator
)
uniform_samples
[:,
i
].
uniform_
(
generator
=
generator
)
if
no_top_k
:
if
k
is
None
:
# Top-p only.
# Top-p only.
next_token_ids
,
success
=
flashinfer
.
sampling
.
top_p_sampling_from_probs
(
next_token_ids
,
success
=
flashinfer
.
sampling
.
top_p_sampling_from_probs
(
probs
,
uniform_samples
,
p
,
deterministic
=
True
)
probs
,
uniform_samples
,
p
,
deterministic
=
True
)
elif
no_top_p
:
elif
p
is
None
:
# Top-k only.
# Top-k only.
next_token_ids
,
success
=
flashinfer
.
sampling
.
top_k_sampling_from_probs
(
next_token_ids
,
success
=
flashinfer
.
sampling
.
top_k_sampling_from_probs
(
probs
,
uniform_samples
,
k
,
deterministic
=
True
)
probs
,
uniform_samples
,
k
,
deterministic
=
True
)
...
@@ -194,9 +186,9 @@ def flashinfer_sample(
...
@@ -194,9 +186,9 @@ def flashinfer_sample(
# NOTE: CPU-GPU synchronization happens here.
# NOTE: CPU-GPU synchronization happens here.
if
not
success
.
all
():
if
not
success
.
all
():
if
not
no_top_k
:
if
k
is
not
None
:
probs
=
flashinfer
.
sampling
.
top_k_renorm_prob
(
probs
,
k
)
probs
=
flashinfer
.
sampling
.
top_k_renorm_prob
(
probs
,
k
)
if
not
no_top_p
:
if
p
is
not
None
:
probs
=
flashinfer
.
sampling
.
top_p_renorm_prob
(
probs
,
p
)
probs
=
flashinfer
.
sampling
.
top_p_renorm_prob
(
probs
,
p
)
next_token_ids
=
flashinfer
.
sampling
.
sampling_from_probs
(
next_token_ids
=
flashinfer
.
sampling
.
sampling_from_probs
(
probs
,
uniform_samples
[
0
],
deterministic
=
True
)
probs
,
uniform_samples
[
0
],
deterministic
=
True
)
...
...
vllm/v1/sample/rejection_sampler.py
View file @
30172b49
...
@@ -68,6 +68,7 @@ class RejectionSampler(nn.Module):
...
@@ -68,6 +68,7 @@ class RejectionSampler(nn.Module):
# NOTE: The following input preparationg can be moved
# NOTE: The following input preparationg can be moved
# to the model runner with a persistent manner for better
# to the model runner with a persistent manner for better
# performance.
# performance.
assert
sampling_metadata
.
spec_token_ids
is
not
None
spec_token_ids
=
sampling_metadata
.
spec_token_ids
spec_token_ids
=
sampling_metadata
.
spec_token_ids
max_spec_len
=
max
(
len
(
s
)
for
s
in
spec_token_ids
)
max_spec_len
=
max
(
len
(
s
)
for
s
in
spec_token_ids
)
batch_size
=
len
(
spec_token_ids
)
batch_size
=
len
(
spec_token_ids
)
...
@@ -119,6 +120,7 @@ class RejectionSampler(nn.Module):
...
@@ -119,6 +120,7 @@ class RejectionSampler(nn.Module):
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
SamplerOutput
:
)
->
SamplerOutput
:
assert
sampling_metadata
.
spec_token_ids
is
not
None
spec_lens
=
[
len
(
x
)
for
x
in
sampling_metadata
.
spec_token_ids
]
spec_lens
=
[
len
(
x
)
for
x
in
sampling_metadata
.
spec_token_ids
]
# Add 1 to include the 'bonus' token.
# Add 1 to include the 'bonus' token.
sample_lens
=
[
x
+
1
for
x
in
spec_lens
]
sample_lens
=
[
x
+
1
for
x
in
spec_lens
]
...
...
vllm/v1/sample/sampler.py
View file @
30172b49
...
@@ -26,7 +26,7 @@ class Sampler(nn.Module):
...
@@ -26,7 +26,7 @@ class Sampler(nn.Module):
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
SamplerOutput
:
)
->
SamplerOutput
:
if
sampling_metadata
.
rejection_sampling
:
if
sampling_metadata
.
spec_token_ids
:
if
sampling_metadata
.
max_num_logprobs
:
if
sampling_metadata
.
max_num_logprobs
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Rejection sampling does not support logprobs."
)
"Rejection sampling does not support logprobs."
)
...
@@ -104,16 +104,14 @@ class Sampler(nn.Module):
...
@@ -104,16 +104,14 @@ class Sampler(nn.Module):
logits
=
self
.
apply_temperature
(
logits
,
sampling_metadata
.
temperature
)
logits
=
self
.
apply_temperature
(
logits
,
sampling_metadata
.
temperature
)
# Apply min_p.
# Apply min_p.
if
not
sampling_metadata
.
no_
min_p
:
if
sampling_metadata
.
min_p
is
not
None
:
logits
=
self
.
apply_min_p
(
logits
,
sampling_metadata
.
min_p
)
logits
=
self
.
apply_min_p
(
logits
,
sampling_metadata
.
min_p
)
# Apply top_k and/or top_p.
# Apply top_k and/or top_p.
random_sampled
=
self
.
topk_topp_sampler
(
random_sampled
=
self
.
topk_topp_sampler
(
logits
,
logits
,
sampling_metadata
.
generators
,
sampling_metadata
.
generators
,
sampling_metadata
.
no_top_k
,
sampling_metadata
.
top_k
,
sampling_metadata
.
top_k
,
sampling_metadata
.
no_top_p
,
sampling_metadata
.
top_p
,
sampling_metadata
.
top_p
,
)
)
...
@@ -179,9 +177,10 @@ class Sampler(nn.Module):
...
@@ -179,9 +177,10 @@ class Sampler(nn.Module):
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
apply_min_token_penalties
(
logits
,
sampling_metadata
.
output_token_ids
,
if
sampling_metadata
.
min_tokens
:
sampling_metadata
.
stop_token_ids
,
apply_min_token_penalties
(
logits
,
sampling_metadata
.
min_tokens
)
sampling_metadata
.
output_token_ids
,
sampling_metadata
.
min_tokens
)
if
not
sampling_metadata
.
no_penalties
:
if
not
sampling_metadata
.
no_penalties
:
assert
sampling_metadata
.
prompt_token_ids
is
not
None
assert
sampling_metadata
.
prompt_token_ids
is
not
None
logits
=
apply_all_penalties
(
logits
=
apply_all_penalties
(
...
...
vllm/v1/utils.py
View file @
30172b49
...
@@ -188,3 +188,14 @@ def bind_kv_cache(
...
@@ -188,3 +188,14 @@ def bind_kv_cache(
for
layer_name
,
kv_cache
in
kv_caches
.
items
():
for
layer_name
,
kv_cache
in
kv_caches
.
items
():
# NOTE: Use list because of v0 PP virtual engine.
# NOTE: Use list because of v0 PP virtual engine.
forward_context
[
layer_name
].
kv_cache
=
[
kv_cache
]
forward_context
[
layer_name
].
kv_cache
=
[
kv_cache
]
def
copy_slice
(
from_tensor
:
torch
.
Tensor
,
to_tensor
:
torch
.
Tensor
,
length
:
int
)
->
None
:
"""
Copy the first length elements of a tensor into another tensor in a
non-blocking manner.
Used to copy pinned CPU tensor data to pre-allocated GPU tensors.
"""
to_tensor
[:
length
].
copy_
(
from_tensor
[:
length
],
non_blocking
=
True
)
vllm/v1/worker/gpu_input_batch.py
View file @
30172b49
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# Datastructures defining an input batch
# Datastructures defining an input batch
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
cast
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -12,6 +11,7 @@ from vllm.lora.request import LoRARequest
...
@@ -12,6 +11,7 @@ from vllm.lora.request import LoRARequest
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.utils
import
copy_slice
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
_SAMPLING_EPS
=
1e-5
_SAMPLING_EPS
=
1e-5
...
@@ -63,7 +63,7 @@ class InputBatch:
...
@@ -63,7 +63,7 @@ class InputBatch:
self
.
pin_memory
=
pin_memory
self
.
pin_memory
=
pin_memory
self
.
vocab_size
=
vocab_size
self
.
vocab_size
=
vocab_size
self
.
req_ids
:
List
[
Optional
[
str
]]
=
[
None
]
*
max_num_reqs
self
.
_
req_ids
:
List
[
Optional
[
str
]]
=
[
]
self
.
req_id_to_index
:
Dict
[
str
,
int
]
=
{}
self
.
req_id_to_index
:
Dict
[
str
,
int
]
=
{}
# TODO(woosuk): This buffer could be too large if max_model_len is big.
# TODO(woosuk): This buffer could be too large if max_model_len is big.
...
@@ -171,11 +171,8 @@ class InputBatch:
...
@@ -171,11 +171,8 @@ class InputBatch:
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_reqs
:
Set
[
str
]
=
set
()
self
.
repetition_penalties_reqs
:
Set
[
str
]
=
set
()
self
.
min_tokens
:
List
[
int
]
=
[
0
]
*
max_num_reqs
# req_index -> (min_tokens, stop_token_ids)
self
.
stop_token_ids
:
List
[
Set
[
int
]]
=
[
self
.
min_tokens
:
Dict
[
int
,
Tuple
[
int
,
Set
[
int
]]]
=
{}
set
()
for
_
in
range
(
max_num_reqs
)
]
self
.
prompt_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
# lora related
# lora related
self
.
request_lora_mapping
=
np
.
zeros
((
self
.
max_num_reqs
,
),
self
.
request_lora_mapping
=
np
.
zeros
((
self
.
max_num_reqs
,
),
...
@@ -196,6 +193,17 @@ class InputBatch:
...
@@ -196,6 +193,17 @@ class InputBatch:
self
.
logit_bias
:
List
[
Optional
[
Dict
[
int
,
self
.
logit_bias
:
List
[
Optional
[
Dict
[
int
,
float
]]]
=
[
None
]
*
max_num_reqs
float
]]]
=
[
None
]
*
max_num_reqs
self
.
req_output_token_ids
:
List
[
Optional
[
List
[
int
]]]
=
[]
# This is updated each time the batch constituents change.
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
()
@
property
def
req_ids
(
self
)
->
List
[
str
]:
# None elements should only be present transiently
# while performing state updates to the batch.
return
cast
(
List
[
str
],
self
.
_req_ids
)
def
add_request
(
def
add_request
(
self
,
self
,
request
:
"CachedRequestState"
,
request
:
"CachedRequestState"
,
...
@@ -206,7 +214,13 @@ class InputBatch:
...
@@ -206,7 +214,13 @@ class InputBatch:
assert
req_index
<
self
.
max_num_reqs
assert
req_index
<
self
.
max_num_reqs
req_id
=
request
.
req_id
req_id
=
request
.
req_id
self
.
req_ids
[
req_index
]
=
req_id
if
req_index
==
len
(
self
.
_req_ids
):
self
.
_req_ids
.
append
(
req_id
)
self
.
req_output_token_ids
.
append
(
request
.
output_token_ids
)
else
:
self
.
_req_ids
[
req_index
]
=
req_id
self
.
req_output_token_ids
[
req_index
]
=
request
.
output_token_ids
self
.
req_id_to_index
[
req_id
]
=
req_index
self
.
req_id_to_index
[
req_id
]
=
req_index
# Copy the prompt token ids and output token ids.
# Copy the prompt token ids and output token ids.
...
@@ -255,8 +269,9 @@ class InputBatch:
...
@@ -255,8 +269,9 @@ class InputBatch:
req_index
]
=
sampling_params
.
repetition_penalty
req_index
]
=
sampling_params
.
repetition_penalty
if
sampling_params
.
repetition_penalty
!=
1.0
:
if
sampling_params
.
repetition_penalty
!=
1.0
:
self
.
repetition_penalties_reqs
.
add
(
req_id
)
self
.
repetition_penalties_reqs
.
add
(
req_id
)
self
.
min_tokens
[
req_index
]
=
sampling_params
.
min_tokens
if
sampling_params
.
min_tokens
:
self
.
stop_token_ids
[
req_index
]
=
sampling_params
.
all_stop_token_ids
self
.
min_tokens
[
req_index
]
=
(
sampling_params
.
min_tokens
,
sampling_params
.
all_stop_token_ids
)
# NOTE(woosuk): self.generators should not include the requests that
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
# do not have their own generator.
...
@@ -284,16 +299,20 @@ class InputBatch:
...
@@ -284,16 +299,20 @@ class InputBatch:
self
.
request_lora_mapping
[
req_index
]
=
0
self
.
request_lora_mapping
[
req_index
]
=
0
def
remove_request
(
self
,
req_id
:
str
)
->
Optional
[
int
]:
def
remove_request
(
self
,
req_id
:
str
)
->
Optional
[
int
]:
"""This method must always be followed by a call to condense()."""
req_index
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
req_index
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
if
req_index
is
None
:
if
req_index
is
None
:
return
None
return
None
self
.
req_ids
[
req_index
]
=
None
self
.
_req_ids
[
req_index
]
=
None
self
.
req_output_token_ids
[
req_index
]
=
None
self
.
greedy_reqs
.
discard
(
req_id
)
self
.
greedy_reqs
.
discard
(
req_id
)
self
.
random_reqs
.
discard
(
req_id
)
self
.
random_reqs
.
discard
(
req_id
)
self
.
top_p_reqs
.
discard
(
req_id
)
self
.
top_p_reqs
.
discard
(
req_id
)
self
.
top_k_reqs
.
discard
(
req_id
)
self
.
top_k_reqs
.
discard
(
req_id
)
self
.
min_p_reqs
.
discard
(
req_id
)
self
.
min_p_reqs
.
discard
(
req_id
)
self
.
min_tokens
.
pop
(
req_index
,
None
)
self
.
frequency_penalties_reqs
.
discard
(
req_id
)
self
.
frequency_penalties_reqs
.
discard
(
req_id
)
self
.
presence_penalties_reqs
.
discard
(
req_id
)
self
.
presence_penalties_reqs
.
discard
(
req_id
)
self
.
repetition_penalties_reqs
.
discard
(
req_id
)
self
.
repetition_penalties_reqs
.
discard
(
req_id
)
...
@@ -313,33 +332,17 @@ class InputBatch:
...
@@ -313,33 +332,17 @@ class InputBatch:
self
.
logit_bias
[
req_index
]
=
None
self
.
logit_bias
[
req_index
]
=
None
return
req_index
return
req_index
def
clear
(
self
)
->
None
:
self
.
req_ids
=
[
None
]
*
self
.
max_num_reqs
self
.
req_id_to_index
.
clear
()
self
.
greedy_reqs
.
clear
()
self
.
random_reqs
.
clear
()
self
.
top_p_reqs
.
clear
()
self
.
top_k_reqs
.
clear
()
self
.
min_p_reqs
.
clear
()
self
.
frequency_penalties_reqs
.
clear
()
self
.
presence_penalties_reqs
.
clear
()
self
.
repetition_penalties_reqs
.
clear
()
self
.
generators
.
clear
()
self
.
num_logprobs
.
clear
()
self
.
num_prompt_logprobs
.
clear
()
self
.
request_lora_mapping
.
fill
(
0
)
self
.
lora_id_to_lora_request
.
clear
()
self
.
lora_id_to_request_ids
.
clear
()
self
.
logit_bias
=
[
None
]
*
self
.
max_num_reqs
def
condense
(
self
,
empty_req_indices
:
List
[
int
])
->
None
:
def
condense
(
self
,
empty_req_indices
:
List
[
int
])
->
None
:
if
self
.
num_reqs
==
0
:
num_reqs
=
self
.
num_reqs
if
num_reqs
==
0
:
# The batched states are empty.
# The batched states are empty.
self
.
_req_ids
.
clear
()
self
.
req_output_token_ids
.
clear
()
return
return
# NOTE(woosuk): This function assumes that the empty_req_indices
# NOTE(woosuk): This function assumes that the empty_req_indices
# is sorted in descending order.
# is sorted in descending order.
last_req_index
=
self
.
num_reqs
+
len
(
empty_req_indices
)
-
1
last_req_index
=
num_reqs
+
len
(
empty_req_indices
)
-
1
while
empty_req_indices
:
while
empty_req_indices
:
# Find the largest non-empty index.
# Find the largest non-empty index.
while
last_req_index
in
empty_req_indices
:
while
last_req_index
in
empty_req_indices
:
...
@@ -351,10 +354,13 @@ class InputBatch:
...
@@ -351,10 +354,13 @@ class InputBatch:
break
break
# Swap the states.
# Swap the states.
req_id
=
self
.
req_ids
[
last_req_index
]
req_id
=
self
.
_req_ids
[
last_req_index
]
output_token_ids
=
self
.
req_output_token_ids
[
last_req_index
]
assert
req_id
is
not
None
assert
req_id
is
not
None
self
.
req_ids
[
empty_index
]
=
req_id
self
.
_req_ids
[
empty_index
]
=
req_id
self
.
req_ids
[
last_req_index
]
=
None
self
.
_req_ids
[
last_req_index
]
=
None
self
.
req_output_token_ids
[
empty_index
]
=
output_token_ids
self
.
req_output_token_ids
[
last_req_index
]
=
None
self
.
req_id_to_index
[
req_id
]
=
empty_index
self
.
req_id_to_index
[
req_id
]
=
empty_index
num_tokens
=
self
.
num_tokens
[
last_req_index
]
num_tokens
=
self
.
num_tokens
[
last_req_index
]
...
@@ -379,13 +385,14 @@ class InputBatch:
...
@@ -379,13 +385,14 @@ class InputBatch:
self
.
repetition_penalties_cpu
[
self
.
repetition_penalties_cpu
[
empty_index
]
=
self
.
repetition_penalties_cpu
[
last_req_index
]
empty_index
]
=
self
.
repetition_penalties_cpu
[
last_req_index
]
self
.
min_p_cpu
[
empty_index
]
=
self
.
min_p_cpu
[
last_req_index
]
self
.
min_p_cpu
[
empty_index
]
=
self
.
min_p_cpu
[
last_req_index
]
self
.
min_tokens
[
empty_index
]
=
self
.
min_tokens
[
last_req_index
]
self
.
stop_token_ids
[
empty_index
]
=
self
.
stop_token_ids
[
last_req_index
]
generator
=
self
.
generators
.
pop
(
last_req_index
,
None
)
generator
=
self
.
generators
.
pop
(
last_req_index
,
None
)
if
generator
is
not
None
:
if
generator
is
not
None
:
self
.
generators
[
empty_index
]
=
generator
self
.
generators
[
empty_index
]
=
generator
min_token
=
self
.
min_tokens
.
pop
(
last_req_index
,
None
)
if
min_token
is
not
None
:
self
.
min_tokens
[
empty_index
]
=
min_token
self
.
request_lora_mapping
[
empty_index
]
=
self
.
request_lora_mapping
[
self
.
request_lora_mapping
[
empty_index
]
=
self
.
request_lora_mapping
[
last_req_index
]
last_req_index
]
...
@@ -394,87 +401,71 @@ class InputBatch:
...
@@ -394,87 +401,71 @@ class InputBatch:
# Decrement last_req_index since it is now empty.
# Decrement last_req_index since it is now empty.
last_req_index
-=
1
last_req_index
-=
1
def
make_sampling_metadata
(
# Trim lists to the batch size.
self
,
del
self
.
_req_ids
[
self
.
num_reqs
:]
req_id_output_token_ids
:
Dict
[
str
,
List
[
int
]],
del
self
.
req_output_token_ids
[
self
.
num_reqs
:]
req_id_to_spec_token_ids
:
Dict
[
str
,
List
[
int
]],
skip_copy
:
bool
=
False
,
def
refresh_sampling_metadata
(
self
):
)
->
SamplingMetadata
:
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
()
if
not
skip_copy
:
self
.
temperature
[:
self
.
num_reqs
].
copy_
(
def
_make_sampling_metadata
(
self
)
->
SamplingMetadata
:
self
.
temperature_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
num_reqs
=
self
.
num_reqs
self
.
top_p
[:
self
.
num_reqs
].
copy_
(
copy_slice
(
self
.
temperature_cpu_tensor
,
self
.
temperature
,
num_reqs
)
self
.
top_p_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
if
not
self
.
no_top_p
:
self
.
top_k
[:
self
.
num_reqs
].
copy_
(
copy_slice
(
self
.
top_p_cpu_tensor
,
self
.
top_p
,
num_reqs
)
self
.
top_k_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
if
not
self
.
no_top_k
:
self
.
min_p
[:
self
.
num_reqs
].
copy_
(
copy_slice
(
self
.
top_k_cpu_tensor
,
self
.
top_k
,
num_reqs
)
self
.
min_p_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
if
not
self
.
no_min_p
:
if
not
self
.
no_penalties
:
copy_slice
(
self
.
min_p_cpu_tensor
,
self
.
min_p
,
num_reqs
)
# Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require
if
not
self
.
no_penalties
:
# penalties to be applied during sampling.
# Since syncing these tensors is expensive only copy them
self
.
frequency_penalties
[:
self
.
num_reqs
].
copy_
(
# if necessary i.e. if there are requests which require
self
.
frequency_penalties_cpu_tensor
[:
self
.
num_reqs
],
# penalties to be applied during sampling.
non_blocking
=
True
,
copy_slice
(
self
.
frequency_penalties_cpu_tensor
,
)
self
.
frequency_penalties
,
num_reqs
)
self
.
presence_penalties
[:
self
.
num_reqs
].
copy_
(
copy_slice
(
self
.
presence_penalties_cpu_tensor
,
self
.
presence_penalties_cpu_tensor
[:
self
.
num_reqs
],
self
.
presence_penalties
,
num_reqs
)
non_blocking
=
True
,
copy_slice
(
self
.
repetition_penalties_cpu_tensor
,
)
self
.
repetition_penalties
,
num_reqs
)
self
.
repetition_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
repetition_penalties_cpu_tensor
[:
self
.
num_reqs
],
# The prompt tokens are used only for applying penalties during
non_blocking
=
True
,
# the sampling process. Hence copy these tensors only when
)
# there are requests which need penalties to be applied.
# The prompt tokens are used only for applying penalties during
prompt_token_ids
=
self
.
_make_prompt_token_ids_tensor
()
# the sampling process. Hence copy these tensors only when
else
:
# there are requests which need penalties to be applied.
prompt_token_ids
=
None
self
.
prompt_token_ids
=
self
.
_make_prompt_token_ids_tensor
()
output_token_ids
:
List
[
List
[
int
]]
=
[]
spec_token_ids
:
List
[
List
[
int
]]
=
[]
rejection_sampling
=
False
for
req_id
in
self
.
req_ids
[:
self
.
num_reqs
]:
assert
req_id
is
not
None
# Currently we create a tensor for output_token_ids from scratch
# at each step. However, for the penalties computation what we
# need is stats about the token ids present in the output. This
# stats can be maintained incrementally instead of computing it
# from scratch at each step.
# TODO - Replace this with incremental update to output token
# statistics.
output_token_ids
.
append
(
req_id_output_token_ids
[
req_id
])
req_spec_token_ids
=
req_id_to_spec_token_ids
.
get
(
req_id
,
[])
spec_token_ids
.
append
(
req_spec_token_ids
)
if
req_spec_token_ids
:
# If any of the requests require speculative decoding, set the
# flag to True.
rejection_sampling
=
True
return
SamplingMetadata
(
return
SamplingMetadata
(
temperature
=
self
.
temperature
[:
self
.
num_reqs
],
temperature
=
self
.
temperature
[:
num_reqs
],
all_greedy
=
self
.
all_greedy
,
all_greedy
=
self
.
all_greedy
,
all_random
=
self
.
all_random
,
all_random
=
self
.
all_random
,
rejection_sampling
=
rejection_sampling
,
top_p
=
None
if
self
.
no_top_p
else
self
.
top_p
[:
num_reqs
],
top_p
=
self
.
top_p
[:
self
.
num_reqs
],
top_k
=
None
if
self
.
no_top_k
else
self
.
top_k
[:
num_reqs
],
top_k
=
self
.
top_k
[:
self
.
num_reqs
],
min_p
=
None
if
self
.
no_min_p
else
self
.
min_p
[:
num_reqs
],
min_p
=
self
.
min_p
[:
self
.
num_reqs
],
no_min_p
=
self
.
no_min_p
,
no_top_p
=
self
.
no_top_p
,
no_top_k
=
self
.
no_top_k
,
generators
=
self
.
generators
,
generators
=
self
.
generators
,
max_num_logprobs
=
self
.
max_num_logprobs
,
max_num_logprobs
=
self
.
max_num_logprobs
,
prompt_token_ids
=
self
.
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
frequency_penalties
=
self
.
frequency_penalties
[:
self
.
num_reqs
],
frequency_penalties
=
self
.
frequency_penalties
[:
num_reqs
],
presence_penalties
=
self
.
presence_penalties
[:
self
.
num_reqs
],
presence_penalties
=
self
.
presence_penalties
[:
num_reqs
],
repetition_penalties
=
self
.
repetition_penalties
[:
self
.
num_reqs
],
repetition_penalties
=
self
.
repetition_penalties
[:
num_reqs
],
output_token_ids
=
output_token_ids
,
output_token_ids
=
cast
(
List
[
List
[
int
]],
self
.
req_output_token_ids
),
spec_token_ids
=
spec_token_ids
,
spec_token_ids
=
None
,
min_tokens
=
self
.
min_tokens
[:
self
.
num_reqs
],
min_tokens
=
self
.
min_tokens
,
stop_token_ids
=
self
.
stop_token_ids
[:
self
.
num_reqs
],
no_penalties
=
self
.
no_penalties
,
no_penalties
=
self
.
no_penalties
,
logit_bias
=
self
.
logit_bias
[:
self
.
num_reqs
],
logit_bias
=
self
.
logit_bias
[:
num_reqs
],
)
)
def
get_sampling_metadata
(
self
,
req_id_to_spec_token_ids
:
Dict
[
str
,
List
[
int
]],
)
->
SamplingMetadata
:
# Set the new spec token ids in the cached sampling metadata.
self
.
sampling_metadata
.
spec_token_ids
=
[
req_id_to_spec_token_ids
.
get
(
req_id
,
[])
for
req_id
in
self
.
req_ids
]
if
req_id_to_spec_token_ids
else
None
return
self
.
sampling_metadata
def
_make_prompt_token_ids_tensor
(
self
)
->
torch
.
Tensor
:
def
_make_prompt_token_ids_tensor
(
self
)
->
torch
.
Tensor
:
max_prompt_len
=
self
.
num_prompt_tokens
[:
self
.
num_reqs
].
max
()
max_prompt_len
=
self
.
num_prompt_tokens
[:
self
.
num_reqs
].
max
()
prompt_token_ids_cpu_tensor
=
torch
.
empty
(
prompt_token_ids_cpu_tensor
=
torch
.
empty
(
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
30172b49
...
@@ -31,7 +31,6 @@ from vllm.v1.engine.mm_input_cache import MMInputCacheClient
...
@@ -31,7 +31,6 @@ from vllm.v1.engine.mm_input_cache import MMInputCacheClient
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
KVCacheSpec
)
from
vllm.v1.outputs
import
LogprobsTensors
,
ModelRunnerOutput
from
vllm.v1.outputs
import
LogprobsTensors
,
ModelRunnerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
INVALID_TOKEN_ID
from
vllm.v1.sample.rejection_sampler
import
INVALID_TOKEN_ID
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.utils
import
bind_kv_cache
...
@@ -224,16 +223,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -224,16 +223,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pin_memory
=
self
.
pin_memory
)
pin_memory
=
self
.
pin_memory
)
self
.
seq_lens_np
=
self
.
seq_lens_cpu
.
numpy
()
self
.
seq_lens_np
=
self
.
seq_lens_cpu
.
numpy
()
def
_update_states
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
def
_update_states
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""Update the cached states and the persistent batch with the scheduler
"""Update the cached states and the persistent batch with the scheduler
output.
output.
The updated states are used by the `_prepare_inputs` function to create
The updated states are used by the `_prepare_inputs` function to create
the input GPU tensors for the model.
the input GPU tensors for the model.
Returns:
The SamplingMetadata is updated and copied to the GPU if there is a
True if there is a new/resumed/paused/finished request in the batch.
new/resumed/paused/finished request in the batch.
If False, we can skip copying SamplingMetadata to the GPU.
"""
"""
# Remove finished requests from the cached states.
# Remove finished requests from the cached states.
for
req_id
in
scheduler_output
.
finished_req_ids
:
for
req_id
in
scheduler_output
.
finished_req_ids
:
...
@@ -344,9 +342,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -344,9 +342,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_new_tokens
=
(
num_computed_tokens
+
num_new_tokens
=
(
num_computed_tokens
+
len
(
req_data
.
new_token_ids
)
-
len
(
req_data
.
new_token_ids
)
-
req_state
.
num_tokens
)
req_state
.
num_tokens
)
new_token_ids
=
(
req_data
.
new_token_ids
[
-
num_new_tokens
:]
if
num_new_tokens
==
1
:
if
num_new_tokens
>
0
else
[])
# Avoid slicing list in most common case.
req_state
.
output_token_ids
.
extend
(
new_token_ids
)
req_state
.
output_token_ids
.
append
(
req_data
.
new_token_ids
[
-
1
])
elif
num_new_tokens
>
0
:
req_state
.
output_token_ids
.
extend
(
req_data
.
new_token_ids
[
-
num_new_tokens
:])
# Update the block IDs.
# Update the block IDs.
if
not
req_data
.
resumed_from_preemption
:
if
not
req_data
.
resumed_from_preemption
:
# Append the new blocks to the existing block IDs.
# Append the new blocks to the existing block IDs.
...
@@ -380,7 +381,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -380,7 +381,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
input_batch
.
num_tokens_no_spec
[
req_index
]
=
end_token_index
self
.
input_batch
.
num_tokens_no_spec
[
req_index
]
=
end_token_index
# Add spec_token_ids to token_ids_cpu.
# Add spec_token_ids to token_ids_cpu.
spec_token_ids
=
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
spec_token_ids
=
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
[]
)
req_id
,
()
)
if
spec_token_ids
:
if
spec_token_ids
:
start_index
=
end_token_index
start_index
=
end_token_index
end_token_index
+=
len
(
spec_token_ids
)
end_token_index
+=
len
(
spec_token_ids
)
...
@@ -410,7 +411,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -410,7 +411,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
removed_req_indices
:
if
removed_req_indices
:
self
.
input_batch
.
condense
(
removed_req_indices
)
self
.
input_batch
.
condense
(
removed_req_indices
)
return
batch_changed
if
batch_changed
:
self
.
input_batch
.
refresh_sampling_metadata
()
def
_prepare_inputs
(
def
_prepare_inputs
(
self
,
self
,
...
@@ -429,8 +431,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -429,8 +431,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO: The Python loop can be slow. Optimize.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens
=
np
.
empty
(
num_reqs
,
dtype
=
np
.
int32
)
num_scheduled_tokens
=
np
.
empty
(
num_reqs
,
dtype
=
np
.
int32
)
max_num_scheduled_tokens
=
0
max_num_scheduled_tokens
=
0
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_scheduled_tokens
[
i
]
=
num_tokens
num_scheduled_tokens
[
i
]
=
num_tokens
max_num_scheduled_tokens
=
max
(
max_num_scheduled_tokens
,
max_num_scheduled_tokens
=
max
(
max_num_scheduled_tokens
,
...
@@ -669,10 +670,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -669,10 +670,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
_calc_mrope_positions
(
self
,
scheduler_output
:
"SchedulerOutput"
):
def
_calc_mrope_positions
(
self
,
scheduler_output
:
"SchedulerOutput"
):
mrope_pos_ptr
=
0
mrope_pos_ptr
=
0
num_reqs
=
self
.
input_batch
.
num_reqs
for
index
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
for
index
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
[:
num_reqs
]):
assert
req_id
is
not
None
req
=
self
.
requests
[
req_id
]
req
=
self
.
requests
[
req_id
]
assert
req
.
mrope_positions
is
not
None
assert
req
.
mrope_positions
is
not
None
...
@@ -726,12 +724,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -726,12 +724,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
cu_num_tokens
:
np
.
ndarray
,
cu_num_tokens
:
np
.
ndarray
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
# Get the number of spec decode tokens for each request.
# Get the number of spec decode tokens for each request.
num_reqs
=
self
.
input_batch
.
num_reqs
num_reqs
=
self
.
input_batch
.
num_reqs
num_spec_decode_tokens
=
np
.
empty
(
num_reqs
,
dtype
=
np
.
int32
)
num_spec_decode_tokens
=
np
.
empty
(
num_reqs
,
dtype
=
np
.
int32
)
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
num_spec_decode_tokens
[
i
]
=
len
(
num_spec_decode_tokens
[
i
]
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
()))
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
()))
...
@@ -769,22 +766,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -769,22 +766,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return
torch
.
from_numpy
(
spec_decode_logits_indices
).
to
(
return
torch
.
from_numpy
(
spec_decode_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
self
.
device
,
non_blocking
=
True
)
def
_prepare_sampling
(
self
,
batch_changed
:
bool
,
req_to_spec_token_ids
:
Dict
[
str
,
List
[
int
]],
)
->
SamplingMetadata
:
# Create the sampling metadata.
req_id_output_token_ids
:
Dict
[
str
,
List
[
int
]]
=
\
{
req_id
:
req
.
output_token_ids
\
for
req_id
,
req
in
self
.
requests
.
items
()}
sampling_metadata
=
self
.
input_batch
.
make_sampling_metadata
(
req_id_output_token_ids
,
req_to_spec_token_ids
,
skip_copy
=
not
batch_changed
)
return
sampling_metadata
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
if
not
scheduled_encoder_inputs
:
if
not
scheduled_encoder_inputs
:
...
@@ -838,9 +819,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -838,9 +819,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
)
->
List
[
torch
.
Tensor
]:
)
->
List
[
torch
.
Tensor
]:
encoder_outputs
:
List
[
torch
.
Tensor
]
=
[]
encoder_outputs
:
List
[
torch
.
Tensor
]
=
[]
num_reqs
=
self
.
input_batch
.
num_reqs
for
req_id
in
self
.
input_batch
.
req_ids
:
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]:
assert
req_id
is
not
None
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
req_id
]
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
...
@@ -882,7 +861,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -882,7 +861,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
ModelRunnerOutput
,
torch
.
Tensor
]:
)
->
Union
[
ModelRunnerOutput
,
torch
.
Tensor
]:
batch_changed
=
self
.
_update_states
(
scheduler_output
)
self
.
_update_states
(
scheduler_output
)
if
self
.
is_multimodal_model
:
if
self
.
is_multimodal_model
:
# Run the multimodal encoder if any.
# Run the multimodal encoder if any.
...
@@ -964,8 +943,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -964,8 +943,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
# Sample the next token and get logprobs if needed.
# Sample the next token and get logprobs if needed.
sampling_metadata
=
self
.
_prepare_sampling
(
sampling_metadata
=
self
.
input_batch
.
get_sampling_metadata
(
batch_changed
,
scheduler_output
.
scheduled_spec_decode_tokens
)
scheduler_output
.
scheduled_spec_decode_tokens
)
sampler_output
=
self
.
model
.
sample
(
sampler_output
=
self
.
model
.
sample
(
logits
=
logits
,
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
...
@@ -973,14 +952,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -973,14 +952,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(woosuk): The following loop can be slow since it iterates over
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
# the requests one by one. Optimize.
num_reqs
=
self
.
input_batch
.
num_reqs
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
req_ids
:
List
[
str
]
=
[]
# Because `input_batch.req_ids` is a list of length `max_num_reqs`,
# we need to stop at `num_reqs`.
# FIXME(woosuk): This is hacky. Refactor.
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
req_ids
.
append
(
req_id
)
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
seq_len
=
(
req_state
.
num_computed_tokens
+
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
scheduler_output
.
num_scheduled_tokens
[
req_id
])
...
@@ -1027,7 +999,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1027,7 +999,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
valid_sampled_token_ids
)
valid_sampled_token_ids
)
model_runner_output
=
ModelRunnerOutput
(
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
req_ids
,
req_ids
=
self
.
input_batch
.
req_ids
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
sampled_token_ids
=
valid_sampled_token_ids
,
sampled_token_ids
=
valid_sampled_token_ids
,
spec_token_ids
=
spec_token_ids
,
spec_token_ids
=
spec_token_ids
,
...
@@ -1041,19 +1013,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1041,19 +1013,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampled_token_ids
:
List
[
List
[
int
]],
sampled_token_ids
:
List
[
List
[
int
]],
)
->
List
[
List
[
int
]]:
)
->
List
[
List
[
int
]]:
# TODO(woosuk): Optimize.
# TODO(woosuk): Optimize.
num_reqs
=
len
(
sampled_token_ids
)
draft_token_ids
:
List
[
List
[
int
]]
=
[]
draft_token_ids
:
List
[
List
[
int
]]
=
[]
for
i
in
range
(
num_reqs
):
for
i
,
sampled_ids
in
enumerate
(
sampled_token_ids
):
if
len
(
sampled_token_ids
[
i
])
==
0
:
num_sampled_ids
=
len
(
sampled_ids
)
if
not
num_sampled_ids
:
# Skip speculative decoding.
# Skip speculative decoding.
draft_token_ids
.
append
([])
draft_token_ids
.
append
([])
continue
continue
# Add sampled_token_ids to token_ids_cpu.
# Add sampled_token_ids to token_ids_cpu.
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
i
]
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
i
]
end_idx
=
start_idx
+
len
(
sampled_token_ids
[
i
])
end_idx
=
start_idx
+
num_sampled_ids
self
.
input_batch
.
token_ids_cpu
[
self
.
input_batch
.
token_ids_cpu
[
i
,
start_idx
:
end_idx
]
=
sampled_ids
i
,
start_idx
:
end_idx
]
=
sampled_token_ids
[
i
]
drafter_output
=
self
.
drafter
.
propose
(
drafter_output
=
self
.
drafter
.
propose
(
self
.
input_batch
.
token_ids_cpu
[
i
,
:
end_idx
],
self
.
input_batch
.
token_ids_cpu
[
i
,
:
end_idx
],
self
.
speculative_config
.
ngram_prompt_lookup_min
,
self
.
speculative_config
.
ngram_prompt_lookup_min
,
...
@@ -1204,7 +1175,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1204,7 +1175,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# multiplying the list, to avoid Dynamo from treating them as
# multiplying the list, to avoid Dynamo from treating them as
# tensor aliasing.
# tensor aliasing.
dummy_kv_caches
=
[
dummy_kv_caches
=
[
torch
.
tensor
(
[]
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
torch
.
tensor
(
()
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
for
_
in
range
(
self
.
num_attn_layers
)
for
_
in
range
(
self
.
num_attn_layers
)
]
]
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
30172b49
...
@@ -1048,8 +1048,6 @@ def swap_positions(b: InputBatch, id_1, id_2):
...
@@ -1048,8 +1048,6 @@ def swap_positions(b: InputBatch, id_1, id_2):
b
.
min_tokens
[
id_1
],
b
.
min_tokens
[
id_2
]
=
b
.
min_tokens
[
id_2
],
b
.
min_tokens
[
b
.
min_tokens
[
id_1
],
b
.
min_tokens
[
id_2
]
=
b
.
min_tokens
[
id_2
],
b
.
min_tokens
[
id_1
]
id_1
]
b
.
stop_token_ids
[
id_1
],
b
.
stop_token_ids
[
id_2
]
=
b
.
stop_token_ids
[
id_2
],
b
.
stop_token_ids
[
id_1
]
gen_1
=
b
.
generators
.
pop
(
id_1
,
None
)
gen_1
=
b
.
generators
.
pop
(
id_1
,
None
)
gen_2
=
b
.
generators
.
pop
(
id_2
,
None
)
gen_2
=
b
.
generators
.
pop
(
id_2
,
None
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment