"tests/nn/vscode:/vscode.git/clone" did not exist on "8c8a625a828ae5a222b30cb1256af0da7ecfee26"
Commit d859d1ca authored by Nathan Habib

batch commit

parent 6e49b1f6
@@ -102,10 +102,12 @@ results = lm_eval.simple_evaluate( # call simple_evaluate
 )
 ```
-See the `simple_evaluate()` and `evaluate()` functions in [lm_eval/evaluator.py](../lm_eval/evaluator.py#:~:text=simple_evaluate) for a full description of all arguments available. All keyword arguments to simple_evaluate share the same role as the command-line flags described previously.
+See https://github.com/EleutherAI/lm-evaluation-harness/blob/365fcda9b85bbb6e0572d91976b8daf409164500/lm_eval/evaluator.py#L35 for a full description of all arguments available. All keyword arguments to simple_evaluate share the same role as the command-line flags described previously.
 Additionally, the `evaluate()` function offers the core evaluation functionality provided by the library, but without some of the special handling and simplification + abstraction provided by `simple_evaluate()`.
+See https://github.com/EleutherAI/lm-evaluation-harness/blob/365fcda9b85bbb6e0572d91976b8daf409164500/lm_eval/evaluator.py#L173 for more details.
 As a brief example usage of `evaluate()`:
 ```python
@@ -145,7 +147,7 @@ task_dict = lm_eval.tasks.get_task_dict(
     task_manager # A task manager that allows lm_eval to
                  # load the task during evaluation.
                  # If none is provided, `get_task_dict`
-                 # will instantiate one itself, but this
+                 # will instantiated one itself, but this
                  # only includes the stock tasks so users
                  # will need to set this if including
                  # custom paths is required.
......
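For context on the documentation hunk above, the high-level entry point it points to can be driven roughly as sketched below; the model arguments and task name are illustrative placeholders, not values taken from this commit.

```python
import lm_eval

# Illustrative call of the documented entry point; "pretrained=..." and the
# task list are placeholder values chosen only for the example.
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=EleutherAI/pythia-160m",
    tasks=["hellaswag"],
    num_fewshot=0,
)

# simple_evaluate returns None on non-primary ranks, so guard before reading.
if results is not None:
    print(results["results"])
```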
@@ -5,6 +5,7 @@ import os
 import sys
 from functools import partial
 from typing import Union
+from accelerate import Accelerator
 from lm_eval import evaluator, utils
 from lm_eval.evaluator import request_caching_arg_to_dict
@@ -292,13 +293,6 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
             "If fewshot_as_multiturn is set, apply_chat_template must be set to True."
         )
-    if (
-        args.num_fewshot is None or args.num_fewshot == 0
-    ) and args.fewshot_as_multiturn:
-        raise ValueError(
-            "If fewshot_as_multiturn is set, num_fewshot must be greater than 0."
-        )
     if args.include_path is not None:
         eval_logger.info(f"Including path: {args.include_path}")
         task_manager = TaskManager(args.verbosity, include_path=args.include_path)
@@ -354,17 +348,11 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
     # Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
     if args.trust_remote_code:
-        eval_logger.info(
-            "Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
-        )
-        # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
-        # because it's already been determined based on the prior env var before launching our
-        # script--`datasets` gets imported by lm_eval internally before these lines can update the env.
-        import datasets
-        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
-        args.model_args = args.model_args + ",trust_remote_code=True"
+        os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = str(args.trust_remote_code)
+        args.model_args = (
+            args.model_args
+            + f",trust_remote_code={os.environ['HF_DATASETS_TRUST_REMOTE_CODE']}"
+        )
     eval_logger.info(f"Selected Tasks: {task_names}")
@@ -400,7 +388,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         **request_caching_args,
     )
-    if results is not None:
+    accelerator = Accelerator()
+    if results is not None and accelerator.is_main_process:
         if args.log_samples:
             samples = results.pop("samples")
         dumped = json.dumps(
......
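The CLI hunk above replaces the plain `if results is not None:` check with an `accelerate` main-process guard so that only rank 0 writes output. A minimal sketch of that pattern, with a stand-in for the evaluation call:

```python
from accelerate import Accelerator

def run_evaluation():
    # stand-in for the harness call made inside cli_evaluate; returns a results dict
    return {"results": {}}

accelerator = Accelerator()
results = run_evaluation()

# Mirror the guard introduced above: only the main process writes results to disk.
if results is not None and accelerator.is_main_process:
    print("main process: writing results")
```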
@@ -67,9 +67,9 @@ class TaskConfig(dict):
     training_split: Optional[str] = None
     validation_split: Optional[str] = None
     test_split: Optional[str] = None
-    fewshot_split: Optional[str] = (
-        None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
-    )
+    fewshot_split: Optional[
+        str
+    ] = None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
     # formatting / prompting options.
     # see docs/advanced_task_guide.md for more info
     process_docs: Optional[Callable] = None
@@ -92,9 +92,9 @@ class TaskConfig(dict):
     filter_list: Optional[Union[str, list]] = None
     should_decontaminate: bool = False
     doc_to_decontamination_query: Optional[str] = None
-    metadata: Optional[dict] = (
-        None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
-    )
+    metadata: Optional[
+        dict
+    ] = None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
     def __post_init__(self) -> None:
         if self.generation_kwargs is not None:
@@ -229,9 +229,9 @@ class Task(abc.ABC):
         self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
         self._filters = [build_filter_ensemble("none", [["take_first", None]])]
-        self.fewshot_rnd: Optional[random.Random] = (
-            None  # purposely induce errors in case of improper usage
-        )
+        self.fewshot_rnd: Optional[
+            random.Random
+        ] = None  # purposely induce errors in case of improper usage
     def download(
         self,
@@ -368,16 +368,15 @@
     def build_all_requests(
         self,
         *,
-        limit: Union[int, None] = None,
-        rank: int = 0,
-        world_size: int = 1,
-        cache_requests: bool = False,
-        rewrite_requests_cache: bool = False,
-        system_instruction: Optional[str] = None,
-        apply_chat_template: bool = False,
-        fewshot_as_multiturn: bool = False,
-        chat_template: Optional[Callable] = None,
-        tokenizer_name: str = "",
+        limit=None,
+        rank=None,
+        world_size=None,
+        cache_requests=False,
+        rewrite_requests_cache=False,
+        system_instruction=None,
+        apply_chat_template=False,
+        fewshot_as_multiturn=False,
+        lm=None,
     ) -> None:
         """Build a set of Instances for a task, and store them in task.instances"""
@@ -392,7 +391,7 @@
             if system_instruction is not None
             else ""
         )
-        cache_key += f"-tokenizer{tokenizer_name}"
+        cache_key += f"-tokenizer{lm.tokenizer_name}" if apply_chat_template else ""
         cached_instances = load_from_cache(file_name=cache_key)
@@ -437,7 +436,7 @@
                 system_instruction,
                 apply_chat_template,
                 fewshot_as_multiturn,
-                chat_template,
+                lm,
             )
             # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
@@ -445,6 +444,7 @@
                 doc=doc,
                 ctx=fewshot_ctx,
                 metadata=(self.config["task"], doc_id, self.config.repeats),
+                apply_chat_template=apply_chat_template
             )
             if not isinstance(inst, list):
@@ -987,6 +987,28 @@ class ConfigurableTask(Task):
             return super().fewshot_docs()
     @staticmethod
+    def append_target_question(
+        labeled_examples: List[Dict[str, str]],
+        question: str,
+        fewshot_as_multiturn: bool = False,
+    ) -> None:
+        """Adds a target question to the labeled examples list.
+        If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
+        Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant.
+        """
+        if not fewshot_as_multiturn:
+            # if no messages or last message is system, append as new user entry
+            if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system":
+                labeled_examples.append({"role": "user", "content": question})
+            # if last message is user, append to it to avoid two user messages in a row
+            else:
+                labeled_examples[-1]["content"] += question
+        else:
+            return self.sampler.fewshot_delimiter + "".join(
+                f"{s['role']}: {s['content']}" + self.sampler.fewshot_delimiter
+                for s in chat_history
+            )
+    @staticmethod
     def append_target_question(
         labeled_examples: List[Dict[str, str]],
         question: str,
@@ -1015,7 +1037,7 @@ class ConfigurableTask(Task):
         system_instruction: Optional[str] = None,
         apply_chat_template: bool = False,
         fewshot_as_multiturn: bool = False,
-        chat_template: Optional[Callable] = None,
+        lm=None,
     ) -> str:
         """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.
@@ -1028,10 +1050,12 @@ class ConfigurableTask(Task):
             System instruction to be applied to the prompt.
         :param apply_chat_template: bool
             Whether to apply the chat template to the fewshot context.
+        :param tokenizer:
+            The tokenizer to use for applying the chat template.
         :param fewshot_as_multiturn: bool
             Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
-        :param chat_template: Callable
-            Chat template to be applied to the fewshot context.
+        :param lm:
+            Language model with definition of the tokenizer/function to use for applying the chat template.
         :returns: str
             The fewshot context.
         """
@@ -1078,7 +1102,7 @@ class ConfigurableTask(Task):
         example = self.doc_to_text(doc)
         if apply_chat_template:
             if self.multiple_input:
-                return chat_template(labeled_examples)
+                return lm.apply_chat_template(labeled_examples)
             if isinstance(example, str):
                 self.append_target_question(
                     labeled_examples, example, fewshot_as_multiturn
@@ -1090,7 +1114,7 @@
                 for ex in example:
                     chat = deepcopy(labeled_examples)
                     self.append_target_question(chat, ex, fewshot_as_multiturn)
-                    labeled_examples_list.append(chat_template(chat))
+                    labeled_examples_list.append(lm.apply_chat_template(chat))
                 return labeled_examples_list
             # if example is an integer, append the choice or convert to string
             elif isinstance(example, int):
@@ -1104,7 +1128,7 @@
                     labeled_examples, str(example), fewshot_as_multiturn
                 )
             # return lm.apply_chat_template(labeled_examples)
-            return chat_template(labeled_examples)
+            return lm.apply_chat_template(labeled_examples)
         else:
             if self.multiple_input:
                 return labeled_examples
@@ -1271,6 +1295,8 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "multiple_choice":
             choices = self.doc_to_choice(doc)
             target_delimiter = self.config.target_delimiter
+            if kwargs.get("apply_chat_template", False) is True:
+                target_delimiter = ""
             if self.multiple_input:
                 # If there are multiple inputs, choices are placed in the ctx
                 cont = self.doc_to_target(doc)
@@ -1280,6 +1306,7 @@
             else:
                 # Otherwise they are placed in the continuation
                 arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
+            kwargs.pop("apply_chat_template")
             request_list = [
                 Instance(
@@ -1316,6 +1343,7 @@
         elif self.OUTPUT_TYPE == "generate_until":
             arguments = (ctx, deepcopy(self.config.generation_kwargs))
+            kwargs.pop("apply_chat_template")
             return Instance(
                 request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
             )
......
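The `append_target_question` helper added above maintains a chat-style message list (dicts with `role`/`content` keys). The following standalone sketch follows the behavior described in its docstring rather than the exact committed body, so it is illustrative only:

```python
from typing import Dict, List

def append_target_question(
    labeled_examples: List[Dict[str, str]],
    question: str,
    fewshot_as_multiturn: bool = False,
) -> None:
    """Append the target question as a new user turn, or merge it into the last user turn."""
    if fewshot_as_multiturn or not labeled_examples or labeled_examples[-1]["role"] == "system":
        labeled_examples.append({"role": "user", "content": question})
    else:
        # avoid two consecutive user messages by extending the previous user turn
        labeled_examples[-1]["content"] += question

chat = [{"role": "system", "content": "Answer concisely."}]
append_target_question(chat, "Q: What is 2 + 2?\nA:")
print(chat)  # the question lands in a new user turn after the system message
```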
@@ -22,7 +22,7 @@ from lm_eval.evaluator_utils import (
     run_task_tests,
 )
 from lm_eval.loggers import EvaluationTracker
-from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
+from lm_eval.loggers.utils import add_env_info, get_git_commit_hash
 from lm_eval.tasks import TaskManager, get_task_dict
 from lm_eval.utils import (
     eval_logger,
@@ -271,7 +271,6 @@ def simple_evaluate(
         model_args=model_args,
         system_instruction=system_instruction,
         chat_template=lm.chat_template if apply_chat_template else None,
-        fewshot_as_multiturn=fewshot_as_multiturn,
     )
     results = evaluate(
@@ -326,7 +325,6 @@ def simple_evaluate(
         results["git_hash"] = get_git_commit_hash()
         results["date"] = start_date
         add_env_info(results)  # additional environment info to results
-        add_tokenizer_info(results, lm)  # additional info about tokenizer
         return results
     else:
         return None
@@ -399,12 +397,7 @@ def evaluate(
             system_instruction=system_instruction,
             apply_chat_template=apply_chat_template,
             fewshot_as_multiturn=fewshot_as_multiturn,
-            chat_template=getattr(lm, "apply_chat_template")
-            if apply_chat_template
-            else None,
-            tokenizer_name=getattr(lm, "tokenizer_name", "")
-            if apply_chat_template
-            else "",
+            lm=lm,
         )
         eval_logger.debug(
             f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
@@ -614,16 +607,16 @@ def evaluate(
                 ]
                 # compute group's pooled metric and stderr
-                results[group][metric] = (
-                    lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
-                )
+                results[group][
+                    metric
+                ] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
                 # TODO: calculate grouped metric using aggregation fn
                 if "N/A" in stderrs:
                     results[group][stderr] = "N/A"
                 else:
-                    results[group][stderr] = (
-                        lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
-                    )
+                    results[group][
+                        stderr
+                    ] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
                 # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
                 # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
                 # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
......
@@ -275,9 +275,9 @@ def consolidate_results(
                     metric_key
                 ]
                 results[task_output.task_name]["samples"] = task_output.sample_len
-                results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = (
-                    task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
-                )
+                results[task_output.task_name][
+                    f"{metric}_stderr,{filter_key}"
+                ] = task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
     return results, samples, configs, versions, num_fewshot, higher_is_better
......
@@ -4,6 +4,7 @@ from lm_eval.api.registry import register_filter
 @register_filter("decontaminate")
 class DecontaminationFilter(Filter):
     """
     A filter which evaluates
     """
......
@@ -62,11 +62,8 @@ class WhitespaceFilter(Filter):
         def filter_set(inst):
             filtered_resp = []
             for resp in inst:
-                if resp.startswith(" "):
-                    resp = resp[1:]
+                resp = resp.lstrip()
                 filtered_resp.append(resp)
             return filtered_resp
         filtered_resps = [filter_set(resp) for resp in resps]
......
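The whitespace-filter hunk above swaps a one-character strip for `str.lstrip()`. A small sketch of the behavioral difference on a response with several leading spaces:

```python
def strip_one_leading_space(resp: str) -> str:
    # pre-change behavior: remove at most one leading space
    return resp[1:] if resp.startswith(" ") else resp

def strip_all_leading_whitespace(resp: str) -> str:
    # post-change behavior: remove all leading whitespace
    return resp.lstrip()

sample = "   answer"
print(repr(strip_one_leading_space(sample)))       # '  answer'
print(repr(strip_all_leading_whitespace(sample)))  # 'answer'
```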
 import json
+import os
 import re
 import time
 from collections import defaultdict
 from dataclasses import asdict, dataclass
 from datetime import datetime
 from pathlib import Path
+from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
 from datasets import load_dataset
 from datasets.utils.metadata import MetadataConfigs
@@ -17,15 +19,9 @@ from huggingface_hub import (
 from lm_eval.utils import (
     eval_logger,
-    get_file_datetime,
-    get_file_task_name,
-    get_results_filenames,
-    get_sample_results_filenames,
     handle_non_serializable,
     hash_string,
     sanitize_list,
-    sanitize_model_name,
-    sanitize_task_name,
 )
@@ -48,7 +44,6 @@ class GeneralConfigTracker:
     model_name_sanitized: str = None
     system_instruction: str = None
     system_instruction_sha: str = None
-    fewshot_as_multiturn: bool = None
     chat_template: str = None
     chat_template_sha: str = None
     start_time: float = None
@@ -81,19 +76,24 @@ class GeneralConfigTracker:
         model_args: str,
         system_instruction: str,
         chat_template: str,
-        fewshot_as_multiturn: bool,
     ) -> None:
         """Logs model parameters and job ID."""
         self.model_source = model_source
         self.model_name = GeneralConfigTracker._get_model_name(model_args)
-        self.model_name_sanitized = sanitize_model_name(self.model_name)
+        self.model_name_sanitized = re.sub(
+            r"[\"<>:/\|\\?\*\[\]]+", "__", self.model_name
+        )
         self.system_instruction = system_instruction
         self.system_instruction_sha = (
             hash_string(system_instruction) if system_instruction else None
         )
         self.chat_template = chat_template
-        self.chat_template_sha = hash_string(chat_template) if chat_template else None
-        self.fewshot_as_multiturn = fewshot_as_multiturn
+        self.chat_template_sha = None
+        if chat_template:
+            if not isinstance(chat_template, str):
+                self.chat_template_sha = hash_string(str(chat_template))
+            else:
+                self.chat_template_sha = hash_string(chat_template)
     def log_end_time(self) -> None:
         """Logs the end time of the evaluation and calculates the total evaluation time."""
@@ -210,21 +210,17 @@ class EvaluationTracker:
             file_results_aggregated.open("w", encoding="utf-8").write(dumped)
             if self.api and self.push_results_to_hub:
-                repo_id = (
-                    self.hub_results_repo
-                    if self.public_repo
-                    else self.hub_results_repo_private
-                )
+                repo_id = "open-llm-leaderboard/results_v2"
                 self.api.create_repo(
                     repo_id=repo_id,
                     repo_type="dataset",
                     private=not self.public_repo,
                     exist_ok=True,
                 )
-                self.api.upload_folder(
+                self.api.upload_file(
                     repo_id=repo_id,
-                    folder_path=str(path),
-                    path_in_repo=self.general_config_tracker.model_name_sanitized,
+                    path_or_fileobj=str(path.joinpath(f"results_{self.date_id}.json")),
+                    path_in_repo=os.path.join(self.general_config_tracker.model_name, f"results_{self.date_id}.json"),
                     repo_type="dataset",
                     commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}",
                 )
@@ -262,7 +258,7 @@ class EvaluationTracker:
             path.mkdir(parents=True, exist_ok=True)
             file_results_samples = path.joinpath(
-                f"samples_{task_name}_{self.date_id}.jsonl"
+                f"samples_{task_name}_{self.date_id}.json"
             )
             for sample in samples:
@@ -278,6 +274,7 @@ class EvaluationTracker:
                 sample["resps"] = sanitize_list(sample["resps"])
                 sample["filtered_resps"] = sanitize_list(sample["filtered_resps"])
                 sample["arguments"] = arguments
+                sample["target"] = str(sample["target"])
                 sample_dump = (
                     json.dumps(
@@ -303,6 +300,13 @@ class EvaluationTracker:
                     private=not self.public_repo,
                     exist_ok=True,
                 )
+                headers = build_hf_headers()
+                r = get_session().put(
+                    url=f"https://huggingface.co/api/datasets/{repo_id}/settings",
+                    headers=headers,
+                    json={"gated": "auto"},
+                )
+                hf_raise_for_status(r)
                 self.api.upload_folder(
                     repo_id=repo_id,
                     folder_path=str(path),
@@ -326,14 +330,23 @@ class EvaluationTracker:
         Creates a metadata card for the evaluation results dataset and pushes it to the Hugging Face hub.
         """
+        def get_file_task_name(filename: str) -> str:
+            return filename[filename.find("_") + 1 : filename.rfind("_")]
+        def get_file_datetime(filename: str) -> str:
+            return filename[filename.rfind("_") + 1 :].replace(".json", "")
+        def sanitize_task_name(task_name: str) -> str:
+            return re.sub(r"\W", "_", task_name)
         eval_logger.info("Recreating metadata card")
         repo_id = (
             self.hub_results_repo if self.public_repo else self.hub_results_repo_private
         )
         files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset")
-        results_files = get_results_filenames(files_in_repo)
-        sample_files = get_sample_results_filenames(files_in_repo)
+        results_files = [f for f in files_in_repo if "/results_" in f and ".json" in f]
+        sample_files = [f for f in files_in_repo if "/samples_" in f and ".json" in f]
         # Build a dictionary to store the latest evaluation datetime for:
         # - Each tested model and its aggregated results
@@ -360,7 +373,10 @@ class EvaluationTracker:
                     results_datetime,
                 )
                 latest_task_results_datetime[samples_key] = latest_datetime
-                latest_task_results_datetime[results_key] = latest_datetime
+                latest_task_results_datetime[results_key] = max(
+                    latest_task_results_datetime[results_key],
+                    latest_datetime,
+                )
         # Create metadata card
         card_metadata = MetadataConfigs()
@@ -377,14 +393,15 @@ class EvaluationTracker:
             sanitized_last_eval_date_results = re.sub(
                 r"[^\w\.]", "_", latest_task_results_datetime[config_name]
             )
-            # Ensure that all results files are listed in the metadata card
-            current_results = card_metadata.get(config_name, {"data_files": []})
-            current_results["data_files"].append(
-                {"split": eval_date_sanitized, "path": [str(results_filename)]}
-            )
-            card_metadata[config_name] = current_results
-            # If the results file is the newest, update the "latest" field in the metadata card
             if eval_date_sanitized == sanitized_last_eval_date_results:
+                # Ensure that all results files are listed in the metadata card
+                current_results = card_metadata.get(config_name, {"data_files": []})
+                current_results["data_files"].append(
+                    {"split": eval_date_sanitized, "path": [str(results_filename)]}
+                )
+                card_metadata[config_name] = current_results
+                # If the results file is the newest, update the "latest" field in the metadata card
                 card_metadata[config_name]["data_files"].append(
                     {"split": "latest", "path": [str(results_filename)]}
                 )
@@ -403,64 +420,65 @@ class EvaluationTracker:
             sanitized_last_eval_date_results = re.sub(
                 r"[^\w\.]", "_", latest_task_results_datetime[config_name]
             )
-            # Ensure that all sample results files are listed in the metadata card
-            current_details_for_task = card_metadata.get(
-                config_name, {"data_files": []}
-            )
-            current_details_for_task["data_files"].append(
-                {"split": eval_date_sanitized, "path": [str(results_filename)]}
-            )
-            card_metadata[config_name] = current_details_for_task
-            # If the samples results file is the newest, update the "latest" field in the metadata card
             if eval_date_sanitized == sanitized_last_eval_date_results:
+                print(f"adding {config_name} for {eval_date_sanitized}")
+                # Ensure that all sample results files are listed in the metadata card
+                current_details_for_task = card_metadata.get(
+                    config_name, {"data_files": []}
+                )
+                current_details_for_task["data_files"].append(
+                    {"split": eval_date_sanitized, "path": [str(results_filename)]}
+                )
+                card_metadata[config_name] = current_details_for_task
+                # If the samples results file is the newest, update the "latest" field in the metadata card
                 card_metadata[config_name]["data_files"].append(
                     {"split": "latest", "path": [str(results_filename)]}
                 )
             # Special case for MMLU with a single split covering it all
             # We add another config with all MMLU splits results together for easy inspection
-            SPECIAL_TASKS = ["mmlu", "gpqa", "minerva_math"]
+            SPECIAL_TASKS = ["leaderboard_gpqa", "leaderboard_math", "leaderboard_bbh", "leaderboard_musr"]
             for special_task in SPECIAL_TASKS:
                 if special_task in config_name:
                     special_task = f"{model_name}__{special_task}"
                     former_entry = card_metadata.get(special_task, {"data_files": []})
-                    former_split = [
-                        (i, entry)
-                        for i, entry in enumerate(former_entry["data_files"])
-                        if entry.get("split", None) == eval_date_sanitized
-                    ]
-                    if len(former_split) == 0:
-                        former_entry["data_files"].append(
-                            {
-                                "split": eval_date_sanitized,
-                                "path": [str(results_filename)],
-                            }
-                        )
-                    else:
-                        split_index, _ = former_split[0]
-                        former_entry["data_files"][split_index]["path"].append(
-                            str(results_filename)
-                        )
-                    if eval_date_sanitized == sanitized_last_eval_date_results:
-                        latest_split = [
-                            (i, entry)
-                            for i, entry in enumerate(former_entry["data_files"])
-                            if entry.get("split", None) == "latest"
-                        ]
-                        if len(latest_split) == 0:
-                            former_entry["data_files"].append(
-                                {"split": "latest", "path": [str(results_filename)]}
-                            )
-                        else:
-                            latest_index, _ = latest_split[0]
-                            former_entry["data_files"][latest_index]["path"].append(
-                                str(results_filename)
-                            )
-                    card_metadata[special_task] = former_entry
+                        former_split = [
+                            (i, entry)
+                            for i, entry in enumerate(former_entry["data_files"])
+                            if entry.get("split", None) == eval_date_sanitized
+                        ]
+                        if len(former_split) == 0:
+                            former_entry["data_files"].append(
+                                {
+                                    "split": eval_date_sanitized,
+                                    "path": [str(results_filename)],
+                                }
+                            )
+                        else:
+                            split_index, _ = former_split[0]
+                            former_entry["data_files"][split_index]["path"].append(
+                                str(results_filename)
+                            )
+                        if eval_date_sanitized == sanitized_last_eval_date_results:
+                            latest_split = [
+                                (i, entry)
+                                for i, entry in enumerate(former_entry["data_files"])
+                                if entry.get("split", None) == "latest"
+                            ]
+                            if len(latest_split) == 0:
+                                former_entry["data_files"].append(
+                                    {"split": "latest", "path": [str(results_filename)]}
+                                )
+                            else:
+                                latest_index, _ = latest_split[0]
+                                former_entry["data_files"][latest_index]["path"].append(
+                                    str(results_filename)
+                                )
+                        card_metadata[special_task] = former_entry
         # Get latest results and extract info to update metadata card examples
         latest_datetime = max(latest_task_results_datetime.values())
......
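Among other things, the tracker hunks above inline the model-name sanitization as a regex substitution. A quick sketch of that substitution in isolation, reusing the same pattern that appears in the diff:

```python
import re

def sanitize_model_name_inline(model_name: str) -> str:
    # same pattern as in the hunk above: replace path- and wildcard-like characters with "__"
    return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name)

print(sanitize_model_name_inline("open-llm-leaderboard/some-model-7B"))  # open-llm-leaderboard__some-model-7B
```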
@@ -110,20 +110,3 @@ def add_env_info(storage: Dict[str, Any]):
         "upper_git_hash": upper_dir_commit,  # in case this repo is submodule
     }
     storage.update(added_info)
-def add_tokenizer_info(storage: Dict[str, Any], lm):
-    if getattr(lm, "tokenizer", False):
-        tokenizer_info = {
-            "tokenizer_pad_token": [lm.tokenizer.pad_token, lm.tokenizer.pad_token_id],
-            "tokenizer_eos_token": [lm.tokenizer.eos_token, lm.tokenizer.eos_token_id],
-            "tokenizer_bos_token": [lm.tokenizer.bos_token, lm.tokenizer.bos_token_id],
-            "eot_token_id": getattr(lm, "eot_token_id", None),
-            "max_length": getattr(lm, "max_length", None),
-        }
-        storage.update(tokenizer_info)
-    # seems gguf and textsynth do not have tokenizer
-    else:
-        logger.debug(
-            "LM does not have a 'tokenizer' attribute, not logging tokenizer metadata to results."
-        )
@@ -307,7 +307,7 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install
         # defaults to os.environ.get("ANTHROPIC_API_KEY")
         self.client = anthropic.Anthropic()
         self.temperature = temperature
-        self.max_tokens = max_tokens
+        self.max_token = max_tokens
         self.tokenizer = self.client.get_tokenizer()
         self.kwargs = kwargs
......
This diff is collapsed.
@@ -288,7 +288,7 @@ class NEURON_HF(TemplateLM):
         self.vocab_size = self.tokenizer.vocab_size
         self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
-        self.add_bos_token = add_bos_token
+        self.add_bos_token = self.add_bos_token
         self._max_length = max_length
......
"""TextSynth API """ TextSynth API
Implementation provided by Fabrice Bellard: Implementation provided by Fabrice Bellard:
https://github.com/EleutherAI/lm-evaluation-harness/issues/295 https://github.com/EleutherAI/lm-evaluation-harness/issues/295
...@@ -11,7 +11,6 @@ Example usage: ...@@ -11,7 +11,6 @@ Example usage:
Homepage: https://textsynth.com/index.html Homepage: https://textsynth.com/index.html
""" """
import logging import logging
import os import os
......
@@ -389,7 +389,7 @@ class Collator:
             self._arr_with_indices, fn=self._group_fn, group_by="contexts"
         )
-    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
+    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None, reset_batch_fn: Optional[Callable] = None, accelerator=None) -> Iterator:
         """
         Generates and yields batches from the reordered array. The method of grouping and batching
         depends on the parameter `group_by`.
@@ -402,6 +402,8 @@
         - n (int): The size of each batch. Defaults to 1.
         - batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of
           each batch. Optional, defaults to None.
+        - reset_batch_fn ([Callable[[int, Iterable], int]] | None): A function to reset the scheduler of
+          the batch_fn, if present, when we change group in generative mode.
         Returns:
         Iterator: An iterator over batches of reordered elements grouped as per the `group_by`
@@ -411,10 +413,9 @@
             List of batched elements according to the `group_by` attribute.
         """
         if self._group_by == "gen_kwargs":
-            for (
-                key,
-                values,
-            ) in self._arr_with_indices.items():  # type: ignore
+            for key, values in self._arr_with_indices.items():  # type: ignore
+                if reset_batch_fn is not None:  # with each group change, we must recompute the batch size, so we restart the scheduler
+                    reset_batch_fn()
                 values = self._reorder(values)
                 batch = self.get_chunks(values, n=n, fn=batch_fn)
                 yield from batch
......
@@ -119,12 +119,6 @@ class VLLM(TemplateLM):
             tokenizer_revision=tokenizer_revision,
         )
         self.add_bos_token = add_bos_token
-        if "gemma" in pretrained.lower():
-            self.add_bos_token = True
-            eval_logger.info(
-                "Found 'gemma' in model name, a BOS token will be used as Gemma underperforms without it."
-            )
         self.custom_prefix_token_id = prefix_token_id
         if prefix_token_id is not None:
             eval_logger.info(
@@ -499,10 +493,7 @@ class VLLM(TemplateLM):
     def modify_gen_kwargs(kwargs: dict) -> dict:
         # sampling_params
         do_sample = kwargs.pop("do_sample", None)
-        if do_sample is False and "temperature" not in kwargs:
-            eval_logger.debug(
-                "Got `do_sample=False` and no temperature value, setting VLLM temperature to 0.0 ..."
-            )
+        if do_sample is False or "temperature" not in kwargs:
             kwargs["temperature"] = 0.0
         # hf defaults
         kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
......
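The second VLLM hunk above loosens the greedy-decoding condition from `and` to `or`. The two rules can be compared directly; the functions below are standalone stand-ins for `modify_gen_kwargs`, not the committed code itself:

```python
def temperature_rule_before(kwargs: dict) -> dict:
    do_sample = kwargs.pop("do_sample", None)
    # force greedy only when sampling is disabled AND no temperature was given
    if do_sample is False and "temperature" not in kwargs:
        kwargs["temperature"] = 0.0
    return kwargs

def temperature_rule_after(kwargs: dict) -> dict:
    do_sample = kwargs.pop("do_sample", None)
    # force greedy when sampling is disabled OR no temperature was given
    if do_sample is False or "temperature" not in kwargs:
        kwargs["temperature"] = 0.0
    return kwargs

print(temperature_rule_before({"do_sample": False, "temperature": 0.7}))  # keeps 0.7
print(temperature_rule_after({"do_sample": False, "temperature": 0.7}))   # forces 0.0
```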
@@ -10,8 +10,8 @@
 | [aclue](aclue/README.md) | Tasks focusing on ancient Chinese language understanding and cultural aspects. | Ancient Chinese |
 | [aexams](aexams/README.md) | Tasks in Arabic related to various academic exams covering a range of subjects. | Arabic |
 | [agieval](agieval/README.md) | Tasks involving historical data or questions related to history and historical texts. | English, Chinese |
+| [ammlu](ammlu/README.md) | Arabic version of MMLU. | Arabic |
 | [anli](anli/README.md) | Adversarial natural language inference tasks designed to test model robustness. | English |
-| [arabicmmlu](arabicmmlu/README.md) | Localized Arabic version of MMLU with multiple-choice questions from 40 subjects. | Arabic |
 | [arc](arc/README.md) | Tasks involving complex reasoning over a diverse set of questions. | English |
 | [arithmetic](arithmetic/README.md) | Tasks involving numerical computations and arithmetic reasoning. | English |
 | [asdiv](asdiv/README.md) | Tasks involving arithmetic and mathematical reasoning challenges. | English |
@@ -20,13 +20,11 @@
 | [bbh](bbh/README.md) | Tasks focused on deep semantic understanding through hypothesization and reasoning. | English, German |
 | [belebele](belebele/README.md) | Language understanding tasks in a variety of languages and scripts. | Multiple (122 languages) |
 | benchmarks | General benchmarking tasks that test a wide range of language understanding capabilities. | |
-| [bertaqa](bertaqa/README.md) | Local Basque cultural trivia QA tests in English and Basque languages. | English, Basque, Basque (MT) |
 | [bigbench](bigbench/README.md) | Broad tasks from the BIG-bench benchmark designed to push the boundaries of large models. | Multiple |
 | [blimp](blimp/README.md) | Tasks testing grammatical phenomena to evaluate language model's linguistic capabilities. | English |
 | [ceval](ceval/README.md) | Tasks that evaluate language understanding and reasoning in an educational context. | Chinese |
 | [cmmlu](cmmlu/README.md) | Multi-subject multiple choice question tasks for comprehensive academic assessment. | Chinese |
 | code_x_glue | Tasks that involve understanding and generating code across multiple programming languages. | Go, Java, JS, PHP, Python, Ruby |
-| [commonsense_qa](commmonsense_qa/README.md) | CommonsenseQA, a multiple-choice QA dataset for measuring commonsense knowledge. | English |
 | [copal_id](copal_id/README.md) | Indonesian causal commonsense reasoning dataset that captures local nuances. | Indonesian |
 | [coqa](coqa/README.md) | Conversational question answering tasks to test dialog understanding. | English |
 | [crows_pairs](crows_pairs/README.md) | Tasks designed to test model biases in various sociodemographic groups. | English, French |
@@ -73,7 +71,6 @@
 | okapi/mmlu_multilingual | Tasks that involve reading comprehension and information retrieval challenges. | Multiple (34 languages) |
 | [okapi/truthfulqa_multilingual](okapi/truthfulqa_multilingual/README.md) | Tasks that involve reading comprehension and information retrieval challenges. | Multiple (31 languages) |
 | [openbookqa](openbookqa/README.md) | Open-book question answering tasks that require external knowledge and reasoning. | English |
-| [paloma](paloma/README.md) | Paloma is a comprehensive benchmark designed to evaluate open language models across a wide range of domains, ranging from niche artist communities to mental health forums on Reddit. | English |
 | [paws-x](paws-x/README.md) | Paraphrase Adversaries from Word Scrambling, focusing on cross-lingual capabilities. | English, French, Spanish, German, Chinese, Japanese, Korean |
 | [pile](pile/README.md) | Open source language modelling data set that consists of 22 smaller, high-quality datasets. | English |
 | [pile_10k](pile_10k/README.md) | The first 10K elements of The Pile, useful for debugging models trained on it. | English |
......
@@ -14,43 +14,27 @@ class TaskManager:
     """
-    def __init__(
-        self,
-        verbosity="INFO",
-        include_path: Optional[Union[str, List]] = None,
-        include_defaults: bool = True,
-    ) -> None:
+    def __init__(self, verbosity="INFO", include_path: Optional[str] = None) -> None:
         self.verbosity = verbosity
         self.include_path = include_path
         self.logger = utils.eval_logger
         self.logger.setLevel(getattr(logging, f"{verbosity}"))
-        self._task_index = self.initialize_tasks(
-            include_path=include_path, include_defaults=include_defaults
-        )
+        self._task_index = self.initialize_tasks(include_path=include_path)
         self._all_tasks = sorted(list(self._task_index.keys()))
         self.task_group_map = collections.defaultdict(list)
-    def initialize_tasks(
-        self,
-        include_path: Optional[Union[str, List]] = None,
-        include_defaults: bool = True,
-    ):
+    def initialize_tasks(self, include_path: Optional[str] = None):
         """Creates a dictionary of tasks index.
-        :param include_path: Union[str, List] = None
-            An additional path to be searched for tasks recursively.
-            Can provide more than one such path as a list.
-        :param include_defaults: bool = True
-            If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
+        :param include_path: str = None
+            An additional path to be searched for tasks
         :return
             Dictionary of task names as key and task metadata
         """
-        if include_defaults:
-            all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
-        else:
-            all_paths = []
+        all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
         if include_path is not None:
             if isinstance(include_path, str):
                 include_path = [include_path]
@@ -312,13 +296,8 @@ class TaskManager:
         :return
             Dictionary of task names as key and task metadata
         """
-        ignore_dirs = [
-            "__pycache__",
-            ".ipynb_checkpoints",
-        ]
         tasks_and_groups = collections.defaultdict()
-        for root, dirs, file_list in os.walk(task_dir):
-            dirs[:] = [d for d in dirs if d not in ignore_dirs]
+        for root, _, file_list in os.walk(task_dir):
             for f in file_list:
                 if f.endswith(".yaml"):
                     yaml_path = os.path.join(root, f)
......
""" """
Take in a YAML, and output all other splits with this YAML Take in a YAML, and output all other splits with this YAML
""" """
import argparse import argparse
import os import os
......
""" """
Take in a YAML, and output all other splits with this YAML Take in a YAML, and output all other splits with this YAML
""" """
import argparse import argparse
import os import os
import re import re
......
""" """
Take in a YAML, and output all other splits with this YAML Take in a YAML, and output all other splits with this YAML
""" """
import argparse import argparse
import os import os
......