Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7134303c
Unverified
Commit
7134303c
authored
Apr 27, 2024
by
Roy
Committed by
GitHub
Apr 27, 2024
Browse files
[Bugfix][Core] Fix get decoding config from ray (#4335)
parent
3da24c2d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
174 additions
and
3 deletions
+174
-3
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+2
-0
tests/async_engine/test_openapi_server_ray.py
tests/async_engine/test_openapi_server_ray.py
+157
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+9
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+4
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+1
-1
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+1
-1
No files found.
tests/async_engine/test_async_llm_engine.py
View file @
7134303c
...
...
@@ -91,4 +91,6 @@ async def test_new_requests_event():
assert
engine
.
engine
.
step_calls
==
old_step_calls
+
1
engine
=
MockAsyncLLMEngine
(
worker_use_ray
=
True
,
engine_use_ray
=
True
)
assert
engine
.
get_model_config
()
is
not
None
assert
engine
.
get_tokenizer
()
is
not
None
assert
engine
.
get_decoding_config
()
is
not
None
tests/async_engine/test_openapi_server_ray.py
0 → 100644
View file @
7134303c
# imports for guided decoding tests
import
os
import
subprocess
import
sys
import
time
import
openai
# use the official client for correctness check
import
pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import
ray
import
requests
MAX_SERVER_START_WAIT_S
=
600
# wait for server to start for 60 seconds
# any model with a chat template should work here
MODEL_NAME
=
"facebook/opt-125m"
@
ray
.
remote
(
num_gpus
=
1
)
class
ServerRunner
:
def
__init__
(
self
,
args
):
env
=
os
.
environ
.
copy
()
env
[
"PYTHONUNBUFFERED"
]
=
"1"
self
.
proc
=
subprocess
.
Popen
(
[
"python3"
,
"-m"
,
"vllm.entrypoints.openai.api_server"
]
+
args
,
env
=
env
,
stdout
=
sys
.
stdout
,
stderr
=
sys
.
stderr
,
)
self
.
_wait_for_server
()
def
ready
(
self
):
return
True
def
_wait_for_server
(
self
):
# run health check
start
=
time
.
time
()
while
True
:
try
:
if
requests
.
get
(
"http://localhost:8000/health"
).
status_code
==
200
:
break
except
Exception
as
err
:
if
self
.
proc
.
poll
()
is
not
None
:
raise
RuntimeError
(
"Server exited unexpectedly."
)
from
err
time
.
sleep
(
0.5
)
if
time
.
time
()
-
start
>
MAX_SERVER_START_WAIT_S
:
raise
RuntimeError
(
"Server failed to start in time."
)
from
err
def
__del__
(
self
):
if
hasattr
(
self
,
"proc"
):
self
.
proc
.
terminate
()
@
pytest
.
fixture
(
scope
=
"session"
)
def
server
():
ray
.
init
()
server_runner
=
ServerRunner
.
remote
([
"--model"
,
MODEL_NAME
,
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"float16"
,
"--max-model-len"
,
"2048"
,
"--enforce-eager"
,
"--engine-use-ray"
])
ray
.
get
(
server_runner
.
ready
.
remote
())
yield
server_runner
ray
.
shutdown
()
@
pytest
.
fixture
(
scope
=
"session"
)
def
client
():
client
=
openai
.
AsyncOpenAI
(
base_url
=
"http://localhost:8000/v1"
,
api_key
=
"token-abc123"
,
)
yield
client
@
pytest
.
mark
.
asyncio
async
def
test_check_models
(
server
,
client
:
openai
.
AsyncOpenAI
):
models
=
await
client
.
models
.
list
()
models
=
models
.
data
served_model
=
models
[
0
]
assert
served_model
.
id
==
MODEL_NAME
assert
all
(
model
.
root
==
MODEL_NAME
for
model
in
models
)
@
pytest
.
mark
.
asyncio
async
def
test_single_completion
(
server
,
client
:
openai
.
AsyncOpenAI
):
completion
=
await
client
.
completions
.
create
(
model
=
MODEL_NAME
,
prompt
=
"Hello, my name is"
,
max_tokens
=
5
,
temperature
=
0.0
)
assert
completion
.
id
is
not
None
assert
completion
.
choices
is
not
None
and
len
(
completion
.
choices
)
==
1
assert
completion
.
choices
[
0
].
text
is
not
None
and
len
(
completion
.
choices
[
0
].
text
)
>=
5
assert
completion
.
choices
[
0
].
finish_reason
==
"length"
assert
completion
.
usage
==
openai
.
types
.
CompletionUsage
(
completion_tokens
=
5
,
prompt_tokens
=
6
,
total_tokens
=
11
)
# test using token IDs
completion
=
await
client
.
completions
.
create
(
model
=
MODEL_NAME
,
prompt
=
[
0
,
0
,
0
,
0
,
0
],
max_tokens
=
5
,
temperature
=
0.0
,
)
assert
completion
.
choices
[
0
].
text
is
not
None
and
len
(
completion
.
choices
[
0
].
text
)
>=
5
@
pytest
.
mark
.
asyncio
async
def
test_single_chat_session
(
server
,
client
:
openai
.
AsyncOpenAI
):
messages
=
[{
"role"
:
"system"
,
"content"
:
"you are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
"what is 1+1?"
}]
# test single completion
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
messages
,
max_tokens
=
10
,
logprobs
=
True
,
top_logprobs
=
5
)
assert
chat_completion
.
id
is
not
None
assert
chat_completion
.
choices
is
not
None
and
len
(
chat_completion
.
choices
)
==
1
assert
chat_completion
.
choices
[
0
].
message
is
not
None
assert
chat_completion
.
choices
[
0
].
logprobs
is
not
None
assert
chat_completion
.
choices
[
0
].
logprobs
.
top_logprobs
is
not
None
assert
len
(
chat_completion
.
choices
[
0
].
logprobs
.
top_logprobs
[
0
])
==
5
message
=
chat_completion
.
choices
[
0
].
message
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
10
assert
message
.
role
==
"assistant"
messages
.
append
({
"role"
:
"assistant"
,
"content"
:
message
.
content
})
# test multi-turn dialogue
messages
.
append
({
"role"
:
"user"
,
"content"
:
"express your result in json"
})
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
messages
,
max_tokens
=
10
,
)
message
=
chat_completion
.
choices
[
0
].
message
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
0
vllm/engine/async_llm_engine.py
View file @
7134303c
...
...
@@ -7,7 +7,7 @@ from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
ModelConfig
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
...
...
@@ -697,6 +697,14 @@ class AsyncLLMEngine:
else
:
return
self
.
engine
.
get_model_config
()
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
"""Get the decoding configuration of the vLLM engine."""
if
self
.
engine_use_ray
:
return
await
self
.
engine
.
get_decoding_config
.
remote
(
# type: ignore
)
else
:
return
self
.
engine
.
get_decoding_config
()
async
def
do_log_stats
(
self
)
->
None
:
if
self
.
engine_use_ray
:
await
self
.
engine
.
do_log_stats
.
remote
()
# type: ignore
...
...
vllm/engine/llm_engine.py
View file @
7134303c
...
...
@@ -467,6 +467,10 @@ class LLMEngine:
"""Gets the model configuration."""
return
self
.
model_config
def
get_decoding_config
(
self
)
->
DecodingConfig
:
"""Gets the decoding configuration."""
return
self
.
decoding_config
def
get_num_unfinished_requests
(
self
)
->
int
:
"""Gets the number of unfinished requests."""
return
self
.
scheduler
.
get_num_unfinished_seq_groups
()
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
7134303c
...
...
@@ -101,7 +101,7 @@ class OpenAIServingChat(OpenAIServing):
request
,
prompt
=
prompt
)
sampling_params
=
request
.
to_sampling_params
()
lora_request
=
self
.
_maybe_get_lora
(
request
)
decoding_config
=
self
.
engine
.
engine
.
decoding_config
decoding_config
=
await
self
.
engine
.
get_
decoding_config
()
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
guided_decode_logits_processor
=
(
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
7134303c
...
...
@@ -89,7 +89,7 @@ class OpenAIServingCompletion(OpenAIServing):
try
:
sampling_params
=
request
.
to_sampling_params
()
lora_request
=
self
.
_maybe_get_lora
(
request
)
decoding_config
=
self
.
engine
.
engine
.
decoding_config
decoding_config
=
await
self
.
engine
.
get_
decoding_config
()
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
guided_decode_logit_processor
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment