zhaoyu6 / sglang · Commit fb9296f0 (Unverified)
Authored Jun 12, 2024 by Ying Sheng; committed by GitHub, Jun 12, 2024.
Higher priority for user input of max_prefill_tokens & format (#540)
Parent: 1374334d
Changes: 50 files in the full commit; this page (1 of 3) shows 10 changed files with 92 additions and 53 deletions (+92 / -53).
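Beyond the formatting pass, the headline change gives a user-supplied max_prefill_tokens priority over the server's derived default. The relevant hunk is not on this page, so the following is only a hedged sketch of exercising the flag through the sglang Runtime API; the model path and the exact precedence behavior are assumptions based on the commit title.

```python
import sglang as sgl

# Hedged sketch: after this commit, an explicitly passed max_prefill_tokens
# should win over the default derived from the model config (assumed).
runtime = sgl.Runtime(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # assumed example model
    max_prefill_tokens=16384,
)
print(runtime.url)  # e.g. http://127.0.0.1:30000
runtime.shutdown()
```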
python/sglang/srt/models/qwen.py         +1  -1
python/sglang/srt/models/qwen2.py        +1  -1
python/sglang/srt/models/stablelm.py     +1  -1
python/sglang/srt/models/yivl.py         +2  -2
python/sglang/srt/openai_api_adapter.py  +33 -23
python/sglang/srt/server.py              +12 -5
python/sglang/srt/utils.py               +8  -7
python/sglang/test/test_programs.py      +27 -7
python/sglang/utils.py                   +4  -3
test/lang/test_openai_backend.py         +3  -3
python/sglang/srt/models/qwen.py

```diff
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
-from typing import Any, Dict, Optional, Iterable, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
 from torch import nn
 ...
```
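The one-line change above (and the matching ones in qwen2.py, stablelm.py, and yivl.py below) is pure import reordering; it matches what isort produces, and the line-wrapping changes later in the commit match black. A small sketch reproducing this with both tools' Python APIs (isort >= 5 and black assumed installed):

```python
import black
import isort

src = "from typing import Any, Dict, Optional, Iterable, Tuple\n"
print(isort.code(src), end="")
# -> from typing import Any, Dict, Iterable, Optional, Tuple
print(black.format_str(src, mode=black.Mode()), end="")  # wrapping only; no reordering
```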
python/sglang/srt/models/qwen2.py

```diff
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Any, Dict, Optional, Tuple, Iterable
+from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
 from torch import nn
 ...
```
python/sglang/srt/models/stablelm.py

```diff
@@ -2,7 +2,7 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
 """Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
 model compatible with HuggingFace weights."""
-from typing import Optional, Tuple, Iterable
+from typing import Iterable, Optional, Tuple

 import torch
 from torch import nn
 ...
```
python/sglang/srt/models/yivl.py

```diff
 """Inference-only Yi-VL model."""
-from typing import Tuple, Iterable, Optional
+from typing import Iterable, Optional, Tuple

 import torch
 import torch.nn as nn
 from transformers import CLIPVisionModel, LlavaConfig
 from vllm.config import CacheConfig
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

 from sglang.srt.models.llava import (
     LlavaLlamaForCausalLM,
     monkey_path_clip_vision_embed_forward,
 ...
```
python/sglang/srt/openai_api_adapter.py

```diff
@@ -6,7 +6,7 @@ import os
 from http import HTTPStatus

 from fastapi import Request
-from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse

 from sglang.srt.conversation import (
     Conversation,
 ...
```
```diff
@@ -40,21 +40,18 @@ chat_template_name = None
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
-    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
-    error = ErrorResponse(message=message,
-                          type=err_type,
-                          code=status_code.value)
-    return JSONResponse(content=error.model_dump(),
-                        status_code=error.code)
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+):
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
+    return JSONResponse(content=error.model_dump(), status_code=error.code)


 def create_streaming_error_response(
     message: str,
     err_type: str = "BadRequestError",
-    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
-    error = ErrorResponse(message=message,
-                          type=err_type,
-                          code=status_code.value)
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+) -> str:
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
     json_str = json.dumps({"error": error.model_dump()})
     return json_str
 ...
```
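For context, these two helpers serialize an ErrorResponse either as a JSON HTTP response or as a plain JSON string for embedding in an SSE stream. A self-contained sketch of the streaming variant, using a stand-in pydantic model (the real ErrorResponse lives in sglang's OpenAI protocol module; this stand-in is an assumption):

```python
import json
from http import HTTPStatus

from pydantic import BaseModel


class ErrorResponse(BaseModel):  # stand-in for sglang's actual model
    message: str
    type: str
    code: int


def create_streaming_error_response(
    message: str,
    err_type: str = "BadRequestError",
    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> str:
    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
    return json.dumps({"error": error.model_dump()})


print(create_streaming_error_response("prompt too long"))
# {"error": {"message": "prompt too long", "type": "BadRequestError", "code": 400}}
```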
```diff
@@ -125,7 +122,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         n_prev_token = 0

         try:
-            async for content in tokenizer_manager.generate_request(adapted_request,
-                                                                    raw_request):
+            async for content in tokenizer_manager.generate_request(
+                adapted_request, raw_request
+            ):
                 text = content["text"]
                 prompt_tokens = content["meta_info"]["prompt_tokens"]
                 completion_tokens = content["meta_info"]["completion_tokens"]
 ...
```
```diff
@@ -154,12 +152,14 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                         decode_token_logprobs=content["meta_info"]["decode_token_logprobs"][n_prev_token:],
-                        decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][n_prev_token:],
+                        decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
+                            n_prev_token:
+                        ],
                     )
-                    n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
+                    n_prev_token = len(
+                        content["meta_info"]["decode_token_logprobs"]
+                    )
                 else:
                     logprobs = None
 ...
```
```diff
@@ -188,13 +188,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                 yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"

-        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
-                                 background=tokenizer_manager.create_abort_task(adapted_request))
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )

     # Non-streaming response.
     try:
-        ret = await tokenizer_manager.generate_request(adapted_request,
-                                                       raw_request).__anext__()
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
     except ValueError as e:
         return create_error_response(str(e))
 ...
```
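A hedged client-side sketch of consuming the `data: ...\n\n` event stream this endpoint emits; the URL, port, and payload fields are assumptions based on the OpenAI-compatible API shape:

```python
import json

import requests

# Assumed local sglang server exposing the OpenAI-compatible completions route.
resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={"model": "default", "prompt": "Once upon a time,", "max_tokens": 32, "stream": True},
    stream=True,
)
for raw in resp.iter_lines():
    if not raw:
        continue  # blank separator line between events
    line = raw.decode()
    if not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload == "[DONE]":
        break
    chunk = json.loads(payload)
    print(chunk["choices"][0]["text"], end="", flush=True)
```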
```diff
@@ -299,7 +303,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         stream_buffer = ""

         try:
-            async for content in tokenizer_manager.generate_request(adapted_request, raw_request):
+            async for content in tokenizer_manager.generate_request(
+                adapted_request, raw_request
+            ):
                 if is_first:
                     # First chunk with role
                     is_first = False
 ...
```
```diff
@@ -334,13 +340,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"

-        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
-                                 background=tokenizer_manager.create_abort_task(adapted_request))
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )

     # Non-streaming response.
     try:
-        ret = await tokenizer_manager.generate_request(adapted_request,
-                                                       raw_request).__anext__()
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
     except ValueError as e:
         return create_error_response(str(e))
 ...
```
python/sglang/srt/server.py

```diff
@@ -13,7 +13,7 @@ import sys
 import threading
 import time
 from http import HTTPStatus
-from typing import Optional, Dict
+from typing import Dict, Optional

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 ...
```
```diff
@@ -29,10 +29,14 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.manager_multi import (
+    start_controller_process as start_controller_process_multi,
+)
+from sglang.srt.managers.controller.manager_single import (
+    start_controller_process as start_controller_process_single,
+)
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
-from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api_adapter import (
     load_chat_template_for_openai_api,
 ...
```
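The import shuffle above keeps both controller entry points available under aliases. A self-contained sketch of the dispatch pattern this enables; the dp_size condition is an assumption, since the actual selection happens outside the visible hunks:

```python
from dataclasses import dataclass


@dataclass
class ServerArgs:  # toy stand-in for sglang's ServerArgs
    dp_size: int = 1


def start_controller_process_single(args: ServerArgs) -> str:
    return "single-controller path"


def start_controller_process_multi(args: ServerArgs) -> str:
    return "multi-controller (data-parallel) path"


def pick_controller(args: ServerArgs):
    # One controller process without data parallelism, else the
    # multi-controller manager (assumed condition).
    if args.dp_size == 1:
        return start_controller_process_single
    return start_controller_process_multi


print(pick_controller(ServerArgs(dp_size=1))(ServerArgs(dp_size=1)))
print(pick_controller(ServerArgs(dp_size=4))(ServerArgs(dp_size=4)))
```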
```diff
@@ -97,8 +101,11 @@ async def generate_request(obj: GenerateReqInput, request: Request):
             yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
         yield "data: [DONE]\n\n"

-        return StreamingResponse(stream_results(), media_type="text/event-stream",
-                                 background=tokenizer_manager.create_abort_task(obj))
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(obj),
+        )
     else:
         try:
             ret = await tokenizer_manager.generate_request(obj, request).__anext__()
 ...
```
python/sglang/srt/utils.py

```diff
 """Common utilities."""

 import base64
-import multiprocessing
 import logging
+import multiprocessing
 import os
 import random
 import socket
 ...
```
```diff
@@ -17,12 +17,11 @@ import requests
 import rpyc
 import torch
 import triton
-from rpyc.utils.server import ThreadedServer
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
+from rpyc.utils.server import ThreadedServer
 from starlette.middleware.base import BaseHTTPMiddleware

-
 logger = logging.getLogger(__name__)
 ...
```
```diff
@@ -377,7 +376,7 @@ def init_rpyc_service(service: rpyc.Service, port: int):
         protocol_config={
             "allow_public_attrs": True,
             "allow_pickle": True,
-            "sync_request_timeout": 3600
+            "sync_request_timeout": 3600,
         },
     )
     t.logger.setLevel(logging.WARN)
 ...
```
```diff
@@ -396,7 +395,7 @@ def connect_to_rpyc_service(port, host="localhost"):
             config={
                 "allow_public_attrs": True,
                 "allow_pickle": True,
-                "sync_request_timeout": 3600
+                "sync_request_timeout": 3600,
             },
         )
         break
 ...
```
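Both hunks only add a trailing comma, but the protocol_config they touch is what keeps long-running generation RPCs alive. A runnable sketch of the same ThreadedServer/connect pairing, with a hypothetical EchoService standing in for sglang's model service:

```python
import threading

import rpyc
from rpyc.utils.server import ThreadedServer


class EchoService(rpyc.Service):
    def exposed_echo(self, msg):
        return msg


server = ThreadedServer(
    EchoService,
    port=18861,
    protocol_config={
        "allow_public_attrs": True,
        "allow_pickle": True,
        "sync_request_timeout": 3600,  # generous timeout for slow requests
    },
)
threading.Thread(target=server.start, daemon=True).start()

conn = rpyc.connect("localhost", 18861, config={"sync_request_timeout": 3600})
print(conn.root.echo("hello"))  # -> hello
conn.close()
server.close()
```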
```diff
@@ -423,7 +422,9 @@ def suppress_other_loggers():
     vllm_default_logger.setLevel(logging.WARN)
     logging.getLogger("vllm.config").setLevel(logging.ERROR)
-    logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(logging.WARN)
+    logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
+        logging.WARN
+    )
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
     logging.getLogger("vllm.utils").setLevel(logging.WARN)
 ...
```
```diff
@@ -464,6 +465,7 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     device_name = torch.cuda.get_device_name(gpu_id)
     if "RTX 40" not in device_name:
         import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+
         setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
 ...
```
```diff
@@ -485,4 +487,3 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
         )
         response = await call_next(request)
         return response
-
```
python/sglang/test/test_programs.py

```diff
@@ -356,16 +356,25 @@ def test_completion_speculative():
         s += "Construct a character within the following format:\n"
         s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
         s += "\nPlease generate new Name, Birthday and Job.\n"
-        s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
+        s += (
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+        )
         s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

     @sgl.function
     def gen_character_no_spec(s):
         s += "Construct a character within the following format:\n"
         s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
         s += "\nPlease generate new Name, Birthday and Job.\n"
-        s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
+        s += (
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+        )
         s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

     token_usage = sgl.global_config.default_backend.token_usage
 ...
```
```diff
@@ -378,7 +387,9 @@ def test_completion_speculative():
     gen_character_no_spec().sync()
     usage_with_no_spec = token_usage.prompt_tokens

-    assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}"
+    assert (
+        usage_with_spec < usage_with_no_spec
+    ), f"{usage_with_spec} vs {usage_with_no_spec}"


 def test_chat_completion_speculative():
 ...
```
```diff
@@ -386,8 +397,17 @@ def test_chat_completion_speculative():
     def gen_character_spec(s):
         s += sgl.system("You are a helpful assistant.")
         s += sgl.user("Construct a character within the following format:")
-        s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
+        s += sgl.assistant(
+            "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        )
         s += sgl.user("Please generate new Name, Birthday and Job.\n")
-        s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
+        s += sgl.assistant(
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+            + "\nJob:"
+            + sgl.gen("job", stop="\n")
+        )

     gen_character_spec().sync()
```
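Both speculative tests assert that speculative execution lowers prompt-token usage. A hedged sketch of driving them the way test/lang/test_openai_backend.py does, with the model names assumed:

```python
from sglang import OpenAI, set_default_backend
from sglang.test.test_programs import (
    test_chat_completion_speculative,
    test_completion_speculative,
)

set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))  # completion-style model (assumed)
test_completion_speculative()

set_default_backend(OpenAI("gpt-3.5-turbo"))  # chat-style model (assumed)
test_chat_completion_speculative()
```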
python/sglang/utils.py

```diff
@@ -15,7 +15,6 @@ from json import dumps
 import numpy as np
 import requests

-
 logger = logging.getLogger(__name__)
 ...
```
```diff
@@ -255,7 +254,9 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
 def graceful_registry(sub_module_name):
     def graceful_shutdown(signum, frame):
-        logger.info(f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown...")
+        logger.info(
+            f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
+        )
         if signum == signal.SIGTERM:
             logger.info(f"{sub_module_name} recive sigterm")
 ...
```
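For reference, graceful_registry installs the inner handler through the standard signal module. A minimal self-contained sketch of that wiring; the registration call itself sits outside the visible hunk, so the last line is an assumption:

```python
import logging
import signal

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def graceful_registry(sub_module_name):
    def graceful_shutdown(signum, frame):
        logger.info(
            f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
        )

    # Assumed registration, mirroring what the real function does internally.
    signal.signal(signal.SIGTERM, graceful_shutdown)


graceful_registry("detokenizer")  # this process now logs on SIGTERM
```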
test/lang/test_openai_backend.py

```diff
@@ -2,6 +2,8 @@ import unittest
 from sglang import OpenAI, set_default_backend
 from sglang.test.test_programs import (
+    test_chat_completion_speculative,
+    test_completion_speculative,
     test_decode_int,
     test_decode_json,
     test_expert_answer,
 ...
```
```diff
@@ -14,8 +16,6 @@ from sglang.test.test_programs import (
     test_select,
     test_stream,
     test_tool_use,
-    test_completion_speculative,
-    test_chat_completion_speculative,
 )
```