Commit b2e1bfc6 authored by artemorloff's avatar artemorloff
Browse files

Merge remote-tracking branch 'origin' into feature/eval_from_config

parents b5d16d61 e4a7b69f
......@@ -20,13 +20,12 @@ jobs:
with:
fetch-depth: 2 # OR "2" -> To retrieve the preceding commit.
# Uses the dorny/paths-filter@v3 action to check for changes.
# Outputs provided here: https://github.com/dorny/paths-filter#outputs
# Uses the tj-actions/changed-files action to check for changes.
# The `files_yaml` input optionally takes a yaml string to specify filters,
# and prepends the filter name to the standard output names.
- name: Check task folders
id: changed-tasks
uses: dorny/paths-filter@v3
uses: tj-actions/changed-files@v46.0.5
with:
# tasks checks the tasks folder and api checks the api folder for changes
files_yaml: |
......
......@@ -32,13 +32,14 @@ jobs:
env:
SKIP: "no-commit-to-branch,mypy"
uses: pre-commit/action@v3.0.1
# Job 2
# Job 2
testcpu:
name: CPU Tests
runs-on: ubuntu-latest
strategy:
fail-fast: true
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12" ]
python-version: ["3.9", "3.10", "3.11"]
timeout-minutes: 30
steps:
- name: Checkout Code
......@@ -49,18 +50,35 @@ jobs:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
# Cache HuggingFace cache directory for CPU tests
- name: Cache HuggingFace cache (CPU tests)
uses: actions/cache@v3
id: cache-hf-cpu
with:
path: ~/.cache/huggingface
key: ${{ runner.os }}-hf-cache-cpu
restore-keys: |
${{ runner.os }}-hf-cache-cpu
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e '.[dev,sentencepiece,api]' --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e '.[dev]' --extra-index-url https://download.pytorch.org/whl/cpu
pip install hf_xet
- name: Test with pytest
run: python -m pytest --showlocals -s -vv -n=auto --ignore=tests/models/test_neuralmagic.py --ignore=tests/models/test_openvino.py --ignore=tests/models/test_hf_steered.py
- name: Archive artifacts
continue-on-error: true # Continue workflow even if tests fail
# Save test artifacts
- name: Archive test artifacts
uses: actions/upload-artifact@v4
with:
name: output_testcpu${{ matrix.python-version }}
path: |
test_logs/*
testmodels:
name: External LM Tests
runs-on: ubuntu-latest
......@@ -74,10 +92,23 @@ jobs:
python-version: 3.9
cache: pip
cache-dependency-path: pyproject.toml
# Cache HuggingFace cache directory for External LM tests
- name: Cache HuggingFace cache (External LM tests)
uses: actions/cache@v3
id: cache-hf-lm
with:
path: ~/.cache/huggingface
key: ${{ runner.os }}-hf-cache-external-lm
restore-keys: |
${{ runner.os }}-hf-cache-external-lm
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e '.[dev,optimum,deepsparse,sparseml,api]' --extra-index-url https://download.pytorch.org/whl/cpu
pip install -U transformers peft
pip install -U transformers peft accelerate
- name: Test with pytest
run: python -m pytest tests/models --showlocals -s -vv
continue-on-error: true # Continue workflow even if tests fail
......@@ -113,6 +113,9 @@ class TaskConfig(dict):
)
if "until" not in self.generation_kwargs:
eval_logger.warning(
f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
)
self.generation_kwargs["until"] = [self.fewshot_delimiter]
else:
if self.output_type == "generate_until":
......@@ -124,7 +127,11 @@ class TaskConfig(dict):
else [self.fewshot_delimiter]
),
"do_sample": False,
"temperature": 0,
}
eval_logger.warning(
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
)
def __getitem__(self, item):
return getattr(self, item)
......@@ -928,11 +935,17 @@ class ConfigurableTask(Task):
num_choice = len(test_choice)
if isinstance(test_text, int):
eval_logger.debug(
"doc_to_text returned an int. Assuming multiple inputs."
)
self.multiple_input = num_choice
else:
test_choice = None
if isinstance(test_target, list):
eval_logger.debug(
"doc_to_target returned a list. Assuming multiple targets."
)
self.multiple_target = len(test_target)
else:
if (isinstance(test_target, int)) and (test_choice is not None):
......
......@@ -49,6 +49,11 @@ class HFMultimodalLM(HFLM):
max_pixels: Optional[int] = None,
**kwargs,
):
# init pixels before calling tokenizer creation to avoid errors
self.pixels = ({"min_pixels": min_pixels} if min_pixels else {}) | (
{"max_pixels": max_pixels} if max_pixels else {}
)
# We initialize using HFLM's init. Sub-methods like _create_model and _create_tokenizer
# modify init behavior.
super().__init__(pretrained, **kwargs)
......@@ -65,9 +70,6 @@ class HFMultimodalLM(HFLM):
self.interleave = interleave
self.max_images = max_images
self.rgb = convert_img_format
self.pixels = ({"min_pixels": min_pixels} if min_pixels else {}) | (
{"max_pixels": max_pixels} if max_pixels else {}
)
# WARNING: improperly set image_token_id can lead to ignored image input or other (potentially silent) errors!
if not image_string:
self.image_token_id = (
......
......@@ -3,7 +3,7 @@ import logging
import os
from datetime import timedelta
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import jinja2
import torch
......@@ -74,6 +74,7 @@ class HFLM(TemplateLM):
max_length: Optional[int] = None,
device: Optional[str] = "cuda",
dtype: Optional[Union[str, torch.dtype]] = "auto",
softmax_dtype: Optional[Union[str, torch.dtype]] = None,
batch_size: Optional[Union[int, str]] = 1,
max_batch_size: Optional[int] = 64,
trust_remote_code: Optional[bool] = False,
......@@ -204,6 +205,7 @@ class HFLM(TemplateLM):
autogptq=autogptq,
gptqmodel=gptqmodel,
gguf_file=gguf_file,
quantization_config=getattr(self.config, "quantization_config", None),
**kwargs,
)
......@@ -233,6 +235,9 @@ class HFLM(TemplateLM):
self.batch_schedule = 1
self.batch_sizes = {}
self.max_batch_size = max_batch_size
self.softmax_dtype = (
get_dtype(softmax_dtype) if softmax_dtype is not None else None
)
if str(batch_size).startswith("auto"):
batch_size = batch_size.split(":")
......@@ -546,6 +551,7 @@ class HFLM(TemplateLM):
autogptq: Optional[Union[bool, str]] = False,
gptqmodel: Optional[bool] = False,
gguf_file: Optional[str] = None,
quantization_config: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
"""
......@@ -591,6 +597,7 @@ class HFLM(TemplateLM):
torch_dtype=get_dtype(dtype),
trust_remote_code=trust_remote_code,
gguf_file=gguf_file,
quantization_config=quantization_config,
**model_kwargs,
)
else:
......@@ -765,7 +772,11 @@ class HFLM(TemplateLM):
(batch_size, max_length), device=self.device
).long()
for _ in range(5):
out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1) # noqa: F841
out = F.log_softmax( # noqa: F841
self._model_call(test_batch, **call_kwargs),
dim=-1,
dtype=self.softmax_dtype,
)
return batch_size
......@@ -1197,7 +1208,9 @@ class HFLM(TemplateLM):
}
multi_logits = F.log_softmax(
self._model_call(batched_inps, **call_kwargs), dim=-1
self._model_call(batched_inps, **call_kwargs),
dim=-1,
dtype=self.softmax_dtype,
) # [batch, padding_length (inp or cont), vocab]
for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
......
......@@ -28,6 +28,9 @@ try:
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import get_tokenizer
if parse_version(version("vllm")) >= parse_version("0.8.3"):
from vllm.entrypoints.chat_utils import resolve_hf_chat_template
except ModuleNotFoundError:
pass
......@@ -133,6 +136,16 @@ class VLLM(TemplateLM):
"Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
)
if parse_version(version("vllm")) >= parse_version("0.8.3"):
self.hf_chat_template = resolve_hf_chat_template(
tokenizer=self.tokenizer,
chat_template=None,
tools=None,
trust_remote_code=trust_remote_code,
)
else:
self.hf_chat_template = None
self.custom_prefix_token_id = prefix_token_id
if prefix_token_id is not None:
eval_logger.info(
......@@ -195,6 +208,7 @@ class VLLM(TemplateLM):
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
chat_template=self.hf_chat_template,
)
return chat_templated
......
......@@ -6,7 +6,7 @@
For more information, including a full list of task names and their precise meanings or sources, follow the links provided to the individual README.md files for each subfolder.
| Task Family | Description | Language(s) |
|--------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------|
|--------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------|
| [aclue](aclue/README.md) | Tasks focusing on ancient Chinese language understanding and cultural aspects. | Ancient Chinese |
| [acp_bench](acpbench/README.md) | Tasks evaluating the reasoning ability about Action, Change, and Planning | English |
| [aexams](aexams/README.md) | Tasks in Arabic related to various academic exams covering a range of subjects. | Arabic |
......@@ -41,14 +41,14 @@
| csatqa | Tasks related to SAT and other standardized testing questions for academic assessment. | Korean |
| [darija_bench](darija_bench/README.md) | Traditional NLP tasks (Translation, Summariation, etc..) for Moroccan Darija | Moroccan Darija (some MT) |
| [darijahellaswag](darijahellaswag/README.md) | Moroccan Darija version of HellaSwag. | Moroccan Darija (MT) |
| [darijammlu](darijammlu/README.md)| Multiple-choice QA in Moroccan Darija (an Arabic dialect). | Moroccan Darija (MT) |
| [darijammlu](darijammlu/README.md) | Multiple-choice QA in Moroccan Darija (an Arabic dialect). | Moroccan Darija (MT) |
| [drop](drop/README.md) | Tasks requiring numerical reasoning, reading comprehension, and question answering. | English |
| [eq_bench](eq_bench/README.md) | Tasks focused on equality and ethics in question answering and decision-making. | English |
| [eus_exams](eus_exams/README.md) | Tasks based on various professional and academic exams in the Basque language. | Basque |
| [eus_proficiency](eus_proficiency/README.md) | Tasks designed to test proficiency in the Basque language across various topics. | Basque |
| [eus_reading](eus_reading/README.md) | Reading comprehension tasks specifically designed for the Basque language. | Basque |
| [eus_trivia](eus_trivia/README.md) | Trivia and knowledge testing tasks in the Basque language. | Basque |
| [evalita-LLM](evalita-LLM/README.md) | A native Italian benchmark with diverse tasks formats and multiple prompts. | Italian |
| [evalita_LLM](evalita_llm/README.md) | A native Italian benchmark with diverse tasks formats and multiple prompts. | Italian |
| [fda](fda/README.md) | Tasks for extracting key-value pairs from FDA documents to test information extraction. | English |
| [fld](fld/README.md) | Tasks involving free-form and directed dialogue understanding. | English |
| [french_bench](french_bench/README.md) | Set of tasks designed to assess language model performance in French. | French |
......@@ -80,8 +80,10 @@
| [lambada_multilingual_stablelm](lambada_multilingual_stablelm/README.md) | Multilingual LAMBADA dataset. Users should prefer evaluating on this version of the multilingual dataset instead of on `lambada_multilingual`. | German, English, Spanish, French, Italian, Dutch, Portuguese |
| [leaderboard](leaderboard/README.md) | Task group used by Hugging Face's [Open LLM Leaderboard v2](https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard). Those tasks are static and will not change through time | English |
| [lingoly](lingoly/README.md) | Challenging logical reasoning benchmark in low-resource languages with controls for memorization | English, Multilingual |
| [llama3](llama3/README.md) | Evals reproducing those provided by the LLAMA team in the Hugging Face repo (instruct) | English, Multilingual |
| [logiqa](logiqa/README.md) | Logical reasoning tasks requiring advanced inference and deduction. | English, Chinese |
| [logiqa2](logiqa2/README.md) | Large-scale logical reasoning dataset adapted from the Chinese Civil Service Examination. | English, Chinese |
| [longbench](longbench/README.md) | LongBench evaluates language models' ability to understand lengthy texts across multiple tasks and languages. | English, Chinese |
| [mastermind](mastermind/README.md) | Reasoning benchmark based on the board game of Mastermind. | English |
| [mathqa](mathqa/README.md) | Question answering tasks involving mathematical reasoning and problem-solving. | English |
| [mbpp](mbpp/README.md) | A benchmark designed to measure the ability to synthesize short Python programs from natural language descriptions. | Python |
......@@ -104,7 +106,7 @@
| [mmlu_prox](mmlu_prox/README.md) | A multilingual benchmark that extends MMLU-Pro to multiple typologically diverse languages with human validation. | English, Japanese, Chinese, Korean, French, German, Spanish, Portuguese, Swahili, Thai, Arabic, Hindi, Bengali |
| [mmlusr](mmlusr/README.md) | Variation of MMLU designed to be more rigorous. | English |
| model_written_evals | Evaluation tasks auto-generated for evaluating a collection of AI Safety concerns. | |
| [moral_stories](moral_stories/README.md) | A crowd-sourced dataset of structured narratives that describe normative and norm-divergent actions taken by individuals to accomplish certain intentions in concrete situations. | English
| [moral_stories](moral_stories/README.md) | A crowd-sourced dataset of structured narratives that describe normative and norm-divergent actions taken by individuals to accomplish certain intentions in concrete situations. | English |
| [mts_dialog](mts_dialog/README.md) | Open-ended healthcare QA from the MTS-Dialog dataset. | English |
| [mutual](mutual/README.md) | A retrieval-based dataset for multi-turn dialogue reasoning. | English |
| [nq_open](nq_open/README.md) | Open domain question answering tasks based on the Natural Questions dataset. | English |
......@@ -163,7 +165,7 @@
| [xstorycloze](xstorycloze/README.md) | Cross-lingual narrative understanding tasks to predict story endings in multiple languages. | Russian, Simplified Chinese, Spanish, Arabic, Hindi, Indonesian, Telugu, Swahili, Basque, Burmese |
| [xwinograd](xwinograd/README.md) | Cross-lingual Winograd schema tasks for coreference resolution in multiple languages. | English, French, Japanese, Portuguese, Russian, Chinese |
## Multilingual Tasks
## Multimodal Tasks
| Task Family | Description | Modality |
|------------------------------|---------------------------------------------------------------------------------------------------------|-------------|
| [chartqa](chartqa/README.md) | A benchmark for question answering about charts that requires both visual and logical reasoning. | Image, Text |
......
......@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
test_split: test
dataset_name: 2wikimqa
doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
doc_to_target: '{{answers}}'
doc_to_target: '{{answers[0]}}'
generation_kwargs:
max_gen_toks: 32
temperature: 1
do_sample: True
until: []
metric_list:
- metric: !function metrics.qa_f1_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
version: 2.0
tag:
- longbench_e
task: longbench_2wikimqa_e
......@@ -5,14 +6,15 @@ dataset_path: THUDM/LongBench
test_split: test
dataset_name: 2wikimqa_e
doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
doc_to_target: '{{answers}}'
doc_to_target: '{{answers[0]}}'
generation_kwargs:
max_gen_toks: 32
temperature: 1
do_sample: True
until: []
metric_list:
- metric: !function metrics.qa_f1_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
version: 2.0
......@@ -95,3 +95,4 @@ If other tasks on this dataset are already supported:
* [x] Have you noted which, if any, published evaluation setups are matched by this variant?
### Changelog
v2.: fix doc_to_target; add vcsum
......@@ -138,7 +138,7 @@ DATASETS = [
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--save_prefix_path", default="longbench")
parser.add_argument("--save_prefix_path", default="")
return parser.parse_args()
......@@ -156,6 +156,7 @@ generation_kwargs:
max_gen_toks: {{ generation_kwargs.max_gen_toks }}
temperature: {{ generation_kwargs.temperature }}
do_sample: {{ generation_kwargs.do_sample }}
until: {{ generation_kwargs.until }}
metric_list:
- metric: {{ metric_list[0].metric }}
aggregation: {{ metric_list[0].aggregation }}
......@@ -171,10 +172,21 @@ if __name__ == "__main__":
template = env.from_string(template_str)
for ds in DATASETS:
df = ds[:-2] if ds.endswith("_e") else ds
# from https://github.com/THUDM/LongBench/blob/2e00731f8d0bff23dc4325161044d0ed8af94c1e/LongBench/eval.py#L52C25-L52C29
if df in ["trec", "triviaqa", "samsum", "lsht"] + [
"trec_e",
"triviaqa_e",
"samsum_e",
"lsht_e",
]:
until = ["\n"]
else:
until = []
generation_kwargs = {
"max_gen_toks": dataset2maxlen[df],
"temperature": 1,
"do_sample": True,
"until": until,
}
raw_doc_to_text = (
dataset2prompt[df]
......@@ -199,10 +211,10 @@ if __name__ == "__main__":
"test_split": "test",
"dataset_name": ds,
"doc_to_text": raw_doc_to_text,
"doc_to_target": "{{answers}}",
"doc_to_target": "{{answers[0]}}",
"generation_kwargs": generation_kwargs,
"metric_list": metric_list,
"metadata": {"version": "1.0"},
"metadata": {"version": "2.0"},
}
# Render template
......@@ -211,35 +223,3 @@ if __name__ == "__main__":
# Save to file
with open(args.save_prefix_path + f"{ds}.yaml", "w") as f:
f.write(rendered_yaml)
# for ds in DATASETS:
# df = ds[:-2] if ds.endswith("_e") else ds
# generation_kwargs = {"max_gen_toks": dataset2maxlen[df], "temperature": 1, "do_sample": False}
# # Escape newlines and curly braces
# raw_doc_to_text = dataset2prompt[df].replace("\n", "\\n").replace("{", "{{").replace("}", "}}")
# metric_list = [
# {"metric": f"!function metrics.{dataset2metric[df]}", "aggregation": "mean", "higher_is_better": True}]
# yaml_dict = {
# "tag": ["longbench_e" if ds.endswith("_e") else "longbench"],
# "task": f"longbench_{ds}",
# "dataset_path": "THUDM/LongBench",
# "test_split": "test",
# "dataset_name": ds,
# "doc_to_text": raw_doc_to_text,
# "doc_to_target": "{{answers}}",
# "generation_kwargs": generation_kwargs,
# "metric_list": metric_list,
# "metadata": {"version": "1.0"}
# }
# template = env.from_string(yaml_dict)
#
#
# file_save_path = args.save_prefix_path + f"{ds}.yaml"
# with open(file_save_path, "w", encoding="utf-8") as yaml_file:
# yaml.dump(
# yaml_dict,
# yaml_file,
# allow_unicode=True,
# default_flow_style=False,
# sort_keys=False
# )
......@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
test_split: test
dataset_name: dureader
doc_to_text: '请基于给定的文章回答下述问题。\n\n文章:{{context}}\n\n请基于上述文章回答下面的问题。\n\n问题:{{input}}\n回答:'
doc_to_target: '{{answers}}'
doc_to_target: '{{answers[0]}}'
generation_kwargs:
max_gen_toks: 128
temperature: 1
do_sample: True
until: []
metric_list:
- metric: !function metrics.rouge_zh_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
version: 2.0
......@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
test_split: test
dataset_name: gov_report
doc_to_text: 'You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{{context}}\n\nNow, write a one-page summary of the report.\n\nSummary:'
doc_to_target: '{{answers}}'
doc_to_target: '{{answers[0]}}'
generation_kwargs:
max_gen_toks: 512
temperature: 1
do_sample: True
until: []
metric_list:
- metric: !function metrics.rouge_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
version: 2.0
......@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
test_split: test
dataset_name: gov_report_e
doc_to_text: 'You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{{context}}\n\nNow, write a one-page summary of the report.\n\nSummary:'
doc_to_target: '{{answers}}'
doc_to_target: '{{answers[0]}}'
generation_kwargs:
max_gen_toks: 512
temperature: 1
do_sample: True
until: []
metric_list:
- metric: !function metrics.rouge_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
version: 2.0
......@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
test_split: test
dataset_name: hotpotqa
doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
doc_to_target: '{{answers}}'
doc_to_target: '{{answers[0]}}'
generation_kwargs:
max_gen_toks: 32
temperature: 1
do_sample: True
until: []
metric_list:
- metric: !function metrics.qa_f1_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
version: 2.0
......@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
test_split: test
dataset_name: hotpotqa_e
doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
doc_to_target: '{{answers}}'
doc_to_target: '{{answers[0]}}'
generation_kwargs:
max_gen_toks: 32
temperature: 1
do_sample: True
until: []
metric_list:
- metric: !function metrics.qa_f1_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
version: 2.0
......@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
test_split: test
dataset_name: lcc
doc_to_text: 'Please complete the code given below. \n{{context}}Next line of code:\n'
doc_to_target: '{{answers}}'
doc_to_target: '{{answers[0]}}'
generation_kwargs:
max_gen_toks: 64
temperature: 1
do_sample: True
until: []
metric_list:
- metric: !function metrics.code_sim_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
version: 2.0
......@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
test_split: test
dataset_name: lcc_e
doc_to_text: 'Please complete the code given below. \n{{context}}Next line of code:\n'
doc_to_target: '{{answers}}'
doc_to_target: '{{answers[0]}}'
generation_kwargs:
max_gen_toks: 64
temperature: 1
do_sample: True
until: []
metric_list:
- metric: !function metrics.code_sim_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
version: 2.0
......@@ -6,14 +6,16 @@ dataset_path: THUDM/LongBench
test_split: test
dataset_name: lsht
doc_to_text: '请判断给定新闻的类别,下面是一些例子。\n\n{{context}}\n{{input}}'
doc_to_target: '{{answers}}'
doc_to_target: '{{answers[0]}}'
process_results: !function metrics.classification_score
generation_kwargs:
max_gen_toks: 64
temperature: 1
do_sample: True
until: ['\n']
metric_list:
- metric: !function metrics.classification_score
- metric: "classification_score"
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
version: 2.0
......@@ -124,12 +124,10 @@ def code_sim_score(predictions: list[str], references: list[str], **kwargs) -> f
return fuzz.ratio(prediction, ground_truth) / 100
def classification_score(
predictions: list[str], references: list[str], **kwargs
) -> float:
prediction, ground_truth = predictions[0], references[0]
def classification_score(doc: dict, results: list[str], **kwargs) -> dict:
prediction, ground_truth = results[0], doc["answers"][0]
em_match_list = []
all_classes = kwargs["all_classes"]
all_classes = doc["all_classes"]
for class_name in all_classes:
if class_name in prediction:
em_match_list.append(class_name)
......@@ -140,12 +138,14 @@ def classification_score(
score = 1.0 / len(em_match_list)
else:
score = 0.0
return score
return {"classification_score": score}
def rouge_score(predictions: list[str], references: list[str], **kwargs) -> float:
prediction, ground_truth = predictions[0], references[0]
global rouge
if "rouge" not in globals():
rouge = Rouge()
prediction, ground_truth = predictions[0], references[0]
try:
scores = rouge.get_scores([prediction], [ground_truth], avg=True)
# ruff: noqa
......@@ -162,7 +162,7 @@ def rouge_zh_score(predictions: list[str], references: list[str], **kwargs) -> f
return score
def f1_score(predictions: list[str], references: list[str], **kwargs):
def f1_score(predictions: list[str], references: list[str], **kwargs) -> float:
try:
prediction, ground_truth = predictions[0], references[0]
except:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment