Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
18fd4a83
Unverified
Commit
18fd4a83
authored
Jan 21, 2025
by
Andy Lo
Committed by
GitHub
Jan 21, 2025
Browse files
[Bugfix] Multi-sequence broken (#11898)
Signed-off-by:
Andy Lo
<
andy@mistral.ai
>
parent
132a1321
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
59 additions
and
39 deletions
+59
-39
tests/samplers/test_seeded_generate.py
tests/samplers/test_seeded_generate.py
+6
-1
vllm/outputs.py
vllm/outputs.py
+1
-1
vllm/sequence.py
vllm/sequence.py
+52
-37
No files found.
tests/samplers/test_seeded_generate.py
View file @
18fd4a83
...
@@ -31,7 +31,7 @@ def test_random_sample_with_seed(
...
@@ -31,7 +31,7 @@ def test_random_sample_with_seed(
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
# Parameters to ensure sufficient randomness
# Parameters to ensure sufficient randomness
temperature
=
2
.0
,
temperature
=
3
.0
,
top_p
=
min
(
random
.
random
()
+
0.3
,
1
),
top_p
=
min
(
random
.
random
()
+
0.3
,
1
),
top_k
=
random
.
randint
(
5
,
20
),
top_k
=
random
.
randint
(
5
,
20
),
n
=
random
.
randint
(
1
,
10
),
n
=
random
.
randint
(
1
,
10
),
...
@@ -75,3 +75,8 @@ def test_random_sample_with_seed(
...
@@ -75,3 +75,8 @@ def test_random_sample_with_seed(
# verify requests with the same seed match
# verify requests with the same seed match
assert
outputs
[
1
]
==
outputs
[
4
]
assert
outputs
[
1
]
==
outputs
[
4
]
assert
outputs
[
2
]
==
outputs
[
5
]
assert
outputs
[
2
]
==
outputs
[
5
]
# verify generations within the same parallel sampling group differ
for
output
in
outputs
:
for
sub_output_a
,
sub_output_b
in
combinations
(
output
,
2
):
assert
sub_output_a
!=
sub_output_b
vllm/outputs.py
View file @
18fd4a83
...
@@ -172,9 +172,9 @@ class RequestOutput:
...
@@ -172,9 +172,9 @@ class RequestOutput:
if
seq_group
.
request_id
in
seq_id_to_seq_group
:
if
seq_group
.
request_id
in
seq_id_to_seq_group
:
group
:
SequenceGroupBase
=
seq_id_to_seq_group
[
group
:
SequenceGroupBase
=
seq_id_to_seq_group
[
seq_group
.
request_id
]
seq_group
.
request_id
]
assembled_seq_group
=
group
.
maybe_assemble_group
(
seq_group
)
if
finished
:
if
finished
:
group
.
finish_seq
(
seq_group
)
group
.
finish_seq
(
seq_group
)
assembled_seq_group
=
group
.
maybe_assemble_group
(
seq_group
)
if
assembled_seq_group
is
None
:
if
assembled_seq_group
is
None
:
return
None
return
None
return
cls
.
from_seq_group
(
assembled_seq_group
,
use_cache
,
return
cls
.
from_seq_group
(
assembled_seq_group
,
use_cache
,
...
...
vllm/sequence.py
View file @
18fd4a83
...
@@ -815,7 +815,9 @@ class SequenceGroup:
...
@@ -815,7 +815,9 @@ class SequenceGroup:
def
get_max_num_running_seqs
(
self
)
->
int
:
def
get_max_num_running_seqs
(
self
)
->
int
:
"""The maximum number of sequences running in parallel in the remaining
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
lifetime of the request."""
if
self
.
is_single_seq
:
return
0
if
self
.
first_seq
.
is_finished
()
else
1
return
0
if
self
.
first_seq
.
is_finished
()
else
1
return
self
.
num_seqs
()
-
self
.
num_finished_seqs
()
def
get_seqs
(
def
get_seqs
(
self
,
self
,
...
@@ -824,8 +826,11 @@ class SequenceGroup:
...
@@ -824,8 +826,11 @@ class SequenceGroup:
if
status
is
None
:
if
status
is
None
:
return
self
.
seqs
return
self
.
seqs
if
self
.
is_single_seq
:
return
self
.
seqs
if
self
.
first_seq
.
status
==
status
else
[]
return
self
.
seqs
if
self
.
first_seq
.
status
==
status
else
[]
return
[
seq
for
seq
in
self
.
seqs
if
seq
.
status
==
status
]
def
is_encoder_decoder
(
self
)
->
bool
:
def
is_encoder_decoder
(
self
)
->
bool
:
return
self
.
encoder_seq
is
not
None
return
self
.
encoder_seq
is
not
None
...
@@ -833,17 +838,20 @@ class SequenceGroup:
...
@@ -833,17 +838,20 @@ class SequenceGroup:
return
self
.
encoder_seq
return
self
.
encoder_seq
def
get_finished_seqs
(
self
)
->
List
[
Sequence
]:
def
get_finished_seqs
(
self
)
->
List
[
Sequence
]:
if
self
.
is_single_seq
:
return
self
.
seqs
if
self
.
first_seq
.
is_finished
()
else
[]
return
self
.
seqs
if
self
.
first_seq
.
is_finished
()
else
[]
return
[
seq
for
seq
in
self
.
seqs
if
seq
.
is_finished
()]
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
):
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
):
"""Update number of tokens computed so far."""
"""Update number of tokens computed so far."""
seq
=
self
.
first_
seq
for
seq
in
self
.
seq
s
:
if
not
seq
.
is_finished
():
if
not
seq
.
is_finished
():
seq
.
data
.
update_num_computed_tokens
(
num_new_computed_tokens
)
seq
.
data
.
update_num_computed_tokens
(
num_new_computed_tokens
)
def
get_num_uncomputed_tokens
(
self
)
->
int
:
def
get_num_uncomputed_tokens
(
self
)
->
int
:
num_uncomputed_tokens
=
0
num_uncomputed_tokens
=
0
seq
=
self
.
first_
seq
for
seq
in
self
.
seq
s
:
if
not
seq
.
is_finished
():
if
not
seq
.
is_finished
():
num_uncomputed_tokens
+=
seq
.
data
.
get_num_uncomputed_tokens
()
num_uncomputed_tokens
+=
seq
.
data
.
get_num_uncomputed_tokens
()
return
num_uncomputed_tokens
return
num_uncomputed_tokens
...
@@ -860,10 +868,14 @@ class SequenceGroup:
...
@@ -860,10 +868,14 @@ class SequenceGroup:
return
len
(
self
.
get_seqs
(
status
))
return
len
(
self
.
get_seqs
(
status
))
def
num_finished_seqs
(
self
)
->
int
:
def
num_finished_seqs
(
self
)
->
int
:
return
1
if
self
.
first_seq
.
is_finished
()
else
0
if
self
.
is_single_seq
:
return
1
if
self
.
seqs
[
0
].
is_finished
()
else
0
return
len
(
self
.
get_finished_seqs
())
def
is_finished
(
self
)
->
bool
:
def
is_finished
(
self
)
->
bool
:
if
self
.
is_single_seq
:
return
self
.
first_seq
.
is_finished
()
return
self
.
first_seq
.
is_finished
()
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
seqs
)
def
is_prefill
(
self
)
->
bool
:
def
is_prefill
(
self
)
->
bool
:
return
self
.
first_seq
.
is_prefill
()
return
self
.
first_seq
.
is_prefill
()
...
@@ -1391,13 +1403,15 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
...
@@ -1391,13 +1403,15 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
@
staticmethod
@
staticmethod
def
add_request
(
request_id
:
str
,
engine
,
params
,
**
kwargs
):
def
add_request
(
request_id
:
str
,
engine
,
params
,
**
kwargs
):
original_params
=
params
original_params
=
params
params
=
original_params
.
clone
()
params
.
n
=
1
group
=
ParallelSampleSequenceGroup
(
request_id
)
group
=
ParallelSampleSequenceGroup
(
request_id
)
seqs
=
[]
seqs
=
[]
for
i
in
range
(
original_params
.
n
):
for
i
in
range
(
original_params
.
n
):
request_id_i
=
f
"
{
request_id
}
_parallel_sample_
{
i
}
"
request_id_i
=
f
"
{
request_id
}
_parallel_sample_
{
i
}
"
group
.
seq_id_to_index
[
request_id_i
]
=
i
group
.
seq_id_to_index
[
request_id_i
]
=
i
params
=
copy
.
deepcopy
(
original_params
)
params
.
n
=
1
if
params
.
seed
is
not
None
:
params
.
seed
+=
i
seq_group
=
engine
.
_add_processed_request
(
seq_group
=
engine
.
_add_processed_request
(
request_id_i
,
request_id_i
,
params
=
params
,
params
=
params
,
...
@@ -1432,20 +1446,20 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
...
@@ -1432,20 +1446,20 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
self
,
seq_group
:
SequenceGroup
)
->
Optional
[
SequenceGroup
]:
self
,
seq_group
:
SequenceGroup
)
->
Optional
[
SequenceGroup
]:
# in the streaming mode, we will return the assembled sequence
# in the streaming mode, we will return the assembled sequence
# for the first sequence, and then return None for the
rest of
# for the first
remaining
sequence, and then return None for the
# sequences
#
rest of
sequences
if
self
.
streaming
:
if
self
.
streaming
:
if
self
.
seq_id_to_index
[
seq_group
.
request_id
]
==
0
:
first_remaining_id
=
next
(
iter
(
self
.
to_be_finished
))
if
seq_group
.
request_id
==
first_remaining_id
:
return
self
.
assembled_seq_group
return
self
.
assembled_seq_group
return
None
return
None
# in the non-streaming mode, we will return the assembled sequence
# in the non-streaming mode, we will return the assembled sequence
#
once after all
sequences finish, and then return None for the
#
when the last
sequences finish
es
, and then return None for the
# rest of the time
# rest of the time
if
(
len
(
self
.
to_be_finished
)
==
1
if
len
(
self
.
to_be_finished
)
>
0
:
and
seq_group
.
request_id
in
self
.
to_be_finished
return
None
and
seq_group
.
is_finished
()):
assert
self
.
assembled_seq_group
is
not
None
assert
self
.
assembled_seq_group
is
not
None
params
=
self
.
assembled_seq_group
.
sampling_params
params
=
self
.
assembled_seq_group
.
sampling_params
assert
isinstance
(
params
,
SamplingParams
)
assert
isinstance
(
params
,
SamplingParams
)
...
@@ -1462,3 +1476,4 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
...
@@ -1462,3 +1476,4 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
return
self
.
assembled_seq_group
return
self
.
assembled_seq_group
if
self
.
output_produced
:
if
self
.
output_produced
:
return
None
return
None
return
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment