Commit b2e1bfc6 authored by artemorloff

Merge remote-tracking branch 'origin' into feature/eval_from_config

parents b5d16d61 e4a7b69f
@@ -20,13 +20,12 @@ jobs:
 with:
 fetch-depth: 2 # OR "2" -> To retrieve the preceding commit.
-# Uses the dorny/paths-filter@v3 action to check for changes.
-# Outputs provided here: https://github.com/dorny/paths-filter#outputs
+# Uses the tj-actions/changed-files action to check for changes.
 # The `files_yaml` input optionally takes a yaml string to specify filters,
 # and prepends the filter name to the standard output names.
 - name: Check task folders
 id: changed-tasks
-uses: dorny/paths-filter@v3
+uses: tj-actions/changed-files@v46.0.5
 with:
 # tasks checks the tasks folder and api checks the api folder for changes
 files_yaml: |
...
@@ -32,13 +32,14 @@ jobs:
 env:
 SKIP: "no-commit-to-branch,mypy"
 uses: pre-commit/action@v3.0.1
 # Job 2
 testcpu:
 name: CPU Tests
 runs-on: ubuntu-latest
 strategy:
+fail-fast: true
 matrix:
-python-version: ["3.9", "3.10", "3.11", "3.12" ]
+python-version: ["3.9", "3.10", "3.11"]
 timeout-minutes: 30
 steps:
 - name: Checkout Code
@@ -49,18 +50,35 @@ jobs:
 python-version: ${{ matrix.python-version }}
 cache: pip
 cache-dependency-path: pyproject.toml
+# Cache HuggingFace cache directory for CPU tests
+- name: Cache HuggingFace cache (CPU tests)
+uses: actions/cache@v3
+id: cache-hf-cpu
+with:
+path: ~/.cache/huggingface
+key: ${{ runner.os }}-hf-cache-cpu
+restore-keys: |
+${{ runner.os }}-hf-cache-cpu
 - name: Install dependencies
 run: |
 python -m pip install --upgrade pip
-pip install -e '.[dev,sentencepiece,api]' --extra-index-url https://download.pytorch.org/whl/cpu
+pip install -e '.[dev]' --extra-index-url https://download.pytorch.org/whl/cpu
+pip install hf_xet
 - name: Test with pytest
 run: python -m pytest --showlocals -s -vv -n=auto --ignore=tests/models/test_neuralmagic.py --ignore=tests/models/test_openvino.py --ignore=tests/models/test_hf_steered.py
+continue-on-error: true # Continue workflow even if tests fail
-- name: Archive artifacts
+# Save test artifacts
+- name: Archive test artifacts
 uses: actions/upload-artifact@v4
 with:
 name: output_testcpu${{ matrix.python-version }}
 path: |
 test_logs/*
 testmodels:
 name: External LM Tests
 runs-on: ubuntu-latest
@@ -74,10 +92,23 @@ jobs:
 python-version: 3.9
 cache: pip
 cache-dependency-path: pyproject.toml
+# Cache HuggingFace cache directory for External LM tests
+- name: Cache HuggingFace cache (External LM tests)
+uses: actions/cache@v3
+id: cache-hf-lm
+with:
+path: ~/.cache/huggingface
+key: ${{ runner.os }}-hf-cache-external-lm
+restore-keys: |
+${{ runner.os }}-hf-cache-external-lm
 - name: Install dependencies
 run: |
 python -m pip install --upgrade pip
 pip install -e '.[dev,optimum,deepsparse,sparseml,api]' --extra-index-url https://download.pytorch.org/whl/cpu
-pip install -U transformers peft
+pip install -U transformers peft accelerate
 - name: Test with pytest
 run: python -m pytest tests/models --showlocals -s -vv
+continue-on-error: true # Continue workflow even if tests fail
@@ -113,6 +113,9 @@ class TaskConfig(dict):
 )
 if "until" not in self.generation_kwargs:
+eval_logger.warning(
+f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
+)
 self.generation_kwargs["until"] = [self.fewshot_delimiter]
 else:
 if self.output_type == "generate_until":
@@ -124,7 +127,11 @@ class TaskConfig(dict):
 else [self.fewshot_delimiter]
 ),
 "do_sample": False,
+"temperature": 0,
 }
+eval_logger.warning(
+f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
+)
 def __getitem__(self, item):
 return getattr(self, item)
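For readers skimming the two hunks above: the new warnings fire only when a generate_until task omits `until` or the whole `generation_kwargs` block, and the stop sequence then falls back to the few-shot delimiter. A minimal sketch of that defaulting rule, with hypothetical names (fill_until, fewshot_delimiter) rather than the real TaskConfig attributes:

# Hedged sketch (not part of the diff): mirrors the defaulting logic in isolation.
from typing import Any, Dict

def fill_until(generation_kwargs: Dict[str, Any], fewshot_delimiter: str = "\n\n") -> Dict[str, Any]:
    # If the task author gave no stop sequence, stop on the few-shot delimiter
    # so generation does not run on into the next (hallucinated) example.
    if "until" not in generation_kwargs:
        generation_kwargs["until"] = [fewshot_delimiter]
    return generation_kwargs

print(fill_until({"max_gen_toks": 32}))   # {'max_gen_toks': 32, 'until': ['\n\n']}
print(fill_until({"until": ["</s>"]}))    # explicit stop sequences are left untouched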
@@ -928,11 +935,17 @@ class ConfigurableTask(Task):
 num_choice = len(test_choice)
 if isinstance(test_text, int):
+eval_logger.debug(
+"doc_to_text returned an int. Assuming multiple inputs."
+)
 self.multiple_input = num_choice
 else:
 test_choice = None
 if isinstance(test_target, list):
+eval_logger.debug(
+"doc_to_target returned a list. Assuming multiple targets."
+)
 self.multiple_target = len(test_target)
 else:
 if (isinstance(test_target, int)) and (test_choice is not None):
...
@@ -49,6 +49,11 @@ class HFMultimodalLM(HFLM):
 max_pixels: Optional[int] = None,
 **kwargs,
 ):
+# init pixels before calling tokenizer creation to avoid errors
+self.pixels = ({"min_pixels": min_pixels} if min_pixels else {}) | (
+{"max_pixels": max_pixels} if max_pixels else {}
+)
 # We initialize using HFLM's init. Sub-methods like _create_model and _create_tokenizer
 # modify init behavior.
 super().__init__(pretrained, **kwargs)
@@ -65,9 +70,6 @@ class HFMultimodalLM(HFLM):
 self.interleave = interleave
 self.max_images = max_images
 self.rgb = convert_img_format
-self.pixels = ({"min_pixels": min_pixels} if min_pixels else {}) | (
-{"max_pixels": max_pixels} if max_pixels else {}
-)
 # WARNING: improperly set image_token_id can lead to ignored image input or other (potentially silent) errors!
 if not image_string:
 self.image_token_id = (
...
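The two hunks above move the `self.pixels` assignment from after `super().__init__(...)` to before it, because the parent constructor triggers tokenizer/processor creation that reads `self.pixels`. A hedged sketch of that ordering constraint, using hypothetical class names rather than the real HFLM/HFMultimodalLM classes:

# Hedged sketch (illustrative only, hypothetical classes): why self.pixels must be
# assigned before the parent constructor runs.
class Parent:
    def __init__(self):
        self._create_processor()                   # parent __init__ calls a hook...

    def _create_processor(self):
        print("processor kwargs:", self.pixels)    # ...which reads self.pixels

class Child(Parent):
    def __init__(self, min_pixels=None, max_pixels=None):
        # Build the kwargs first (dict-union needs Python >= 3.9), as the diff does;
        # assigning self.pixels only after super().__init__() would raise
        # AttributeError inside _create_processor.
        self.pixels = ({"min_pixels": min_pixels} if min_pixels else {}) | (
            {"max_pixels": max_pixels} if max_pixels else {}
        )
        super().__init__()

Child(min_pixels=256)   # prints: processor kwargs: {'min_pixels': 256}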
@@ -3,7 +3,7 @@ import logging
 import os
 from datetime import timedelta
 from pathlib import Path
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 import jinja2
 import torch
@@ -74,6 +74,7 @@ class HFLM(TemplateLM):
 max_length: Optional[int] = None,
 device: Optional[str] = "cuda",
 dtype: Optional[Union[str, torch.dtype]] = "auto",
+softmax_dtype: Optional[Union[str, torch.dtype]] = None,
 batch_size: Optional[Union[int, str]] = 1,
 max_batch_size: Optional[int] = 64,
 trust_remote_code: Optional[bool] = False,
@@ -204,6 +205,7 @@ class HFLM(TemplateLM):
 autogptq=autogptq,
 gptqmodel=gptqmodel,
 gguf_file=gguf_file,
+quantization_config=getattr(self.config, "quantization_config", None),
 **kwargs,
 )
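The hunk above forwards any `quantization_config` found on the already-loaded model config into model creation, instead of silently dropping it. A hedged sketch of the same pattern outside the class (the model id is a placeholder, not taken from the diff; it assumes a pre-quantized checkpoint whose config.json carries a quantization_config block):

# Hedged sketch, mirroring the passthrough in the hunk above.
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "some-org/some-prequantized-model"   # hypothetical checkpoint id
config = AutoConfig.from_pretrained(model_id)

# Forward the checkpoint's own quantization settings, if any; None is a harmless default.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=getattr(config, "quantization_config", None),
)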
@@ -233,6 +235,9 @@ class HFLM(TemplateLM):
 self.batch_schedule = 1
 self.batch_sizes = {}
 self.max_batch_size = max_batch_size
+self.softmax_dtype = (
+get_dtype(softmax_dtype) if softmax_dtype is not None else None
+)
 if str(batch_size).startswith("auto"):
 batch_size = batch_size.split(":")
@@ -546,6 +551,7 @@ class HFLM(TemplateLM):
 autogptq: Optional[Union[bool, str]] = False,
 gptqmodel: Optional[bool] = False,
 gguf_file: Optional[str] = None,
+quantization_config: Optional[Dict[str, Any]] = None,
 **kwargs,
 ) -> None:
 """
@@ -591,6 +597,7 @@ class HFLM(TemplateLM):
 torch_dtype=get_dtype(dtype),
 trust_remote_code=trust_remote_code,
 gguf_file=gguf_file,
+quantization_config=quantization_config,
 **model_kwargs,
 )
 else:
@@ -765,7 +772,11 @@ class HFLM(TemplateLM):
 (batch_size, max_length), device=self.device
 ).long()
 for _ in range(5):
-out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)  # noqa: F841
+out = F.log_softmax(  # noqa: F841
+self._model_call(test_batch, **call_kwargs),
+dim=-1,
+dtype=self.softmax_dtype,
+)
 return batch_size
@@ -1197,7 +1208,9 @@ class HFLM(TemplateLM):
 }
 multi_logits = F.log_softmax(
-self._model_call(batched_inps, **call_kwargs), dim=-1
+self._model_call(batched_inps, **call_kwargs),
+dim=-1,
+dtype=self.softmax_dtype,
 ) # [batch, padding_length (inp or cont), vocab]
 for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
...
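The new `softmax_dtype` argument threads through to the two `F.log_softmax` calls above; a `None` value keeps the current behavior. A hedged sketch of what the upcast buys, using toy tensor sizes (not taken from the diff):

# Hedged sketch: upcasting only the log-softmax, as the softmax_dtype plumbing allows.
import torch
import torch.nn.functional as F

logits = torch.randn(2, 8, 32_000, dtype=torch.bfloat16)   # [batch, seq, vocab], toy sizes

# dtype=None keeps the input dtype; torch.float32 makes F.log_softmax compute (and
# return) in fp32 without converting the whole logits tensor beforehand.
logprobs_bf16 = F.log_softmax(logits, dim=-1, dtype=None)
logprobs_fp32 = F.log_softmax(logits, dim=-1, dtype=torch.float32)

print(logprobs_bf16.dtype, logprobs_fp32.dtype)   # torch.bfloat16 torch.float32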
@@ -28,6 +28,9 @@ try:
 from vllm import LLM, SamplingParams
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizer import get_tokenizer
+if parse_version(version("vllm")) >= parse_version("0.8.3"):
+from vllm.entrypoints.chat_utils import resolve_hf_chat_template
 except ModuleNotFoundError:
 pass
@@ -133,6 +136,16 @@ class VLLM(TemplateLM):
 "Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
 )
+if parse_version(version("vllm")) >= parse_version("0.8.3"):
+self.hf_chat_template = resolve_hf_chat_template(
+tokenizer=self.tokenizer,
+chat_template=None,
+tools=None,
+trust_remote_code=trust_remote_code,
+)
+else:
+self.hf_chat_template = None
 self.custom_prefix_token_id = prefix_token_id
 if prefix_token_id is not None:
 eval_logger.info(
@@ -195,6 +208,7 @@ class VLLM(TemplateLM):
 tokenize=False,
 add_generation_prompt=add_generation_prompt,
 continue_final_message=not add_generation_prompt,
+chat_template=self.hf_chat_template,
 )
 return chat_templated
...
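The three vLLM hunks above gate the import of `resolve_hf_chat_template` on vLLM >= 0.8.3, resolve the template once at init, and pass it explicitly to `apply_chat_template`. A hedged sketch of that version-gating pattern as a standalone helper (the fallback branch and helper name are illustrative, not from the diff):

# Hedged sketch of the version-gating pattern used above.
from importlib.metadata import version
from packaging.version import parse as parse_version

def pick_chat_template(tokenizer, trust_remote_code: bool = False):
    if parse_version(version("vllm")) >= parse_version("0.8.3"):
        from vllm.entrypoints.chat_utils import resolve_hf_chat_template
        return resolve_hf_chat_template(
            tokenizer=tokenizer,
            chat_template=None,
            tools=None,
            trust_remote_code=trust_remote_code,
        )
    return None  # older vLLM: let tokenizer.apply_chat_template fall back to its default

# Usage (commented, assumes an HF tokenizer is in scope):
# messages = [{"role": "user", "content": "Hi"}]
# tokenizer.apply_chat_template(messages, tokenize=False,
#                               chat_template=pick_chat_template(tokenizer))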
This diff is collapsed.
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: 2wikimqa
 doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
 max_gen_toks: 32
 temperature: 1
 do_sample: True
+until: []
 metric_list:
 - metric: !function metrics.qa_f1_score
 aggregation: mean
 higher_is_better: True
 metadata:
-version: 1.0
+version: 2.0
 tag:
 - longbench_e
 task: longbench_2wikimqa_e
@@ -5,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: 2wikimqa_e
 doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
 max_gen_toks: 32
 temperature: 1
 do_sample: True
+until: []
 metric_list:
 - metric: !function metrics.qa_f1_score
 aggregation: mean
 higher_is_better: True
 metadata:
-version: 1.0
+version: 2.0
@@ -95,3 +95,4 @@ If other tasks on this dataset are already supported:
 * [x] Have you noted which, if any, published evaluation setups are matched by this variant?
 ### Changelog
+v2.0: fix doc_to_target; add vcsum
@@ -138,7 +138,7 @@ DATASETS = [
 def parse_args():
 parser = argparse.ArgumentParser()
-parser.add_argument("--save_prefix_path", default="longbench")
+parser.add_argument("--save_prefix_path", default="")
 return parser.parse_args()
@@ -156,6 +156,7 @@ generation_kwargs:
 max_gen_toks: {{ generation_kwargs.max_gen_toks }}
 temperature: {{ generation_kwargs.temperature }}
 do_sample: {{ generation_kwargs.do_sample }}
+until: {{ generation_kwargs.until }}
 metric_list:
 - metric: {{ metric_list[0].metric }}
 aggregation: {{ metric_list[0].aggregation }}
@@ -171,10 +172,21 @@ if __name__ == "__main__":
 template = env.from_string(template_str)
 for ds in DATASETS:
 df = ds[:-2] if ds.endswith("_e") else ds
+# from https://github.com/THUDM/LongBench/blob/2e00731f8d0bff23dc4325161044d0ed8af94c1e/LongBench/eval.py#L52C25-L52C29
+if df in ["trec", "triviaqa", "samsum", "lsht"] + [
+"trec_e",
+"triviaqa_e",
+"samsum_e",
+"lsht_e",
+]:
+until = ["\n"]
+else:
+until = []
 generation_kwargs = {
 "max_gen_toks": dataset2maxlen[df],
 "temperature": 1,
 "do_sample": True,
+"until": until,
 }
 raw_doc_to_text = (
 dataset2prompt[df]
@@ -199,10 +211,10 @@ if __name__ == "__main__":
 "test_split": "test",
 "dataset_name": ds,
 "doc_to_text": raw_doc_to_text,
-"doc_to_target": "{{answers}}",
+"doc_to_target": "{{answers[0]}}",
 "generation_kwargs": generation_kwargs,
 "metric_list": metric_list,
-"metadata": {"version": "1.0"},
+"metadata": {"version": "2.0"},
 }
 # Render template
@@ -211,35 +223,3 @@ if __name__ == "__main__":
 # Save to file
 with open(args.save_prefix_path + f"{ds}.yaml", "w") as f:
 f.write(rendered_yaml)
-# for ds in DATASETS:
-# df = ds[:-2] if ds.endswith("_e") else ds
-# generation_kwargs = {"max_gen_toks": dataset2maxlen[df], "temperature": 1, "do_sample": False}
-# # Escape newlines and curly braces
-# raw_doc_to_text = dataset2prompt[df].replace("\n", "\\n").replace("{", "{{").replace("}", "}}")
-# metric_list = [
-# {"metric": f"!function metrics.{dataset2metric[df]}", "aggregation": "mean", "higher_is_better": True}]
-# yaml_dict = {
-# "tag": ["longbench_e" if ds.endswith("_e") else "longbench"],
-# "task": f"longbench_{ds}",
-# "dataset_path": "THUDM/LongBench",
-# "test_split": "test",
-# "dataset_name": ds,
-# "doc_to_text": raw_doc_to_text,
-# "doc_to_target": "{{answers}}",
-# "generation_kwargs": generation_kwargs,
-# "metric_list": metric_list,
-# "metadata": {"version": "1.0"}
-# }
-# template = env.from_string(yaml_dict)
-#
-#
-# file_save_path = args.save_prefix_path + f"{ds}.yaml"
-# with open(file_save_path, "w", encoding="utf-8") as yaml_file:
-# yaml.dump(
-# yaml_dict,
-# yaml_file,
-# allow_unicode=True,
-# default_flow_style=False,
-# sort_keys=False
-# )
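The generator above now computes a per-dataset `until` list and renders it through the Jinja2 template into each task YAML. A hedged, toy version of that render loop with a trimmed template string (values and template text are simplified; file writing is replaced by printing):

# Hedged sketch: a toy version of the render-and-write loop above.
import jinja2

template_str = (
    "task: longbench_{{ task }}\n"
    "generation_kwargs:\n"
    "  max_gen_toks: {{ generation_kwargs.max_gen_toks }}\n"
    "  do_sample: {{ generation_kwargs.do_sample }}\n"
    "  until: {{ generation_kwargs.until }}\n"
)
template = jinja2.Environment().from_string(template_str)

for ds, until in [("samsum", ["\n"]), ("gov_report", [])]:
    rendered = template.render(
        task=ds,
        generation_kwargs={"max_gen_toks": 32, "do_sample": True, "until": until},
    )
    print(rendered)   # in the real script this text is written to f"{ds}.yaml"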
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: dureader
 doc_to_text: '请基于给定的文章回答下述问题。\n\n文章:{{context}}\n\n请基于上述文章回答下面的问题。\n\n问题:{{input}}\n回答:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
 max_gen_toks: 128
 temperature: 1
 do_sample: True
+until: []
 metric_list:
 - metric: !function metrics.rouge_zh_score
 aggregation: mean
 higher_is_better: True
 metadata:
-version: 1.0
+version: 2.0
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: gov_report
 doc_to_text: 'You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{{context}}\n\nNow, write a one-page summary of the report.\n\nSummary:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
 max_gen_toks: 512
 temperature: 1
 do_sample: True
+until: []
 metric_list:
 - metric: !function metrics.rouge_score
 aggregation: mean
 higher_is_better: True
 metadata:
-version: 1.0
+version: 2.0
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: gov_report_e
 doc_to_text: 'You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{{context}}\n\nNow, write a one-page summary of the report.\n\nSummary:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
 max_gen_toks: 512
 temperature: 1
 do_sample: True
+until: []
 metric_list:
 - metric: !function metrics.rouge_score
 aggregation: mean
 higher_is_better: True
 metadata:
-version: 1.0
+version: 2.0
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: hotpotqa
 doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
 max_gen_toks: 32
 temperature: 1
 do_sample: True
+until: []
 metric_list:
 - metric: !function metrics.qa_f1_score
 aggregation: mean
 higher_is_better: True
 metadata:
-version: 1.0
+version: 2.0
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: hotpotqa_e
 doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
 max_gen_toks: 32
 temperature: 1
 do_sample: True
+until: []
 metric_list:
 - metric: !function metrics.qa_f1_score
 aggregation: mean
 higher_is_better: True
 metadata:
-version: 1.0
+version: 2.0
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: lcc
 doc_to_text: 'Please complete the code given below. \n{{context}}Next line of code:\n'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
 max_gen_toks: 64
 temperature: 1
 do_sample: True
+until: []
 metric_list:
 - metric: !function metrics.code_sim_score
 aggregation: mean
 higher_is_better: True
 metadata:
-version: 1.0
+version: 2.0
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: lcc_e
 doc_to_text: 'Please complete the code given below. \n{{context}}Next line of code:\n'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
 max_gen_toks: 64
 temperature: 1
 do_sample: True
+until: []
 metric_list:
 - metric: !function metrics.code_sim_score
 aggregation: mean
 higher_is_better: True
 metadata:
-version: 1.0
+version: 2.0
@@ -6,14 +6,16 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: lsht
 doc_to_text: '请判断给定新闻的类别,下面是一些例子。\n\n{{context}}\n{{input}}'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
+process_results: !function metrics.classification_score
 generation_kwargs:
 max_gen_toks: 64
 temperature: 1
 do_sample: True
+until: ['\n']
 metric_list:
-- metric: !function metrics.classification_score
+- metric: "classification_score"
 aggregation: mean
 higher_is_better: True
 metadata:
-version: 1.0
+version: 2.0
@@ -124,12 +124,10 @@ def code_sim_score(predictions: list[str], references: list[str], **kwargs) -> f
 return fuzz.ratio(prediction, ground_truth) / 100
-def classification_score(
-predictions: list[str], references: list[str], **kwargs
-) -> float:
-prediction, ground_truth = predictions[0], references[0]
+def classification_score(doc: dict, results: list[str], **kwargs) -> dict:
+prediction, ground_truth = results[0], doc["answers"][0]
 em_match_list = []
-all_classes = kwargs["all_classes"]
+all_classes = doc["all_classes"]
 for class_name in all_classes:
 if class_name in prediction:
 em_match_list.append(class_name)
@@ -140,12 +138,14 @@ def classification_score(
 score = 1.0 / len(em_match_list)
 else:
 score = 0.0
-return score
+return {"classification_score": score}
 def rouge_score(predictions: list[str], references: list[str], **kwargs) -> float:
-prediction, ground_truth = predictions[0], references[0]
+global rouge
+if "rouge" not in globals():
 rouge = Rouge()
+prediction, ground_truth = predictions[0], references[0]
 try:
 scores = rouge.get_scores([prediction], [ground_truth], avg=True)
 # ruff: noqa
@@ -162,7 +162,7 @@ def rouge_zh_score(predictions: list[str], references: list[str], **kwargs) -> f
 return score
-def f1_score(predictions: list[str], references: list[str], **kwargs):
+def f1_score(predictions: list[str], references: list[str], **kwargs) -> float:
 try:
 prediction, ground_truth = predictions[0], references[0]
 except:
...
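The metrics hunks, together with the lsht YAML change, convert `classification_score` from a metric function taking `(predictions, references)` into a `process_results`-style hook taking `(doc, results)` and returning a dict keyed by metric name, registered via `process_results: !function metrics.classification_score` with the metric listed as `"classification_score"`. A hedged sketch of that call shape on toy data; the scoring body here is a trimmed re-implementation for illustration, not the repo's function verbatim:

# Hedged sketch: process_results-style scoring exercised on toy data.
def classification_score(doc: dict, results: list[str], **kwargs) -> dict:
    # Gets the raw doc plus the model output(s), and returns {metric_name: value}
    # so the harness can aggregate results by key.
    prediction, ground_truth = results[0], doc["answers"][0]
    matches = [c for c in doc["all_classes"] if c in prediction]
    score = 1.0 / len(matches) if ground_truth in matches else 0.0
    return {"classification_score": score}

doc = {"answers": ["sports"], "all_classes": ["sports", "finance", "tech"]}
print(classification_score(doc, ["I think this is sports news."]))  # {'classification_score': 1.0}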