sglang / Commits / 9c939a3d

Clean up metrics code (#1972)

Commit 9c939a3d (unverified), authored Nov 09, 2024 by Lianmin Zheng; committed by GitHub on Nov 09, 2024.
Parent: 549e8b83

16 changed files with 101 additions and 107 deletions (+101 −107):
python/sglang/srt/managers/tokenizer_manager.py   +6  −2
python/sglang/srt/server.py                       +7  −43
python/sglang/srt/utils.py                        +38 −0
python/sglang/test/test_utils.py                  +1  −3
scripts/ci_install_dependency.sh                  +1  −1
test/srt/run_suite.py                             +1  −0
test/srt/test_bench_latency.py                    +0  −2
test/srt/test_cache_report.py                     +2  −2
test/srt/test_large_max_new_tokens.py             +13 −4
test/srt/test_metrics.py                          +9  −30
test/srt/test_openai_server.py                    +3  −3
test/srt/test_radix_attention.py                  +4  −4
test/srt/test_skip_tokenizer_init.py              +2  −2
test/srt/test_srt_endpoint.py                     +2  −2
test/srt/test_srt_engine.py                       +8  −5
test/srt/test_update_weights.py                   +4  −4
python/sglang/srt/managers/tokenizer_manager.py

@@ -391,8 +391,12 @@ class TokenizerManager:
         async with self.model_update_lock:
             # wait for the previous generation requests to finish
-            while len(self.rid_to_state) > 0:
-                await asyncio.sleep(0.001)
+            for i in range(3):
+                while len(self.rid_to_state) > 0:
+                    await asyncio.sleep(0.001)
+                # FIXME: We add some sleep here to avoid some race conditions.
+                # We can use a read-write lock as a better fix.
+                await asyncio.sleep(0.01)

             self.send_to_scheduler.send_pyobj(obj)
             self.model_update_result = asyncio.Future()
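The FIXME in this hunk names the cleaner fix: a read-write lock in which generation requests enter as readers and a model update enters as the sole writer, removing both the retry loop and the magic sleeps. Below is a minimal asyncio sketch of that idea; it is not part of this commit, and every name in it is illustrative rather than an sglang API.

import asyncio


class RWLock:
    """Readers are generation requests; the writer is a model update."""

    def __init__(self):
        self._cond = asyncio.Condition()
        self._readers = 0      # in-flight generation requests
        self._writing = False  # True while a model update holds the lock

    async def acquire_read(self):
        async with self._cond:
            await self._cond.wait_for(lambda: not self._writing)
            self._readers += 1

    async def release_read(self):
        async with self._cond:
            self._readers -= 1
            self._cond.notify_all()

    async def acquire_write(self):
        async with self._cond:
            # Claim the writer slot first, then wait for readers to drain.
            await self._cond.wait_for(lambda: not self._writing)
            self._writing = True
            await self._cond.wait_for(lambda: self._readers == 0)

    async def release_write(self):
        async with self._cond:
            self._writing = False
            self._cond.notify_all()

With something like this in place, the model-update path would take acquire_write() instead of polling rid_to_state, and each generation request would bracket its work with acquire_read()/release_read().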
python/sglang/srt/server.py

@@ -25,20 +25,16 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import re
-import tempfile
 import threading
 import time
 from http import HTTPStatus
 from typing import AsyncIterator, Dict, List, Optional, Union

-import orjson
-from starlette.routing import Mount
-
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

 import aiohttp
 import orjson
 import requests
 import uvicorn
 import uvloop
@@ -77,6 +73,7 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     add_api_key_middleware,
+    add_prometheus_middleware,
     assert_pkg_version,
     configure_logger,
     delete_directory,
@@ -84,16 +81,13 @@ from sglang.srt.utils import (
     kill_child_process,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
+    set_prometheus_multiproc_dir,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

-# Temporary directory for prometheus multiprocess mode
-# Cleaned up automatically when this object is garbage collected
-prometheus_multiproc_dir: tempfile.TemporaryDirectory
-
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -445,10 +439,6 @@ def launch_server(
     1. The HTTP server and Tokenizer Manager both run in the main process.
     2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
     """
-    if server_args.enable_metrics:
-        _set_prometheus_env()
-
     launch_engine(server_args=server_args)

     # Add api key authorization
@@ -487,36 +477,6 @@ def launch_server(
         t.join()


-def add_prometheus_middleware(app: FastAPI):
-    # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
-    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
-
-    registry = CollectorRegistry()
-    multiprocess.MultiProcessCollector(registry)
-    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
-
-    # Workaround for 307 Redirect for /metrics
-    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
-    app.routes.append(metrics_route)
-
-
-def _set_prometheus_env():
-    # Set prometheus multiprocess directory
-    # sglang uses prometheus multiprocess mode
-    # we need to set this before importing prometheus_client
-    # https://prometheus.github.io/client_python/multiprocess/
-    global prometheus_multiproc_dir
-    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
-        logger.debug(f"User set PROMETHEUS_MULTIPROC_DIR detected.")
-        prometheus_multiproc_dir = tempfile.TemporaryDirectory(
-            dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]
-        )
-    else:
-        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
-    os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
-    logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
-
-
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -543,6 +503,10 @@ def _set_envs_and_config(server_args: ServerArgs):
             "at https://docs.flashinfer.ai/installation.html.",
         )

+    # Set prometheus env vars
+    if server_args.enable_metrics:
+        set_prometheus_multiproc_dir()
+
     mp.set_start_method("spawn", force=True)
python/sglang/srt/utils.py

@@ -22,10 +22,12 @@ import logging
 import os
 import pickle
 import random
+import re
 import resource
 import shutil
 import signal
 import socket
+import tempfile
 import time
 import warnings
 from importlib.metadata import PackageNotFoundError, version
@@ -41,6 +43,7 @@ import triton
 import zmq
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
+from starlette.routing import Mount
 from torch import nn
 from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
@@ -752,3 +755,38 @@ def delete_directory(dirpath):
         shutil.rmtree(dirpath)
     except OSError as e:
         print(f"Warning: {dirpath} : {e.strerror}")
+
+
+# Temporary directory for prometheus multiprocess mode
+# Cleaned up automatically when this object is garbage collected
+prometheus_multiproc_dir: tempfile.TemporaryDirectory
+
+
+def set_prometheus_multiproc_dir():
+    # Set prometheus multiprocess directory
+    # sglang uses prometheus multiprocess mode
+    # we need to set this before importing prometheus_client
+    # https://prometheus.github.io/client_python/multiprocess/
+    global prometheus_multiproc_dir
+
+    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
+        logger.debug("User set PROMETHEUS_MULTIPROC_DIR detected.")
+        prometheus_multiproc_dir = tempfile.TemporaryDirectory(
+            dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]
+        )
+    else:
+        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+    os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
+    logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
+
+
+def add_prometheus_middleware(app):
+    # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
+    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
+
+    registry = CollectorRegistry()
+    multiprocess.MultiProcessCollector(registry)
+    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))

+    # Workaround for 307 Redirect for /metrics
+    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
+    app.routes.append(metrics_route)
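Taken together, the two helpers that moved here give any ASGI app a Prometheus endpoint aggregated across worker processes. A minimal sketch of wiring them up, assuming FastAPI and prometheus_client are installed; the app below is illustrative, not sglang's actual server:

from fastapi import FastAPI

from sglang.srt.utils import add_prometheus_middleware, set_prometheus_multiproc_dir

# Must run before prometheus_client is imported anywhere, so the client
# library sees PROMETHEUS_MULTIPROC_DIR and switches to multiprocess mode.
set_prometheus_multiproc_dir()

app = FastAPI()
add_prometheus_middleware(app)  # metrics from all worker processes at GET /metrics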
python/sglang/test/test_utils.py

@@ -27,6 +27,7 @@ from sglang.utils import get_exception_traceback
 DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8"
 DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
+DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
 DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
 DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
@@ -404,7 +405,6 @@ def popen_launch_server(
     other_args: tuple = (),
     env: Optional[dict] = None,
     return_stdout_stderr: Optional[tuple] = None,
-    enable_metrics: bool = False,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]
@@ -423,8 +423,6 @@ def popen_launch_server(
     ]
     if api_key:
         command += ["--api-key", api_key]
-    if enable_metrics:
-        command += ["--enable-metrics"]

     if return_stdout_stderr:
         process = subprocess.Popen(
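With the enable_metrics keyword removed, callers pass the server flag through the generic other_args passthrough instead, which is exactly what the updated test_metrics.py does. A sketch of the new call shape:

from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

process = popen_launch_server(
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=["--enable-metrics"],  # generic flag passthrough replaces the keyword
)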
scripts/ci_install_dependency.sh

@@ -4,5 +4,5 @@ Install the dependency in CI.
 pip install --upgrade pip
 pip install -e "python[all]"
-pip install transformers==4.45.2 sentence_transformers
+pip install transformers==4.45.2 sentence_transformers accelerate peft
 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
test/srt/run_suite.py

@@ -16,6 +16,7 @@ suites = {
         "test_eval_accuracy_mini.py",
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
+        "test_metrics.py",
         "test_openai_server.py",
         "test_overlap_schedule.py",
         "test_pytorch_sampling_backend.py",
test/srt/test_bench_latency.py

import subprocess
import unittest

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_MOE_MODEL_NAME_FOR_TEST,
test/srt/test_cache_report.py

@@ -6,7 +6,7 @@ import requests
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
 )
@@ -15,7 +15,7 @@ from sglang.test.test_utils import (
 class TestCacheReport(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.min_cached = 5
         cls.process = popen_launch_server(
test/srt/test_large_max_new_tokens.py

@@ -3,6 +3,7 @@ python3 -m unittest test_large_max_new_tokens.TestLargeMaxNewTokens.test_chat_co
 """

 import os
+import time
 import unittest
 from concurrent.futures import ThreadPoolExecutor
@@ -11,7 +12,7 @@ import openai
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
@@ -21,7 +22,7 @@ from sglang.test.test_utils import (
 class TestLargeMaxNewTokens(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.api_key = "sk-123456"
@@ -33,12 +34,19 @@ class TestLargeMaxNewTokens(unittest.TestCase):
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             api_key=cls.api_key,
-            other_args=("--max-total-token", "1024", "--context-len", "8192"),
+            other_args=(
+                "--max-total-token",
+                "1024",
+                "--context-len",
+                "8192",
+                "--decode-log-interval",
+                "2",
+            ),
             env={"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256", **os.environ},
             return_stdout_stderr=(cls.stdout, cls.stderr),
         )
         cls.base_url += "/v1"
-        cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)
+        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

     @classmethod
     def tearDownClass(cls):
@@ -75,6 +83,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
         # Ensure that they are running concurrently
         pt = 0
         while pt >= 0:
+            time.sleep(5)
             lines = open("stderr.txt").readlines()
             for line in lines[pt:]:
                 print(line, end="", flush=True)
test/srt/test_enable_metrics.py → test/srt/test_metrics.py

 import unittest
 from types import SimpleNamespace

 import requests

 from sglang.srt.utils import kill_child_process
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
 )

-TEST_MODEL = (
-    DEFAULT_MODEL_NAME_FOR_TEST  # I used "google/gemma-2-2b-it" for testing locally
-)
-

 class TestEnableMetrics(unittest.TestCase):
     def test_metrics_enabled(self):
         """Test that metrics endpoint returns data when enabled"""
         # Launch server with metrics enabled
         process = popen_launch_server(
-            model=TEST_MODEL,
-            base_url=DEFAULT_URL_FOR_TEST,
+            DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+            DEFAULT_URL_FOR_TEST,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            enable_metrics=True,
+            other_args=["--enable-metrics"],
         )

         try:
@@ -38,6 +31,8 @@ class TestEnableMetrics(unittest.TestCase):
             self.assertEqual(metrics_response.status_code, 200)
             metrics_content = metrics_response.text

+            print(f"{metrics_content=}")
+
             # Verify essential metrics are present
             essential_metrics = [
                 "sglang:prompt_tokens_total",
@@ -53,7 +48,7 @@ class TestEnableMetrics(unittest.TestCase):
                 self.assertIn(metric, metrics_content, f"Missing metric: {metric}")

             # Verify model name label is present and correct
-            expected_model_name = TEST_MODEL
+            expected_model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
             self.assertIn(f'model_name="{expected_model_name}"', metrics_content)

             # Verify metrics have values (not empty)
             self.assertIn("_sum{", metrics_content)
@@ -63,22 +58,6 @@ class TestEnableMetrics(unittest.TestCase):
         finally:
             kill_child_process(process.pid, include_self=True)

-    def test_metrics_disabled(self):
-        """Test that metrics endpoint returns 404 when disabled"""
-        # Launch server with metrics disabled
-        process = popen_launch_server(
-            model=TEST_MODEL,
-            base_url=DEFAULT_URL_FOR_TEST,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            enable_metrics=False,
-        )
-
-        try:
-            response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
-            self.assertEqual(response.status_code, 200)
-
-            # Verify metrics endpoint is not available
-            metrics_response = requests.get(f"{DEFAULT_URL_FOR_TEST}/metrics")
-            self.assertEqual(metrics_response.status_code, 404)
-
-        finally:
-            kill_child_process(process.pid, include_self=True)
-

 if __name__ == "__main__":
     unittest.main()
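For a quick manual spot-check outside the test suite, the same endpoint can be scraped directly once a server is running with --enable-metrics. A sketch; the port below is only the usual local default and should be adjusted to the actual server address:

import requests

resp = requests.get("http://127.0.0.1:30000/metrics")
resp.raise_for_status()
# One of the metrics the test asserts on:
assert "sglang:prompt_tokens_total" in resp.text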
test/srt/test_openai_server.py

@@ -13,7 +13,7 @@ import openai
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
@@ -23,7 +23,7 @@ from sglang.test.test_utils import (
 class TestOpenAIServer(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.api_key = "sk-123456"
         cls.process = popen_launch_server(
@@ -33,7 +33,7 @@ class TestOpenAIServer(unittest.TestCase):
             api_key=cls.api_key,
         )
         cls.base_url += "/v1"
-        cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)
+        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

     @classmethod
     def tearDownClass(cls):
test/srt/test_radix_attention.py

@@ -5,7 +5,7 @@ import unittest
 import requests

 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     kill_child_process,
@@ -62,7 +62,7 @@ def run_test(base_url, nodes):
 class TestRadixCacheFCFS(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model,
@@ -90,7 +90,7 @@ class TestRadixCacheFCFS(unittest.TestCase):
 class TestRadixCacheLPM(TestRadixCacheFCFS):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model,
@@ -110,7 +110,7 @@ class TestRadixCacheLPM(TestRadixCacheFCFS):
 class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model,
test/srt/test_skip_tokenizer_init.py

@@ -9,7 +9,7 @@ import requests
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
@@ -19,7 +19,7 @@ from sglang.test.test_utils import (
 class TestSkipTokenizerInit(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model,
test/srt/test_srt_endpoint.py

@@ -10,7 +10,7 @@ import requests
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
@@ -20,7 +20,7 @@ from sglang.test.test_utils import (
 class TestSRTEndpoint(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
test/srt/test_srt_engine.py

@@ -11,14 +11,17 @@ from types import SimpleNamespace
 import sglang as sgl
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.test.few_shot_gsm8k_engine import run_eval
-from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+)


 class TestSRTEngine(unittest.TestCase):

     def test_1_engine_runtime_consistency(self):
         prompt = "Today is a sunny day and I like"
-        model_path = DEFAULT_MODEL_NAME_FOR_TEST
+        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

         sampling_params = {"temperature": 0, "max_new_tokens": 8}
@@ -40,7 +43,7 @@ class TestSRTEngine(unittest.TestCase):
     def test_2_engine_multiple_generate(self):
         # just to ensure there is no issue running multiple generate calls
         prompt = "Today is a sunny day and I like"
-        model_path = DEFAULT_MODEL_NAME_FOR_TEST
+        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

         sampling_params = {"temperature": 0, "max_new_tokens": 8}
@@ -66,7 +69,7 @@ class TestSRTEngine(unittest.TestCase):
         # Create an LLM.
         llm = sgl.Engine(
-            model_path=DEFAULT_MODEL_NAME_FOR_TEST,
+            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
             log_level="error",
         )
@@ -110,7 +113,7 @@ class TestSRTEngine(unittest.TestCase):
     def test_5_prompt_input_ids_consistency(self):
         prompt = "The capital of UK is"
-        model_path = DEFAULT_MODEL_NAME_FOR_TEST
+        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
         sampling_params = {"temperature": 0, "max_new_tokens": 8}
         out1 = engine.generate(prompt, sampling_params)["text"]
test/srt/test_update_weights.py

@@ -5,7 +5,7 @@ import requests
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
@@ -15,7 +15,7 @@ from sglang.test.test_utils import (
 class TestUpdateWeights(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
@@ -64,7 +64,7 @@ class TestUpdateWeights(unittest.TestCase):
         origin_response = self.run_decode()

         # update weights
-        new_model_path = "meta-llama/Meta-Llama-3.1-8B"
+        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
         ret = self.run_update_weights(new_model_path)
         assert ret["success"]
@@ -92,7 +92,7 @@ class TestUpdateWeights(unittest.TestCase):
         origin_response = self.run_decode()

         # update weights
-        new_model_path = "meta-llama/Meta-Llama-3.1-8B-1"
+        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "wrong")
         ret = self.run_update_weights(new_model_path)
         assert not ret["success"]
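A worked expansion of the two new paths, given DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct": the success case strips the suffix to the valid base checkpoint, while the failure case deliberately produces a repo id that does not exist.

DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"

# Success path: a real base-model checkpoint.
print(DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", ""))
# -> meta-llama/Llama-3.2-1B

# Failure path: an intentionally invalid repo id, so the update must fail.
print(DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "wrong"))
# -> meta-llama/Llama-3.2-1Bwrong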