"tests/vscode:/vscode.git/clone" did not exist on "100ddd067e388891da8d5677f3af72fc7c5f9edf"
Unverified Commit 1acbaf1b authored by Xingyao Wang, committed by GitHub
Browse files

Add generator-style run_batch function (#2513)


Co-authored-by: openhands <openhands@all-hands.dev>
parent 287427e2
...@@ -96,6 +96,7 @@ def run_program_batch( ...@@ -96,6 +96,7 @@ def run_program_batch(
default_sampling_para, default_sampling_para,
num_threads, num_threads,
progress_bar, progress_bar,
generator_style=False,
): ):
if hasattr(backend, "endpoint"): if hasattr(backend, "endpoint"):
backend = backend.endpoint backend = backend.endpoint
...@@ -109,6 +110,17 @@ def run_program_batch( ...@@ -109,6 +110,17 @@ def run_program_batch(
num_threads = max(96, multiprocessing.cpu_count() * 16) num_threads = max(96, multiprocessing.cpu_count() * 16)
num_threads = min(num_threads, len(batch_arguments)) num_threads = min(num_threads, len(batch_arguments))
if generator_style:
return _run_program_batch_generator(
program,
backend,
batch_arguments,
default_sampling_para,
num_threads,
progress_bar,
)
# Original code path when generator_style=False
if num_threads == 1: if num_threads == 1:
rets = [] rets = []
if progress_bar: if progress_bar:
...@@ -168,6 +180,64 @@ def run_program_batch( ...@@ -168,6 +180,64 @@ def run_program_batch(
return rets return rets
def _run_program_batch_generator(
    program,
    backend,
    batch_arguments,
    default_sampling_para,
    num_threads,
    progress_bar,
):
    """Run ``program`` once per entry of ``batch_arguments``, yielding each result.

    This is the generator-style counterpart of the list-returning batch path:
    results are yielded one at a time, in the same order as
    ``batch_arguments``, instead of being collected into a list first.

    Args:
        program: Forwarded unchanged as the first argument of ``run_program``.
        backend: Forwarded unchanged as the second argument of ``run_program``.
        batch_arguments: Indexable, sized collection; each element is passed
            to ``run_program`` as the per-call arguments.
        default_sampling_para: Forwarded unchanged to ``run_program``.
        num_threads: ``1`` runs every call inline in the consuming thread;
            any other value is used as the ``ThreadPoolExecutor`` worker count.
        progress_bar: When truthy, a ``tqdm`` progress bar tracks completion.

    Yields:
        The return value of ``run_program`` for each element, in input order.

    Raises:
        Whatever ``run_program`` raises. In the threaded path the exception is
        re-raised by ``future.result()`` when the failing item is reached.
    """
    if num_threads == 1:
        iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments
        for arguments in iterator:
            # The trailing positional flags (False, True) are passed through
            # as-is; their meaning is defined by run_program.
            yield run_program(
                program,
                backend,
                (),
                arguments,
                default_sampling_para,
                False,
                True,
            )
        return

    total = len(batch_arguments)
    pbar = tqdm.tqdm(total=total) if progress_bar else None

    # Submit work in fixed-size chunks so that results can start being yielded
    # before every task in the batch has been submitted to the executor.
    chunk_size = 200

    try:
        with ThreadPoolExecutor(num_threads) as executor:
            for chunk_start in range(0, total, chunk_size):
                chunk_end = min(chunk_start + chunk_size, total)
                chunk_futures = []

                # Submit this chunk of tasks.
                for i in range(chunk_start, chunk_end):
                    future = executor.submit(
                        run_program,
                        program,
                        backend,
                        (),
                        batch_arguments[i],
                        default_sampling_para,
                        False,
                        True,
                    )
                    if pbar is not None:
                        # Advance the bar when the task finishes, which may be
                        # before its result is actually yielded.
                        future.add_done_callback(lambda _: pbar.update())
                    chunk_futures.append(future)

                # Yield in submission order (blocking on each future in turn)
                # so the output order matches batch_arguments.
                for future in chunk_futures:
                    yield future.result()
    finally:
        # Close the bar even if a task raised or the consumer stopped
        # iterating early (generator closed before exhaustion).
        if pbar is not None:
            pbar.close()
def cache_program(program, backend): def cache_program(program, backend):
from sglang.lang.tracer import extract_prefix_by_tracing from sglang.lang.tracer import extract_prefix_by_tracing
......
...@@ -227,6 +227,7 @@ class SglFunction: ...@@ -227,6 +227,7 @@ class SglFunction:
backend=None, backend=None,
num_threads: Union[str, int] = "auto", num_threads: Union[str, int] = "auto",
progress_bar: bool = False, progress_bar: bool = False,
generator_style: bool = False,
): ):
from sglang.lang.interpreter import run_program_batch from sglang.lang.interpreter import run_program_batch
...@@ -277,6 +278,7 @@ class SglFunction: ...@@ -277,6 +278,7 @@ class SglFunction:
default_sampling_para, default_sampling_para,
num_threads, num_threads,
progress_bar, progress_bar,
generator_style=generator_style,
) )
def trace(self, *, backend=None, **kwargs): def trace(self, *, backend=None, **kwargs):
......
...@@ -509,13 +509,35 @@ def test_hellaswag_select(): ...@@ -509,13 +509,35 @@ def test_hellaswag_select():
temperature=0, temperature=0,
num_threads=64, num_threads=64,
progress_bar=True, progress_bar=True,
generator_style=False,
) )
preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))] preds = []
for i, ret in enumerate(rets):
preds.append(choices[i].index(ret["answer"]))
latency = time.time() - tic latency = time.time() - tic
# Compute accuracy # Compute accuracy
accuracy = np.mean(np.array(preds) == np.array(labels)) accuracy = np.mean(np.array(preds) == np.array(labels))
# Test generator style of run_batch
tic = time.time()
rets = few_shot_hellaswag.run_batch(
arguments,
temperature=0,
num_threads=64,
progress_bar=True,
generator_style=True,
)
preds_gen = []
for i, ret in enumerate(rets):
preds_gen.append(choices[i].index(ret["answer"]))
latency_gen = time.time() - tic
# Compute accuracy
accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
assert np.abs(accuracy_gen - accuracy) < 0.01
assert np.abs(latency_gen - latency) < 1
return accuracy, latency return accuracy, latency
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment