Unverified Commit 7862a723 authored by Ceng's avatar Ceng Committed by GitHub
Browse files

issue/122 :更新benchmark脚本和README.md (#123)



* issue/122 :更新benchmark脚本和README.md
Signed-off-by: default avatarCeng23333 <441651826@qq.com>

* .
Signed-off-by: default avatarCeng23333 <441651826@qq.com>

* fix input_ids
Signed-off-by: default avatarCeng23333 <441651826@qq.com>

* explicitly split mmul all subject
Signed-off-by: default avatarCeng23333 <441651826@qq.com>

---------
Signed-off-by: default avatarCeng23333 <441651826@qq.com>
parent cdce626e
...@@ -87,3 +87,48 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA ...@@ -87,3 +87,48 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA
```bash ```bash
python examples/jiuge.py --nvidia --model_path=/models/9G7B_MHA/ --backend=cpp --tp=4 --batch_size=16 python examples/jiuge.py --nvidia --model_path=/models/9G7B_MHA/ --backend=cpp --tp=4 --batch_size=16
``` ```
- 运行推理基准测试(C-Eval/MMLU)
```bash
python test/bench/test_benchmark.py [--cpu | --nvidia | --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] <path/to/model_dir> --bench {ceval|mmlu} [--backend cpp] [--ndev N] [--subject SUBJECT] [--num_samples N] [--max_new_tokens N] [--output_csv PATH] [--cache_dir PATH]
```
- 参数说明:
- `--subject`: 指定科目,支持单个科目、多个科目(逗号分隔)或 `all`(默认值,加载全部科目)
- `--output_csv`: 可选,指定CSV输出文件路径。如未指定则不生成CSV文件。CSV包含每个科目的结果和总体结果
- `--cache_dir`: 可选,指定数据集缓存目录的父目录。应指向包含 `ceval___ceval-exam` 和 `cais___mmlu` 等数据集子目录的父目录(例如 `~/.cache/huggingface/datasets/`)。设置后脚本优先使用本地 CSV(`pandas.read_csv`)离线加载数据,避免 `load_dataset` 的网络请求
- C-Eval示例:
- 单个科目:
```bash
python test/bench/test_benchmark.py --nvidia /models/9G7B_MHA --bench ceval --subject middle_school_mathematics --num_samples 100 --backend cpp --ndev 1
```
- 多个科目(逗号分隔):
```bash
python test/bench/test_benchmark.py --nvidia /models/9G7B_MHA --bench ceval --subject middle_school_mathematics,high_school_physics --backend cpp --ndev 1 --output_csv results.csv
```
- 全部科目并输出CSV:
```bash
python test/bench/test_benchmark.py --nvidia /models/9G7B_MHA --bench ceval --subject all --backend cpp --ndev 1 --output_csv results.csv
```
- 使用缓存目录加速加载:
```bash
python test/bench/test_benchmark.py --nvidia /models/9G7B_MHA --bench ceval --subject middle_school_mathematics --backend cpp --ndev 1 --cache_dir ~/.cache/huggingface/datasets/
```
> 注意:`--cache_dir` 应指向包含 `ceval___ceval-exam` 和 `cais___mmlu` 等数据集子目录的父目录,而不是直接指向这些子目录
- MMLU示例:
- 单个科目:
```bash
python test/bench/test_benchmark.py --nvidia /models/9G7B_MHA --bench mmlu --subject abstract_algebra --backend cpp --ndev 1
```
- 多个科目(逗号分隔):
```bash
python test/bench/test_benchmark.py --nvidia /models/9G7B_MHA --bench mmlu --subject abstract_algebra,anatomy,astronomy --backend cpp --ndev 1 --output_csv results.csv
```
- 使用缓存目录加速加载:
```bash
python test/bench/test_benchmark.py --nvidia /models/9G7B_MHA --bench mmlu --subject abstract_algebra --backend cpp --ndev 1 --cache_dir ~/.cache/huggingface/datasets/
```
> 注意:`--cache_dir` 应指向包含 `ceval___ceval-exam` 和 `cais___mmlu` 等数据集子目录的父目录,而不是直接指向这些子目录
...@@ -3,11 +3,12 @@ import os ...@@ -3,11 +3,12 @@ import os
import argparse import argparse
import time import time
import re import re
from datasets import load_dataset import csv
from datasets import load_dataset, Dataset
import infinicore import infinicore
import infinilm import infinilm
from infinilm.models.llama import AutoLlamaModel from infinilm.models.llama import AutoLlamaModel
from infinilm.modeling_utils import get_model_state_dict from infinilm.modeling_utils import load_model_state_dict_by_file
from infinilm.distributed import DistConfig from infinilm.distributed import DistConfig
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -112,12 +113,11 @@ class InfiniLMBenchmark(BaseBenchmark): ...@@ -112,12 +113,11 @@ class InfiniLMBenchmark(BaseBenchmark):
# Load weights # Load weights
print("Loading model weights...") print("Loading model weights...")
model_param_infini = get_model_state_dict( load_model_state_dict_by_file(
self.model,
model_dir_path, model_dir_path,
device=self.device,
dtype=self.dtype, dtype=self.dtype,
) )
self.model.load_state_dict(model_param_infini)
print("Model loaded successfully") print("Model loaded successfully")
def max_context_len(self): def max_context_len(self):
...@@ -195,7 +195,7 @@ class InfiniLMBenchmark(BaseBenchmark): ...@@ -195,7 +195,7 @@ class InfiniLMBenchmark(BaseBenchmark):
""" """
# Convert tokens to infinicore format # Convert tokens to infinicore format
input_ids_list = [tokens] input_ids_list = [tokens]
input_ids = infinicore.from_list(input_ids_list, dtype=infinicore.int64).to(self.device) input_ids = infinicore.from_list(input_ids_list)
# Use model's built-in generate() method which properly handles KV cache # Use model's built-in generate() method which properly handles KV cache
# Pass sampling parameters (temperature, topk, topp) via kwargs # Pass sampling parameters (temperature, topk, topp) via kwargs
...@@ -246,11 +246,197 @@ def extract_answer_mmlu(output_content): ...@@ -246,11 +246,197 @@ def extract_answer_mmlu(output_content):
return None return None
def evaluate_samples(model, samples, benchmark, max_new_tokens, subject_name=None):
    """Evaluate samples for a single subject and return aggregate results.

    Args:
        model: Inference wrapper exposing ``generate(...)`` that returns an
            ``(output_text, avg_time)`` tuple.
        samples: Iterable of dataset records. For ``"ceval"`` each record has
            'question', 'A'..'D' and 'answer' (a letter); for ``"mmlu"`` each
            record has 'question', 'choices' and 'answer' (a 0-3 index).
        benchmark: Either ``"ceval"`` or ``"mmlu"``; samples are skipped for
            any other value.
        max_new_tokens: Generation budget passed to ``model.generate``.
        subject_name: Optional subject label stored with every answer record.

    Returns:
        dict with keys 'subject', 'correct', 'total', 'accuracy' and the raw
        per-sample 'answers_list'.
    """
    answers_list = []
    for idx, sample in enumerate(samples):
        if benchmark == "ceval":
            input_content = f"'question':{sample['question']},'A': {sample['A']}, 'B':{sample['B']}, 'C': {sample['C']},'D': {sample['D']}。"
            conversation = [
                {
                    "role": "system",
                    "content": "请从question的A,B,C,D四个选项中选择正确的选项。例如,标准答案:A。",
                },
                {"role": "user", "content": input_content},
            ]
            answer = sample["answer"]
            output_content, avg_time = model.generate(
                conversation, max_steps=max_new_tokens, topp_=1.0, topk_=1, temperature_=1.0
            )
            is_correct = extract_answer_ceval(output_content, answer)
            answers_list.append({
                "id": sample.get("id", idx),
                "output_content": output_content,
                "answer": answer,
                "is_correct": is_correct,
                "subject": subject_name
            })
            # We are already inside the ceval branch, so print unconditionally
            # (the original re-checked `benchmark == "ceval"` redundantly here).
            print("标准答案:", answer)
        elif benchmark == "mmlu":
            question = sample['question']
            choices = sample['choices']
            answer_idx = sample['answer']  # MMLU answer is a 0-3 index
            output_content, avg_time = model.generate(
                question, choices, max_steps=max_new_tokens, topp_=1.0, topk_=1, temperature_=1.0
            )
            predicted_answer = extract_answer_mmlu(output_content)
            # Convert answer indices to letters (A-D) for readable logging.
            answer_letter = chr(65 + answer_idx) if answer_idx < 4 else "?"
            predicted_letter = chr(65 + predicted_answer) if predicted_answer is not None and predicted_answer < 4 else "?"
            print(f"Sample {idx}: Correct answer: {answer_letter} ({answer_idx}), Predicted: {predicted_letter} ({predicted_answer})")
            answers_list.append({
                "id": idx,
                "output_content": output_content,
                "answer": answer_idx,
                "predicted": predicted_answer,
                "subject": subject_name
            })

    # Score the collected answers for this subject.
    true_num = 0
    all_num = 0
    for cont in answers_list:
        sample_id = cont["id"]  # renamed from `id` to avoid shadowing the builtin
        all_num = all_num + 1
        if benchmark == "ceval":
            if cont["is_correct"]:
                true_num = true_num + 1
                print(f"id {sample_id} : ", "正确")
            else:
                print(f"id {sample_id}: ", "错误")
        elif benchmark == "mmlu":
            answer = cont["answer"]
            predicted = cont["predicted"]
            if predicted is not None and predicted == answer:
                true_num = true_num + 1
                print(f"id {sample_id}: Correct")
            else:
                answer_letter = chr(65 + answer) if answer < 4 else "?"
                predicted_letter = chr(65 + predicted) if predicted is not None and predicted < 4 else "?"
                print(f"id {sample_id}: Wrong (correct: {answer_letter}, predicted: {predicted_letter})")

    accuracy = true_num / all_num if all_num > 0 else 0.0
    if benchmark == "ceval":
        print(f"成绩: {true_num}/{all_num}", accuracy)
    else:
        print(f"Accuracy: {true_num}/{all_num} = {accuracy:.2%}")

    return {
        "subject": subject_name or "all",
        "correct": true_num,
        "total": all_num,
        "accuracy": accuracy,
        "answers_list": answers_list
    }
def _load_ceval_from_cache(cache_dir, subject_name, split, ceval_subjects):
    """
    Load CEval data from local cache avoiding network calls.
    Scans cached Arrow files under ceval___ceval-exam and filters by split.

    Walks ``cache_dir/ceval___ceval-exam/<subject_name>`` for ``.arrow``
    shards, keeps files whose name matches the requested split, and returns
    the concatenated records.  Raises FileNotFoundError (never touching the
    network) when nothing usable is found.

    NOTE(review): ``ceval_subjects`` is accepted but never used here —
    presumably kept for signature symmetry with the MMLU loader; confirm.
    """
    # Only used in the error message below; the actual filtering is done on
    # filename substrings per-file in the loop.
    split_names = (
        ["test"] if split == "test" else ["val"] if split == "val" else ["val", "test"]
    )
    base = os.path.join(cache_dir, "ceval___ceval-exam", subject_name)
    if os.path.isdir(base):
        records = []
        for root, _, files in os.walk(base):
            for fname in files:
                if not fname.endswith(".arrow"):
                    continue
                lower = fname.lower()
                # NOTE(review): these checks are plain substring matches on the
                # cached file name. "val" is a substring of "ceval", so if the
                # cached shard names contain the dataset name (e.g.
                # "ceval-exam-test.arrow") the split == "val" filter may also
                # accept test shards — verify against the actual cache layout.
                if split == "test" and "test" not in lower:
                    continue
                if split == "val" and not any(x in lower for x in ["val", "validation", "dev"]):
                    continue
                if split == "all" and not any(x in lower for x in ["val", "validation", "dev", "test"]):
                    continue
                try:
                    ds = Dataset.from_file(os.path.join(root, fname))
                    records.extend(ds.to_list())
                except Exception:
                    # Skip unreadable/corrupt shards; best-effort offline load.
                    continue
        if records:
            return records
    # If cache_dir provided and nothing loaded, fail without network
    raise FileNotFoundError(f"CEval cached data not found for subject '{subject_name}' with splits {split_names}")
def _load_mmlu_from_cache(cache_dir, subject_name, split, mmlu_subjects):
    """
    Load MMLU samples from a local datasets cache without network access.

    Walks ``cache_dir/cais___mmlu/<subject>`` for cached ``.arrow`` shards,
    keeps only files matching the requested split, and returns
    ``(records, subject_label)``.  When ``subject_name`` is ``"all"``, every
    subject in ``mmlu_subjects`` is loaded (missing ones are skipped).
    """

    def _matches_split(lower_name):
        # Decide from the lowercased file name whether a shard belongs to
        # the requested split; unknown split values accept everything.
        if split == "test":
            return "test" in lower_name
        if split == "val":
            return any(tag in lower_name for tag in ["validation", "dev"])
        if split == "all":
            return any(tag in lower_name for tag in ["validation", "dev", "test"])
        return True

    def _load_subject(subj):
        # Split labels reported in the "not found" error message below.
        wanted_splits = (
            ["test"]
            if split == "test"
            else ["validation", "dev"]
            if split == "val"
            else ["validation", "dev", "test"]
        )
        subject_dir = os.path.join(cache_dir, "cais___mmlu", subj)
        if not os.path.isdir(subject_dir):
            raise FileNotFoundError(f"MMLU cache dir not found: {subject_dir}")
        rows = []
        for dirpath, _, filenames in os.walk(subject_dir):
            for name in filenames:
                if not name.endswith(".arrow"):
                    continue
                if not _matches_split(name.lower()):
                    continue
                try:
                    shard = Dataset.from_file(os.path.join(dirpath, name))
                    rows.extend(shard.to_list())
                except Exception:
                    # Best-effort offline load: skip unreadable shards.
                    continue
        if rows:
            return rows
        raise FileNotFoundError(f"MMLU cached data not found for subject '{subj}' with splits {wanted_splits}")

    if subject_name != "all":
        return _load_subject(subject_name), subject_name

    # "all": aggregate every hardcoded subject, tolerating missing caches.
    combined = []
    for subj in mmlu_subjects:
        try:
            combined.extend(_load_subject(subj))
        except FileNotFoundError:
            continue
    if not combined:
        raise FileNotFoundError(
            f"No MMLU cached data found for any subject. Please ensure datasets are cached."
        )
    return combined, "all"
def test(): def test():
# Parse arguments manually to handle device flags properly # Parse arguments manually to handle device flags properly
if len(sys.argv) < 4: if len(sys.argv) < 4:
print( print(
"Usage: python test_benchmark.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] <path/to/model_dir> --bench [ceval|mmlu] [--backend cpp] [--ndev N] [--subject SUBJECT] [--num_samples N] [--max_new_tokens N]" "Usage: python test_benchmark.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] <path/to/model_dir> --bench [ceval|mmlu] [--backend cpp] [--ndev N] [--subject SUBJECT] [--split {test|val|all}] [--num_samples N] [--max_new_tokens N] [--output_csv PATH] [--cache_dir PATH]"
) )
sys.exit(1) sys.exit(1)
...@@ -262,10 +448,12 @@ def test(): ...@@ -262,10 +448,12 @@ def test():
backend = "cpp" backend = "cpp"
ndev = 1 ndev = 1
benchmark = None benchmark = None
subject = None # For MMLU subject = "all" # Shared for both C-Eval and MMLU, can be comma-separated
dataset_name = "middle_school_mathematics" # For C-Eval split = "test" # test | val | all
num_samples = None num_samples = None
max_new_tokens = 500 max_new_tokens = 500
output_csv = None
cache_dir = None
i = 3 i = 3
while i < len(sys.argv): while i < len(sys.argv):
...@@ -281,8 +469,8 @@ def test(): ...@@ -281,8 +469,8 @@ def test():
elif sys.argv[i] == "--subject" and i + 1 < len(sys.argv): elif sys.argv[i] == "--subject" and i + 1 < len(sys.argv):
subject = sys.argv[i + 1] subject = sys.argv[i + 1]
i += 2 i += 2
elif sys.argv[i] == "--dataset" and i + 1 < len(sys.argv): elif sys.argv[i] == "--split" and i + 1 < len(sys.argv):
dataset_name = sys.argv[i + 1] split = sys.argv[i + 1]
i += 2 i += 2
elif sys.argv[i] == "--num_samples" and i + 1 < len(sys.argv): elif sys.argv[i] == "--num_samples" and i + 1 < len(sys.argv):
num_samples = int(sys.argv[i + 1]) num_samples = int(sys.argv[i + 1])
...@@ -290,6 +478,12 @@ def test(): ...@@ -290,6 +478,12 @@ def test():
elif sys.argv[i] == "--max_new_tokens" and i + 1 < len(sys.argv): elif sys.argv[i] == "--max_new_tokens" and i + 1 < len(sys.argv):
max_new_tokens = int(sys.argv[i + 1]) max_new_tokens = int(sys.argv[i + 1])
i += 2 i += 2
elif sys.argv[i] == "--output_csv" and i + 1 < len(sys.argv):
output_csv = sys.argv[i + 1]
i += 2
elif sys.argv[i] == "--cache_dir" and i + 1 < len(sys.argv):
cache_dir = sys.argv[i + 1]
i += 2
else: else:
i += 1 i += 1
...@@ -323,63 +517,243 @@ def test(): ...@@ -323,63 +517,243 @@ def test():
device_type_str = "hygon" device_type_str = "hygon"
else: else:
print( print(
"Usage: python test_benchmark.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] <path/to/model_dir> --bench [ceval|mmlu] [--backend cpp] [--ndev N] [--subject SUBJECT] [--num_samples N] [--max_new_tokens N]" "Usage: python test_benchmark.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] <path/to/model_dir> --bench [ceval|mmlu] [--backend cpp] [--ndev N] [--subject SUBJECT] [--num_samples N] [--max_new_tokens N] [--output_csv PATH] [--cache_dir PATH]"
) )
sys.exit(1) sys.exit(1)
# Load dataset based on benchmark # Normalize cache_dir and force offline when provided
if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"
# Parse comma-separated subjects
if split not in ["test", "val", "all"]:
print("Error: --split must be one of: test, val, all")
sys.exit(1)
if subject and subject != "all":
subject_list = [s.strip() for s in subject.split(",")]
else:
subject_list = ["all"]
# Create model based on backend (create once, reuse for all subjects)
if backend != "010":
model = InfiniLMBenchmark(model_path, device_type_str, ndev, backend, benchmark)
else:
print(f"test 010 backend by scripts/test_ceval.py")
exit(0)
# Define helper functions for loading datasets
if benchmark == "ceval": if benchmark == "ceval":
# Load C-Eval dataset ceval_subjects = [
# https://huggingface.co/datasets/ceval/ceval-exam/tree/main/middle_school_geography "accountant",
print(f"Loading C-Eval dataset (dataset: {dataset_name})...") "advanced_mathematics",
"art_studies",
"basic_medicine",
"business_administration",
"chinese_language_and_literature",
"civil_servant",
"clinical_medicine",
"college_chemistry",
"college_economics",
"college_physics",
"college_programming",
"computer_architecture",
"computer_network",
"discrete_mathematics",
"education_science",
"electrical_engineer",
"environmental_impact_assessment_engineer",
"fire_engineer",
"high_school_biology",
"high_school_chemistry",
"high_school_chinese",
"high_school_geography",
"high_school_history",
"high_school_mathematics",
"high_school_physics",
"high_school_politics",
"ideological_and_moral_cultivation",
"law",
"legal_professional",
"logic",
"mao_zedong_thought",
"marxism",
"metrology_engineer",
"middle_school_biology",
"middle_school_chemistry",
"middle_school_geography",
"middle_school_history",
"middle_school_mathematics",
"middle_school_physics",
"middle_school_politics",
"modern_chinese_history",
"operating_system",
"physician",
"plant_protection",
"probability_and_statistics",
"professional_tour_guide",
"sports_science",
"tax_accountant",
"teacher_qualification",
"urban_and_rural_planner",
"veterinary_medicine",
]
def _load_ceval_subject(subj):
print(f"Loading C-Eval dataset (subject: {subj})...")
if cache_dir:
return _load_ceval_from_cache(cache_dir, subj, split, ceval_subjects)
# online fallback via HF load_dataset
if split == "all":
records = []
for split_name in ["val", "test"]:
try: try:
dataset = load_dataset(r"ceval/ceval-exam", name=dataset_name) ds = load_dataset(r"ceval/ceval-exam", name=subj, split=split_name)
samples = dataset["val"] records.extend(ds.to_list())
# Convert Dataset to list if needed except Exception:
if hasattr(samples, 'to_list'): continue
samples = samples.to_list() if records:
return records
raise FileNotFoundError(f"No ceval splits found online for subject {subj}")
hf_split = "test" if split == "test" else "val"
ds = load_dataset(r"ceval/ceval-exam", name=subj, split=hf_split)
data = ds.to_list()
return data
def load_subject_samples(subj_name):
if subj_name == "all":
samples = []
for subj in ceval_subjects:
samples.extend(_load_ceval_subject(subj))
return samples, "all"
else: else:
samples = list(samples) if subj_name not in ceval_subjects:
except Exception as e: raise ValueError(f"Unknown C-Eval subject '{subj_name}'. Available subjects: {', '.join(ceval_subjects)}")
print(f"Error loading dataset: {e}") return _load_ceval_subject(subj_name), subj_name
print("Available datasets: middle_school_mathematics, high_school_history, high_school_chinese, high_school_physics, middle_school_geography, middle_school_physics")
sys.exit(1)
elif benchmark == "mmlu": elif benchmark == "mmlu":
# Load MMLU dataset mmlu_subjects = [
# https://huggingface.co/datasets/cais/mmlu "abstract_algebra",
if subject is None: "anatomy",
subject = "all" "astronomy",
print(f"Loading MMLU dataset (subject: {subject})...") "business_ethics",
try: "clinical_knowledge",
if subject == "all": "college_biology",
dataset = load_dataset("cais/mmlu", "all") "college_chemistry",
# Combine all subjects into a single dataset "college_computer_science",
"college_mathematics",
"college_medicine",
"college_physics",
"computer_security",
"conceptual_physics",
"econometrics",
"electrical_engineering",
"elementary_mathematics",
"formal_logic",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_computer_science",
"high_school_european_history",
"high_school_geography",
"high_school_government_and_politics",
"high_school_macroeconomics",
"high_school_mathematics",
"high_school_microeconomics",
"high_school_physics",
"high_school_psychology",
"high_school_statistics",
"high_school_us_history",
"high_school_world_history",
"human_aging",
"human_sexuality",
"international_law",
"jurisprudence",
"logical_fallacies",
"machine_learning",
"management",
"marketing",
"medical_genetics",
"miscellaneous",
"moral_disputes",
"moral_scenarios",
"nutrition",
"philosophy",
"prehistory",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_studies",
"sociology",
"us_foreign_policy",
"virology",
"world_religions",
]
def _load_mmlu_subject(subj):
print(f"Loading MMLU dataset (subject: {subj})...")
if cache_dir:
return _load_mmlu_from_cache(cache_dir, subj, split, mmlu_subjects)
if subj == "all":
samples = [] samples = []
for subject_name in dataset.keys(): splits_to_load = ["test"] if split == "test" else ["validation"] if split == "val" else ["validation", "test"]
if subject_name in ["train", "validation", "test"]: # Load each subject individually from hardcoded list, excluding "all"
continue for subject_name in mmlu_subjects:
# Convert Dataset to list for sp in splits_to_load:
test_data = dataset[subject_name]["test"] try:
if hasattr(test_data, 'to_list'): dataset = load_dataset("cais/mmlu", subject_name, split=sp)
samples.extend(test_data.to_list()) if hasattr(dataset, 'to_list'):
samples.extend(dataset.to_list())
else: else:
samples.extend(list(test_data)) samples.extend(list(dataset))
except Exception:
continue
if not samples:
raise FileNotFoundError(f"No MMLU data found for any subject in the list")
return samples, "all"
else: else:
dataset = load_dataset("cais/mmlu", subject) splits_to_load = ["test"] if split == "test" else ["validation"] if split == "val" else ["validation", "test"]
test_data = dataset["test"] records = []
# Convert Dataset to list for sp in splits_to_load:
if hasattr(test_data, 'to_list'): try:
samples = test_data.to_list() dataset = load_dataset("cais/mmlu", subj, split=sp)
if hasattr(dataset, 'to_list'):
records.extend(dataset.to_list())
else: else:
samples = list(test_data) records.extend(list(dataset))
except Exception as e: except Exception:
print(f"Error loading dataset: {e}") continue
print("Available subjects: abstract_algebra, anatomy, astronomy, business_ethics, etc.") if not records:
print("Use --subject all to load all subjects") raise FileNotFoundError(f"MMLU subject {subj} split(s) {splits_to_load} not found")
sys.exit(1) return records, subj
print(f"Loaded {len(samples)} samples") def load_subject_samples(subj_name):
return _load_mmlu_subject(subj_name)
# Expand "all" to individual subjects for per-subject reporting
if "all" in subject_list:
if benchmark == "ceval":
# Replace "all" with all individual ceval subjects
subject_list = [s for s in subject_list if s != "all"] + ceval_subjects
elif benchmark == "mmlu":
# Replace "all" with all individual mmlu subjects
subject_list = [s for s in subject_list if s != "all"] + mmlu_subjects
# Evaluate each subject separately
all_results = []
for subj in subject_list:
print(f"\n{'='*60}")
print(f"Evaluating subject: {subj}")
print(f"{'='*60}\n")
try:
samples, actual_subj_name = load_subject_samples(subj)
print(f"Loaded {len(samples)} samples for subject: {actual_subj_name}")
# Limit number of samples if specified # Limit number of samples if specified
if num_samples is not None and num_samples > 0: if num_samples is not None and num_samples > 0:
...@@ -387,13 +761,6 @@ def test(): ...@@ -387,13 +761,6 @@ def test():
samples = samples[:num_samples] samples = samples[:num_samples]
print(f"Limited to {len(samples)} samples for validation (from {original_count} total)") print(f"Limited to {len(samples)} samples for validation (from {original_count} total)")
# Create model based on backend
if backend != "010":
model = InfiniLMBenchmark(model_path, device_type_str, ndev, backend, benchmark)
else:
print(f"test 010 backend by scripts/test_ceval.py")
exit(0)
# Test with first sample if available # Test with first sample if available
if len(samples) > 0: if len(samples) > 0:
sample = samples[0] sample = samples[0]
...@@ -411,93 +778,42 @@ def test(): ...@@ -411,93 +778,42 @@ def test():
question = sample['question'] question = sample['question']
choices = sample['choices'] choices = sample['choices']
test_output, _ = model.generate(question, choices, max_steps=max_new_tokens, topp_=1.0, topk_=1, temperature_=1.0) test_output, _ = model.generate(question, choices, max_steps=max_new_tokens, topp_=1.0, topk_=1, temperature_=1.0)
print(f"\nTest output: {test_output}") print(f"\nTest output: {test_output}\n")
answers_list = []
for idx, sample in enumerate(samples):
if benchmark == "ceval":
input_content = f"'question':{sample['question']},'A': {sample['A']}, 'B':{sample['B']}, 'C': {sample['C']},'D': {sample['D']}。"
conversation = [
{
"role": "system",
"content": "请从question的A,B,C,D四个选项中选择正确的选项。例如,标准答案:A。",
},
{"role": "user", "content": input_content},
]
answer = sample["answer"]
output_content, avg_time = model.generate(
conversation, max_steps=max_new_tokens, topp_=1.0, topk_=1, temperature_=1.0
)
is_correct = extract_answer_ceval(output_content, answer)
answers_list.append({
"id": sample.get("id", idx),
"output_content": output_content,
"answer": answer,
"is_correct": is_correct
})
if benchmark == "ceval":
print("标准答案:", answer)
elif benchmark == "mmlu":
question = sample['question']
choices = sample['choices']
answer_idx = sample['answer'] # MMLU answer is 0-3 index
output_content, avg_time = model.generate(
question, choices, max_steps=max_new_tokens, topp_=1.0, topk_=1, temperature_=1.0
)
predicted_answer = extract_answer_mmlu(output_content)
# Convert answer index to letter for display
answer_letter = chr(65 + answer_idx) if answer_idx < 4 else "?"
predicted_letter = chr(65 + predicted_answer) if predicted_answer is not None and predicted_answer < 4 else "?"
print(f"Sample {idx}: Correct answer: {answer_letter} ({answer_idx}), Predicted: {predicted_letter} ({predicted_answer})") # Evaluate samples for this subject
result = evaluate_samples(model, samples, benchmark, max_new_tokens, actual_subj_name)
all_results.append(result)
print(f"\nSubject '{actual_subj_name}' completed: {result['correct']}/{result['total']} = {result['accuracy']:.2%}")
answers_list.append({ except Exception as e:
"id": idx, print(f"Error evaluating subject '{subj}': {e}")
"output_content": output_content, continue
"answer": answer_idx,
"predicted": predicted_answer
})
model.destroy_model_instance() model.destroy_model_instance()
print("-------------------------------------------------------------") # Calculate overall results
overall_correct = sum(r['correct'] for r in all_results)
# Evaluate results overall_total = sum(r['total'] for r in all_results)
true_num = 0 overall_accuracy = overall_correct / overall_total if overall_total > 0 else 0.0
all_num = 0
for cont in answers_list:
id = cont["id"]
all_num = all_num + 1
print(f"\n{'='*60}")
print("OVERALL RESULTS")
print(f"{'='*60}")
if benchmark == "ceval": if benchmark == "ceval":
answer = cont["answer"] print(f"Overall 成绩: {overall_correct}/{overall_total} = {overall_accuracy:.2%}")
is_correct = cont["is_correct"]
if is_correct:
true_num = true_num + 1
print(f"id {id} : ", "正确")
else: else:
print(f"id {id}: ", "错误") print(f"Overall Accuracy: {overall_correct}/{overall_total} = {overall_accuracy:.2%}")
elif benchmark == "mmlu": # Write CSV if output path is specified
answer = cont["answer"] if output_csv:
predicted = cont["predicted"] print(f"\nWriting results to CSV: {output_csv}")
if predicted is not None and predicted == answer: with open(output_csv, 'w', newline='', encoding='utf-8') as csvfile:
true_num = true_num + 1 writer = csv.writer(csvfile)
print(f"id {id}: Correct") writer.writerow(['Subject', 'Correct', 'Total', 'Accuracy'])
else: for result in all_results:
answer_letter = chr(65 + answer) if answer < 4 else "?" writer.writerow([result['subject'], result['correct'], result['total'], f"{result['accuracy']:.4f}"])
predicted_letter = chr(65 + predicted) if predicted is not None and predicted < 4 else "?" writer.writerow(['Overall', overall_correct, overall_total, f"{overall_accuracy:.4f}"])
print(f"id {id}: Wrong (correct: {answer_letter}, predicted: {predicted_letter})") print(f"CSV file written successfully: {output_csv}")
accuracy = true_num / all_num if all_num > 0 else 0.0
if benchmark == "ceval":
print(f"成绩: {true_num}/{all_num}", accuracy)
else:
print(f"Accuracy: {true_num}/{all_num} = {accuracy:.2%}")
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment