"vscode:/vscode.git/clone" did not exist on "40dc810c5e3e790278ece6b36bb9e687b0bc13ff"
Commit 352127ae authored by Baber

allow passing metadata from main

parent a74d0408
@@ -262,6 +262,12 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Confirm that you understand the risks of running unsafe code for tasks that require it",
     )
+    parser.add_argument(
+        "--metadata",
+        type=str,
+        default=None,
+        help="Comma-separated key=value metadata to pass to task configs, for example: max_context_len=4096.",
+    )
     return parser
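For reference, the value handed to --metadata follows the harness's comma-separated key=value convention. A minimal sketch of that parsing, as an assumption about the observable behaviour of simple_parse_args_string rather than its actual implementation:

# Sketch only: approximates the comma-separated key=value parsing that
# simple_parse_args_string is assumed to perform on the --metadata string
# (no type coercion, no escaping).
def parse_kv_string(args_string):
    if not args_string:
        return {}
    parsed = {}
    for pair in args_string.split(","):
        key, _, value = pair.partition("=")
        parsed[key.strip()] = value.strip()
    return parsed


# e.g. --metadata max_context_len=4096 on the command line
assert parse_kv_string("max_context_len=4096") == {"max_context_len": "4096"}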
@@ -410,6 +416,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         torch_random_seed=args.seed[2],
         fewshot_random_seed=args.seed[3],
         confirm_run_unsafe_code=args.confirm_run_unsafe_code,
+        metadata=args.metadata,
         **request_caching_args,
     )
...
@@ -75,6 +75,7 @@ def simple_evaluate(
     torch_random_seed: int = 1234,
     fewshot_random_seed: int = 1234,
     confirm_run_unsafe_code: bool = False,
+    metadata: Optional[dict] = None,
 ):
     """Instantiate and evaluate a model on a list of tasks.
@@ -98,9 +99,9 @@ def simple_evaluate(
     :param cache_requests: bool, optional
         Speed up evaluation by caching the building of dataset requests. `None` if not caching.
     :param rewrite_requests_cache: bool, optional
-        Rewrites all of the request cache if set to `True`. `None` if not desired.
+        Rewrites all the request cache if set to `True`. `None` if not desired.
     :param delete_requests_cache: bool, optional
-        Deletes all of the request cache if set to `True`. `None` if not desired.
+        Deletes all the request cache if set to `True`. `None` if not desired.
     :param limit: int or float, optional
         Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
     :param bootstrap_iters:
@@ -134,7 +135,7 @@ def simple_evaluate(
     :param fewshot_random_seed: int
         Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
-    :return
+    return
         Dictionary of results
     """
     eval_logger.setLevel(getattr(logging, f"{verbosity}"))
@@ -235,7 +236,10 @@ def simple_evaluate(
     # TODO fix this. hack to get around the fact that we can't pass model to task config
     task_dict = get_task_dict(
-        tasks, task_manager, metadata=simple_parse_args_string(model_args)
+        tasks,
+        task_manager,
+        metadata=simple_parse_args_string(model_args)
+        | simple_parse_args_string(metadata),
     )
     # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
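The dict union means keys supplied via --metadata take precedence over identical keys parsed out of --model_args, because the right-hand operand of | wins. A small illustration with made-up values:

# Python 3.9+ dict union: on duplicate keys the right-hand side wins.
from_model_args = {"pretrained": "gpt2", "max_context_len": "2048"}
from_metadata = {"max_context_len": "4096"}

merged = from_model_args | from_metadata
assert merged == {"pretrained": "gpt2", "max_context_len": "4096"}

Note that --metadata defaults to None; whether simple_parse_args_string treats None as an empty mapping is not visible in this diff.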
...
@@ -9,12 +9,12 @@ if TYPE_CHECKING:
     import transformers

-SEQ_LENGTHS = (
+DEFAULT_SEQ_LENGTHS = (
     # 131072,
     # 65536,
     # 32768,
-    # 16384,
-    # 8192,
+    16384,
+    8192,
     4096,
 )
@@ -61,7 +61,7 @@ def string_match_part(preds: list[str], refs: list[list[str]]) -> float:
 def process_results(doc: dict, results: list[str]) -> dict[str, float]:
     # hacky: set all other lengths to -1
-    metrics = {str(length): -1.0 for length in SEQ_LENGTHS}
+    metrics = {str(length): -1.0 for length in DEFAULT_SEQ_LENGTHS}
     input_len = doc["max_length"]
     pred = postprocess_pred(results[0])
     score = string_match_all([pred], [doc["outputs"]])
@@ -71,7 +71,7 @@ def process_results(doc: dict, results: list[str]) -> dict[str, float]:
 def process_results_part(doc: dict, results: list[str]) -> dict[str, float]:
     # hacky: set all other lengths to -1
-    metrics = {str(length): -1.0 for length in SEQ_LENGTHS}
+    metrics = {str(length): -1.0 for length in DEFAULT_SEQ_LENGTHS}
     input_len = doc["max_length"]
     pred = postprocess_pred(results[0])
     score = string_match_part([pred], [doc["outputs"]])
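Each default sequence length becomes its own results column, so the per-document metrics dict is pre-filled with a -1.0 sentinel and only the entry matching the document's max_length receives the real score. The assignment itself falls outside the shown hunk, so the helper below is a reconstruction of the intent, not the file's code:

DEFAULT_SEQ_LENGTHS = (16384, 8192, 4096)


def metrics_for(doc_max_length, score):
    # Hypothetical helper: -1.0 marks "not evaluated at this length".
    metrics = {str(length): -1.0 for length in DEFAULT_SEQ_LENGTHS}
    metrics[str(doc_max_length)] = score
    return metrics


print(metrics_for(4096, 0.8))  # {'16384': -1.0, '8192': -1.0, '4096': 0.8}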
...
@@ -18,7 +18,7 @@ import datasets
 import wonderwords
 from tqdm import tqdm

-from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer
+from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer

 RNG = random.Random(42)
@@ -172,7 +172,7 @@ def get_dataset(pretrained, seq=None, **kwargs):
 def get_cw_dataset(**kwargs):
     kwargs = kwargs.get("metadata", {})
     pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
-    df = (get_dataset(pretrained, seq=seq) for seq in SEQ_LENGTHS)
+    df = (get_dataset(pretrained, seq=seq) for seq in DEFAULT_SEQ_LENGTHS)
     return {
         "test": datasets.Dataset.from_list(
...
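The cwe builder above (like the fwe, qa, and vt builders later in this commit) still sweeps the module-level DEFAULT_SEQ_LENGTHS rather than a per-run override; each per-length call yields a list of samples that is flattened into a single test split. A sketch of that shape, assuming get_dataset returns a list of dicts (the flattening is hidden behind the collapsed part of the diff):

import datasets

DEFAULT_SEQ_LENGTHS = (16384, 8192, 4096)


def fake_get_dataset(seq):
    # Stand-in for the real per-length sample builder.
    return [{"max_length": seq, "input": f"<~{seq}-token prompt>"}]


per_length = (fake_get_dataset(seq) for seq in DEFAULT_SEQ_LENGTHS)
test_split = datasets.Dataset.from_list([row for rows in per_length for row in rows])
print(test_split.num_rows)  # 3 in this sketch: one sample per default length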
@@ -21,7 +21,7 @@ import transformers
 from scipy.special import zeta
 from tqdm import tqdm

-from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer
+from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer

 config = (
@@ -159,7 +159,7 @@ def get_dataset(pretrained, max_seq_length=None, **kwargs):
 def fwe_download(**kwargs):
     kwargs = kwargs.get("metadata", {})
     pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
-    df = (get_dataset(pretrained, max_seq_length=seq) for seq in SEQ_LENGTHS)
+    df = (get_dataset(pretrained, max_seq_length=seq) for seq in DEFAULT_SEQ_LENGTHS)
     return {
         "test": datasets.Dataset.from_list(
...
@@ -3,8 +3,9 @@ from typing import Generator
 import datasets

+from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
 from lm_eval.tasks.ruler.prepare_niah import generate_samples, get_haystack
-from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer

 TEMPLATE = """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text?"""
@@ -17,8 +18,9 @@ def download_dataset(df: Generator) -> dict[str, datasets.Dataset]:
     }

-# ruff: noqa
-niah_single_1 = lambda **kwargs: download_dataset(
+
+def niah_single_1(**kwargs):
+    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
+    return download_dataset(
         generate_samples(
             get_haystack(type_haystack="repeat"),
             max_seq_length=seq,
@@ -28,10 +30,13 @@ niah_single_1 = lambda **kwargs: download_dataset(
             type_needle_v="numbers",
             TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
         )
-    for seq in SEQ_LENGTHS
+        for seq in seq_lengths
     )

-# ruff: noqa
-niah_single_2 = lambda **kwargs: download_dataset(
+
+def niah_single_2(**kwargs):
+    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
+    return download_dataset(
         generate_samples(
             get_haystack(type_haystack="essay"),
             max_seq_length=seq,
@@ -41,10 +46,13 @@ niah_single_2 = lambda **kwargs: download_dataset(
             type_needle_v="numbers",
             TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
         )
-    for seq in SEQ_LENGTHS
+        for seq in seq_lengths
     )

-# noqa
-niah_single_3 = lambda **kwargs: download_dataset(
+
+def niah_single_3(**kwargs):
+    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
+    return download_dataset(
         generate_samples(
             get_haystack(type_haystack="essay"),
             max_seq_length=seq,
@@ -54,10 +62,13 @@ niah_single_3 = lambda **kwargs: download_dataset(
             type_needle_v="uuids",
             TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
         )
-    for seq in SEQ_LENGTHS
+        for seq in seq_lengths
     )

-# noqa
-niah_multikey_1 = lambda **kwargs: download_dataset(
+
+def niah_multikey_1(**kwargs):
+    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
+    return download_dataset(
         generate_samples(
             get_haystack(type_haystack="essay"),
             max_seq_length=seq,
@@ -68,10 +79,13 @@ niah_multikey_1 = lambda **kwargs: download_dataset(
             num_needle_k=4,
             TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
         )
-    for seq in SEQ_LENGTHS
+        for seq in seq_lengths
     )

-# noqa
-niah_multikey_2 = lambda **kwargs: download_dataset(
+
+def niah_multikey_2(**kwargs):
+    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
+    return download_dataset(
         generate_samples(
             get_haystack(type_haystack="needle"),
             max_seq_length=seq,
@@ -81,10 +95,13 @@ niah_multikey_2 = lambda **kwargs: download_dataset(
             type_needle_v="numbers",
             TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
         )
-    for seq in SEQ_LENGTHS
+        for seq in seq_lengths
     )

-# noqa
-niah_multikey_3 = lambda **kwargs: download_dataset(
+
+def niah_multikey_3(**kwargs):
+    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
+    return download_dataset(
         generate_samples(
             get_haystack(type_haystack="needle"),
             max_seq_length=seq,
@@ -94,10 +111,13 @@ niah_multikey_3 = lambda **kwargs: download_dataset(
             type_needle_v="uuids",
             TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
         )
-    for seq in SEQ_LENGTHS
+        for seq in seq_lengths
     )

-# noqa
-niah_multivalue = lambda **kwargs: download_dataset(
+
+def niah_multivalue(**kwargs):
+    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
+    return download_dataset(
         generate_samples(
             get_haystack(type_haystack="essay"),
             max_seq_length=seq,
@@ -108,10 +128,13 @@ niah_multivalue = lambda **kwargs: download_dataset(
             num_needle_v=4,
             TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
         )
-    for seq in SEQ_LENGTHS
+        for seq in seq_lengths
     )

-# noqa
-niah_multiquery = lambda **kwargs: download_dataset(
+
+def niah_multiquery(**kwargs):
+    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
+    return download_dataset(
         generate_samples(
             get_haystack(type_haystack="essay"),
             max_seq_length=seq,
@@ -122,5 +145,5 @@ niah_multiquery = lambda **kwargs: download_dataset(
             num_needle_q=4,
             TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
         )
-    for seq in SEQ_LENGTHS
+        for seq in seq_lengths
     )
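With the lambdas rewritten as functions, every niah builder now reads its sequence-length sweep from its keyword arguments and falls back to DEFAULT_SEQ_LENGTHS when no max_seq_lengths override arrives; how the task config delivers that key end to end is not shown in this diff. A condensed sketch of the pattern, with the real generate_samples/get_haystack/download_dataset machinery stubbed out:

DEFAULT_SEQ_LENGTHS = (16384, 8192, 4096)


def niah_builder(**kwargs):
    # Hypothetical stand-in for the niah_* functions above: if the caller supplies
    # max_seq_lengths it wins, otherwise the module default sweep is used.
    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    return [f"dataset built at max_seq_length={seq}" for seq in seq_lengths]


print(niah_builder())                         # default three-length sweep
print(niah_builder(max_seq_lengths=(4096,)))  # explicit override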
@@ -21,7 +21,7 @@ import datasets
 import requests
 from tqdm import tqdm

-from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer
+from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer

 config = {
     "tokens_to_generate": 32,
@@ -223,7 +223,7 @@ def get_qa_dataset(ds, **kwargs) -> dict[str, datasets.Dataset]:
     qas, docs = read_hotpotqa()
     df = (
         get_dataset(pretrained=pretrained, docs=docs, qas=qas, max_seq_length=seq)
-        for seq in SEQ_LENGTHS
+        for seq in DEFAULT_SEQ_LENGTHS
     )
     return {
...
@@ -22,7 +22,7 @@ import datasets
 import numpy as np
 from tqdm import tqdm

-from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer
+from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer

 TASKS = {
     "variable_tracking": {
@@ -239,7 +239,7 @@ def get_dataset(pretrained, seq=None, **kwargs) -> list[dict]:
 def get_vt_dataset(**kwargs) -> dict[str, datasets.Dataset]:
     kwargs = kwargs.get("metadata", {})
     pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
-    df = (get_dataset(pretrained, seq=seq) for seq in SEQ_LENGTHS)
+    df = (get_dataset(pretrained, seq=seq) for seq in DEFAULT_SEQ_LENGTHS)
     return {
         "test": datasets.Dataset.from_list(
...