Unverified commit 4d1c9db6, authored by Yineng Zhang, committed by GitHub

feat: support loogle eval (#6190)

parent 17c36c55
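
The commit makes two changes: it updates the code-structure list in the README (the `python/sglang` README, by the look of the entries) and adds the evaluation script itself. The README diff: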
 # Code Structures
+- `eval`: The evaluation utilities.
 - `lang`: The frontend language.
 - `srt`: The backend engine for running local models. (SRT = SGLang Runtime).
 - `test`: The test utilities.
@@ -11,6 +12,5 @@
 - `check_env.py`: Check the environment variables and dependencies.
 - `global_config.py`: The global configs and constants.
 - `launch_server.py`: The entry point for launching the local server.
-- `llama3_eval.py`: Evaluation of Llama 3 using the Meta Llama dataset.
 - `utils.py`: Common utilities.
 - `version.py`: Version info.
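
The second file is the evaluation script, placed under the new `eval` directory (the exact path is not shown in this view; in the SGLang tree it would presumably be something like `python/sglang/eval/loogle_eval.py`). It queries an OpenAI-compatible endpoint for every LooGLE example, caches each raw response as a pickle, and then scores the answers with BERTScore: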
import argparse
import asyncio
import os
import pickle
from pathlib import Path
from typing import List
import openai
import torch
from bert_score import BERTScorer
from datasets import load_dataset
from tqdm import tqdm
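
# The OpenAI SDK requires an API key even for local servers, so fall back to "EMPTY".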
def get_client(api_url: str) -> openai.AsyncOpenAI:
if os.getenv("OPENAI_API_KEY") is None:
os.environ["OPENAI_API_KEY"] = "EMPTY"
return openai.AsyncOpenAI(base_url=api_url)
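
# The long-dependency QA split of the LooGLE benchmark, fetched from Hugging Face.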
def get_dataset():
return load_dataset("bigai-nlco/LooGLE", "longdep_qa", split="test")
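
# Ask the model one question over its long context and cache the raw response to
# disk, so an interrupted run can resume without re-issuing completed requests.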
async def fetch_response(
client: openai.AsyncOpenAI,
context: str,
question: str,
semaphore: asyncio.Semaphore,
index: int,
model: str,
output_dir: Path,
):
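    # Resume support: skip examples whose response (or error) is already cached.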
output_file = output_dir / f"response_{index}.pkl"
if output_file.exists():
return
prompt = (
"Please answer the question based on the long texts below.\n"
f"{context}\n"
f"Question: {question}\n"
"Answer:"
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
]
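    # The semaphore caps in-flight requests; a BadRequestError (e.g. the prompt
    # exceeding the server's context window) is cached as an error record rather
    # than aborting the whole run.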
async with semaphore:
try:
response = await client.chat.completions.create(
model=model,
messages=messages,
temperature=0.0,
max_tokens=512,
)
except openai.BadRequestError as e:
with open(output_file, "wb") as f:
pickle.dump({"error": str(e)}, f)
return
with open(output_file, "wb") as f:
pickle.dump(response, f)
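
# Launch one request task per dataset example; concurrency is bounded by the semaphore.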
async def benchmark(args):
dataset = get_dataset()
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
client = get_client(args.api_url)
semaphore = asyncio.Semaphore(args.max_concurrency)
tasks: List[asyncio.Task] = []
for idx, ex in enumerate(dataset):
tasks.append(
asyncio.create_task(
fetch_response(
client,
ex["context"],
ex["question"],
semaphore,
idx,
args.model,
output_dir,
)
)
)
    for completed in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Running benchmark"
    ):
        await completed
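
# Score the cached responses against the reference answers with BERTScore F1.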
def analyse(args):
dataset = get_dataset()
output_dir = Path(args.output_dir)
device = "cuda" if torch.cuda.is_available() else "cpu"
scorer = BERTScorer(lang="en", device=device)
hyps: List[str] = []
refs: List[str] = []
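    # Pair each model answer with its reference, skipping cached error records.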
for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
pkl_file = output_dir / f"response_{idx}.pkl"
if not pkl_file.exists():
raise FileNotFoundError(pkl_file)
        with pkl_file.open("rb") as f:
            response = pickle.load(f)
if isinstance(response, dict) and "error" in response:
continue
hyps.append(response.choices[0].message.content.strip())
refs.append(ex["answer"])
if not hyps:
print("No valid responses to score!")
return
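    # Score in fixed-size batches to keep the scorer's GPU memory use bounded.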
batch_size = 64
all_f1: List[float] = []
for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
h_batch = hyps[i : i + batch_size]
r_batch = refs[i : i + batch_size]
_, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
all_f1.extend([float(x) for x in f1_scores])
avg = sum(all_f1) / len(all_f1)
print(f"Average BERTScore (F1): {avg:.2%}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run benchmark and evaluation in one go."
)
parser.add_argument(
"--api-url",
default="http://127.0.0.1:30000/v1",
help="OpenAI‑compatible API base URL",
)
parser.add_argument(
"--model",
default="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
help="Model name or ID",
)
parser.add_argument(
"--max-concurrency", type=int, default=144, help="Maximum concurrent requests"
)
parser.add_argument(
"--output-dir", default="tmp-output-dir", help="Directory for cached responses"
)
args = parser.parse_args()
asyncio.run(benchmark(args))
analyse(args)
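
To reproduce the run, start an OpenAI-compatible server first and point the script at it. With SGLang that would be roughly `python -m sglang.launch_server --model-path meta-llama/Llama-4-Maverick-17B-128E-Instruct --port 30000` (the usual launch flags; adjust to your setup), followed by `python loogle_eval.py --api-url http://127.0.0.1:30000/v1` (assuming the file name suggested above). Scoring requires `bert_score` and `datasets` installed, and uses a GPU when one is available, falling back to CPU otherwise.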