Unverified Commit 1e495e08 authored by Lianmin Zheng, committed by GitHub

[Fix] Fix select by ensuring each request has at least one token (#1318)

parent 12cb115d
@@ -178,19 +178,22 @@ class Req:
     def adjust_max_prefix_ids(self):
         self.fill_ids = self.origin_input_ids + self.output_ids
         input_len = len(self.fill_ids)
-        max_prefix_len = input_len
+
+        # FIXME: To work around some bugs in logprob computation, we need to ensure each
+        # request has at least one token. Later, we can relax this requirement and use `input_len`.
+        max_prefix_len = input_len - 1
 
         if self.sampling_params.max_new_tokens > 0:
             # Need at least one token to compute logits
             max_prefix_len = min(max_prefix_len, input_len - 1)
 
         if self.return_logprob:
-            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
             if self.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 max_prefix_len = min(max_prefix_len, input_len - 2)
+            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
 
+        max_prefix_len = max(max_prefix_len, 0)
         return self.fill_ids[:max_prefix_len]
 
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
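Note: a minimal standalone sketch of the clamping logic above (the helper name and plain arguments are illustrative, not part of the commit):

    def max_prefix_len_sketch(input_len, max_new_tokens, return_logprob,
                              normalized_prompt_logprob, logprob_start_len):
        # Always hold back at least one token (the FIXME workaround).
        max_prefix_len = input_len - 1
        if max_new_tokens > 0:
            # Need at least one token to compute logits.
            max_prefix_len = min(max_prefix_len, input_len - 1)
        if return_logprob:
            if normalized_prompt_logprob is None:
                # Need at least two tokens to compute the normalized logprob.
                max_prefix_len = min(max_prefix_len, input_len - 2)
            max_prefix_len = min(max_prefix_len, logprob_start_len)
        return max(max_prefix_len, 0)

    # A select-style request (max_new_tokens == 0, return_logprob == True) with a
    # 10-token prompt now reuses at most 8 cached prefix tokens, never all 10:
    assert max_prefix_len_sketch(10, 0, True, None, 10) == 8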
@@ -2,8 +2,12 @@
 
 import json
 import re
 import time
+
+import numpy as np
+
 import sglang as sgl
+from sglang.utils import fetch_and_cache_jsonl
 
 
 def test_few_shot_qa():
@@ -447,3 +451,67 @@ def test_chat_completion_speculative():
     )
 
     gen_character_spec().sync()
+
+
+def test_hellaswag_select():
+    """Benchmark the accuracy of sgl.select on the HellaSwag dataset."""
+
+    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
+    lines = fetch_and_cache_jsonl(url)
+
+    # Construct prompts
+    def get_one_example(lines, i, include_answer):
+        ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
+        if include_answer:
+            ret += lines[i]["endings"][lines[i]["label"]]
+        return ret
+
+    def get_few_shot_examples(lines, k):
+        ret = ""
+        for i in range(k):
+            ret += get_one_example(lines, i, True) + "\n\n"
+        return ret
+
+    num_questions = 200
+    num_shots = 20
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
+
+    questions = []
+    choices = []
+    labels = []
+    for i in range(len(lines[:num_questions])):
+        questions.append(get_one_example(lines, i, False))
+        choices.append(lines[i]["endings"])
+        labels.append(lines[i]["label"])
+    arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]
+
+    #####################################
+    ######### SGL Program Begin #########
+    #####################################
+
+    import sglang as sgl
+
+    @sgl.function
+    def few_shot_hellaswag(s, question, choices):
+        s += few_shot_examples + question
+        s += sgl.select("answer", choices=choices)
+
+    #####################################
+    ########## SGL Program End ##########
+    #####################################
+
+    # Run requests
+    tic = time.time()
+    rets = few_shot_hellaswag.run_batch(
+        arguments,
+        temperature=0,
+        num_threads=64,
+        progress_bar=True,
+    )
+    preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
+    latency = time.time() - tic
+
+    # Compute accuracy
+    accuracy = np.mean(np.array(preds) == np.array(labels))
+
+    return accuracy, latency
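As a usage note, the new benchmark can also be run directly against a live endpoint; a sketch assuming a local sglang server on port 30000 (the port and model are placeholders):

    import sglang as sgl
    from sglang.test.test_programs import test_hellaswag_select

    # Assumes a server was started separately, e.g.:
    #   python -m sglang.launch_server --model-path <model> --port 30000
    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

    accuracy, latency = test_hellaswag_select()
    print(f"accuracy={accuracy:.3f}, latency={latency:.1f}s")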
@@ -4,6 +4,7 @@ import base64
 import importlib
 import json
 import logging
+import os
 import signal
 import sys
 import traceback
@@ -15,6 +16,7 @@ from typing import Union
 
 import numpy as np
 import requests
+from tqdm import tqdm
 
 logger = logging.getLogger(__name__)
 
......@@ -260,3 +262,40 @@ class LazyImport:
def __call__(self, *args, **kwargs):
module = self._load()
return module(*args, **kwargs)
def fetch_and_cache_jsonl(url, cache_file="cached_data.jsonl"):
"""Read and cache a jsonl file from a url."""
# Check if the cache file already exists
if os.path.exists(cache_file):
print("Loading data from cache...")
with open(cache_file, "r") as f:
data = [json.loads(line) for line in f]
else:
print("Downloading data from URL...")
# Stream the response to show the progress bar
response = requests.get(url, stream=True)
response.raise_for_status() # Check for request errors
# Total size of the file in bytes
total_size = int(response.headers.get("content-length", 0))
chunk_size = 1024 # Download in chunks of 1KB
# Use tqdm to display the progress bar
with open(cache_file, "wb") as f, tqdm(
desc=cache_file,
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as bar:
for chunk in response.iter_content(chunk_size=chunk_size):
f.write(chunk)
bar.update(len(chunk))
# Convert the data to a list of dictionaries
with open(cache_file, "r") as f:
data = [json.loads(line) for line in f]
return data
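The helper is independent of the tests and can be reused as-is; a small sketch (the cache filename is arbitrary):

    from sglang.utils import fetch_and_cache_jsonl

    lines = fetch_and_cache_jsonl(
        "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl",
        cache_file="hellaswag_val.jsonl",
    )
    print(f"{len(lines)} records; first label: {lines[0]['label']}")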
@@ -7,6 +7,7 @@ from sglang.test.test_programs import (
     test_dtype_gen,
     test_expert_answer,
     test_few_shot_qa,
+    test_hellaswag_select,
     test_mt_bench,
     test_parallel_decoding,
     test_regex,
@@ -62,6 +63,12 @@ class TestSRTBackend(unittest.TestCase):
     def test_dtype_gen(self):
         test_dtype_gen()
 
+    def test_hellaswag_select(self):
+        # Run twice to capture more bugs
+        for _ in range(2):
+            accuracy, latency = test_hellaswag_select()
+            assert accuracy > 0.71
+
 
 if __name__ == "__main__":
     unittest.main()
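The new case can be run in isolation with the standard unittest runner (assuming the suite file is named test_srt_backend.py, as suggested by the class name; adjust the module path to match the repository layout):

    python -m unittest test_srt_backend.TestSRTBackend.test_hellaswag_select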