Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
1acbaf1b
Unverified
Commit
1acbaf1b
authored
Jan 06, 2025
by
Xingyao Wang
Committed by
GitHub
Jan 06, 2025
Browse files
Add generator-style run_batch function (#2513)
Co-authored-by:
openhands
<
openhands@all-hands.dev
>
parent
287427e2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
95 additions
and
1 deletion
+95
-1
python/sglang/lang/interpreter.py
python/sglang/lang/interpreter.py
+70
-0
python/sglang/lang/ir.py
python/sglang/lang/ir.py
+2
-0
python/sglang/test/test_programs.py
python/sglang/test/test_programs.py
+23
-1
No files found.
python/sglang/lang/interpreter.py
View file @
1acbaf1b
...
...
@@ -96,6 +96,7 @@ def run_program_batch(
default_sampling_para
,
num_threads
,
progress_bar
,
generator_style
=
False
,
):
if
hasattr
(
backend
,
"endpoint"
):
backend
=
backend
.
endpoint
...
...
@@ -109,6 +110,17 @@ def run_program_batch(
num_threads
=
max
(
96
,
multiprocessing
.
cpu_count
()
*
16
)
num_threads
=
min
(
num_threads
,
len
(
batch_arguments
))
if
generator_style
:
return
_run_program_batch_generator
(
program
,
backend
,
batch_arguments
,
default_sampling_para
,
num_threads
,
progress_bar
,
)
# Original code path when generator_style=False
if
num_threads
==
1
:
rets
=
[]
if
progress_bar
:
...
...
@@ -168,6 +180,64 @@ def run_program_batch(
return
rets
def _run_program_batch_generator(
    program,
    backend,
    batch_arguments,
    default_sampling_para,
    num_threads,
    progress_bar,
):
    """Lazily run ``program`` once per entry of ``batch_arguments``.

    This is a generator: nothing is executed until the caller starts
    iterating. Results are yielded in the same order as ``batch_arguments``
    (submission order), not in completion order.

    Args:
        program: Forwarded unchanged as the first argument of ``run_program``.
        backend: Forwarded unchanged as the second argument of ``run_program``.
        batch_arguments: Sized, indexable collection; each element is passed
            to ``run_program`` as its fourth positional argument.
        default_sampling_para: Forwarded unchanged to ``run_program``.
        num_threads: ``1`` runs every item inline in the calling thread; any
            other value is used as the ``ThreadPoolExecutor`` worker count.
        progress_bar: When truthy, progress is reported through ``tqdm``.

    Yields:
        Whatever ``run_program`` returns for each element, one at a time.

    Raises:
        Any exception raised by ``run_program`` is re-raised when the
        corresponding result is reached during iteration.
    """
    if num_threads == 1:
        # Sequential path: tqdm wraps the iterable directly, so no manual
        # progress-bar bookkeeping is needed.
        iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments
        for arguments in iterator:
            # NOTE(review): the trailing positional False/True are presumably
            # run_program's stream/sync flags -- confirm against its signature.
            yield run_program(
                program,
                backend,
                (),
                arguments,
                default_sampling_para,
                False,
                True,
            )
        return

    total = len(batch_arguments)
    pbar = tqdm.tqdm(total=total) if progress_bar else None

    # Submit work one chunk at a time instead of all at once, so results from
    # the first chunk are yielded before later chunks are even queued, and at
    # most `chunk_size` futures are outstanding at any moment.
    chunk_size = 200

    try:
        with ThreadPoolExecutor(num_threads) as executor:
            for chunk_start in range(0, total, chunk_size):
                chunk_end = min(chunk_start + chunk_size, total)
                chunk_futures = []

                # Submit this chunk of tasks.
                for i in range(chunk_start, chunk_end):
                    future = executor.submit(
                        run_program,
                        program,
                        backend,
                        (),
                        batch_arguments[i],
                        default_sampling_para,
                        False,
                        True,
                    )
                    # Compare against None explicitly rather than relying on
                    # the truthiness of the progress-bar object.
                    if pbar is not None:
                        future.add_done_callback(lambda _: pbar.update())
                    chunk_futures.append(future)

                # Yield this chunk's results in submission order; result()
                # blocks until that particular task has finished.
                for future in chunk_futures:
                    yield future.result()
    finally:
        # Runs after the executor has shut down, and also when a task raises
        # or the consumer stops iterating early, so the bar is always closed.
        if pbar is not None:
            pbar.close()
def
cache_program
(
program
,
backend
):
from
sglang.lang.tracer
import
extract_prefix_by_tracing
...
...
python/sglang/lang/ir.py
View file @
1acbaf1b
...
...
@@ -227,6 +227,7 @@ class SglFunction:
backend
=
None
,
num_threads
:
Union
[
str
,
int
]
=
"auto"
,
progress_bar
:
bool
=
False
,
generator_style
:
bool
=
False
,
):
from
sglang.lang.interpreter
import
run_program_batch
...
...
@@ -277,6 +278,7 @@ class SglFunction:
default_sampling_para
,
num_threads
,
progress_bar
,
generator_style
=
generator_style
,
)
def
trace
(
self
,
*
,
backend
=
None
,
**
kwargs
):
...
...
python/sglang/test/test_programs.py
View file @
1acbaf1b
...
...
@@ -509,13 +509,35 @@ def test_hellaswag_select():
temperature
=
0
,
num_threads
=
64
,
progress_bar
=
True
,
generator_style
=
False
,
)
preds
=
[
choices
[
i
].
index
(
rets
[
i
][
"answer"
])
for
i
in
range
(
len
(
rets
))]
preds
=
[]
for
i
,
ret
in
enumerate
(
rets
):
preds
.
append
(
choices
[
i
].
index
(
ret
[
"answer"
]))
latency
=
time
.
time
()
-
tic
# Compute accuracy
accuracy
=
np
.
mean
(
np
.
array
(
preds
)
==
np
.
array
(
labels
))
# Test generator style of run_batch
tic
=
time
.
time
()
rets
=
few_shot_hellaswag
.
run_batch
(
arguments
,
temperature
=
0
,
num_threads
=
64
,
progress_bar
=
True
,
generator_style
=
True
,
)
preds_gen
=
[]
for
i
,
ret
in
enumerate
(
rets
):
preds_gen
.
append
(
choices
[
i
].
index
(
ret
[
"answer"
]))
latency_gen
=
time
.
time
()
-
tic
# Compute accuracy
accuracy_gen
=
np
.
mean
(
np
.
array
(
preds_gen
)
==
np
.
array
(
labels
))
assert
np
.
abs
(
accuracy_gen
-
accuracy
)
<
0.01
assert
np
.
abs
(
latency_gen
-
latency
)
<
1
return
accuracy
,
latency
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment