Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8999ec3c
Unverified
Commit
8999ec3c
authored
Mar 05, 2024
by
Nick Hill
Committed by
GitHub
Mar 05, 2024
Browse files
Store `eos_token_id` in `Sequence` for easy access (#3166)
parent
05af6da8
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
44 additions
and
49 deletions
+44
-49
tests/test_cache_block_hashing.py
tests/test_cache_block_hashing.py
+2
-1
vllm/core/scheduler.py
vllm/core/scheduler.py
+3
-4
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+13
-17
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+0
-1
vllm/outputs.py
vllm/outputs.py
+21
-20
vllm/sequence.py
vllm/sequence.py
+5
-6
No files found.
tests/test_cache_block_hashing.py
View file @
8999ec3c
...
@@ -54,7 +54,8 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):
...
@@ -54,7 +54,8 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):
for
prompt
in
prompts
:
for
prompt
in
prompts
:
hashes
[
-
1
].
append
([])
hashes
[
-
1
].
append
([])
prompt_token_ids
=
tokenizer
.
encode
(
prompt
)
prompt_token_ids
=
tokenizer
.
encode
(
prompt
)
seq
=
Sequence
(
seq_id
,
prompt
,
prompt_token_ids
,
block_size
)
seq
=
Sequence
(
seq_id
,
prompt
,
prompt_token_ids
,
block_size
,
tokenizer
.
tokenizer
.
eos_token_id
)
num_blocks
=
len
(
prompt_token_ids
)
//
block_size
num_blocks
=
len
(
prompt_token_ids
)
//
block_size
for
idx
in
range
(
num_blocks
):
for
idx
in
range
(
num_blocks
):
...
...
vllm/core/scheduler.py
View file @
8999ec3c
...
@@ -59,10 +59,9 @@ class SchedulerOutputs:
...
@@ -59,10 +59,9 @@ class SchedulerOutputs:
and
not
self
.
blocks_to_swap_out
and
not
self
.
blocks_to_copy
)
and
not
self
.
blocks_to_swap_out
and
not
self
.
blocks_to_copy
)
def
_sort_by_lora_ids
(
self
)
->
bool
:
def
_sort_by_lora_ids
(
self
)
->
bool
:
self
.
scheduled_seq_groups
=
sorted
(
self
.
scheduled_seq_groups
=
sorted
(
self
.
scheduled_seq_groups
,
self
.
scheduled_seq_groups
,
key
=
lambda
g
:
key
=
lambda
g
:
(
g
.
lora_request
.
lora_int_id
(
g
.
lora_int_id
,
g
.
request_id
))
if
g
.
lora_request
else
0
,
g
.
request_id
))
@
property
@
property
def
lora_requests
(
self
)
->
Set
[
LoRARequest
]:
def
lora_requests
(
self
)
->
Set
[
LoRARequest
]:
...
...
vllm/engine/llm_engine.py
View file @
8999ec3c
...
@@ -491,8 +491,10 @@ class LLMEngine:
...
@@ -491,8 +491,10 @@ class LLMEngine:
# Create the sequences.
# Create the sequences.
block_size
=
self
.
cache_config
.
block_size
block_size
=
self
.
cache_config
.
block_size
seq_id
=
next
(
self
.
seq_counter
)
seq_id
=
next
(
self
.
seq_counter
)
eos_token_id
=
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
eos_token_id
seq
=
Sequence
(
seq_id
,
prompt
,
prompt_token_ids
,
block_size
,
seq
=
Sequence
(
seq_id
,
prompt
,
prompt_token_ids
,
block_size
,
lora_request
)
eos_token_id
,
lora_request
)
# Defensive copy of SamplingParams, which are used by the sampler,
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
# this doesn't deep-copy LogitsProcessor objects
...
@@ -548,15 +550,13 @@ class LLMEngine:
...
@@ -548,15 +550,13 @@ class LLMEngine:
if
early_stopping
is
True
:
if
early_stopping
is
True
:
return
True
return
True
current_worst_score
=
(
current_worst_seq
.
get_beam_search_score
(
current_worst_score
=
current_worst_seq
.
get_beam_search_score
(
length_penalty
=
length_penalty
,
length_penalty
=
length_penalty
,
eos_token_id
=
self
.
get_tokenizer_for_seq
(
eos_token_id
=
current_worst_seq
.
eos_token_id
)
current_worst_seq
).
eos_token_id
))
if
early_stopping
is
False
:
if
early_stopping
is
False
:
highest_attainable_score
=
(
best_running_seq
.
get_beam_search_score
(
highest_attainable_score
=
best_running_seq
.
get_beam_search_score
(
length_penalty
=
length_penalty
,
length_penalty
=
length_penalty
,
eos_token_id
=
self
.
get_tokenizer_for_seq
(
eos_token_id
=
best_running_seq
.
eos_token_id
)
best_running_seq
).
eos_token_id
))
else
:
else
:
assert
early_stopping
==
"never"
assert
early_stopping
==
"never"
if
length_penalty
>
0.0
:
if
length_penalty
>
0.0
:
...
@@ -570,8 +570,7 @@ class LLMEngine:
...
@@ -570,8 +570,7 @@ class LLMEngine:
highest_attainable_score
=
(
highest_attainable_score
=
(
best_running_seq
.
get_beam_search_score
(
best_running_seq
.
get_beam_search_score
(
length_penalty
=
length_penalty
,
length_penalty
=
length_penalty
,
eos_token_id
=
self
.
get_tokenizer_for_seq
(
eos_token_id
=
best_running_seq
.
eos_token_id
,
best_running_seq
).
eos_token_id
,
seq_len
=
max_possible_length
))
seq_len
=
max_possible_length
))
else
:
else
:
# Otherwise, beam search will prefer shorter sequences. The
# Otherwise, beam search will prefer shorter sequences. The
...
@@ -580,8 +579,7 @@ class LLMEngine:
...
@@ -580,8 +579,7 @@ class LLMEngine:
highest_attainable_score
=
(
highest_attainable_score
=
(
best_running_seq
.
get_beam_search_score
(
best_running_seq
.
get_beam_search_score
(
length_penalty
=
length_penalty
,
length_penalty
=
length_penalty
,
eos_token_id
=
self
.
get_tokenizer_for_seq
(
eos_token_id
=
best_running_seq
.
eos_token_id
))
best_running_seq
).
eos_token_id
))
return
current_worst_score
>=
highest_attainable_score
return
current_worst_score
>=
highest_attainable_score
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
...
@@ -679,8 +677,7 @@ class LLMEngine:
...
@@ -679,8 +677,7 @@ class LLMEngine:
all_finished_seqs
=
existing_finished_seqs
+
new_finished_seqs
all_finished_seqs
=
existing_finished_seqs
+
new_finished_seqs
# Sort the finished sequences by their scores.
# Sort the finished sequences by their scores.
all_finished_seqs
.
sort
(
key
=
lambda
x
:
x
[
0
].
get_beam_search_score
(
all_finished_seqs
.
sort
(
key
=
lambda
x
:
x
[
0
].
get_beam_search_score
(
length_penalty
=
length_penalty
,
length_penalty
=
length_penalty
,
eos_token_id
=
x
[
0
].
eos_token_id
),
eos_token_id
=
self
.
get_tokenizer_for_seq
(
x
[
0
]).
eos_token_id
),
reverse
=
True
)
reverse
=
True
)
for
seq
,
parent
,
is_new
in
all_finished_seqs
[:
beam_width
]:
for
seq
,
parent
,
is_new
in
all_finished_seqs
[:
beam_width
]:
if
is_new
:
if
is_new
:
...
@@ -707,8 +704,7 @@ class LLMEngine:
...
@@ -707,8 +704,7 @@ class LLMEngine:
if
not
seq
.
is_finished
()]
if
not
seq
.
is_finished
()]
# Sort the running sequences by their scores.
# Sort the running sequences by their scores.
running_child_seqs
.
sort
(
key
=
lambda
x
:
x
[
0
].
get_beam_search_score
(
running_child_seqs
.
sort
(
key
=
lambda
x
:
x
[
0
].
get_beam_search_score
(
length_penalty
=
length_penalty
,
length_penalty
=
length_penalty
,
eos_token_id
=
x
[
0
].
eos_token_id
),
eos_token_id
=
self
.
get_tokenizer_for_seq
(
x
[
0
]).
eos_token_id
),
reverse
=
True
)
reverse
=
True
)
# Check if we can stop the beam search.
# Check if we can stop the beam search.
...
@@ -1014,8 +1010,8 @@ class LLMEngine:
...
@@ -1014,8 +1010,8 @@ class LLMEngine:
return
return
# Check if the sequence has generated the EOS token.
# Check if the sequence has generated the EOS token.
if
((
not
sampling_params
.
ignore_eos
)
and
seq
.
get_last_token_id
()
if
((
not
sampling_params
.
ignore_eos
)
==
se
lf
.
get_
tokenizer_for_seq
(
seq
)
.
eos_token_id
):
and
se
q
.
get_
last_token_id
()
==
seq
.
eos_token_id
):
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
return
...
...
vllm/model_executor/layers/sampler.py
View file @
8999ec3c
...
@@ -516,7 +516,6 @@ def _get_logprobs(
...
@@ -516,7 +516,6 @@ def _get_logprobs(
if
(
i
<
sampling_metadata
.
num_prompts
if
(
i
<
sampling_metadata
.
num_prompts
and
sampling_params
.
prompt_logprobs
is
not
None
):
and
sampling_params
.
prompt_logprobs
is
not
None
):
num_logprobs
=
sampling_params
.
prompt_logprobs
num_logprobs
=
sampling_params
.
prompt_logprobs
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
prompt_tokens
=
sampling_metadata
.
seq_data
[
prompt_tokens
=
sampling_metadata
.
seq_data
[
seq_ids
[
0
]].
prompt_token_ids
seq_ids
[
0
]].
prompt_token_ids
group_prompt_logprobs
:
PromptLogprobs
=
[
None
]
group_prompt_logprobs
:
PromptLogprobs
=
[
None
]
...
...
vllm/outputs.py
View file @
8999ec3c
...
@@ -90,6 +90,9 @@ class RequestOutput:
...
@@ -90,6 +90,9 @@ class RequestOutput:
# Get the top-n sequences.
# Get the top-n sequences.
n
=
seq_group
.
sampling_params
.
n
n
=
seq_group
.
sampling_params
.
n
seqs
=
seq_group
.
get_seqs
()
seqs
=
seq_group
.
get_seqs
()
if
n
==
1
:
top_n_seqs
=
seqs
else
:
if
seq_group
.
sampling_params
.
use_beam_search
:
if
seq_group
.
sampling_params
.
use_beam_search
:
sorting_key
=
lambda
seq
:
seq
.
get_beam_search_score
(
sorting_key
=
lambda
seq
:
seq
.
get_beam_search_score
(
seq_group
.
sampling_params
.
length_penalty
)
seq_group
.
sampling_params
.
length_penalty
)
...
@@ -99,20 +102,18 @@ class RequestOutput:
...
@@ -99,20 +102,18 @@ class RequestOutput:
top_n_seqs
=
sorted_seqs
[:
n
]
top_n_seqs
=
sorted_seqs
[:
n
]
# Create the outputs.
# Create the outputs.
outputs
:
List
[
CompletionOutput
]
=
[]
# NOTE: We need omit logprobs here explicitly because the sequence
for
seq
in
top_n_seqs
:
logprobs
=
seq
.
output_logprobs
if
seq_group
.
sampling_params
.
logprobs
is
None
:
# NOTE: We need to take care of this case because the sequence
# always has the logprobs of the sampled tokens even if the
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
# logprobs are not requested.
logprobs
=
None
include_
logprobs
=
seq_group
.
sampling_params
.
logprobs
finshed_reason
=
SequenceStatus
.
get_finished_reason
(
seq
.
status
)
outputs
=
[
output
=
CompletionOutput
(
seqs
.
index
(
seq
),
seq
.
output_text
,
CompletionOutput
(
seqs
.
index
(
seq
),
seq
.
output_text
,
seq
.
get_output_token_ids
(),
seq
.
get_output_token_ids
(),
seq
.
get_cumulative_logprob
(),
logprobs
,
seq
.
get_cumulative_logprob
(),
finshed_reason
)
seq
.
output_logprobs
if
include_logprobs
else
None
,
outputs
.
append
(
output
)
SequenceStatus
.
get_finished_reason
(
seq
.
status
))
for
seq
in
top_n_seqs
]
# Every sequence in the sequence group should have the same prompt.
# Every sequence in the sequence group should have the same prompt.
prompt
=
seq_group
.
prompt
prompt
=
seq_group
.
prompt
...
...
vllm/sequence.py
View file @
8999ec3c
...
@@ -142,11 +142,13 @@ class Sequence:
...
@@ -142,11 +142,13 @@ class Sequence:
prompt
:
str
,
prompt
:
str
,
prompt_token_ids
:
List
[
int
],
prompt_token_ids
:
List
[
int
],
block_size
:
int
,
block_size
:
int
,
eos_token_id
:
int
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
None
:
)
->
None
:
self
.
seq_id
=
seq_id
self
.
seq_id
=
seq_id
self
.
prompt
=
prompt
self
.
prompt
=
prompt
self
.
block_size
=
block_size
self
.
block_size
=
block_size
self
.
eos_token_id
=
eos_token_id
self
.
lora_request
=
lora_request
self
.
lora_request
=
lora_request
self
.
data
=
SequenceData
(
prompt_token_ids
)
self
.
data
=
SequenceData
(
prompt_token_ids
)
...
@@ -362,10 +364,7 @@ class SequenceGroup:
...
@@ -362,10 +364,7 @@ class SequenceGroup:
self
,
self
,
status
:
Optional
[
SequenceStatus
]
=
None
,
status
:
Optional
[
SequenceStatus
]
=
None
,
)
->
List
[
Sequence
]:
)
->
List
[
Sequence
]:
if
status
is
None
:
return
list
(
self
.
seqs_dict
.
values
())
if
status
is
None
else
[
return
list
(
self
.
seqs_dict
.
values
())
else
:
return
[
seq
for
seq
in
self
.
seqs_dict
.
values
()
if
seq
.
status
==
status
seq
for
seq
in
self
.
seqs_dict
.
values
()
if
seq
.
status
==
status
]
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment