Unverified Commit e4d68afc authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

[Minor] Many cleanup (#1357)

parent c9b75917
......@@ -7,7 +7,7 @@ import time
import numpy as np
import sglang as sgl
from sglang.utils import fetch_and_cache_jsonl
from sglang.utils import download_and_cache_file, read_jsonl
def test_few_shot_qa():
......@@ -456,10 +456,6 @@ def test_chat_completion_speculative():
def test_hellaswag_select():
"""Benchmark the accuracy of sgl.select on the HellaSwag dataset."""
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
lines = fetch_and_cache_jsonl(url)
# Construct prompts
def get_one_example(lines, i, include_answer):
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
if include_answer:
......@@ -472,6 +468,12 @@ def test_hellaswag_select():
ret += get_one_example(lines, i, True) + "\n\n"
return ret
# Read data
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts
num_questions = 200
num_shots = 20
few_shot_examples = get_few_shot_examples(lines, num_shots)
......
......@@ -12,7 +12,7 @@ import urllib.request
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from json import dumps
from typing import Union
from typing import Optional, Union
import numpy as np
import requests
......@@ -38,13 +38,11 @@ def is_same_type(values: list):
def read_jsonl(filename: str):
"""Read a JSONL file."""
rets = []
with open(filename) as fin:
for line in fin:
if line.startswith("#"):
continue
rets.append(json.loads(line))
return rets
yield json.loads(line)
def dump_state_text(filename: str, states: list, mode: str = "w"):
......@@ -264,38 +262,35 @@ class LazyImport:
return module(*args, **kwargs)
def fetch_and_cache_jsonl(url, cache_file="cached_data.jsonl"):
"""Read and cache a jsonl file from a url."""
def download_and_cache_file(url: str, filename: Optional[str] = None):
"""Read and cache a file from a url."""
if filename is None:
filename = os.path.join("/tmp", url.split("/")[-1])
# Check if the cache file already exists
if os.path.exists(cache_file):
print("Loading data from cache...")
with open(cache_file, "r") as f:
data = [json.loads(line) for line in f]
else:
print("Downloading data from URL...")
# Stream the response to show the progress bar
response = requests.get(url, stream=True)
response.raise_for_status() # Check for request errors
# Total size of the file in bytes
total_size = int(response.headers.get("content-length", 0))
chunk_size = 1024 # Download in chunks of 1KB
# Use tqdm to display the progress bar
with open(cache_file, "wb") as f, tqdm(
desc=cache_file,
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as bar:
for chunk in response.iter_content(chunk_size=chunk_size):
f.write(chunk)
bar.update(len(chunk))
# Convert the data to a list of dictionaries
with open(cache_file, "r") as f:
data = [json.loads(line) for line in f]
return data
if os.path.exists(filename):
return filename
print(f"Downloading from {url} to {filename}")
# Stream the response to show the progress bar
response = requests.get(url, stream=True)
response.raise_for_status() # Check for request errors
# Total size of the file in bytes
total_size = int(response.headers.get("content-length", 0))
chunk_size = 1024 # Download in chunks of 1KB
# Use tqdm to display the progress bar
with open(filename, "wb") as f, tqdm(
desc=filename,
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as bar:
for chunk in response.iter_content(chunk_size=chunk_size):
f.write(chunk)
bar.update(len(chunk))
return filename
......@@ -42,7 +42,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
)
metrics = run_eval(args)
assert metrics["score"] >= 0.63, f"{metrics}"
assert metrics["score"] >= 0.62, f"{metrics}"
def test_human_eval(self):
args = SimpleNamespace(
......@@ -66,7 +66,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
)
metrics = run_eval(args)
assert metrics["score"] >= 0.63, f"{metrics}"
assert metrics["score"] >= 0.62, f"{metrics}"
if __name__ == "__main__":
......
import json
import unittest
from sglang.srt.server_args import prepare_server_args
......@@ -15,7 +16,7 @@ class TestPrepareServerArgs(unittest.TestCase):
)
self.assertEqual(server_args.model_path, "model_path")
self.assertEqual(
server_args.json_model_override_args,
json.loads(server_args.json_model_override_args),
{"rope_scaling": {"factor": 2.0, "type": "linear"}},
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment