Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
18fd4a83
Unverified
Commit
18fd4a83
authored
Jan 21, 2025
by
Andy Lo
Committed by
GitHub
Jan 21, 2025
Browse files
[Bugfix] Multi-sequence broken (#11898)
Signed-off-by:
Andy Lo
<
andy@mistral.ai
>
parent
132a1321
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
59 additions
and
39 deletions
+59
-39
tests/samplers/test_seeded_generate.py
tests/samplers/test_seeded_generate.py
+6
-1
vllm/outputs.py
vllm/outputs.py
+1
-1
vllm/sequence.py
vllm/sequence.py
+52
-37
No files found.
tests/samplers/test_seeded_generate.py
View file @
18fd4a83
...
...
@@ -31,7 +31,7 @@ def test_random_sample_with_seed(
sampling_params
=
SamplingParams
(
# Parameters to ensure sufficient randomness
temperature
=
2
.0
,
temperature
=
3
.0
,
top_p
=
min
(
random
.
random
()
+
0.3
,
1
),
top_k
=
random
.
randint
(
5
,
20
),
n
=
random
.
randint
(
1
,
10
),
...
...
@@ -75,3 +75,8 @@ def test_random_sample_with_seed(
# verify requests with the same seed match
assert
outputs
[
1
]
==
outputs
[
4
]
assert
outputs
[
2
]
==
outputs
[
5
]
# verify generations within the same parallel sampling group differ
for
output
in
outputs
:
for
sub_output_a
,
sub_output_b
in
combinations
(
output
,
2
):
assert
sub_output_a
!=
sub_output_b
vllm/outputs.py
View file @
18fd4a83
...
...
@@ -172,9 +172,9 @@ class RequestOutput:
if
seq_group
.
request_id
in
seq_id_to_seq_group
:
group
:
SequenceGroupBase
=
seq_id_to_seq_group
[
seq_group
.
request_id
]
assembled_seq_group
=
group
.
maybe_assemble_group
(
seq_group
)
if
finished
:
group
.
finish_seq
(
seq_group
)
assembled_seq_group
=
group
.
maybe_assemble_group
(
seq_group
)
if
assembled_seq_group
is
None
:
return
None
return
cls
.
from_seq_group
(
assembled_seq_group
,
use_cache
,
...
...
vllm/sequence.py
View file @
18fd4a83
...
...
@@ -815,7 +815,9 @@ class SequenceGroup:
def
get_max_num_running_seqs
(
self
)
->
int
:
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
if
self
.
is_single_seq
:
return
0
if
self
.
first_seq
.
is_finished
()
else
1
return
self
.
num_seqs
()
-
self
.
num_finished_seqs
()
def
get_seqs
(
self
,
...
...
@@ -824,8 +826,11 @@ class SequenceGroup:
if
status
is
None
:
return
self
.
seqs
if
self
.
is_single_seq
:
return
self
.
seqs
if
self
.
first_seq
.
status
==
status
else
[]
return
[
seq
for
seq
in
self
.
seqs
if
seq
.
status
==
status
]
def
is_encoder_decoder
(
self
)
->
bool
:
return
self
.
encoder_seq
is
not
None
...
...
@@ -833,17 +838,20 @@ class SequenceGroup:
return
self
.
encoder_seq
def
get_finished_seqs
(
self
)
->
List
[
Sequence
]:
if
self
.
is_single_seq
:
return
self
.
seqs
if
self
.
first_seq
.
is_finished
()
else
[]
return
[
seq
for
seq
in
self
.
seqs
if
seq
.
is_finished
()]
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
):
"""Update number of tokens computed so far."""
seq
=
self
.
first_
seq
for
seq
in
self
.
seq
s
:
if
not
seq
.
is_finished
():
seq
.
data
.
update_num_computed_tokens
(
num_new_computed_tokens
)
def
get_num_uncomputed_tokens
(
self
)
->
int
:
num_uncomputed_tokens
=
0
seq
=
self
.
first_
seq
for
seq
in
self
.
seq
s
:
if
not
seq
.
is_finished
():
num_uncomputed_tokens
+=
seq
.
data
.
get_num_uncomputed_tokens
()
return
num_uncomputed_tokens
...
...
@@ -860,10 +868,14 @@ class SequenceGroup:
return
len
(
self
.
get_seqs
(
status
))
def
num_finished_seqs
(
self
)
->
int
:
return
1
if
self
.
first_seq
.
is_finished
()
else
0
if
self
.
is_single_seq
:
return
1
if
self
.
seqs
[
0
].
is_finished
()
else
0
return
len
(
self
.
get_finished_seqs
())
def
is_finished
(
self
)
->
bool
:
if
self
.
is_single_seq
:
return
self
.
first_seq
.
is_finished
()
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
seqs
)
def
is_prefill
(
self
)
->
bool
:
return
self
.
first_seq
.
is_prefill
()
...
...
@@ -1391,13 +1403,15 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
@
staticmethod
def
add_request
(
request_id
:
str
,
engine
,
params
,
**
kwargs
):
original_params
=
params
params
=
original_params
.
clone
()
params
.
n
=
1
group
=
ParallelSampleSequenceGroup
(
request_id
)
seqs
=
[]
for
i
in
range
(
original_params
.
n
):
request_id_i
=
f
"
{
request_id
}
_parallel_sample_
{
i
}
"
group
.
seq_id_to_index
[
request_id_i
]
=
i
params
=
copy
.
deepcopy
(
original_params
)
params
.
n
=
1
if
params
.
seed
is
not
None
:
params
.
seed
+=
i
seq_group
=
engine
.
_add_processed_request
(
request_id_i
,
params
=
params
,
...
...
@@ -1432,20 +1446,20 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
self
,
seq_group
:
SequenceGroup
)
->
Optional
[
SequenceGroup
]:
# in the streaming mode, we will return the assembled sequence
# for the first sequence, and then return None for the
rest of
# sequences
# for the first
remaining
sequence, and then return None for the
#
rest of
sequences
if
self
.
streaming
:
if
self
.
seq_id_to_index
[
seq_group
.
request_id
]
==
0
:
first_remaining_id
=
next
(
iter
(
self
.
to_be_finished
))
if
seq_group
.
request_id
==
first_remaining_id
:
return
self
.
assembled_seq_group
return
None
# in the non-streaming mode, we will return the assembled sequence
#
once after all
sequences finish, and then return None for the
#
when the last
sequences finish
es
, and then return None for the
# rest of the time
if
len
(
self
.
to_be_finished
)
>
0
:
return
None
if
(
len
(
self
.
to_be_finished
)
==
1
and
seq_group
.
request_id
in
self
.
to_be_finished
and
seq_group
.
is_finished
()):
assert
self
.
assembled_seq_group
is
not
None
params
=
self
.
assembled_seq_group
.
sampling_params
assert
isinstance
(
params
,
SamplingParams
)
...
...
@@ -1462,3 +1476,4 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
return
self
.
assembled_seq_group
if
self
.
output_produced
:
return
None
return
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment