sglang · Commits · 0f8eb153

Unverified commit 0f8eb153, authored Dec 09, 2024 by Yineng Zhang, committed via GitHub on Dec 09, 2024.

feat: support custom task runner (#2407)

parent 67470bbb

Showing 3 changed files with 392 additions and 0 deletions:

.github/workflows/experiment-runner.yml   +26  -0
test/srt/configs/sharegpt_config.yaml      +7  -0
test/srt/experiment_runner.py            +359  -0
.github/workflows/experiment-runner.yml (new file, mode 100644)

name: Experiment Runner

on:
  workflow_dispatch:

concurrency:
  group: experiment-runner-${{ github.ref }}
  cancel-in-progress: true

jobs:
  experiment-runner-1-gpu:
    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
    runs-on: 1-gpu-runner
    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Install dependencies
        run: |
          bash scripts/ci_install_dependency.sh

      - name: Test experiment runner
        timeout-minutes: 10
        run: |
          cd test/srt
          python3 experiment_runner.py --config configs/sharegpt_config.yaml
test/srt/configs/sharegpt_config.yaml (new file, mode 100644)

tasks:
  - name: sglang-benchmark
    server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
    client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --request-rate 16
  - name: vllm-benchmark
    server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
    client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --request-rate 16
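
A minimal sketch (illustrative, not part of the commit) of how these task entries are consumed, assuming the runner is invoked from test/srt as the CI workflow does:

# Load the task list above into TaskConfig objects.
from experiment_runner import load_config

tasks = load_config("configs/sharegpt_config.yaml")
for task in tasks:
    # Each entry yields a TaskConfig with server_cmd, client_cmd, an optional
    # name, and an optional server_type (auto-detected from the command when omitted).
    print(task.name, task.server_cmd)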
test/srt/experiment_runner.py (new file, mode 100644)

import argparse
import logging
import os
import queue
import re
import subprocess
import threading
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import psutil
import requests
import yaml


@dataclass
class ServerConfig:
    command: str
    process_names: List[str]
    default_port: int


@dataclass
class TaskConfig:
    server_cmd: str
    client_cmd: str
    name: Optional[str] = None
    server_type: Optional[str] = None


@dataclass
class TaskResult:
    name: str
    success: bool
    output: str
    runtime: float
    timestamp: str


SERVER_DEFAULTS = {
    "sglang": ServerConfig(
        command="sglang.launch_server",
        process_names=["sglang.launch_server"],
        default_port=30000,
    ),
    "vllm": ServerConfig(
        command="vllm.entrypoints.openai.api_server",
        process_names=["vllm.entrypoints.openai.api_server"],
        default_port=8000,
    ),
}


def parse_key_info(output: str) -> str:
    """Extract and format key information from the output"""
    key_info = []

    # Extract Args namespace
    args_match = re.search(r"Namespace\(.*?\)", output, re.DOTALL)
    if args_match:
        key_info.append(args_match.group(0))

    # Extract input/output token counts
    token_matches = re.findall(r"#(Input|Output) tokens: \d+", output)
    key_info.extend(token_matches)

    # Extract benchmark result section
    result_match = re.search(
        r"============ Serving Benchmark Result ============.*?={50,}",
        output,
        re.DOTALL,
    )
    if result_match:
        key_info.append(result_match.group(0))

    return "\n\n".join(key_info)


def extract_port_from_command(cmd: str, server_type: str) -> int:
    port_match = re.search(r"--port[= ](\d+)", cmd)
    if port_match:
        return int(port_match.group(1))
    return SERVER_DEFAULTS.get(server_type, ServerConfig("", [], 8000)).default_port


def detect_server_type(cmd: str) -> str:
    for server_type, config in SERVER_DEFAULTS.items():
        if config.command in cmd:
            return server_type
    return "unknown"


def stream_output(
    process: subprocess.Popen, prefix: str, logger: logging.Logger
) -> Tuple[queue.Queue, Tuple[threading.Thread, threading.Thread]]:
    output_queue = queue.Queue()

    def stream_pipe(pipe, prefix):
        for line in iter(pipe.readline, ""):
            if prefix == "CLIENT":
                output_queue.put(line.rstrip())
            logger.debug(f"{prefix} | {line.rstrip()}")

    stdout_thread = threading.Thread(
        target=stream_pipe, args=(process.stdout, prefix), daemon=True
    )
    stderr_thread = threading.Thread(
        target=stream_pipe, args=(process.stderr, prefix), daemon=True
    )

    stdout_thread.start()
    stderr_thread.start()

    return output_queue, (stdout_thread, stderr_thread)


class ProcessManager:
    def __init__(self):
        self.server_process: Optional[subprocess.Popen] = None
        self.client_process: Optional[subprocess.Popen] = None
        self.logger = logging.getLogger(__name__)

    def start_process(
        self, command: str, prefix: str
    ) -> Tuple[subprocess.Popen, queue.Queue, Tuple[threading.Thread, threading.Thread]]:
        process = subprocess.Popen(
            command,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,
        )
        output_queue, threads = stream_output(process, prefix, self.logger)
        return process, output_queue, threads

    def kill_process_tree(self, process: subprocess.Popen):
        try:
            parent = psutil.Process(process.pid)
            children = parent.children(recursive=True)

            for child in children:
                try:
                    child.kill()
                except psutil.NoSuchProcess:
                    pass

            parent.kill()

            gone, alive = psutil.wait_procs(children + [parent], timeout=3)
            for p in alive:
                try:
                    p.kill()
                except psutil.NoSuchProcess:
                    pass
        except psutil.NoSuchProcess:
            pass

    def cleanup(self, process_names: List[str]):
        if self.client_process:
            self.kill_process_tree(self.client_process)
            self.client_process = None

        if self.server_process:
            self.kill_process_tree(self.server_process)
            self.server_process = None

        for proc in psutil.process_iter(["pid", "name", "cmdline"]):
            try:
                cmdline = " ".join(proc.cmdline())
                if any(name in cmdline for name in process_names):
                    proc.kill()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue


class ExperimentRunner:
    def __init__(self):
        self.process_manager = ProcessManager()
        self.logger = logging.getLogger(__name__)

    def wait_for_server(self, port: int, timeout: int = 300) -> bool:
        start_time = time.time()
        while time.time() - start_time < timeout:
            try:
                response = requests.get(f"http://localhost:{port}/health")
                if response.status_code == 200:
                    self.logger.debug(f"Server ready on port {port}")
                    return True
            except requests.RequestException:
                time.sleep(2)
        return False

    def run_task(self, config: TaskConfig) -> TaskResult:
        start_time = time.time()
        client_output = []

        try:
            if not config.server_type:
                config.server_type = detect_server_type(config.server_cmd)

            server_config = SERVER_DEFAULTS.get(config.server_type)
            if not server_config:
                raise ValueError(f"Unknown server type: {config.server_type}")

            port = extract_port_from_command(config.server_cmd, config.server_type)

            self.process_manager.cleanup(server_config.process_names)

            self.logger.debug(f"Starting server: {config.name}")
            self.process_manager.server_process, _, server_threads = (
                self.process_manager.start_process(config.server_cmd, "SERVER")
            )

            if not self.wait_for_server(port):
                raise TimeoutError("Server startup timeout")

            time.sleep(10)

            self.logger.debug("Starting client")
            self.process_manager.client_process, output_queue, client_threads = (
                self.process_manager.start_process(config.client_cmd, "CLIENT")
            )

            returncode = self.process_manager.client_process.wait()

            while True:
                try:
                    line = output_queue.get_nowait()
                    client_output.append(line)
                except queue.Empty:
                    break

            if returncode != 0:
                raise RuntimeError(f"Client failed with code {returncode}")

            # Parse and format the output
            full_output = "\n".join(client_output)
            formatted_output = parse_key_info(full_output)

            return TaskResult(
                name=config.name,
                success=True,
                output=formatted_output,
                runtime=time.time() - start_time,
                timestamp=datetime.now().isoformat(),
            )

        except Exception as e:
            return TaskResult(
                name=config.name,
                success=False,
                output=str(e),
                runtime=time.time() - start_time,
                timestamp=datetime.now().isoformat(),
            )

        finally:
            if config.server_type in SERVER_DEFAULTS:
                self.process_manager.cleanup(
                    SERVER_DEFAULTS[config.server_type].process_names
                )
            time.sleep(10)


def load_config(config_path: str) -> List[TaskConfig]:
    with open(config_path, "r") as f:
        config_data = yaml.safe_load(f)

    configs = []
    for idx, entry in enumerate(config_data.get("tasks", [])):
        if not isinstance(entry, dict):
            raise ValueError(f"Invalid entry at index {idx}")

        config = TaskConfig(
            server_cmd=entry.get("server_cmd"),
            client_cmd=entry.get("client_cmd"),
            name=entry.get("name", f"task-{idx + 1}"),
            server_type=entry.get("server_type"),
        )

        if not config.server_cmd or not config.client_cmd:
            raise ValueError(f"Missing commands in {config.name}")

        configs.append(config)

    return configs


def setup_logging(debug: bool = False):
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(
        level=level,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[logging.StreamHandler(), logging.FileHandler("experiment.log")],
    )


def format_results(results: List[TaskResult]) -> str:
    """Format experiment results in Markdown for GitHub step summary."""
    output = ["# Experiment Results\n"]

    for result in results:
        output.append(f"## {result.name}")
        output.append(f"**Status**: {'✅ Success' if result.success else '❌ Failed'}")
        output.append(f"**Runtime**: {result.runtime:.2f} seconds")
        output.append(f"**Timestamp**: {result.timestamp}")
        output.append("\n**Output**:\n```")
        output.append(result.output)
        output.append("```\n")

    return "\n".join(output)


def write_in_github_step_summary(results: List[TaskResult]):
    """Write formatted results to GitHub step summary."""
    if not os.environ.get("GITHUB_STEP_SUMMARY"):
        logging.warning("GITHUB_STEP_SUMMARY environment variable not set")
        return

    formatted_content = format_results(results)
    with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
        f.write(formatted_content)


def main():
    parser = argparse.ArgumentParser(description="Experiment Runner")
    parser.add_argument("--config", type=str, required=True, help="Path to YAML config file")
    parser.add_argument("--debug", action="store_true", help="Enable debug output")
    args = parser.parse_args()

    setup_logging(args.debug)
    logger = logging.getLogger(__name__)
    results = []

    try:
        configs = load_config(args.config)
        runner = ExperimentRunner()

        for config in configs:
            logger.info(f"Running {config.name}")
            result = runner.run_task(config)
            results.append(result)

        write_in_github_step_summary(results)

    except Exception as e:
        logger.error(f"Error: {e}")
        raise


if __name__ == "__main__":
    main()
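
For local runs outside CI, a minimal sketch of driving the runner programmatically rather than through a YAML file (illustrative only; the commands are copied from sharegpt_config.yaml and assume the model weights and a GPU are available):

# Run a single benchmark task without a config file.
from experiment_runner import ExperimentRunner, TaskConfig, format_results, setup_logging

setup_logging(debug=True)
task = TaskConfig(
    server_cmd=(
        "python3 -m sglang.launch_server "
        "--model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache"
    ),
    client_cmd=(
        "python3 -m sglang.bench_serving --backend sglang "
        "--dataset-name sharegpt --request-rate 16"
    ),
    name="sglang-benchmark",
)
result = ExperimentRunner().run_task(task)
print(format_results([result]))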