Unverified Commit 80a10075 authored by Baber Abbasi, committed by GitHub

Add long-context tasks (#2629)

Support for long-context (and other synthetic) tasks:
* add ruler
* add longbench
* pass `metadata` to TaskConfig
parent f47ddaf8
tag:
- longbench
task: longbench_trec
dataset_path: THUDM/LongBench
test_split: test
dataset_name: trec
doc_to_text: 'Please determine the type of the question below. Here are some examples of questions.\n\n{{context}}\n{{input}}'
doc_to_target: '{{answers}}'
generation_kwargs:
max_gen_toks: 64
temperature: 1
do_sample: True
metric_list:
- metric: !function metrics.classification_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
tag:
- longbench_e
task: longbench_trec_e
dataset_path: THUDM/LongBench
test_split: test
dataset_name: trec_e
doc_to_text: 'Please determine the type of the question below. Here are some examples of questions.\n\n{{context}}\n{{input}}'
doc_to_target: '{{answers}}'
generation_kwargs:
max_gen_toks: 64
temperature: 1
do_sample: True
metric_list:
- metric: !function metrics.classification_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
tag:
- longbench
task: longbench_triviaqa
dataset_path: THUDM/LongBench
test_split: test
dataset_name: triviaqa
doc_to_text: 'Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{{context}}\n\n{{input}}'
doc_to_target: '{{answers}}'
generation_kwargs:
max_gen_toks: 32
temperature: 1
do_sample: True
metric_list:
- metric: !function metrics.qa_f1_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
tag:
- longbench_e
task: longbench_triviaqa_e
dataset_path: THUDM/LongBench
test_split: test
dataset_name: triviaqa_e
doc_to_text: 'Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{{context}}\n\n{{input}}'
doc_to_target: '{{answers}}'
generation_kwargs:
max_gen_toks: 32
temperature: 1
do_sample: True
metric_list:
- metric: !function metrics.qa_f1_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
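
For reference, these configs render Jinja templates over fields of the THUDM/LongBench dataset. A minimal sketch of what one rendered prompt looks like, assuming the `trec` config exposes `context`, `input`, and `answers` columns (field names inferred from the templates above):

```python
# Illustrative only: render the longbench_trec prompt for one record.
# Assumes the THUDM/LongBench "trec" config exposes `context`, `input`, and
# `answers` columns, as the doc_to_text / doc_to_target templates above suggest.
from datasets import load_dataset
from jinja2 import Template

doc = load_dataset("THUDM/LongBench", "trec", split="test")[0]
prompt = Template(
    "Please determine the type of the question below. "
    "Here are some examples of questions.\n\n{{context}}\n{{input}}"
).render(context=doc["context"], input=doc["input"])
print(prompt[-300:])            # tail of the few-shot context plus the query
print("target:", doc["answers"])
```
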
import argparse
import json
import os
import numpy as np
from metrics import (
classification_score,
code_sim_score,
count_score,
qa_f1_score,
qa_f1_zh_score,
retrieval_score,
retrieval_zh_score,
rouge_score,
rouge_zh_score,
)
dataset2metric = {
"narrativeqa": qa_f1_score,
"qasper": qa_f1_score,
"multifieldqa_en": qa_f1_score,
"multifieldqa_zh": qa_f1_zh_score,
"hotpotqa": qa_f1_score,
"2wikimqa": qa_f1_score,
"musique": qa_f1_score,
"dureader": rouge_zh_score,
"gov_report": rouge_score,
"qmsum": rouge_score,
"multi_news": rouge_score,
"vcsum": rouge_zh_score,
"trec": classification_score,
"triviaqa": qa_f1_score,
"samsum": rouge_score,
"lsht": classification_score,
"passage_retrieval_en": retrieval_score,
"passage_count": count_score,
"passage_retrieval_zh": retrieval_zh_score,
"lcc": code_sim_score,
"repobench-p": code_sim_score,
}
# def parse_args(args=None):
# parser = argparse.ArgumentParser()
# parser.add_argument('--model', type=str, default=None)
# parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E")
# return parser.parse_args(args)
def scorer_e(dataset, predictions, answers, lengths, all_classes):
scores = {"0-4k": [], "4-8k": [], "8k+": []}
for prediction, ground_truths, length in zip(predictions, answers, lengths):
score = 0.0
if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
prediction = prediction.lstrip("\n").split("\n")[0]
for ground_truth in ground_truths:
score = max(
score,
dataset2metric[dataset](
prediction, ground_truth, all_classes=all_classes
),
)
if length < 4000:
scores["0-4k"].append(score)
elif length < 8000:
scores["4-8k"].append(score)
else:
scores["8k+"].append(score)
for key in scores.keys():
scores[key] = round(100 * np.mean(scores[key]), 2)
return scores
def scorer(dataset, predictions, answers, all_classes):
total_score = 0.0
for prediction, ground_truths in zip(predictions, answers):
score = 0.0
if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
prediction = prediction.lstrip("\n").split("\n")[0]
for ground_truth in ground_truths:
score = max(
score,
dataset2metric[dataset](
prediction, ground_truth, all_classes=all_classes
),
)
total_score += score
return round(100 * total_score / len(predictions), 2)
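
A toy invocation of the two helpers above (run in this module, or after importing `scorer` and `scorer_e` from it; the metric implementations come from the `metrics` module imported at the top):

```python
# Toy example: three predictions, reference answers, and context lengths.
preds = ["Paris", "London", "Madrid"]
golds = [["Paris"], ["Berlin", "London"], ["Lisbon"]]
lengths = [3000, 6000, 12000]

# Overall score: best match per sample, averaged and scaled to 0-100.
print(scorer("triviaqa", preds, golds, all_classes=None))            # 66.67
# LongBench-E style: the same scores bucketed by context length.
print(scorer_e("triviaqa", preds, golds, lengths, all_classes=None))
# {'0-4k': 100.0, '4-8k': 100.0, '8k+': 0.0}
```
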
# RULER
### Paper
Title: `RULER: What’s the Real Context Size of Your Long-Context Language Models?`
Abstract: `https://arxiv.org/abs/2404.06654`
`RULER expands upon the vanilla NIAH test to encompass variations with diverse types and quantities of needles. Moreover, RULER introduces new task categories multi-hop tracing and aggregation to test behaviors beyond searching from context. We evaluate 17 long-context LMs with 13 representative tasks in RULER.`
Homepage: `https://github.com/NVIDIA/RULER`
> [!NOTE]
> When using RULER tasks, please note:
> 1. A tokenizer is required for data processing. The harness uses the `tokenizer` from `model_args`, or falls back to the tokenizer associated with the `pretrained` model name.
> 2. The default maximum sequence length is 4096. To compute metrics at other maximum sequence lengths, specify them with the `metadata` parameter:
> `--metadata='{"max_seq_lengths":[4096,8192,16384,32768,65536,131072]}'`. The same dictionary can also be passed to the `TaskManager` (`metadata: dict`).
> 3. To prevent truncation of longer sequences, we recommend setting the max_length parameter in model_args:
> `--model_args=pretrained=...,max_length=32768`
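
The same settings can also be wired up from Python. A minimal sketch, assuming the `metadata` keyword added in this change and using a placeholder model name:

```python
# Sketch only: pass max_seq_lengths through the TaskManager instead of the CLI.
# The model name is a placeholder; any HF causal LM with a long context works.
import lm_eval
from lm_eval.tasks import TaskManager

tm = TaskManager(metadata={"max_seq_lengths": [4096, 8192, 16384]})
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-3.1-8B-Instruct,max_length=16384",
    tasks=["niah_single_1"],
    task_manager=tm,
)
```
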
### Citation
```
@article{hsieh2024ruler,
title={RULER: What's the Real Context Size of Your Long-Context Language Models?},
author={Cheng-Ping Hsieh and Simeng Sun and Samuel Kriman and Shantanu Acharya and Dima Rekesh and Fei Jia and Yang Zhang and Boris Ginsburg},
year={2024},
journal={arXiv preprint arXiv:2404.06654},
}
```
### Groups, Tags, and Tasks
#### Groups
* `ruler`: `All 13 tasks in the RULER benchmark`
#### Tags
* `longcxt`: `Long-context tasks`
#### Tasks
* `niah_single_1`: `NIAH single needle; key=word, value=number, haystack=repeated noise (≈ passkey retrieval)`
* `niah_single_2`: `NIAH single needle; key=word, value=number, haystack=essay (≈ vanilla NIAH)`
* `niah_single_3`: `NIAH single needle; key=word, value=uuid, haystack=essay`
* `niah_multikey_1`: `NIAH multi-key (≈ line retrieval)`
* `niah_multikey_2`: `NIAH multi-key (≈ KV retrieval)`
* `niah_multikey_3`: `NIAH multi-key`
* `niah_multiquery`: `NIAH multi-query`
* `niah_multivalue`: `NIAH multi-value`
* `ruler_vt`: `Variable tracking`
* `ruler_cwe`: `Common words extraction`
* `ruler_fwe`: `Frequent words extraction`
* `ruler_qa_hotpot`: `QA over HotpotQA`
* `ruler_qa_squad`: `QA over SQuAD v2`
### Checklist
For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [x] Is the "Main" variant of this task clearly denoted?
* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [x] Have you noted which, if any, published evaluation setups are matched by this variant?
### Changelog
import logging
import re
from functools import cache
from typing import TYPE_CHECKING, Union
from transformers import AutoTokenizer
if TYPE_CHECKING:
import transformers
eval_logger = logging.getLogger(__name__)
DEFAULT_SEQ_LENGTHS = [
4096,
]
@cache
def get_tokenizer(
tokenizer=None, pretrained=None, **kwargs
) -> Union["transformers.PreTrainedTokenizer", "transformers.PreTrainedTokenizerFast"]:
pretrained = tokenizer or pretrained
assert pretrained, "No tokenizer or pretrained provided."
eval_logger.info(f"Using tokenizer {pretrained} for synthetic tasks.")
return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
def postprocess_pred(prediction: list[str]) -> list[str]:
res = []
for predict_str in prediction:
predict_str = predict_str.strip()
# Replace control characters (\x00-\x1f) with newlines
np_pattern = re.compile(r"[\x00-\x1f]")
predict_str = np_pattern.sub("\n", predict_str).strip()
res.append(predict_str)
return res
def string_match_all(preds: list[str], refs: list[list[str]]) -> float:
score = sum(
[
sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref)
for pred, ref in zip(preds, refs)
]
) / len(preds)
return score
def string_match_part(preds: list[str], refs: list[list[str]]) -> float:
score = max(
[
sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref)
for pred, ref in zip(preds, refs)
]
) / len(preds)
return score
def process_results(doc: dict, results: list[str]) -> dict[str, float]:
# hacky: set all other lengths to -1
metrics = {str(length): -1.0 for length in DEFAULT_SEQ_LENGTHS}
input_len = doc["max_length"]
pred = postprocess_pred(results)
score = string_match_all(pred, [doc["outputs"]])
metrics[str(input_len)] = score
return metrics
def process_results_part(doc: dict, results: list[str]) -> dict[str, float]:
# hacky: set all other lengths to -1
metrics = {str(length): -1.0 for length in DEFAULT_SEQ_LENGTHS}
input_len = doc["max_length"]
pred = postprocess_pred(results)
score = string_match_part(pred, [doc["outputs"]])
metrics[str(input_len)] = score
return metrics
def aggregate_metrics(metrics: list[float]) -> float:
res = [x for x in metrics if x != -1]
if not res:
# we don't have any samples with this length
return -1
return sum(res) / len(res)
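
To make the per-length bookkeeping concrete, here is a toy walk-through (the import path mirrors the `lm_eval.tasks.ruler.common_utils` path used by the other task utilities in this directory): each sample reports a real score only under its own `max_length` key, every other length gets the `-1` sentinel, and `aggregate_metrics` drops those sentinels at aggregation time.

```python
# Toy walk-through of the per-length metric trick defined above.
from lm_eval.tasks.ruler.common_utils import aggregate_metrics, process_results

doc = {"max_length": 4096, "outputs": ["eiffel tower"]}
print(process_results(doc, ["The answer is Eiffel Tower."]))
# {'4096': 1.0}  -- only 4096 is in DEFAULT_SEQ_LENGTHS unless metadata adds more

print(aggregate_metrics([1.0, -1, 0.5]))   # 0.75: the -1 placeholders are ignored
```
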
include: niah_single_1.yaml
task: ruler_cwe
custom_dataset: !function cwe_utils.get_cw_dataset
target_delimiter: "\n\n"
generation_kwargs:
do_sample: false
temperature: 0.0
max_gen_toks: 120
until: []
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import itertools
import random
import datasets
import wonderwords
from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
CONFIG = {
"tokens_to_generate": 120,
"template": """Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list?""",
"answer_prefix": """ Answer: The top 10 words that appear most often in the list are:""",
}
RNG = random.Random(42)
TEMPLATE = CONFIG["template"] + CONFIG["answer_prefix"]
r = wonderwords.RandomWord()
WORDS = sorted(
list(
set([item for x in ["noun", "adjective", "verb"] for item in r._categories[x]])
)
)
RNG.shuffle(WORDS)
def get_example(num_words, common_repeats=30, uncommon_repeats=3, common_nums=10):
word_list_full = random.sample(WORDS, num_words)
common, uncommon = word_list_full[:common_nums], word_list_full[common_nums:]
word_list = common * int(common_repeats) + uncommon * int(uncommon_repeats)
RNG.shuffle(word_list)
# Formatting the word list as "1. word1 2. word2 3. word3 ..."
context = " ".join([f"{i + 1}. {word}" for i, word in enumerate(word_list)])
return context, common
def generate_input_output(
num_words: int,
max_seq_length: int,
freq_cw: int = 30,
freq_ucw: int = 3,
num_cw: int = 10,
):
if max_seq_length < 4096:
context_example, answer_example = get_example(20, 3, 1, num_cw)
context, answer = get_example(num_words, 6, 1, num_cw)
else:
context_example, answer_example = get_example(40, 10, 3, num_cw)
context, answer = get_example(num_words, freq_cw, freq_ucw, num_cw)
template = TEMPLATE
input_example = template.format(
context=context_example,
query="",
) + " ".join([f"{i + 1}. {word}" for i, word in enumerate(answer_example)])
input_text = template.format(
context=context,
query="",
)
return input_example, input_text, answer
def sys_word_pair_random(
num_samples: int,
max_seq_length: int,
tokenizer=None,
incremental: int = 10,
remove_newline_tab=False,
tokens_to_generate=120,
):
assert tokenizer is not None, "Tokenizer is not provided."
write_jsons = []
tokens_to_generate = tokens_to_generate
# Find the perfect num_words
num_words = incremental
total_tokens = 0
while total_tokens + tokens_to_generate < max_seq_length:
input_example, input_text, answer = generate_input_output(
num_words, max_seq_length
)
# Calculate the number of tokens in the example
total_tokens = len(
tokenizer(
input_example
+ "\n"
+ input_text
+ " "
+ " ".join([f"{i + 1}. {word}" for i, word in enumerate(answer)])
).input_ids
)
# print(
# f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Words: {num_words}"
# )
if total_tokens + tokens_to_generate > max_seq_length:
num_words -= incremental
break
num_words += incremental
if num_words > len(WORDS):
num_words = len(WORDS)
break
# print("num_words:", num_words)
# Generate samples
for index in tqdm(
range(num_samples), desc=f"Generating CWE Samples | {max_seq_length}"
):
used_words = num_words
while True:
try:
input_example, input_text, answer = generate_input_output(
used_words, max_seq_length
)
length = len(tokenizer(input_text).input_ids) + tokens_to_generate
assert length <= max_seq_length, f"{length} exceeds max_seq_length."
break
except: # noqa: E722
if used_words > incremental:
used_words -= incremental
if remove_newline_tab:
input_text = " ".join(
input_text.replace("\n", " ").replace("\t", " ").strip().split()
)
input_example = " ".join(
input_example.replace("\n", " ").replace("\t", " ").strip().split()
)
gen_prefix_index = input_text.rfind(CONFIG["answer_prefix"])
input_text = input_text[:gen_prefix_index]
formatted_output = {
"index": index,
"input": input_text.strip(),
"input_example": input_example,
"outputs": answer,
"length": length,
"max_length": max_seq_length,
"gen_prefix": CONFIG["answer_prefix"].strip(),
}
write_jsons.append(formatted_output)
return write_jsons
def get_dataset(pretrained, seq=None, **kwargs):
tokenizer = get_tokenizer(pretrained)
write_jsons = sys_word_pair_random(
num_samples=500, max_seq_length=seq, tokenizer=tokenizer
)
return write_jsons
def get_cw_dataset(**kwargs):
pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
df = (
get_dataset(pretrained, seq=seq)
for seq in kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
)
return {
"test": datasets.Dataset.from_list(
list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
)
}
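
A rough usage sketch for the loader above (the `lm_eval.tasks.ruler.cwe_utils` module path is inferred from the `!function cwe_utils.get_cw_dataset` reference in the YAML; building 500 samples per length is not instant):

```python
# Sketch only: build the CWE test split for a single 4k sequence length using a
# small tokenizer such as "gpt2".
from lm_eval.tasks.ruler.cwe_utils import get_cw_dataset

splits = get_cw_dataset(tokenizer="gpt2", max_seq_lengths=[4096])
example = splits["test"][0]
print(example["input"][:200])   # numbered word list plus the question
print(example["outputs"])       # the 10 common words the model must recover
```
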
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import asyncio
import glob
import os
from functools import cache
from typing import Dict
import html2text
import httpx
from bs4 import BeautifulSoup
from tqdm.asyncio import tqdm as async_tqdm
@cache
async def fetch_url(client: httpx.AsyncClient, url: str) -> str:
response = await client.get(url)
response.raise_for_status()
return response.text
@cache
async def process_html_essay(
client: httpx.AsyncClient, url: str, h: html2text.HTML2Text, temp_folder: str
) -> None:
filename = url.split("/")[-1].replace(".html", ".txt")
if os.path.exists(os.path.join(temp_folder, filename)):
return None
try:
content = await fetch_url(client, url)
soup = BeautifulSoup(content, "html.parser")
specific_tag = soup.find("font")
if specific_tag:
parsed = h.handle(str(specific_tag))
with open(
os.path.join(temp_folder, filename), "w", encoding="utf-8"
) as file:
file.write(parsed)
except Exception as e:
print(f"Failed to download {filename}: {str(e)}")
@cache
async def process_text_essay(
client: httpx.AsyncClient, url: str, temp_folder: str
) -> None:
filename = url.split("/")[-1]
if os.path.exists(os.path.join(temp_folder, filename)):
return None
try:
content = await fetch_url(client, url)
with open(os.path.join(temp_folder, filename), "w", encoding="utf-8") as file:
file.write(content)
except Exception as e:
print(f"Failed to download {filename}: {str(e)}")
@cache
async def get_essays() -> Dict[str, str]:
temp_folder_repo = "essay_repo"
temp_folder_html = "essay_html"
os.makedirs(temp_folder_repo, exist_ok=True)
os.makedirs(temp_folder_html, exist_ok=True)
h = html2text.HTML2Text()
h.ignore_images = True
h.ignore_tables = True
h.escape_all = True
h.reference_links = False
h.mark_code = False
url_list = "https://raw.githubusercontent.com/NVIDIA/RULER/main/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt"
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
# Fetch URL list
content = await fetch_url(client, url_list)
urls = content.splitlines()
# Separate HTML and text URLs
html_urls = [url for url in urls if ".html" in url]
text_urls = [url for url in urls if ".html" not in url]
# Process HTML essays
html_tasks = [
process_html_essay(client, url, h, temp_folder_html) for url in html_urls
]
await async_tqdm.gather(*html_tasks, desc="Downloading HTML essays")
# Process text essays
text_tasks = [
process_text_essay(client, url, temp_folder_repo) for url in text_urls
]
await async_tqdm.gather(*text_tasks, desc="Downloading text essays")
# Collect results
files_repo = sorted(glob.glob(os.path.join(temp_folder_repo, "*.txt")))
files_html = sorted(glob.glob(os.path.join(temp_folder_html, "*.txt")))
# Combine all texts
text = ""
for file in files_repo + files_html:
with open(file, "r", encoding="utf-8") as f:
text += f.read()
return {"text": text}
@cache
def get_all_essays() -> Dict[str, str]:
"""Synchronous wrapper for get_essays()"""
return asyncio.run(get_essays())
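
Quick check for the downloader above: `get_all_essays` caches the essays under `./essay_repo` and `./essay_html` in the working directory and returns the concatenated text, presumably consumed as the haystack for the essay-based NIAH tasks. The import path below is hypothetical, since the file name of this module is not shown here.

```python
# Hypothetical import path -- adjust to wherever this module lives in the repo.
from lm_eval.tasks.ruler.essays import get_all_essays

haystack = get_all_essays()["text"]
print(f"{len(haystack.split()):,} words of Paul Graham essays available")
```
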
include: niah_single_1.yaml
task: ruler_fwe
custom_dataset: !function fwe_utils.fwe_download
generation_kwargs:
do_sample: false
temperature: 0.0
max_gen_toks: 50
until: []
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import itertools
import random
import string
import datasets
import numpy as np
import transformers
from scipy.special import zeta
from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
CONFIG = {
"tokens_to_generate": 50,
"template": """Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text?""",
"answer_prefix": """ Answer: According to the coded text above, the three most frequently appeared words are:""",
}
SEED = 42
TEMPLATE = CONFIG["template"] + CONFIG["answer_prefix"]
def generate_input_output(
max_len: int,
tokenizer: "transformers.PreTrainedTokenizerFast",
num_words=-1,
coded_wordlen=6,
vocab_size=2000,
incremental=10,
alpha=2.0,
) -> tuple[str, list[str], int]:
# generate vocab
vocab = [
"".join(random.choices(string.ascii_lowercase, k=coded_wordlen))
for _ in range(vocab_size)
]
while len(set(vocab)) < vocab_size:
vocab.append("".join(random.choices(string.ascii_lowercase, k=coded_wordlen)))
vocab = sorted(list(set(vocab)))
random.Random(SEED).shuffle(vocab)
vocab[0] = "..." # treat the top ranked as noise
# sample words
template = TEMPLATE
def gen_text(num_words):
k = np.arange(1, len(vocab) + 1)
sampled_cnt = num_words * (k**-alpha) / zeta(alpha)
sampled_words = [[w] * zi for w, zi in zip(vocab, sampled_cnt.astype(int))]
sampled_words = [x for wlst in sampled_words for x in wlst]
random.Random(SEED).shuffle(sampled_words)
return template.format(context=" ".join(sampled_words), query=""), vocab[1:4]
if num_words > 0:
num_words = num_words
text, answer = gen_text(num_words)
while len(tokenizer(text).input_ids) > max_len:
num_words -= incremental
text, answer = gen_text(num_words)
else:
num_words = max_len // coded_wordlen # init
text, answer = gen_text(num_words)
while len(tokenizer(text).input_ids) < max_len:
num_words += incremental
text, answer = gen_text(num_words)
num_words -= incremental
text, answer = gen_text(num_words)
return text, answer, num_words
def sys_kwext(
tokenizer: "transformers.PreTrainedTokenizerFast",
max_seq_length: int,
num_samples: int = 500,
vocab_size: int = -1,
coded_wordlen: int = 6,
alpha: float = 2.0,
tokens_to_generate: int = 50,
remove_newline_tab: bool = False,
) -> list[dict]:
write_jsons = []
tokens_to_generate = tokens_to_generate
vocab_size = max_seq_length // 50 if vocab_size == -1 else vocab_size
# get number of words
input_max_len = max_seq_length
_, _, num_example_words = generate_input_output(
input_max_len,
tokenizer,
coded_wordlen=coded_wordlen,
vocab_size=vocab_size,
incremental=input_max_len // 32,
alpha=alpha,
)
# Generate samples
for index in tqdm(
range(num_samples), desc=f"Generating FWE Samples | {max_seq_length}"
):
# construct input
input_max_len = max_seq_length
input_text, answer, _ = generate_input_output(
input_max_len,
tokenizer,
num_words=num_example_words,
coded_wordlen=coded_wordlen,
vocab_size=vocab_size,
incremental=input_max_len // 32,
alpha=alpha,
)
length = len(tokenizer(input_text).input_ids) + tokens_to_generate
if remove_newline_tab:
input_text = " ".join(
input_text.replace("\n", " ").replace("\t", " ").strip().split()
)
formatted_output = {
"index": index,
"input": input_text[: input_text.rfind(CONFIG["answer_prefix"])].strip(),
"outputs": answer,
"length": length,
"max_length": max_seq_length,
"gen_prefix": CONFIG["answer_prefix"].strip(),
}
write_jsons.append(formatted_output)
return write_jsons
def get_dataset(pretrained, max_seq_length=None, **kwargs):
tokenizer = get_tokenizer(pretrained)
write_jsons = sys_kwext(
tokenizer=tokenizer,
max_seq_length=max_seq_length,
)
return write_jsons
def fwe_download(**kwargs):
pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
df = (
get_dataset(pretrained, max_seq_length=seq)
for seq in kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
)
return {
"test": datasets.Dataset.from_list(
list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
)
}
task: niah_multikey_1
include: niah_single_1.yaml
custom_dataset: !function niah_utils.niah_multikey_1
task: niah_multikey_2
include: niah_single_1.yaml
custom_dataset: !function niah_utils.niah_multikey_2
task: niah_multikey_3
include: niah_single_1.yaml
custom_dataset: !function niah_utils.niah_multikey_3
task: niah_multiquery
include: niah_single_1.yaml
custom_dataset: !function niah_utils.niah_multiquery
task: niah_multivalue
include: niah_single_1.yaml
custom_dataset: !function niah_utils.niah_multivalue
tag:
- longcxt
task: niah_single_1
dataset_path: ""
dataset_name: ""
output_type: generate_until
test_split: test
custom_dataset: !function niah_utils.niah_single_1
doc_to_text: "{{input}}"
doc_to_target: "{{outputs}}"
gen_prefix: "{{gen_prefix}}"
target_delimiter: " "
process_results: !function common_utils.process_results
metric_list:
- metric: "4096"
aggregation: !function common_utils.aggregate_metrics
higher_is_better: true
- metric: "8192"
aggregation: !function common_utils.aggregate_metrics
higher_is_better: true
- metric: "16384"
aggregation: !function common_utils.aggregate_metrics
higher_is_better: true
- metric: "32768"
aggregation: !function common_utils.aggregate_metrics
higher_is_better: true
- metric: "65536"
aggregation: !function common_utils.aggregate_metrics
higher_is_better: true
- metric: "131072"
aggregation: !function common_utils.aggregate_metrics
higher_is_better: true
generation_kwargs:
do_sample: false
temperature: 0.0
max_gen_toks: 128
until: []
repeats: 1
metadata:
version: 1.0
task: niah_single_2
include: niah_single_1.yaml
custom_dataset: !function niah_utils.niah_single_2
task: niah_single_3
include: niah_single_1.yaml
custom_dataset: !function niah_utils.niah_single_3