Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
197b4484
Unverified
Commit
197b4484
authored
Nov 27, 2024
by
Mor Zusman
Committed by
GitHub
Nov 27, 2024
Browse files
[Bugfix][Mamba] Fix Multistep on Mamba-like models (#10705)
Signed-off-by:
mzusman
<
mor.zusmann@gmail.com
>
parent
b98c62ba
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
84 additions
and
4 deletions
+84
-4
tests/models/decoder_only/language/test_jamba.py
tests/models/decoder_only/language/test_jamba.py
+38
-0
tests/models/decoder_only/language/test_mamba.py
tests/models/decoder_only/language/test_mamba.py
+36
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+5
-2
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+5
-2
No files found.
tests/models/decoder_only/language/test_jamba.py
View file @
197b4484
...
...
@@ -275,6 +275,44 @@ def test_state_cleanup(
"could be related to finished_requests_ids"
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
def
test_multistep
(
vllm_runner
,
model
:
str
,
dtype
:
str
,
example_prompts
,
)
->
None
:
# This test is verifying that multistep works correctly
#on mamba-like models
with
vllm_runner
(
model
,
num_scheduler_steps
=
8
,
max_num_seqs
=
2
)
as
vllm_model
:
vllm_model
.
generate_greedy
([
example_prompts
[
0
]]
*
10
,
1
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
def
test_multistep_correctness
(
vllm_runner
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
example_prompts
)
->
None
:
with
vllm_runner
(
model
,
num_scheduler_steps
=
8
,
max_num_seqs
=
2
)
as
vllm_model
:
vllm_outputs_multistep
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
with
vllm_runner
(
model
,
num_scheduler_steps
=
1
,
max_num_seqs
=
2
)
as
vllm_model
:
vllm_outputs_single_step
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
check_outputs_equal
(
outputs_0_lst
=
vllm_outputs_multistep
,
outputs_1_lst
=
vllm_outputs_single_step
,
name_0
=
"vllm_outputs_multistep"
,
name_1
=
"vllm_outputs_single_step"
,
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
...
...
tests/models/decoder_only/language/test_mamba.py
View file @
197b4484
...
...
@@ -283,3 +283,39 @@ def test_state_cleanup(
except
ValueError
:
pytest
.
fail
(
"Mamba inner state wasn't cleaned up between states, "
"could be related to finished_requests_ids"
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
def
test_multistep
(
vllm_runner
,
model
:
str
,
dtype
:
str
,
example_prompts
,
)
->
None
:
with
vllm_runner
(
model
,
num_scheduler_steps
=
8
,
max_num_seqs
=
2
)
as
vllm_model
:
vllm_model
.
generate_greedy
([
example_prompts
[
0
]]
*
10
,
1
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
def
test_multistep_correctness
(
vllm_runner
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
example_prompts
)
->
None
:
with
vllm_runner
(
model
,
num_scheduler_steps
=
8
,
max_num_seqs
=
2
)
as
vllm_model
:
vllm_outputs_multistep
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
with
vllm_runner
(
model
,
num_scheduler_steps
=
1
,
max_num_seqs
=
2
)
as
vllm_model
:
vllm_outputs_single_step
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
check_outputs_equal
(
outputs_0_lst
=
vllm_outputs_multistep
,
outputs_1_lst
=
vllm_outputs_single_step
,
name_0
=
"vllm_outputs_multistep"
,
name_1
=
"vllm_outputs_single_step"
,
)
vllm/engine/async_llm_engine.py
View file @
197b4484
...
...
@@ -300,6 +300,9 @@ class _AsyncLLMEngine(LLMEngine):
ctx
.
seq_group_metadata_list
=
seq_group_metadata_list
ctx
.
scheduler_outputs
=
scheduler_outputs
finished_requests_ids
=
self
.
scheduler
[
virtual_engine
].
get_and_reset_finished_requests_ids
()
# Maybe switch from async mode to sync mode
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
...
...
@@ -311,13 +314,13 @@ class _AsyncLLMEngine(LLMEngine):
self
.
_cache_scheduler_outputs_for_multi_step
(
virtual_engine
,
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
else
:
finished_requests_ids
=
list
()
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
if
not
scheduler_outputs
.
is_empty
():
finished_requests_ids
=
self
.
scheduler
[
virtual_engine
].
get_and_reset_finished_requests_ids
()
# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
...
...
vllm/engine/llm_engine.py
View file @
197b4484
...
...
@@ -1398,6 +1398,9 @@ class LLMEngine:
ctx
.
seq_group_metadata_list
=
seq_group_metadata_list
ctx
.
scheduler_outputs
=
scheduler_outputs
finished_requests_ids
=
self
.
scheduler
[
virtual_engine
].
get_and_reset_finished_requests_ids
()
# Maybe switch from async mode to sync mode
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
...
...
@@ -1409,13 +1412,13 @@ class LLMEngine:
self
.
_cache_scheduler_outputs_for_multi_step
(
virtual_engine
,
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
else
:
finished_requests_ids
=
list
()
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
if
not
scheduler_outputs
.
is_empty
():
finished_requests_ids
=
self
.
scheduler
[
virtual_engine
].
get_and_reset_finished_requests_ids
()
# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment