change / sglang · Commits · d84c5e70

Commit d84c5e70 (unverified) · authored Aug 11, 2024 by Lianmin Zheng, committed by GitHub on Aug 11, 2024

Test the case when max_new_tokens is very large (#1038)

Parent: d7854120

Showing 7 changed files with 99 additions and 13 deletions (+99 −13)
python/sglang/srt/managers/detokenizer_manager.py  (+1 −3)
python/sglang/srt/managers/policy_scheduler.py     (+5 −2)
python/sglang/srt/openai_api/adapter.py            (+2 −2)
python/sglang/srt/server_args.py                   (+1 −1)
python/sglang/test/test_utils.py                   (+12 −1)
test/srt/run_suite.py                              (+6 −4)
test/srt/test_large_max_new_tokens.py              (+72 −0, new file)
python/sglang/srt/managers/detokenizer_manager.py

@@ -32,7 +32,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
+from sglang.utils import find_printable_text, get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -164,8 +164,6 @@ def start_detokenizer_process(
     port_args: PortArgs,
     pipe_writer,
 ):
-    graceful_registry(inspect.currentframe().f_code.co_name)
-
     try:
         manager = DetokenizerManager(server_args, port_args)
     except Exception:
python/sglang/srt/managers/policy_scheduler.py

@@ -15,6 +15,7 @@ limitations under the License.

 """Request policy scheduler"""

+import os
 import random
 from collections import defaultdict
 from contextlib import contextmanager

@@ -24,9 +25,11 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.radix_cache import TreeNode

-# Clip the max new tokens for the request whose max_new_tokens is very large.
+# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
 # This can prevent the server from being too conservative.
-CLIP_MAX_NEW_TOKENS = 4096
+# Note that this only clips the estimation in the scheduler but does not change the stop
+# condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
+CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))


 class PolicyScheduler:
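For intuition, the clipped constant only caps the scheduler's estimate of how many tokens a request may still need; the stop condition keeps the original max_new_tokens. A minimal sketch of that split follows; the two helper functions are illustrative stand-ins, not the actual scheduler code:

    import os

    # Default mirrors the diff above; SGLANG_CLIP_MAX_NEW_TOKENS overrides it at runtime.
    CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))

    def estimated_tokens_still_needed(max_new_tokens: int, decoded: int) -> int:
        # Admission-control estimate: clipped, so one request with a huge
        # max_new_tokens does not reserve the entire token budget by itself.
        return min(max_new_tokens - decoded, CLIP_MAX_NEW_TOKENS)

    def hit_stop_condition(decoded: int, max_new_tokens: int) -> bool:
        # Stop condition: NOT clipped, so generation can run past the
        # estimate, all the way to the requested max_new_tokens.
        return decoded >= max_new_tokens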
python/sglang/srt/openai_api/adapter.py

@@ -77,7 +77,7 @@ class FileMetadata:

 batch_storage: Dict[str, BatchResponse] = {}
 file_id_request: Dict[str, FileMetadata] = {}
 file_id_response: Dict[str, FileResponse] = {}
-# map file id to file path in SGlang backend
+# map file id to file path in SGLang backend
 file_id_storage: Dict[str, str] = {}

@@ -335,7 +335,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             }
         except Exception as e:
-            print("error in SGlang:", e)
+            print("error in SGLang:", e)
             # Update batch status to "failed"
             retrieve_batch = batch_storage[batch_id]
             retrieve_batch.status = "failed"
python/sglang/srt/server_args.py

@@ -64,7 +64,7 @@ class ServerArgs:

     # Other
     api_key: Optional[str] = None
-    file_storage_pth: str = "SGlang_storage"
+    file_storage_pth: str = "SGLang_storage"

     # Data parallelism
     dp_size: int = 1
python/sglang/test/test_utils.py

@@ -398,6 +398,8 @@ def popen_launch_server(
     timeout: float,
     api_key: Optional[str] = None,
     other_args: tuple = (),
+    env: Optional[dict] = None,
+    return_stdout_stderr: bool = False,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]

@@ -417,7 +419,16 @@ def popen_launch_server(
     if api_key:
         command += ["--api-key", api_key]

-    process = subprocess.Popen(command, stdout=None, stderr=None)
+    if return_stdout_stderr:
+        process = subprocess.Popen(
+            command,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            env=env,
+            text=True,
+        )
+    else:
+        process = subprocess.Popen(command, stdout=None, stderr=None, env=env)

     start_time = time.time()
     while time.time() - start_time < timeout:
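The new test file below relies on both parameters. Mirroring its call site, a launch that overrides the clip limit and captures the server log looks roughly like this (port and values copied from the test; a sketch, not canonical usage):

    import os

    from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server

    process = popen_launch_server(
        DEFAULT_MODEL_NAME_FOR_TEST,
        "http://127.0.0.1:8157",
        timeout=300,
        env={"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ},
        return_stdout_stderr=True,
    )

    # text=True in the Popen call makes stderr a text stream, so the
    # server log can be scanned line by line:
    for line in iter(process.stderr.readline, ""):
        if "#running-req" in line:
            break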
test/srt/run_suite.py

@@ -5,13 +5,15 @@ from sglang.test.test_utils import run_unittest_files

 suites = {
     "minimal": [
+        "test_chunked_prefill.py",
+        "test_embedding_openai_server.py",
         "test_eval_accuracy.py",
+        "test_large_max_new_tokens.py",
         "test_openai_server.py",
-        "test_vision_openai_server.py",
-        "test_embedding_openai_server.py",
-        "test_chunked_prefill.py",
+        "test_skip_tokenizer_init.py",
         "test_torch_compile.py",
-        "test_large_max_new_tokens.py",
+        "test_models_from_modelscope.py",
+        "test_vision_openai_server.py",
         "models/test_generation_models.py",
         "models/test_embedding_models.py",
         "sampling/penaltylib",
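With the new test registered in the minimal suite, the suite runs as before; assuming run_suite.py's existing --suite flag, for example:

    python3 test/srt/run_suite.py --suite minimal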
test/srt/test_large_max_new_tokens.py  (new file, 0 → 100644)

import json
import os
import time
import unittest
from concurrent.futures import ThreadPoolExecutor

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server


class TestOpenAIServer(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = "http://127.0.0.1:8157"
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=300,
            api_key=cls.api_key,
            other_args=("--max-total-token", "1024"),
            env={"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ},
            return_stdout_stderr=True,
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_chat_completion(self):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "Please repeat the world 'hello' for 10000 times.",
                },
            ],
            temperature=0,
        )
        return response

    def test_chat_completion(self):
        num_requests = 4
        futures = []
        with ThreadPoolExecutor(16) as executor:
            for i in range(num_requests):
                futures.append(executor.submit(self.run_chat_completion))

            # Scan the server log until all requests are reported as running.
            all_requests_running = False
            for line in iter(self.process.stderr.readline, ""):
                line = str(line)
                print(line, end="")
                if f"#running-req: {num_requests}" in line:
                    all_requests_running = True
                    break

            assert all_requests_running


if __name__ == "__main__":
    unittest.main()
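How the test works: the server starts with a tight token budget (--max-total-token 1024) and the scheduler's per-request estimate clipped to 256 tokens via SGLANG_CLIP_MAX_NEW_TOKENS, so the budget can admit several requests at once. Four concurrent chat completions with very long expected outputs are then submitted, and the test asserts that the server log eventually reports "#running-req: 4", i.e. all four decode concurrently rather than one large-max_new_tokens request monopolizing the budget. The test can also be run on its own:

    python3 test/srt/test_large_max_new_tokens.py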