Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d04b13a3
Unverified
Commit
d04b13a3
authored
Nov 26, 2024
by
Chauncey
Committed by
GitHub
Nov 25, 2024
Browse files
[Bug]: Authorization ignored when root_path is set (#10606)
Signed-off-by:
chaunceyjiang
<
chaunceyjiang@gmail.com
>
parent
2b0879bf
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
107 additions
and
2 deletions
+107
-2
tests/entrypoints/openai/test_root_path.py
tests/entrypoints/openai/test_root_path.py
+103
-0
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+4
-2
No files found.
tests/entrypoints/openai/test_root_path.py
0 → 100644
View file @
d04b13a3
import
contextlib
import
os
from
typing
import
Any
,
List
,
NamedTuple
import
openai
# use the official client for correctness check
import
pytest
from
...utils
import
RemoteOpenAIServer
# Any model that ships with a chat template should work here.
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
# Minimal chat template: renders each message as "<role>: <content>" followed
# by the template's '\n' escape. Split across lines to stay within the line
# length limit; the resulting string is unchanged.
DUMMY_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{{message['role'] + ': ' + message['content'] + '\\n'}}"
    "{% endfor %}")
# Key exported to the server process as VLLM_API_KEY.
API_KEY = "abc-123"
# A key that differs from API_KEY, used by the cases expecting an auth error.
ERROR_API_KEY = "abc"
# Path segment used both for --root-path (as "/llm") and in client base URLs.
ROOT_PATH = "llm"
@pytest.fixture(scope="module")
def server():
    """Start one server for the whole module and yield its handle.

    The server is launched with ``--root-path=/llm`` and with
    ``VLLM_API_KEY`` present in its environment. The ``with`` block keeps
    the ``RemoteOpenAIServer`` context open until the module's tests finish.

    Yields:
        The ``RemoteOpenAIServer`` instance returned by the context manager.
    """
    cli_args = [
        # half precision for speed and memory savings in the CI environment
        "--dtype",
        "float16",
        "--enforce-eager",
        "--max-model-len",
        "4080",
        # serve with --root-path=/llm for testing
        "--root-path",
        f"/{ROOT_PATH}",
        "--chat-template",
        DUMMY_CHAT_TEMPLATE,
    ]
    # Snapshot of the current environment with the API key layered on top;
    # os.environ itself is left untouched.
    server_env = {**os.environ, "VLLM_API_KEY": API_KEY}
    with RemoteOpenAIServer(MODEL_NAME, cli_args,
                            env_dict=server_env) as remote_server:
        yield remote_server
class TestCase(NamedTuple):
    """One parametrized scenario for the root-path authorization test."""

    # The class name matches pytest's default ``Test*`` collection pattern.
    # Opt out explicitly so pytest does not try to collect this data holder
    # (and warn about its ``__new__`` constructor). Unannotated, so it is a
    # plain class attribute and not a tuple field.
    __test__ = False

    # Model name sent in the chat completion request.
    model_name: str
    # Path segments passed to ``server.url_for`` to build the client base URL.
    base_url: List[str]
    # API key the client presents to the server.
    api_key: str
    # Exception type the request is expected to raise, or None on success.
    expected_error: Any
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            model_name=MODEL_NAME,
            base_url=["v1"],  # http://localhost:8000/v1
            api_key=ERROR_API_KEY,
            expected_error=openai.AuthenticationError),
        TestCase(
            model_name=MODEL_NAME,
            base_url=[ROOT_PATH, "v1"],  # http://localhost:8000/llm/v1
            api_key=ERROR_API_KEY,
            expected_error=openai.AuthenticationError),
        TestCase(
            model_name=MODEL_NAME,
            base_url=["v1"],  # http://localhost:8000/v1
            api_key=API_KEY,
            expected_error=None),
        TestCase(
            model_name=MODEL_NAME,
            base_url=[ROOT_PATH, "v1"],  # http://localhost:8000/llm/v1
            api_key=API_KEY,
            expected_error=None),
    ],
)
async def test_chat_session_root_path_with_api_key(server: RemoteOpenAIServer,
                                                   test_case: TestCase):
    """Check API-key enforcement on both ``/v1`` and ``/<root_path>/v1``.

    Each case sends one chat completion request through a client built from
    ``test_case.api_key`` and ``test_case.base_url``. When
    ``test_case.expected_error`` is set the request must raise that
    exception; otherwise the response is validated.

    Args:
        server: The module-scoped server fixture.
        test_case: The scenario (URL segments, key, expected outcome).
    """
    saying: str = "Here is a common saying about apple. An apple a day, keeps"
    ctx = contextlib.nullcontext()
    if test_case.expected_error is not None:
        ctx = pytest.raises(test_case.expected_error)
    # Manage the client with ``async with`` so its HTTP connections are
    # closed when the test ends, whether the request succeeds or raises.
    async with openai.AsyncOpenAI(
            api_key=test_case.api_key,
            base_url=server.url_for(*test_case.base_url),
            # No retries: an auth failure must surface immediately.
            max_retries=0) as client:
        with ctx:
            chat_completion = await client.chat.completions.create(
                model=test_case.model_name,
                messages=[{
                    "role": "user",
                    "content": "tell me a common saying"
                }, {
                    "role": "assistant",
                    "content": saying
                }],
                extra_body={
                    "continue_final_message": True,
                    "add_generation_prompt": False
                })
            # Only reached when the request did not raise.
            assert chat_completion.id is not None
            assert len(chat_completion.choices) == 1
            choice = chat_completion.choices[0]
            assert choice.finish_reason == "stop"
            message = choice.message
            assert len(message.content) > 0
            assert message.role == "assistant"
vllm/entrypoints/openai/api_server.py
View file @
d04b13a3
...
...
@@ -499,10 +499,12 @@ def build_app(args: Namespace) -> FastAPI:
@
app
.
middleware
(
"http"
)
async
def
authentication
(
request
:
Request
,
call_next
):
root_path
=
""
if
args
.
root_path
is
None
else
args
.
root_path
if
request
.
method
==
"OPTIONS"
:
return
await
call_next
(
request
)
if
not
request
.
url
.
path
.
startswith
(
f
"
{
root_path
}
/v1"
):
url_path
=
request
.
url
.
path
if
app
.
root_path
and
url_path
.
startswith
(
app
.
root_path
):
url_path
=
url_path
[
len
(
app
.
root_path
):]
if
not
url_path
.
startswith
(
"/v1"
):
return
await
call_next
(
request
)
if
request
.
headers
.
get
(
"Authorization"
)
!=
"Bearer "
+
token
:
return
JSONResponse
(
content
=
{
"error"
:
"Unauthorized"
},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment