Commit c31c4300 authored by lintangsutawika

Merge branch 'recursive-groups' of https://github.com/EleutherAI/lm-evaluation-harness into t5v2-alt-plus
parents f7f298ee 6282a1be
...@@ -172,7 +172,7 @@ lm_eval --model openai-completions \
    --tasks lambada_openai,hellaswag
```
We also support using your own local inference server, as long as it mirrors the OpenAI Completions or ChatCompletions API.
```bash
lm_eval --model local-chat-completions --tasks gsm8k --model_args model=facebook/opt-125m,base_url=http://{yourip}:8000/v1
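# A hypothetical sketch (not in the original README) of the analogous `local-completions`
# invocation, using the `base_url` and `tokenizer_backend` options added in this commit;
# the model name and URL are placeholders:
lm_eval --model local-completions --tasks gsm8k --model_args model=facebook/opt-125m,base_url=http://{yourip}:8000/v1,tokenizer_backend=huggingface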
...@@ -181,7 +181,7 @@ Note that for externally hosted models, configs such as `--device` and `--batch_
| API or Inference Server | Implemented? | `--model <xxx>` name | Models supported: | Request Types: |
|-------------------------|--------------|----------------------|-------------------|----------------|
| OpenAI Completions | :heavy_check_mark: | `openai-completions`, `local-completions` | All OpenAI Completions API models | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| OpenAI ChatCompletions | :heavy_check_mark: | `openai-chat-completions`, `local-chat-completions` | [All ChatCompletions API models](https://platform.openai.com/docs/guides/gpt) | `generate_until` (no logprobs) |
| Anthropic | :heavy_check_mark: | `anthropic` | [Supported Anthropic Engines](https://docs.anthropic.com/claude/reference/selecting-a-model) | `generate_until` (no logprobs) |
| Textsynth | :heavy_check_mark: | `textsynth` | [All supported engines](https://textsynth.com/documentation.html#engines) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
...@@ -189,9 +189,12 @@ Note that for externally hosted models, configs such as `--device` and `--batch_
| [Llama.cpp](https://github.com/ggerganov/llama.cpp) (via [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)) | :heavy_check_mark: | `gguf`, `ggml` | [All models supported by llama.cpp](https://github.com/ggerganov/llama.cpp) | `generate_until`, `loglikelihood`, (perplexity evaluation not yet implemented) |
| vLLM | :heavy_check_mark: | `vllm` | [Most HF Causal Language Models](https://docs.vllm.ai/en/latest/models/supported_models.html) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Mamba | :heavy_check_mark: | `mamba_ssm` | [Mamba architecture Language Models via the `mamba_ssm` package](https://huggingface.co/state-spaces) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Your local inference server! | :heavy_check_mark: | `local-completions` or `local-chat-completions` (using `openai-chat-completions` model type) | Any server address that accepts GET requests using HF models and mirrors OpenAI's ChatCompletions interface | `generate_until` |
| | | `local-completions` (using `openai-completions` model type) | Any server address that accepts GET requests using HF models and mirrors OpenAI's Completions interface | `generate_until` |
| ... |
It is on our roadmap to create task variants designed to enable models that do not serve logprobs/loglikelihoods to be compared against the generation performance of open-source models. Models that do not supply logits or logprobs can be used with tasks of type `generate_until` only, while local models, or APIs that supply logprobs/logits, can be run on all task types: `generate_until`, `loglikelihood`, `loglikelihood_rolling`, and `multiple_choice`.
For more information on the different task `output_types` and model request types, see [our documentation](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/model_guide.md#interface).
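For example, a chat API that returns no logprobs can still be scored on a generative task. A hypothetical sketch (the model and task below are illustrative only):

```bash
# ChatCompletions-style endpoints expose no logprobs, so pair them with
# `generate_until`-type tasks such as gsm8k.
lm_eval --model openai-chat-completions \
    --model_args model=gpt-3.5-turbo \
    --tasks gsm8k
```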
### Other Frameworks
...
...@@ -232,6 +232,10 @@ If you would like to run evaluation on all prompt templates, you can simply call
use_prompt: "promptsource:*"
```
### Weighting evaluation based on task size
By default, all tasks in a group are aggregated by a simple average (for a group of 2 tasks with the same metric, the two scores are summed and divided by 2 to form the group metric). You may instead want each task's contribution weighted by its size. To do this, set `weight_by_size: True` in the task config so that its scores are weighted by the number of samples it contains.
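A minimal sketch of a group config using this option (the group and task names are placeholders, not tasks shipped with the harness):

```yaml
group: my_weighted_group
task:
  - my_task_a
  - my_task_b
weight_by_size: True
```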
### Setting metrics
You're almost done! Now we need to choose how to score our task.
...
...@@ -40,6 +40,21 @@ ALL_OUTPUT_TYPES = [
eval_logger = logging.getLogger("lm-eval")
@dataclass
class GroupConfig(dict):
group: str = None
task: Union[str, list] = None
weight_by_size: bool = False
def __getitem__(self, item):
return getattr(self, item)
def __setitem__(self, item, value):
return setattr(self, item, value)
def to_dict(self):
return asdict(self)
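As an aside, here is an illustrative sketch (not part of this commit) of how the dict-style accessors above behave; it assumes this branch of lm-evaluation-harness is installed and uses made-up group/task names:

```python
# Exercise GroupConfig's dict-style access: __getitem__/__setitem__ forward to getattr/setattr.
from lm_eval.api.task import GroupConfig

cfg = GroupConfig(group="my_group", task=["task_a", "task_b"], weight_by_size=True)
assert cfg["group"] == cfg.group   # item access mirrors attribute access
cfg["weight_by_size"] = False      # item assignment sets the attribute
print(cfg.to_dict())               # {'group': 'my_group', 'task': ['task_a', 'task_b'], 'weight_by_size': False}
```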
@dataclass
class TaskConfig(dict):
...@@ -80,7 +95,7 @@ class TaskConfig(dict):
filter_list: Union[str, list] = None
should_decontaminate: bool = False
doc_to_decontamination_query: str = None
weight_by_size: bool = False
metadata: Union[
str, list
] = None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
...
...@@ -124,7 +124,7 @@ def simple_evaluate(
for task_name in task_dict.keys():
task_obj = task_dict[task_name]
if type(task_obj) == tuple:
_, task_obj = task_obj
if task_obj is None:
continue
...@@ -160,11 +160,16 @@ def simple_evaluate(
)
if lm.rank == 0:
if isinstance(model, str):
model_name = model
elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
model_name = model.config._name_or_path
else:
model_name = type(model).__name__
# add info about the model and few shot config
results["config"] = {
"model": model_name,
if isinstance(model, str)
else model.model.config._name_or_path,
"model_args": model_args, "model_args": model_args,
"batch_size": batch_size, "batch_size": batch_size,
"batch_sizes": list(lm.batch_sizes.values()) "batch_sizes": list(lm.batch_sizes.values())
...@@ -482,10 +487,7 @@ def evaluate(
if "alias" in metrics:
metrics.pop("alias")
if ("weight_by_size" in configs) and configs[task]["weight_by_size"]:
# to toggle between weighted and
# unweighted averaging
if weight_by_size:
current_size = metrics.pop("samples")
else:
metrics.pop("samples")
...
...@@ -42,7 +42,7 @@ class MambaLMWrapper(HFLM):
The HFLM arguments
`backend`, `tokenizer`, `truncation`, `max_length`,
`device`, `dtype`, `batch_size`, `max_batch_size`, `trust_remote_code`, `use_fast_tokenizer`
Are all supported by Mamba where they do not conflict
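For context, a hypothetical CLI invocation exercising a few of those pass-through arguments (the checkpoint name and values below are placeholders, not part of this commit):

```bash
# dtype and batch_size are among the HFLM arguments the Mamba wrapper passes through.
lm_eval --model mamba_ssm \
    --model_args pretrained=state-spaces/mamba-130m,dtype=float16 \
    --tasks lambada_openai \
    --batch_size 8
```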
...@@ -98,7 +98,6 @@ please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba
pretrained,
device=self._device,
dtype=torch.float16 if dtype == "auto" else utils.get_dtype(dtype),
**kwargs,
)
def _model_generate(self, context, max_length, stop, **generation_kwargs):
...
...@@ -2,14 +2,14 @@ import copy
import os
from collections import defaultdict
from importlib.util import find_spec
from typing import List, Literal, Optional, Tuple
from tqdm import tqdm
from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.utils import eval_logger, retry_on_specific_exceptions
def get_result(response, ctxlen: int) -> Tuple[float, bool]:
...@@ -40,7 +40,7 @@ def get_result(response, ctxlen: int) -> Tuple[float, bool]:
return continuation_logprobs, is_greedy
def oa_completion(client, chat: bool = False, **kwargs):
"""Query OpenAI API for completion.
Retry with back-off until they respond
...@@ -64,19 +64,24 @@ def oa_completion(**kwargs):
on_exception_callback=_exception_callback,
)
def completion():
if chat:
return client.chat.completions.create(**kwargs)
else:
return client.completions.create(**kwargs)
return completion()
@register_model("openai-completions") @register_model("openai-completions", "local-completions")
class OpenaiCompletionsLM(LM): class OpenaiCompletionsLM(LM):
REQ_CHUNK_SIZE = 20
_DEFAULT_MAX_LENGTH = 2048
def __init__(
self,
model: str,
base_url: str = None,
tokenizer: Optional[str] = None,
tokenizer_backend: Literal["tiktoken", "huggingface"] = "tiktoken",
truncate: bool = False,
max_gen_toks: int = 256,
batch_size: int = 1,
...@@ -101,15 +106,44 @@ class OpenaiCompletionsLM(LM):
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
)
self.model = model
self.base_url = base_url
self.tokenizer_backend = tokenizer_backend
self.truncate = truncate
self._batch_size = batch_size
self._max_gen_toks = max_gen_toks
self._max_length = max_length
# if we have a local model, use HF tokenizer over tiktoken
if self.tokenizer_backend == "huggingface":
import transformers # noqa: E401
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
tokenizer if tokenizer else self.model
)
self.vocab_size = self.tokenizer.vocab
self.end_of_text_token_id = self.tokenizer.eos_token
elif self.tokenizer_backend == "tiktoken":
if self.base_url:
eval_logger.warning(
f"Passed `base_url={self.base_url}` but using Tiktoken tokenizer backend. "
"Pass `tokenizer_backend=huggingface` and provide the HF tokenizer name if your model does not use Tiktoken."
)
self.tokenizer = tiktoken.encoding_for_model(self.model)
self.vocab_size = self.tokenizer.n_vocab
self.end_of_text_token_id = self.tokenizer.eot_token
else:
raise ValueError(
f"Expected tokenizer_backend to be one of ['tiktoken', 'huggingface'] but got {self.tokenizer_backend}"
)
# Read from environment variable OPENAI_API_KEY
# Set to EMPTY for local
openai.api_key = os.environ["OPENAI_API_KEY"]
if self.base_url:
self.client = openai.OpenAI(base_url=self.base_url)
else:
self.client = openai.OpenAI()
@property
def eot_token_id(self):
...@@ -127,9 +161,8 @@ class OpenaiCompletionsLM(LM):
return self._max_gen_toks
@property
def batch_size(self) -> int:
return self._batch_size
raise NotImplementedError()
@property
def device(self):
...@@ -186,7 +219,7 @@ class OpenaiCompletionsLM(LM):
re_ord = utils.Reorderer(requests, _collate)
for chunk in tqdm(
list(utils.chunks(re_ord.get_reordered(), self.batch_size)),
disable=disable_tqdm,
):
inps = []
...@@ -203,6 +236,7 @@ class OpenaiCompletionsLM(LM):
ctxlens.append(ctxlen)
response = oa_completion(
client=self.client,
model=self.model,
prompt=inps,
echo=True,
...@@ -251,7 +285,7 @@ class OpenaiCompletionsLM(LM):
# todo: more intelligent batching for heterogeneous `until`
for chunk, request_args in tqdm(
list(sameuntil_chunks(re_ord.get_reordered(), self.batch_size))
):
inps = []
self._max_gen_toks = request_args.pop("max_gen_toks", self.max_gen_toks)
...@@ -265,6 +299,7 @@ class OpenaiCompletionsLM(LM):
request_args["temperature"] = request_args.get("temperature", 0)
response = oa_completion(
client=self.client,
model=self.model,
prompt=inps,
max_tokens=self.max_gen_toks,
...@@ -329,35 +364,6 @@ class OpenaiCompletionsLM(LM):
return loglikelihoods
def oa_chat_completion(client, **kwargs):
"""Query OpenAI API for chat completion.
Retry with back-off until they respond
"""
if not find_spec("openai") or not find_spec("tiktoken"):
raise Exception(
"attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. "
"Please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`"
)
else:
import openai
def _exception_callback(e: Exception, sleep_time: float) -> None:
import traceback
traceback.print_exc()
@retry_on_specific_exceptions(
on_exceptions=[openai.OpenAIError],
max_retries=None, # retry forever, consider changing
on_exception_callback=_exception_callback,
)
def completion():
return client.chat.completions.create(**kwargs)
return completion()
@register_model("openai-chat-completions", "local-chat-completions") @register_model("openai-chat-completions", "local-chat-completions")
class OpenaiChatCompletionsLM(LM): class OpenaiChatCompletionsLM(LM):
def __init__( def __init__(
...@@ -460,8 +466,12 @@ class OpenaiChatCompletionsLM(LM): ...@@ -460,8 +466,12 @@ class OpenaiChatCompletionsLM(LM):
f"Expected repr(kwargs) to be of type repr(dict) but got {kwargs}" f"Expected repr(kwargs) to be of type repr(dict) but got {kwargs}"
) )
response = oa_chat_completion( response = oa_completion(
client=self.client, messages=inps, model=self.model, **kwargs client=self.client,
chat=True,
messages=inps,
model=self.model,
**kwargs,
)
for resp, (context, args_) in zip(response.choices, chunk):
...
...@@ -3,34 +3,32 @@ import abc
import yaml
import collections
from functools import partial
from typing import List, Union, Dict
from lm_eval import utils
from lm_eval import prompts
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
from lm_eval.api.registry import (
register_task,
register_group,
TASK_REGISTRY,
GROUP_REGISTRY,
)
import logging
# # import python tasks
# import squadv2.task
# import scrolls.task
# python_tasks = {
# "squadv2": squadv2.task.SQuAD2,
# "scrolls_quality": scrolls.task.QuALITY,
# "scrolls_narrativeqa": scrolls.task.NarrativeQA,
# "scrolls_contractnli": scrolls.task.ContractNLI,
# "scrolls_govreport": scrolls.task.GovReport,
# "scrolls_summscreenfd": scrolls.task.SummScreenFD,
# "scrolls_qmsum": scrolls.task.QMSum,
# }
eval_logger = utils.eval_logger
GROUP_KEYS = ["group", "task", "weight_by_size"]
PYTHON_TASK_KEYS = ["task", "class"]
class TaskManager(abc.ABC):
...@@ -72,15 +70,25 @@ class TaskManager(abc.ABC):
return False
def _name_is_task(self, name):
if self._name_is_registered(name) and ("task" in self.ALL_TASKS[name]["type"]):
return True
return False
def _name_is_python_task(self, name):
if self._name_is_registered(name) and (self.ALL_TASKS[name]["type"] == "python_task"):
return True
return False
def _config_is_task(self, config):
if set(config.keys()) <= set(GROUP_KEYS):
return False
return True
def _config_is_python_task(self, config):
if set(config.keys()) == set(PYTHON_TASK_KEYS):
return True
return False
def _get_yaml_path(self, name):
assert name in self.ALL_TASKS
return self.ALL_TASKS[name]["yaml_path"]
...@@ -94,47 +102,75 @@ class TaskManager(abc.ABC):
assert self._name_is_task(name) == False
return self.ALL_TASKS[name]["task"]
def _load_individual_task_or_group(
self,
name_or_config: Union[str, dict] = None,
parent_name: str = None,
update_config: dict = None
) -> ConfigurableTask:
def load_task(config, task, group=None, is_python_class=False):
if is_python_class:
task_object = config["class"]()
else:
task_object = ConfigurableTask(config=config)
if group is not None:
task_object = (group, task_object)
return {task: task_object}
if isinstance(name_or_config, str):
if update_config is not None:
# Process name_or_config as a dict instead
name_or_config = {"task": name_or_config, **update_config}
elif self._name_is_task(name_or_config):
task_config = self._get_config(name_or_config)
is_python_class=False
if self._name_is_python_task(name_or_config):
is_python_class=True
return load_task(task_config, task=name_or_config, group=parent_name, is_python_class=is_python_class)
else:
group_name = name_or_config
subtask_list = self._get_tasklist(name_or_config)
if subtask_list == -1:
subtask_list = self._get_config(name_or_config)["task"]
if isinstance(name_or_config, dict):
if update_config is not None:
name_or_config={
**name_or_config,
**update_config,
}
if self._config_is_task(name_or_config):
name = name_or_config["task"]
# If the name is registered as a group
if self._name_is_task(name) is False:
group_name = name
update_config = {k:v for k,v in name_or_config.items() if k != "task"}
subtask_list = self._get_tasklist(name)
if subtask_list == -1:
subtask_list = self._get_config(name)["task"]
else:
if self._name_is_registered(name):
base_task_config = self._get_config(name)
task_config={
**base_task_config,
**name_or_config,
}
else:
task_config = name_or_config
return load_task(task_config, task=name, group=parent_name)
else:
group_name = name_or_config["group"]
subtask_list = name_or_config["task"]
if (self._name_is_registered(group_name) is False) or (self._get_yaml_path(group_name) == -1):
all_subtasks = {group_name: (parent_name, None)}
else:
all_subtasks = {}
fn = partial(self._load_individual_task_or_group, parent_name=group_name, update_config=update_config)
all_subtasks = {**all_subtasks, **dict(collections.ChainMap(*map(fn, subtask_list)))}
return all_subtasks
...@@ -161,7 +197,13 @@ class TaskManager(abc.ABC):
if f.endswith(".yaml"):
yaml_path = os.path.join(root, f)
config = utils.simple_load_yaml_config(yaml_path)
if set(config.keys()) == set(PYTHON_TASK_KEYS):
# This is a python class config
tasks_and_groups[config["task"]] = {
"type": "python_task",
"yaml_path": yaml_path,
}
elif set(config.keys()) <= set(GROUP_KEYS):
# This is a group config
tasks_and_groups[config["group"]] = {
"type": "group",
...
group:
- ai2_arc
task: arc_easy
dataset_path: allenai/ai2_arc
dataset_name: ARC-Easy
output_type: multiple_choice
training_split: train
...
group: grouptest
task:
- boolq
- group: random_collection
task:
- ai2_arc
- task: cola
- task: arc_easy
metric_list:
- metric: acc
num_fewshot: 3
- task: mmlu
num_fewshot: 2
...@@ -41,5 +41,6 @@ metric_list:
- metric: accuracy
aggregation: mean
higher_is_better: true
hf_evaluate: true
metadata:
version: 1.0
group: qasper
task: qasper_bool
dataset_path: allenai/qasper
output_type: multiple_choice
training_split: train
validation_split: validation
...
group: qasper
task: qasper_freeform
dataset_path: allenai/qasper
output_type: generate_until
training_split: train
validation_split: validation
...
group: scrolls
task:
# - task: scrolls_qasper
# class: !function task.Qasper
- task: scrolls_quality
class: !function task.QuALITY
# - scrolls_narrativeqa
# class: !function task.NarrativeQA
# - scrolls_contractnli
# class: !function task.ContractNLI
# - scrolls_govreport
# class: !function task.GovReport
# - scrolls_summscreenfd
# class: !function task.SummScreenFD
# - scrolls_qmsum
# class: !function task.QMSum
...@@ -279,7 +279,7 @@ class _SCROLLSSummaryTask(_SCROLLSTask):
return f"{doc['input']}\n\nQuestion: What is a summary of the preceding text?\nAnswer:"
# @register_task("scrolls_qasper")
class Qasper(_SCROLLSTask):
"""A Dataset of Information-Seeking Questions and Answers Anchored in Research Papers
https://arxiv.org/abs/2105.03011
...@@ -337,7 +337,7 @@ class Qasper(_SCROLLSTask):
)
# @register_task("scrolls_quality")
class QuALITY(_SCROLLSMultipleChoiceTask):
"""QuALITY: Question Answering with Long Input Texts, Yes!
https://arxiv.org/abs/2112.08608
...@@ -366,7 +366,7 @@ class QuALITY(_SCROLLSMultipleChoiceTask):
return [doc]
# @register_task("scrolls_narrativeqa")
class NarrativeQA(_SCROLLSTask):
"""The NarrativeQA Reading Comprehension Challenge
https://arxiv.org/abs/1712.07040
...@@ -400,7 +400,7 @@ class NarrativeQA(_SCROLLSTask):
)
# @register_task("scrolls_contractnli")
class ContractNLI(_SCROLLSMultipleChoiceTask):
"""ContractNLI: A Dataset for Document-level Natural Language Inference for Contracts
https://arxiv.org/abs/1712.07040
...@@ -419,7 +419,7 @@ class ContractNLI(_SCROLLSMultipleChoiceTask):
return f"{doc['text']}\n\nHypothesis: {doc['question']}\nConclusion:"
# @register_task("scrolls_govreport")
class GovReport(_SCROLLSSummaryTask):
"""Efficient Attentions for Long Document Summarization
https://arxiv.org/abs/2104.02112
...@@ -433,7 +433,7 @@ class GovReport(_SCROLLSSummaryTask):
DATASET_NAME = "gov_report"
# @register_task("scrolls_summscreenfd")
class SummScreenFD(_SCROLLSSummaryTask):
"""SummScreen: A Dataset for Abstractive Screenplay Summarization
https://arxiv.org/abs/2104.07091
...@@ -442,7 +442,7 @@ class SummScreenFD(_SCROLLSSummaryTask):
DATASET_NAME = "summ_screen_fd"
# @register_task("scrolls_qmsum")
class QMSum(_SCROLLSSummaryTask):
"""QMSum: A New Benchmark for Query-based Multi-domain
Meeting Summarization
...
task: squadv2
class: !function task.SQuAD2
\ No newline at end of file
...@@ -21,7 +21,6 @@ from packaging import version
from lm_eval.api.task import Task
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_task
_CITATION = """
@misc{rajpurkar2018know,
...@@ -47,7 +46,6 @@ def _squad_agg(key, items):
return _squad_metric(predictions=predictions, references=references).get(key, 0)
@register_task("squadv2")
class SQuAD2(Task):
VERSION = 3
DATASET_PATH = "squad_v2"
...
...@@ -480,28 +480,10 @@ def get_git_commit_hash():
return git_hash
def import_function(loader, node):
function_name = loader.construct_scalar(node)
yaml_path = os.path.dirname(loader.name)
*module_name, function_name = function_name.split(".")
if isinstance(module_name, list):
module_name = ".".join(module_name)
module_path = os.path.normpath(os.path.join(yaml_path, "{}.py".format(module_name)))
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
function = getattr(module, function_name)
return function
def ignore_constructor(loader, node):
return node
def simple_load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
def ignore_constructor(loader, node):
return node
yaml.add_constructor("!function", ignore_constructor)
with open(yaml_path, "rb") as file:
yaml_config = yaml.full_load(file)
...@@ -509,6 +491,24 @@ def simple_load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
def import_function(loader, node):
function_name = loader.construct_scalar(node)
yaml_path = os.path.dirname(loader.name)
*module_name, function_name = function_name.split(".")
if isinstance(module_name, list):
module_name = ".".join(module_name)
module_path = os.path.normpath(
os.path.join(yaml_path, "{}.py".format(module_name))
)
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
function = getattr(module, function_name)
return function
# Add the import_function constructor to the YAML loader
yaml.add_constructor("!function", import_function)
if yaml_config is None:
...
...@@ -22,12 +22,13 @@ def load_changed_files(file_path: str) -> List[str]:
# checks the txt file for list of changed files.
# if file ends with .yaml then check yaml for task name
# if file ends with .py then parse the folder for all yaml files
# skips benchmarks folder
def parser(full_path: List[str]) -> List[str]:
_output = set()
for x in full_path:
if x.endswith(".yaml") and "benchmarks" not in x:
_output.add(load_yaml_config(x)["task"])
elif x.endswith(".py") and "benchmarks" not in x:
path = [str(x) for x in (list(Path(x).parent.glob("*.yaml")))]
_output |= {load_yaml_config(x)["task"] for x in path}
return list(_output)
...