Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9c939a3d
"tests/pytorch/test_basics_anonymous.py" did not exist on "4673b96f9298bc6e8ad25adbde0bfd17d57ff232"
Unverified
Commit
9c939a3d
authored
Nov 09, 2024
by
Lianmin Zheng
Committed by
GitHub
Nov 09, 2024
Browse files
Clean up metrics code (#1972)
parent
549e8b83
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
101 additions
and
107 deletions
+101
-107
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+6
-2
python/sglang/srt/server.py
python/sglang/srt/server.py
+7
-43
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+38
-0
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+1
-3
scripts/ci_install_dependency.sh
scripts/ci_install_dependency.sh
+1
-1
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_bench_latency.py
test/srt/test_bench_latency.py
+0
-2
test/srt/test_cache_report.py
test/srt/test_cache_report.py
+2
-2
test/srt/test_large_max_new_tokens.py
test/srt/test_large_max_new_tokens.py
+13
-4
test/srt/test_metrics.py
test/srt/test_metrics.py
+9
-30
test/srt/test_openai_server.py
test/srt/test_openai_server.py
+3
-3
test/srt/test_radix_attention.py
test/srt/test_radix_attention.py
+4
-4
test/srt/test_skip_tokenizer_init.py
test/srt/test_skip_tokenizer_init.py
+2
-2
test/srt/test_srt_endpoint.py
test/srt/test_srt_endpoint.py
+2
-2
test/srt/test_srt_engine.py
test/srt/test_srt_engine.py
+8
-5
test/srt/test_update_weights.py
test/srt/test_update_weights.py
+4
-4
No files found.
python/sglang/srt/managers/tokenizer_manager.py
View file @
9c939a3d
...
...
@@ -391,8 +391,12 @@ class TokenizerManager:
async
with
self
.
model_update_lock
:
# wait for the previous generation requests to finish
while
len
(
self
.
rid_to_state
)
>
0
:
await
asyncio
.
sleep
(
0.001
)
for
i
in
range
(
3
):
while
len
(
self
.
rid_to_state
)
>
0
:
await
asyncio
.
sleep
(
0.001
)
# FIXME: We add some sleep here to avoid some race conditions.
# We can use a read-write lock as a better fix.
await
asyncio
.
sleep
(
0.01
)
self
.
send_to_scheduler
.
send_pyobj
(
obj
)
self
.
model_update_result
=
asyncio
.
Future
()
...
...
python/sglang/srt/server.py
View file @
9c939a3d
...
...
@@ -25,20 +25,16 @@ import json
import
logging
import
multiprocessing
as
mp
import
os
import
re
import
tempfile
import
threading
import
time
from
http
import
HTTPStatus
from
typing
import
AsyncIterator
,
Dict
,
List
,
Optional
,
Union
import
orjson
from
starlette.routing
import
Mount
# Fix a bug of Python threading
setattr
(
threading
,
"_register_atexit"
,
lambda
*
args
,
**
kwargs
:
None
)
import
aiohttp
import
orjson
import
requests
import
uvicorn
import
uvloop
...
...
@@ -77,6 +73,7 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.utils
import
(
add_api_key_middleware
,
add_prometheus_middleware
,
assert_pkg_version
,
configure_logger
,
delete_directory
,
...
...
@@ -84,16 +81,13 @@ from sglang.srt.utils import (
kill_child_process
,
maybe_set_triton_cache_manager
,
prepare_model_and_tokenizer
,
set_prometheus_multiproc_dir
,
set_ulimit
,
)
from
sglang.utils
import
get_exception_traceback
logger
=
logging
.
getLogger
(
__name__
)
# Temporary directory for prometheus multiprocess mode
# Cleaned up automatically when this object is garbage collected
prometheus_multiproc_dir
:
tempfile
.
TemporaryDirectory
asyncio
.
set_event_loop_policy
(
uvloop
.
EventLoopPolicy
())
...
...
@@ -445,10 +439,6 @@ def launch_server(
1. The HTTP server and Tokenizer Manager both run in the main process.
2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
"""
if
server_args
.
enable_metrics
:
_set_prometheus_env
()
launch_engine
(
server_args
=
server_args
)
# Add api key authorization
...
...
@@ -487,36 +477,6 @@ def launch_server(
t
.
join
()
def
add_prometheus_middleware
(
app
:
FastAPI
):
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
from
prometheus_client
import
CollectorRegistry
,
make_asgi_app
,
multiprocess
registry
=
CollectorRegistry
()
multiprocess
.
MultiProcessCollector
(
registry
)
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
(
registry
=
registry
))
# Workaround for 307 Redirect for /metrics
metrics_route
.
path_regex
=
re
.
compile
(
"^/metrics(?P<path>.*)$"
)
app
.
routes
.
append
(
metrics_route
)
def
_set_prometheus_env
():
# Set prometheus multiprocess directory
# sglang uses prometheus multiprocess mode
# we need to set this before importing prometheus_client
# https://prometheus.github.io/client_python/multiprocess/
global
prometheus_multiproc_dir
if
"PROMETHEUS_MULTIPROC_DIR"
in
os
.
environ
:
logger
.
debug
(
f
"User set PROMETHEUS_MULTIPROC_DIR detected."
)
prometheus_multiproc_dir
=
tempfile
.
TemporaryDirectory
(
dir
=
os
.
environ
[
"PROMETHEUS_MULTIPROC_DIR"
]
)
else
:
prometheus_multiproc_dir
=
tempfile
.
TemporaryDirectory
()
os
.
environ
[
"PROMETHEUS_MULTIPROC_DIR"
]
=
prometheus_multiproc_dir
.
name
logger
.
debug
(
f
"PROMETHEUS_MULTIPROC_DIR:
{
os
.
environ
[
'PROMETHEUS_MULTIPROC_DIR'
]
}
"
)
def
_set_envs_and_config
(
server_args
:
ServerArgs
):
# Set global environments
os
.
environ
[
"TF_CPP_MIN_LOG_LEVEL"
]
=
"3"
...
...
@@ -543,6 +503,10 @@ def _set_envs_and_config(server_args: ServerArgs):
"at https://docs.flashinfer.ai/installation.html."
,
)
# Set prometheus env vars
if
server_args
.
enable_metrics
:
set_prometheus_multiproc_dir
()
mp
.
set_start_method
(
"spawn"
,
force
=
True
)
...
...
python/sglang/srt/utils.py
View file @
9c939a3d
...
...
@@ -22,10 +22,12 @@ import logging
import
os
import
pickle
import
random
import
re
import
resource
import
shutil
import
signal
import
socket
import
tempfile
import
time
import
warnings
from
importlib.metadata
import
PackageNotFoundError
,
version
...
...
@@ -41,6 +43,7 @@ import triton
import
zmq
from
fastapi.responses
import
ORJSONResponse
from
packaging
import
version
as
pkg_version
from
starlette.routing
import
Mount
from
torch
import
nn
from
torch.profiler
import
ProfilerActivity
,
profile
,
record_function
from
triton.runtime.cache
import
(
...
...
@@ -752,3 +755,38 @@ def delete_directory(dirpath):
shutil
.
rmtree
(
dirpath
)
except
OSError
as
e
:
print
(
f
"Warning:
{
dirpath
}
:
{
e
.
strerror
}
"
)
# Temporary directory for prometheus multiprocess mode
# Cleaned up automatically when this object is garbage collected
prometheus_multiproc_dir
:
tempfile
.
TemporaryDirectory
def
set_prometheus_multiproc_dir
():
# Set prometheus multiprocess directory
# sglang uses prometheus multiprocess mode
# we need to set this before importing prometheus_client
# https://prometheus.github.io/client_python/multiprocess/
global
prometheus_multiproc_dir
if
"PROMETHEUS_MULTIPROC_DIR"
in
os
.
environ
:
logger
.
debug
(
"User set PROMETHEUS_MULTIPROC_DIR detected."
)
prometheus_multiproc_dir
=
tempfile
.
TemporaryDirectory
(
dir
=
os
.
environ
[
"PROMETHEUS_MULTIPROC_DIR"
]
)
else
:
prometheus_multiproc_dir
=
tempfile
.
TemporaryDirectory
()
os
.
environ
[
"PROMETHEUS_MULTIPROC_DIR"
]
=
prometheus_multiproc_dir
.
name
logger
.
debug
(
f
"PROMETHEUS_MULTIPROC_DIR:
{
os
.
environ
[
'PROMETHEUS_MULTIPROC_DIR'
]
}
"
)
def
add_prometheus_middleware
(
app
):
from
prometheus_client
import
CollectorRegistry
,
make_asgi_app
,
multiprocess
registry
=
CollectorRegistry
()
multiprocess
.
MultiProcessCollector
(
registry
)
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
(
registry
=
registry
))
# Workaround for 307 Redirect for /metrics
metrics_route
.
path_regex
=
re
.
compile
(
"^/metrics(?P<path>.*)$"
)
app
.
routes
.
append
(
metrics_route
)
python/sglang/test/test_utils.py
View file @
9c939a3d
...
...
@@ -27,6 +27,7 @@ from sglang.utils import get_exception_traceback
DEFAULT_FP8_MODEL_NAME_FOR_TEST
=
"neuralmagic/Meta-Llama-3.1-8B-FP8"
DEFAULT_MODEL_NAME_FOR_TEST
=
"meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
=
"meta-llama/Llama-3.2-1B-Instruct"
DEFAULT_MOE_MODEL_NAME_FOR_TEST
=
"mistralai/Mixtral-8x7B-Instruct-v0.1"
DEFAULT_MLA_MODEL_NAME_FOR_TEST
=
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST
=
"neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
...
...
@@ -404,7 +405,6 @@ def popen_launch_server(
other_args
:
tuple
=
(),
env
:
Optional
[
dict
]
=
None
,
return_stdout_stderr
:
Optional
[
tuple
]
=
None
,
enable_metrics
:
bool
=
False
,
):
_
,
host
,
port
=
base_url
.
split
(
":"
)
host
=
host
[
2
:]
...
...
@@ -423,8 +423,6 @@ def popen_launch_server(
]
if
api_key
:
command
+=
[
"--api-key"
,
api_key
]
if
enable_metrics
:
command
+=
[
"--enable-metrics"
]
if
return_stdout_stderr
:
process
=
subprocess
.
Popen
(
...
...
scripts/ci_install_dependency.sh
View file @
9c939a3d
...
...
@@ -4,5 +4,5 @@ Install the dependency in CI.
pip
install
--upgrade
pip
pip
install
-e
"python[all]"
pip
install
transformers
==
4.45.2 sentence_transformers
pip
install
transformers
==
4.45.2 sentence_transformers
accelerate peft
pip
install
flashinfer
-i
https://flashinfer.ai/whl/cu121/torch2.4/
--force-reinstall
test/srt/run_suite.py
View file @
9c939a3d
...
...
@@ -16,6 +16,7 @@ suites = {
"test_eval_accuracy_mini.py"
,
"test_json_constrained.py"
,
"test_large_max_new_tokens.py"
,
"test_metrics.py"
,
"test_openai_server.py"
,
"test_overlap_schedule.py"
,
"test_pytorch_sampling_backend.py"
,
...
...
test/srt/test_bench_latency.py
View file @
9c939a3d
import
subprocess
import
unittest
from
sglang.srt.utils
import
kill_child_process
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_MOE_MODEL_NAME_FOR_TEST
,
...
...
test/srt/test_cache_report.py
View file @
9c939a3d
...
...
@@ -6,7 +6,7 @@ import requests
from
sglang.srt.utils
import
kill_child_process
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
,
DEFAULT_URL_FOR_TEST
,
popen_launch_server
,
)
...
...
@@ -15,7 +15,7 @@ from sglang.test.test_utils import (
class
TestCacheReport
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
min_cached
=
5
cls
.
process
=
popen_launch_server
(
...
...
test/srt/test_large_max_new_tokens.py
View file @
9c939a3d
...
...
@@ -3,6 +3,7 @@ python3 -m unittest test_large_max_new_tokens.TestLargeMaxNewTokens.test_chat_co
"""
import
os
import
time
import
unittest
from
concurrent.futures
import
ThreadPoolExecutor
...
...
@@ -11,7 +12,7 @@ import openai
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
from
sglang.srt.utils
import
kill_child_process
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_server
,
...
...
@@ -21,7 +22,7 @@ from sglang.test.test_utils import (
class
TestLargeMaxNewTokens
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
api_key
=
"sk-123456"
...
...
@@ -33,12 +34,19 @@ class TestLargeMaxNewTokens(unittest.TestCase):
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
api_key
=
cls
.
api_key
,
other_args
=
(
"--max-total-token"
,
"1024"
,
"--context-len"
,
"8192"
),
other_args
=
(
"--max-total-token"
,
"1024"
,
"--context-len"
,
"8192"
,
"--decode-log-interval"
,
"2"
,
),
env
=
{
"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION"
:
"256"
,
**
os
.
environ
},
return_stdout_stderr
=
(
cls
.
stdout
,
cls
.
stderr
),
)
cls
.
base_url
+=
"/v1"
cls
.
tokenizer
=
get_tokenizer
(
DEFAULT_MODEL_NAME_FOR_TEST
)
cls
.
tokenizer
=
get_tokenizer
(
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
)
@
classmethod
def
tearDownClass
(
cls
):
...
...
@@ -75,6 +83,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
# Ensure that they are running concurrently
pt
=
0
while
pt
>=
0
:
time
.
sleep
(
5
)
lines
=
open
(
"stderr.txt"
).
readlines
()
for
line
in
lines
[
pt
:]:
print
(
line
,
end
=
""
,
flush
=
True
)
...
...
test/srt/test_
enable_
metrics.py
→
test/srt/test_metrics.py
View file @
9c939a3d
import
unittest
from
types
import
SimpleNamespace
import
requests
from
sglang.srt.utils
import
kill_child_process
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_server
,
)
TEST_MODEL
=
(
DEFAULT_MODEL_NAME_FOR_TEST
# I used "google/gemma-2-2b-it" for testing locally
)
class
TestEnableMetrics
(
unittest
.
TestCase
):
def
test_metrics_enabled
(
self
):
"""Test that metrics endpoint returns data when enabled"""
# Launch server with metrics enabled
process
=
popen_launch_server
(
model
=
TEST_MODEL
,
base_url
=
DEFAULT_URL_FOR_TEST
,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
DEFAULT_URL_FOR_TEST
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
enable
_
metrics
=
True
,
other_args
=
[
"--
enable
-
metrics
"
]
,
)
try
:
...
...
@@ -38,6 +31,8 @@ class TestEnableMetrics(unittest.TestCase):
self
.
assertEqual
(
metrics_response
.
status_code
,
200
)
metrics_content
=
metrics_response
.
text
print
(
f
"
{
metrics_content
=
}
"
)
# Verify essential metrics are present
essential_metrics
=
[
"sglang:prompt_tokens_total"
,
...
...
@@ -53,7 +48,7 @@ class TestEnableMetrics(unittest.TestCase):
self
.
assertIn
(
metric
,
metrics_content
,
f
"Missing metric:
{
metric
}
"
)
# Verify model name label is present and correct
expected_model_name
=
TEST_MODEL
expected_model_name
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
self
.
assertIn
(
f
'model_name="
{
expected_model_name
}
"'
,
metrics_content
)
# Verify metrics have values (not empty)
self
.
assertIn
(
"_sum{"
,
metrics_content
)
...
...
@@ -63,22 +58,6 @@ class TestEnableMetrics(unittest.TestCase):
finally
:
kill_child_process
(
process
.
pid
,
include_self
=
True
)
def
test_metrics_disabled
(
self
):
"""Test that metrics endpoint returns 404 when disabled"""
# Launch server with metrics disabled
process
=
popen_launch_server
(
model
=
TEST_MODEL
,
base_url
=
DEFAULT_URL_FOR_TEST
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
enable_metrics
=
False
,
)
try
:
response
=
requests
.
get
(
f
"
{
DEFAULT_URL_FOR_TEST
}
/health_generate"
)
self
.
assertEqual
(
response
.
status_code
,
200
)
# Verify metrics endpoint is not available
metrics_response
=
requests
.
get
(
f
"
{
DEFAULT_URL_FOR_TEST
}
/metrics"
)
self
.
assertEqual
(
metrics_response
.
status_code
,
404
)
finally
:
kill_child_process
(
process
.
pid
,
include_self
=
True
)
if
__name__
==
"__main__"
:
unittest
.
main
()
test/srt/test_openai_server.py
View file @
9c939a3d
...
...
@@ -13,7 +13,7 @@ import openai
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
from
sglang.srt.utils
import
kill_child_process
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_server
,
...
...
@@ -23,7 +23,7 @@ from sglang.test.test_utils import (
class
TestOpenAIServer
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
api_key
=
"sk-123456"
cls
.
process
=
popen_launch_server
(
...
...
@@ -33,7 +33,7 @@ class TestOpenAIServer(unittest.TestCase):
api_key
=
cls
.
api_key
,
)
cls
.
base_url
+=
"/v1"
cls
.
tokenizer
=
get_tokenizer
(
DEFAULT_MODEL_NAME_FOR_TEST
)
cls
.
tokenizer
=
get_tokenizer
(
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
)
@
classmethod
def
tearDownClass
(
cls
):
...
...
test/srt/test_radix_attention.py
View file @
9c939a3d
...
...
@@ -5,7 +5,7 @@ import unittest
import
requests
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
kill_child_process
,
...
...
@@ -62,7 +62,7 @@ def run_test(base_url, nodes):
class
TestRadixCacheFCFS
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
process
=
popen_launch_server
(
cls
.
model
,
...
...
@@ -90,7 +90,7 @@ class TestRadixCacheFCFS(unittest.TestCase):
class
TestRadixCacheLPM
(
TestRadixCacheFCFS
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
process
=
popen_launch_server
(
cls
.
model
,
...
...
@@ -110,7 +110,7 @@ class TestRadixCacheLPM(TestRadixCacheFCFS):
class
TestRadixCacheOverlapLPM
(
TestRadixCacheFCFS
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
process
=
popen_launch_server
(
cls
.
model
,
...
...
test/srt/test_skip_tokenizer_init.py
View file @
9c939a3d
...
...
@@ -9,7 +9,7 @@ import requests
from
sglang.srt.utils
import
kill_child_process
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_server
,
...
...
@@ -19,7 +19,7 @@ from sglang.test.test_utils import (
class
TestSkipTokenizerInit
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
process
=
popen_launch_server
(
cls
.
model
,
...
...
test/srt/test_srt_endpoint.py
View file @
9c939a3d
...
...
@@ -10,7 +10,7 @@ import requests
from
sglang.srt.utils
import
kill_child_process
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_server
,
...
...
@@ -20,7 +20,7 @@ from sglang.test.test_utils import (
class
TestSRTEndpoint
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
...
...
test/srt/test_srt_engine.py
View file @
9c939a3d
...
...
@@ -11,14 +11,17 @@ from types import SimpleNamespace
import
sglang
as
sgl
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
from
sglang.test.few_shot_gsm8k_engine
import
run_eval
from
sglang.test.test_utils
import
DEFAULT_MODEL_NAME_FOR_TEST
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
)
class
TestSRTEngine
(
unittest
.
TestCase
):
def
test_1_engine_runtime_consistency
(
self
):
prompt
=
"Today is a sunny day and I like"
model_path
=
DEFAULT_MODEL_NAME_FOR_TEST
model_path
=
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
sampling_params
=
{
"temperature"
:
0
,
"max_new_tokens"
:
8
}
...
...
@@ -40,7 +43,7 @@ class TestSRTEngine(unittest.TestCase):
def
test_2_engine_multiple_generate
(
self
):
# just to ensure there is no issue running multiple generate calls
prompt
=
"Today is a sunny day and I like"
model_path
=
DEFAULT_MODEL_NAME_FOR_TEST
model_path
=
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
sampling_params
=
{
"temperature"
:
0
,
"max_new_tokens"
:
8
}
...
...
@@ -66,7 +69,7 @@ class TestSRTEngine(unittest.TestCase):
# Create an LLM.
llm
=
sgl
.
Engine
(
model_path
=
DEFAULT_MODEL_NAME_FOR_TEST
,
model_path
=
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
,
log_level
=
"error"
,
)
...
...
@@ -110,7 +113,7 @@ class TestSRTEngine(unittest.TestCase):
def
test_5_prompt_input_ids_consistency
(
self
):
prompt
=
"The capital of UK is"
model_path
=
DEFAULT_MODEL_NAME_FOR_TEST
model_path
=
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
engine
=
sgl
.
Engine
(
model_path
=
model_path
,
random_seed
=
42
,
log_level
=
"error"
)
sampling_params
=
{
"temperature"
:
0
,
"max_new_tokens"
:
8
}
out1
=
engine
.
generate
(
prompt
,
sampling_params
)[
"text"
]
...
...
test/srt/test_update_weights.py
View file @
9c939a3d
...
...
@@ -5,7 +5,7 @@ import requests
from
sglang.srt.utils
import
kill_child_process
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_server
,
...
...
@@ -15,7 +15,7 @@ from sglang.test.test_utils import (
class
TestUpdateWeights
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
...
...
@@ -64,7 +64,7 @@ class TestUpdateWeights(unittest.TestCase):
origin_response
=
self
.
run_decode
()
# update weights
new_model_path
=
"meta-llama/Meta-Llama-3.1-8B"
new_model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
.
replace
(
"-Instruct"
,
""
)
ret
=
self
.
run_update_weights
(
new_model_path
)
assert
ret
[
"success"
]
...
...
@@ -92,7 +92,7 @@ class TestUpdateWeights(unittest.TestCase):
origin_response
=
self
.
run_decode
()
# update weights
new_model_path
=
"meta-llama/Meta-Llama-3.1-8B-1"
new_model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
.
replace
(
"-Instruct"
,
"wrong"
)
ret
=
self
.
run_update_weights
(
new_model_path
)
assert
not
ret
[
"success"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment