Commit 352127ae authored by Baber's avatar Baber
Browse files

allow to pass metadata from main

parent a74d0408
......@@ -262,6 +262,12 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="Confirm that you understand the risks of running unsafe code for tasks that require it",
)
parser.add_argument(
"--metadata",
type=str,
default=None,
help="Comma separated string argument metadata to pass to task configs, for example max_context_len=4096,8192 etc.",
)
return parser
......@@ -410,6 +416,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
torch_random_seed=args.seed[2],
fewshot_random_seed=args.seed[3],
confirm_run_unsafe_code=args.confirm_run_unsafe_code,
metadata=args.metadata,
**request_caching_args,
)
......
......@@ -75,6 +75,7 @@ def simple_evaluate(
torch_random_seed: int = 1234,
fewshot_random_seed: int = 1234,
confirm_run_unsafe_code: bool = False,
metadata: Optional[dict] = None,
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -98,9 +99,9 @@ def simple_evaluate(
:param cache_requests: bool, optional
Speed up evaluation by caching the building of dataset requests. `None` if not caching.
:param rewrite_requests_cache: bool, optional
Rewrites all of the request cache if set to `True`. `None` if not desired.
Rewrites all the request cache if set to `True`. `None` if not desired.
:param delete_requests_cache: bool, optional
Deletes all of the request cache if set to `True`. `None` if not desired.
Deletes all the request cache if set to `True`. `None` if not desired.
:param limit: int or float, optional
Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
:param bootstrap_iters:
......@@ -134,7 +135,7 @@ def simple_evaluate(
:param fewshot_random_seed: int
Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
:return
return
Dictionary of results
"""
eval_logger.setLevel(getattr(logging, f"{verbosity}"))
......@@ -235,7 +236,10 @@ def simple_evaluate(
# TODO fix this. hack to get around the fact that we can't pass model to task config
task_dict = get_task_dict(
tasks, task_manager, metadata=simple_parse_args_string(model_args)
tasks,
task_manager,
metadata=simple_parse_args_string(model_args)
| simple_parse_args_string(metadata),
)
# helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
......
......@@ -9,12 +9,12 @@ if TYPE_CHECKING:
import transformers
# Sequence lengths (in tokens) evaluated by default; tasks may override via
# metadata (e.g. --metadata max_seq_lengths=...). Longer lengths are kept
# commented out because they are expensive to generate.
DEFAULT_SEQ_LENGTHS = (
    # 131072,
    # 65536,
    # 32768,
    16384,
    8192,
    4096,
)
......@@ -61,7 +61,7 @@ def string_match_part(preds: list[str], refs: list[list[str]]) -> float:
def process_results(doc: dict, results: list[str]) -> dict[str, float]:
# hacky: set all other lengths to -1
metrics = {str(length): -1.0 for length in SEQ_LENGTHS}
metrics = {str(length): -1.0 for length in DEFAULT_SEQ_LENGTHS}
input_len = doc["max_length"]
pred = postprocess_pred(results[0])
score = string_match_all([pred], [doc["outputs"]])
......@@ -71,7 +71,7 @@ def process_results(doc: dict, results: list[str]) -> dict[str, float]:
def process_results_part(doc: dict, results: list[str]) -> dict[str, float]:
# hacky: set all other lengths to -1
metrics = {str(length): -1.0 for length in SEQ_LENGTHS}
metrics = {str(length): -1.0 for length in DEFAULT_SEQ_LENGTHS}
input_len = doc["max_length"]
pred = postprocess_pred(results[0])
score = string_match_part([pred], [doc["outputs"]])
......
......@@ -18,7 +18,7 @@ import datasets
import wonderwords
from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
RNG = random.Random(42)
......@@ -172,7 +172,7 @@ def get_dataset(pretrained, seq=None, **kwargs):
def get_cw_dataset(**kwargs):
kwargs = kwargs.get("metadata", {})
pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
df = (get_dataset(pretrained, seq=seq) for seq in SEQ_LENGTHS)
df = (get_dataset(pretrained, seq=seq) for seq in DEFAULT_SEQ_LENGTHS)
return {
"test": datasets.Dataset.from_list(
......
......@@ -21,7 +21,7 @@ import transformers
from scipy.special import zeta
from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
config = (
......@@ -159,7 +159,7 @@ def get_dataset(pretrained, max_seq_length=None, **kwargs):
def fwe_download(**kwargs):
kwargs = kwargs.get("metadata", {})
pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
df = (get_dataset(pretrained, max_seq_length=seq) for seq in SEQ_LENGTHS)
df = (get_dataset(pretrained, max_seq_length=seq) for seq in DEFAULT_SEQ_LENGTHS)
return {
"test": datasets.Dataset.from_list(
......
......@@ -3,8 +3,9 @@ from typing import Generator
import datasets
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
from lm_eval.tasks.ruler.prepare_niah import generate_samples, get_haystack
from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer
TEMPLATE = """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text?"""
......@@ -17,110 +18,132 @@ def download_dataset(df: Generator) -> dict[str, datasets.Dataset]:
}
# ruff: noqa
def niah_single_1(**kwargs):
    """Build the NIAH single-needle dataset over a repeated-noise haystack.

    Accepts task ``metadata`` (tokenizer/pretrained settings forwarded by the
    task config) and an optional ``max_seq_lengths`` iterable; falls back to
    ``DEFAULT_SEQ_LENGTHS``.
    """
    # Guard against a missing/None "metadata" entry; the previous
    # `**kwargs.get("metadata")` raised TypeError when it was absent.
    metadata = kwargs.get("metadata") or {}
    # Prefer an explicit kwarg, then metadata, then the module default.
    # NOTE(review): CLI --metadata reaches downloaders via the "metadata"
    # dict — confirm which channel callers use for max_seq_lengths.
    seq_lengths = kwargs.pop(
        "max_seq_lengths", metadata.get("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    )
    return download_dataset(
        generate_samples(
            get_haystack(type_haystack="repeat"),
            max_seq_length=seq,
            template=TEMPLATE,
            type_haystack="repeat",
            type_needle_k="words",
            type_needle_v="numbers",
            TOKENIZER=get_tokenizer(**metadata),
        )
        for seq in seq_lengths
    )
# ruff: noqa
def niah_single_2(**kwargs):
    """Build the NIAH single-needle dataset over an essay haystack.

    Accepts task ``metadata`` (tokenizer/pretrained settings forwarded by the
    task config) and an optional ``max_seq_lengths`` iterable; falls back to
    ``DEFAULT_SEQ_LENGTHS``.
    """
    # Guard against a missing/None "metadata" entry; the previous
    # `**kwargs.get("metadata")` raised TypeError when it was absent.
    metadata = kwargs.get("metadata") or {}
    # Prefer an explicit kwarg, then metadata, then the module default.
    seq_lengths = kwargs.pop(
        "max_seq_lengths", metadata.get("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    )
    return download_dataset(
        generate_samples(
            get_haystack(type_haystack="essay"),
            max_seq_length=seq,
            template=TEMPLATE,
            type_haystack="essay",
            type_needle_k="words",
            type_needle_v="numbers",
            TOKENIZER=get_tokenizer(**metadata),
        )
        for seq in seq_lengths
    )
# noqa
def niah_single_3(**kwargs):
    """Build the NIAH single-needle dataset (essay haystack, UUID values).

    Accepts task ``metadata`` (tokenizer/pretrained settings forwarded by the
    task config) and an optional ``max_seq_lengths`` iterable; falls back to
    ``DEFAULT_SEQ_LENGTHS``.
    """
    # Guard against a missing/None "metadata" entry; the previous
    # `**kwargs.get("metadata")` raised TypeError when it was absent.
    metadata = kwargs.get("metadata") or {}
    # Prefer an explicit kwarg, then metadata, then the module default.
    seq_lengths = kwargs.pop(
        "max_seq_lengths", metadata.get("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    )
    return download_dataset(
        generate_samples(
            get_haystack(type_haystack="essay"),
            max_seq_length=seq,
            template=TEMPLATE,
            type_haystack="essay",
            type_needle_k="words",
            type_needle_v="uuids",
            TOKENIZER=get_tokenizer(**metadata),
        )
        for seq in seq_lengths
    )
# noqa
def niah_multikey_1(**kwargs):
    """Build the NIAH multi-key dataset (essay haystack, 4 needle keys).

    Accepts task ``metadata`` (tokenizer/pretrained settings forwarded by the
    task config) and an optional ``max_seq_lengths`` iterable; falls back to
    ``DEFAULT_SEQ_LENGTHS``.
    """
    # Guard against a missing/None "metadata" entry; the previous
    # `**kwargs.get("metadata")` raised TypeError when it was absent.
    metadata = kwargs.get("metadata") or {}
    # Prefer an explicit kwarg, then metadata, then the module default.
    seq_lengths = kwargs.pop(
        "max_seq_lengths", metadata.get("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    )
    return download_dataset(
        generate_samples(
            get_haystack(type_haystack="essay"),
            max_seq_length=seq,
            template=TEMPLATE,
            type_haystack="essay",
            type_needle_k="words",
            type_needle_v="numbers",
            num_needle_k=4,
            TOKENIZER=get_tokenizer(**metadata),
        )
        for seq in seq_lengths
    )
# noqa
def niah_multikey_2(**kwargs):
    """Build the NIAH multi-key dataset over a needle-only haystack.

    Accepts task ``metadata`` (tokenizer/pretrained settings forwarded by the
    task config) and an optional ``max_seq_lengths`` iterable; falls back to
    ``DEFAULT_SEQ_LENGTHS``.
    """
    # Guard against a missing/None "metadata" entry; the previous
    # `**kwargs.get("metadata")` raised TypeError when it was absent.
    metadata = kwargs.get("metadata") or {}
    # Prefer an explicit kwarg, then metadata, then the module default.
    seq_lengths = kwargs.pop(
        "max_seq_lengths", metadata.get("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    )
    return download_dataset(
        generate_samples(
            get_haystack(type_haystack="needle"),
            max_seq_length=seq,
            template=TEMPLATE,
            type_haystack="needle",
            type_needle_k="words",
            type_needle_v="numbers",
            TOKENIZER=get_tokenizer(**metadata),
        )
        for seq in seq_lengths
    )
# noqa
def niah_multikey_3(**kwargs):
    """Build the NIAH multi-key dataset (needle haystack, UUID keys/values).

    Accepts task ``metadata`` (tokenizer/pretrained settings forwarded by the
    task config) and an optional ``max_seq_lengths`` iterable; falls back to
    ``DEFAULT_SEQ_LENGTHS``.
    """
    # Guard against a missing/None "metadata" entry; the previous
    # `**kwargs.get("metadata")` raised TypeError when it was absent.
    metadata = kwargs.get("metadata") or {}
    # Prefer an explicit kwarg, then metadata, then the module default.
    seq_lengths = kwargs.pop(
        "max_seq_lengths", metadata.get("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    )
    return download_dataset(
        generate_samples(
            get_haystack(type_haystack="needle"),
            max_seq_length=seq,
            template=TEMPLATE,
            type_haystack="needle",
            type_needle_k="uuids",
            type_needle_v="uuids",
            TOKENIZER=get_tokenizer(**metadata),
        )
        for seq in seq_lengths
    )
# noqa
def niah_multivalue(**kwargs):
    """Build the NIAH multi-value dataset (essay haystack, 4 needle values).

    Accepts task ``metadata`` (tokenizer/pretrained settings forwarded by the
    task config) and an optional ``max_seq_lengths`` iterable; falls back to
    ``DEFAULT_SEQ_LENGTHS``.
    """
    # Guard against a missing/None "metadata" entry; the previous
    # `**kwargs.get("metadata")` raised TypeError when it was absent.
    metadata = kwargs.get("metadata") or {}
    # Prefer an explicit kwarg, then metadata, then the module default.
    seq_lengths = kwargs.pop(
        "max_seq_lengths", metadata.get("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    )
    return download_dataset(
        generate_samples(
            get_haystack(type_haystack="essay"),
            max_seq_length=seq,
            template=TEMPLATE,
            type_haystack="essay",
            type_needle_k="words",
            type_needle_v="numbers",
            num_needle_v=4,
            TOKENIZER=get_tokenizer(**metadata),
        )
        for seq in seq_lengths
    )
# noqa
def niah_multiquery(**kwargs):
    """Build the NIAH multi-query dataset (essay haystack, 4 queries).

    Accepts task ``metadata`` (tokenizer/pretrained settings forwarded by the
    task config) and an optional ``max_seq_lengths`` iterable; falls back to
    ``DEFAULT_SEQ_LENGTHS``.
    """
    # Guard against a missing/None "metadata" entry; the previous
    # `**kwargs.get("metadata")` raised TypeError when it was absent.
    metadata = kwargs.get("metadata") or {}
    # Prefer an explicit kwarg, then metadata, then the module default.
    seq_lengths = kwargs.pop(
        "max_seq_lengths", metadata.get("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    )
    return download_dataset(
        generate_samples(
            get_haystack(type_haystack="essay"),
            max_seq_length=seq,
            template=TEMPLATE,
            type_haystack="essay",
            type_needle_k="words",
            type_needle_v="numbers",
            num_needle_q=4,
            TOKENIZER=get_tokenizer(**metadata),
        )
        for seq in seq_lengths
    )
......@@ -21,7 +21,7 @@ import datasets
import requests
from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
config = {
"tokens_to_generate": 32,
......@@ -223,7 +223,7 @@ def get_qa_dataset(ds, **kwargs) -> dict[str, datasets.Dataset]:
qas, docs = read_hotpotqa()
df = (
get_dataset(pretrained=pretrained, docs=docs, qas=qas, max_seq_length=seq)
for seq in SEQ_LENGTHS
for seq in DEFAULT_SEQ_LENGTHS
)
return {
......
......@@ -22,7 +22,7 @@ import datasets
import numpy as np
from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
TASKS = {
"variable_tracking": {
......@@ -239,7 +239,7 @@ def get_dataset(pretrained, seq=None, **kwargs) -> list[dict]:
def get_vt_dataset(**kwargs) -> dict[str, datasets.Dataset]:
kwargs = kwargs.get("metadata", {})
pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
df = (get_dataset(pretrained, seq=seq) for seq in SEQ_LENGTHS)
df = (get_dataset(pretrained, seq=seq) for seq in DEFAULT_SEQ_LENGTHS)
return {
"test": datasets.Dataset.from_list(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment