Commit 105ea3e1 authored by Baber's avatar Baber
Browse files

pass tokenizer to task config

parent adcacfbf
...@@ -821,7 +821,11 @@ class ConfigurableTask(Task): ...@@ -821,7 +821,11 @@ class ConfigurableTask(Task):
if self.config.download_dataset is None: if self.config.download_dataset is None:
self.download(self.config.dataset_kwargs) self.download(self.config.dataset_kwargs)
else: else:
self.dataset = self.config.download_dataset() self.dataset = self.config.download_dataset(
self.config.metadata.get(
"tokenizer", self.config.metadata.get("pretrained")
)
)
self._training_docs = None self._training_docs = None
self._fewshot_docs = None self._fewshot_docs = None
......
...@@ -232,7 +232,10 @@ def simple_evaluate( ...@@ -232,7 +232,10 @@ def simple_evaluate(
if task_manager is None: if task_manager is None:
task_manager = TaskManager(verbosity) task_manager = TaskManager(verbosity)
task_dict = get_task_dict(tasks, task_manager) # TODO fix this. hack to get around the fact that we can't pass model to task config
task_dict = get_task_dict(
tasks, task_manager, metadata=simple_parse_args_string(model_args)
)
# helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups. # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
# (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed) # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
......
...@@ -257,6 +257,7 @@ class TaskManager: ...@@ -257,6 +257,7 @@ class TaskManager:
name_or_config: Optional[Union[str, dict]] = None, name_or_config: Optional[Union[str, dict]] = None,
parent_name: Optional[str] = None, parent_name: Optional[str] = None,
update_config: Optional[dict] = None, update_config: Optional[dict] = None,
metadata: Optional[dict] = None,
) -> Mapping: ) -> Mapping:
def _load_task(config, task): def _load_task(config, task):
if "include" in config: if "include" in config:
...@@ -268,6 +269,7 @@ class TaskManager: ...@@ -268,6 +269,7 @@ class TaskManager:
), ),
**config, **config,
} }
if self._config_is_python_task(config): if self._config_is_python_task(config):
if self._class_has_config_in_constructor(config["class"]): if self._class_has_config_in_constructor(config["class"]):
task_object = config["class"](config=config) task_object = config["class"](config=config)
...@@ -277,6 +279,7 @@ class TaskManager: ...@@ -277,6 +279,7 @@ class TaskManager:
# very scuffed: set task name here. TODO: fixme? # very scuffed: set task name here. TODO: fixme?
task_object.config.task = task task_object.config.task = task
else: else:
config["metadata"] = config.get("metadata", {}) | (metadata or {})
task_object = ConfigurableTask(config=config) task_object = ConfigurableTask(config=config)
return {task: task_object} return {task: task_object}
...@@ -398,7 +401,9 @@ class TaskManager: ...@@ -398,7 +401,9 @@ class TaskManager:
group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
} }
def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict: def load_task_or_group(
self, task_list: Optional[Union[str, list]] = None, metadata=None
) -> dict:
"""Loads a dictionary of task objects from a list """Loads a dictionary of task objects from a list
:param task_list: Union[str, list] = None :param task_list: Union[str, list] = None
...@@ -411,12 +416,19 @@ class TaskManager: ...@@ -411,12 +416,19 @@ class TaskManager:
task_list = [task_list] task_list = [task_list]
all_loaded_tasks = dict( all_loaded_tasks = dict(
collections.ChainMap(*map(self._load_individual_task_or_group, task_list)) collections.ChainMap(
*map(
lambda task: self._load_individual_task_or_group(
task, metadata=metadata
),
task_list,
)
)
) )
return all_loaded_tasks return all_loaded_tasks
def load_config(self, config: Dict): def load_config(self, config: Dict, metadata: Optional[dict] = None):
return self._load_individual_task_or_group(config) return self._load_individual_task_or_group(config, metadata=metadata)
def _get_task_and_group(self, task_dir: str): def _get_task_and_group(self, task_dir: str):
"""Creates a dictionary of tasks index with the following metadata, """Creates a dictionary of tasks index with the following metadata,
...@@ -576,6 +588,7 @@ def _check_duplicates(task_dict: dict) -> List[str]: ...@@ -576,6 +588,7 @@ def _check_duplicates(task_dict: dict) -> List[str]:
def get_task_dict( def get_task_dict(
task_name_list: Union[str, List[Union[str, Dict, Task]]], task_name_list: Union[str, List[Union[str, Dict, Task]]],
task_manager: Optional[TaskManager] = None, task_manager: Optional[TaskManager] = None,
metadata=None,
): ):
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object. """Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
...@@ -616,7 +629,7 @@ def get_task_dict( ...@@ -616,7 +629,7 @@ def get_task_dict(
task_manager = TaskManager() task_manager = TaskManager()
task_name_from_string_dict = task_manager.load_task_or_group( task_name_from_string_dict = task_manager.load_task_or_group(
string_task_name_list string_task_name_list, metadata=metadata
) )
for task_element in others_task_name_list: for task_element in others_task_name_list:
......
...@@ -14,21 +14,21 @@ metric_list: ...@@ -14,21 +14,21 @@ metric_list:
- metric: "4096" - metric: "4096"
aggregation: !function utils.aggregate_metrics aggregation: !function utils.aggregate_metrics
higher_is_better: true higher_is_better: true
- metric: "8192" # - metric: "8192"
aggregation: !function utils.aggregate_metrics # aggregation: !function utils.aggregate_metrics
higher_is_better: true # higher_is_better: true
- metric: "16384" # - metric: "16384"
aggregation: !function utils.aggregate_metrics # aggregation: !function utils.aggregate_metrics
higher_is_better: true # higher_is_better: true
- metric: "32768" # - metric: "32768"
aggregation: !function utils.aggregate_metrics # aggregation: !function utils.aggregate_metrics
higher_is_better: true # higher_is_better: true
- metric: "65536" # - metric: "65536"
aggregation: !function utils.aggregate_metrics # aggregation: !function utils.aggregate_metrics
higher_is_better: true # higher_is_better: true
- metric: "131072" # - metric: "131072"
aggregation: !function utils.aggregate_metrics # aggregation: !function utils.aggregate_metrics
higher_is_better: true # higher_is_better: true
generation_kwargs: generation_kwargs:
do_sample: true do_sample: true
temperature: 1.0 temperature: 1.0
......
...@@ -14,21 +14,19 @@ from lm_eval.tasks.ruler.prepare import generate_samples ...@@ -14,21 +14,19 @@ from lm_eval.tasks.ruler.prepare import generate_samples
@cache @cache
def get_tokenizer(): def get_tokenizer(pretrained):
return AutoTokenizer.from_pretrained( return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
os.environ.get("TOKENIZER"), trust_remote_code=True
)
# TOKENIZER = AutoTokenizer.from_pretrained(os.environ.get("TOKENIZER")) # TOKENIZER = AutoTokenizer.from_pretrained(os.environ.get("TOKENIZER"))
TEMPLATE = """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text? The special magic {type_needle_v} for {query} mentioned in the provided text are""" TEMPLATE = """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text? The special magic {type_needle_v} for {query} mentioned in the provided text are"""
SEQ_LENGTHS = ( SEQ_LENGTHS = (
131072, # 131072,
65536, # 65536,
32768, # 32768,
16384, # 16384,
8192, # 8192,
4096, 4096,
) )
...@@ -63,7 +61,7 @@ def flatten(df): ...@@ -63,7 +61,7 @@ def flatten(df):
# ruff: noqa # ruff: noqa
niah_single_1 = lambda: flatten( niah_single_1 = lambda x: flatten(
generate_samples( generate_samples(
get_haystack(type_haystack="repeat"), get_haystack(type_haystack="repeat"),
max_seq_length=seq, max_seq_length=seq,
...@@ -71,12 +69,12 @@ niah_single_1 = lambda: flatten( ...@@ -71,12 +69,12 @@ niah_single_1 = lambda: flatten(
type_haystack="repeat", type_haystack="repeat",
type_needle_k="words", type_needle_k="words",
type_needle_v="numbers", type_needle_v="numbers",
TOKENIZER=get_tokenizer(), TOKENIZER=get_tokenizer(x),
) )
for seq in SEQ_LENGTHS for seq in SEQ_LENGTHS
) )
# ruff: noqa # ruff: noqa
niah_single_2 = lambda: flatten( niah_single_2 = lambda x: flatten(
generate_samples( generate_samples(
get_haystack(type_haystack="essay"), get_haystack(type_haystack="essay"),
max_seq_length=seq, max_seq_length=seq,
...@@ -84,12 +82,12 @@ niah_single_2 = lambda: flatten( ...@@ -84,12 +82,12 @@ niah_single_2 = lambda: flatten(
type_haystack="essay", type_haystack="essay",
type_needle_k="words", type_needle_k="words",
type_needle_v="numbers", type_needle_v="numbers",
TOKENIZER=get_tokenizer(), TOKENIZER=get_tokenizer(x),
) )
for seq in SEQ_LENGTHS for seq in SEQ_LENGTHS
) )
# noqa # noqa
niah_single_3 = lambda: flatten( niah_single_3 = lambda x: flatten(
generate_samples( generate_samples(
get_haystack(type_haystack="essay"), get_haystack(type_haystack="essay"),
max_seq_length=seq, max_seq_length=seq,
...@@ -97,12 +95,12 @@ niah_single_3 = lambda: flatten( ...@@ -97,12 +95,12 @@ niah_single_3 = lambda: flatten(
type_haystack="essay", type_haystack="essay",
type_needle_k="words", type_needle_k="words",
type_needle_v="uuids", type_needle_v="uuids",
TOKENIZER=get_tokenizer(), TOKENIZER=get_tokenizer(x),
) )
for seq in SEQ_LENGTHS for seq in SEQ_LENGTHS
) )
# noqa # noqa
niah_multikey_1 = lambda: flatten( niah_multikey_1 = lambda x: flatten(
generate_samples( generate_samples(
get_haystack(type_haystack="essay"), get_haystack(type_haystack="essay"),
max_seq_length=seq, max_seq_length=seq,
...@@ -111,12 +109,12 @@ niah_multikey_1 = lambda: flatten( ...@@ -111,12 +109,12 @@ niah_multikey_1 = lambda: flatten(
type_needle_k="words", type_needle_k="words",
type_needle_v="numbers", type_needle_v="numbers",
num_needle_k=4, num_needle_k=4,
TOKENIZER=get_tokenizer(), TOKENIZER=get_tokenizer(x),
) )
for seq in SEQ_LENGTHS for seq in SEQ_LENGTHS
) )
# noqa # noqa
niah_multikey_2 = lambda: flatten( niah_multikey_2 = lambda x: flatten(
generate_samples( generate_samples(
get_haystack(type_haystack="needle"), get_haystack(type_haystack="needle"),
max_seq_length=seq, max_seq_length=seq,
...@@ -124,12 +122,12 @@ niah_multikey_2 = lambda: flatten( ...@@ -124,12 +122,12 @@ niah_multikey_2 = lambda: flatten(
type_haystack="needle", type_haystack="needle",
type_needle_k="words", type_needle_k="words",
type_needle_v="numbers", type_needle_v="numbers",
TOKENIZER=get_tokenizer(), TOKENIZER=get_tokenizer(x),
) )
for seq in SEQ_LENGTHS for seq in SEQ_LENGTHS
) )
# noqa # noqa
niah_multikey_3 = lambda: flatten( niah_multikey_3 = lambda x: flatten(
generate_samples( generate_samples(
get_haystack(type_haystack="needle"), get_haystack(type_haystack="needle"),
max_seq_length=seq, max_seq_length=seq,
...@@ -137,12 +135,12 @@ niah_multikey_3 = lambda: flatten( ...@@ -137,12 +135,12 @@ niah_multikey_3 = lambda: flatten(
type_haystack="needle", type_haystack="needle",
type_needle_k="uuids", type_needle_k="uuids",
type_needle_v="uuids", type_needle_v="uuids",
TOKENIZER=get_tokenizer(), TOKENIZER=get_tokenizer(x),
) )
for seq in SEQ_LENGTHS for seq in SEQ_LENGTHS
) )
# noqa # noqa
niah_multivalue = lambda: flatten( niah_multivalue = lambda x: flatten(
generate_samples( generate_samples(
get_haystack(type_haystack="essay"), get_haystack(type_haystack="essay"),
max_seq_length=seq, max_seq_length=seq,
...@@ -151,12 +149,12 @@ niah_multivalue = lambda: flatten( ...@@ -151,12 +149,12 @@ niah_multivalue = lambda: flatten(
type_needle_k="words", type_needle_k="words",
type_needle_v="numbers", type_needle_v="numbers",
num_needle_v=4, num_needle_v=4,
TOKENIZER=get_tokenizer(), TOKENIZER=get_tokenizer(x),
) )
for seq in SEQ_LENGTHS for seq in SEQ_LENGTHS
) )
# noqa # noqa
niah_multiquery = lambda: flatten( niah_multiquery = lambda x: flatten(
generate_samples( generate_samples(
get_haystack(type_haystack="essay"), get_haystack(type_haystack="essay"),
max_seq_length=seq, max_seq_length=seq,
...@@ -165,7 +163,7 @@ niah_multiquery = lambda: flatten( ...@@ -165,7 +163,7 @@ niah_multiquery = lambda: flatten(
type_needle_k="words", type_needle_k="words",
type_needle_v="numbers", type_needle_v="numbers",
num_needle_q=4, num_needle_q=4,
TOKENIZER=get_tokenizer(), TOKENIZER=get_tokenizer(x),
) )
for seq in SEQ_LENGTHS for seq in SEQ_LENGTHS
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment