Commit 3263c572 authored by lintangsutawika

Merge branch 'big-refactor' of https://github.com/EleutherAI/lm-evaluation-harness into squadv2

parents a27e8ed1 33d52483
@@ -9,7 +9,7 @@ class DecontaminationFilter(Filter):
    name = "track_decontamination"

-    def __init__(self, path):
+    def __init__(self, path) -> None:
        """
        TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
@@ -17,7 +17,7 @@ class DecontaminationFilter(Filter):
        """
        self._decontam_results = None

-    def apply(self, reps):
+    def apply(self, resps, docs) -> None:
        """
        Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
        """
......
@@ -6,7 +6,9 @@ from lm_eval.api.filter import Filter
class RegexFilter(Filter):
    """ """

-    def __init__(self, regex_pattern=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]"):
+    def __init__(
+        self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", fallback: str = "[invalid]"
+    ) -> None:
        """
        pass a string `regex` to run `re.compile(r"regex")` on.
        `fallback` defines the output returned if no matches for the regex are located.
@@ -15,7 +17,7 @@ class RegexFilter(Filter):
        self.regex = re.compile(regex_pattern)
        self.fallback = fallback

-    def apply(self, resps):
+    def apply(self, resps, docs):
        # here, we assume we have a list, in which each element is
        # a list of model responses for some particular input/target pair.
        # so we process each of these (same input/target response sets)
@@ -41,12 +43,11 @@ class RegexFilter(Filter):
class WhitespaceFilter(Filter):
    """ """

-    def __init__(self):
+    def __init__(self) -> None:
        pass

-    def apply(self, resps):
+    def apply(self, resps, docs):
        def filter_set(inst):
            filtered_resp = []
            for resp in inst:
                if resp.startswith(" "):
......
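For reference, a minimal standalone sketch of the regex-extraction behaviour the `RegexFilter` docstring above describes (extract a `#### <number>` answer, fall back to `[invalid]`); the sample responses are made up and the snippet does not call the library class:

```python
import re

# Default pattern and fallback mirroring RegexFilter's arguments above.
regex = re.compile(r"#### (\-?[0-9\.\,]+)")
fallback = "[invalid]"

# One inner list of sampled responses per input/target pair (illustrative data).
resps = [["Adding them up gives 12. #### 12", "I think the answer is seven."]]

filtered = []
for doc_resps in resps:
    extracted = []
    for resp in doc_resps:
        match = regex.search(resp)
        extracted.append(match.group(1).strip() if match else fallback)
    filtered.append(extracted)

print(filtered)  # [['12', '[invalid]']]
```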
@@ -4,12 +4,12 @@ from lm_eval.api.filter import Filter
class TakeFirstFilter(Filter):
-    def __init__(self):
+    def __init__(self) -> None:
        """
        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
        """

-    def apply(self, resps):
+    def apply(self, resps, docs):
        """
        Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
        """
@@ -17,13 +17,12 @@ class TakeFirstFilter(Filter):
class TakeKFilter(Filter):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
        self.k = kwargs.pop("k")

        super().__init__(*args, **kwargs)

-    def apply(self, resps):
+    def apply(self, resps, docs):
        # check we have at least k responses per doc, else we can't take the first k
        assert (
            len(resps[0]) >= self.k
@@ -32,12 +31,12 @@ class TakeKFilter(Filter):
class MajorityVoteFilter(Filter):
-    def __init__(self):
+    def __init__(self) -> None:
        """
        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
        """

-    def apply(self, resps):
+    def apply(self, resps, docs):
        """
        Each entry of `resps` is a list of model responses.
        We select the response that occurs most frequently in each entry of `resps`.
......
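As a quick illustration of the majority-vote semantics described in the docstring above, a small sketch with made-up responses (it does not use the library classes):

```python
from collections import Counter

# One inner list of sampled responses per document; pick the most frequent answer in each.
resps = [["4", "5", "4"], ["yes", "no", "yes", "yes"]]
voted = [Counter(r).most_common(1)[0][0] for r in resps]
print(voted)  # ['4', 'yes']
```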
@@ -6,3 +6,5 @@ logging.basicConfig(
    level=logging.INFO,
)
eval_logger = logging.getLogger("lm-eval")
+
+SPACING = " " * 47
@@ -76,7 +76,7 @@ class AnthropicLM(LM):
        max_tokens_to_sample: int = 256,
        temperature: float = 0,  # defaults to 1
        **kwargs,  # top_p, top_k, etc.
-    ):
+    ) -> None:
        """Anthropic API wrapper.

        :param model: str
@@ -135,11 +135,10 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
    def tok_decode(self, tokens: List[int]) -> str:
        return self.tokenizer.decode(tokens)

-    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
+    def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")

    def greedy_until(self, requests) -> List[str]:
        if not requests:
            return []
......
@@ -5,7 +5,7 @@ from lm_eval.api.registry import register_model
@register_model("dummy")
class DummyLM(LM):
-    def __init__(self):
+    def __init__(self) -> None:
        super().__init__()

    @classmethod
......
+import os
+
import torch
import transformers
from transformers.models.auto.modeling_auto import (
@@ -20,7 +22,7 @@ from lm_eval.api.registry import register_model
from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria

-from accelerate import Accelerator, find_executable_batch_size
+from accelerate import Accelerator, find_executable_batch_size, DistributedType
from typing import List, Optional, Union
@@ -67,6 +69,7 @@ class HFLM(LM):
        revision: Optional[str] = "main",
        subfolder: Optional[str] = None,
        tokenizer: Optional[str] = None,
+        truncation: Optional[bool] = False,
        max_length: Optional[int] = None,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
@@ -75,6 +78,7 @@ class HFLM(LM):
        low_cpu_mem_usage: Optional[bool] = True,
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        parallelize: Optional[bool] = False,
@@ -90,7 +94,7 @@ class HFLM(LM):
        bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
        gptq: Optional[Union[bool, str]] = False,
        gptq_use_triton: Optional[bool] = False,
-    ):
+    ) -> None:
        super().__init__()

        assert isinstance(device, str)
@@ -103,17 +107,20 @@ class HFLM(LM):
        if not (parallelize or accelerator.num_processes > 1):
            # use user-passed device
            device_list = set(
-                ["cuda", "cpu", "mps"]
+                ["cuda", "cpu"]
                + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+                + ["mps", "mps:0"]
            )
            if device:
                if device not in device_list:
                    device = int(device)
                self._device = torch.device(device)
                eval_logger.info(f"Using device '{device}'")
-                if device == "mps":
+                if device in ("mps", "mps:0") and "dev" not in torch.__version__:
                    eval_logger.info(
-                        "MPS is still in beta and only supports float32; setting dtype to float32."
+                        "MPS: Setting dtype to float32. To use float16 with MPS, please install a nightly build of "
+                        "PyTorch: pip3 install --pre torch torchvision torchaudio --index-url "
+                        "https://download.pytorch.org/whl/nightly/cpu"
                    )
            else:
                eval_logger.info("Device not specified")
@@ -240,6 +247,8 @@ class HFLM(LM):
            use_fast=use_fast_tokenizer,
        )
+        self.truncation = truncation
+
        self.vocab_size = self.tokenizer.vocab_size
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
@@ -288,6 +297,13 @@ class HFLM(LM):
                    eval_logger.info(
                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
                    )
+            else:
+                assert accelerator.distributed_type in [
+                    DistributedType.FSDP,
+                    DistributedType.MULTI_GPU,
+                ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
+                if accelerator.distributed_type == DistributedType.FSDP:
+                    self._model = accelerator.prepare(self.model)
                else:
                    self._model = accelerator.prepare_model(
                        self.model, evaluation_mode=True
@@ -334,7 +350,7 @@ class HFLM(LM):
        return self._DEFAULT_MAX_LENGTH

    @property
-    def max_gen_toks(self):
+    def max_gen_toks(self) -> int:
        return 256

    @property
@@ -353,7 +369,7 @@ class HFLM(LM):
    def world_size(self):
        return self._world_size

-    def _detect_batch_size(self, requests=None, pos=0):
+    def _detect_batch_size(self, requests=None, pos: int = 0):
        if requests:
            _, context_enc, continuation_enc = requests[pos]
            max_length = len(
@@ -419,7 +435,11 @@ class HFLM(LM):
        return encoding

    def tok_batch_encode(
-        self, strings: List[str], padding_side="left", left_truncate_len=None
+        self,
+        strings: List[str],
+        padding_side: str = "left",
+        left_truncate_len: int = None,
+        truncation: bool = False,
    ):
        # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
        old_padding_side = self.tokenizer.padding_side
@@ -432,6 +452,7 @@ class HFLM(LM):
        encoding = self.tokenizer(
            strings,
+            truncation=truncation,
            padding="longest",
            return_tensors="pt",
            add_special_tokens=add_special_tokens,
@@ -595,7 +616,9 @@ class HFLM(LM):
        return loglikelihoods

-    def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs=None):
+    def _loglikelihood_tokens(
+        self, requests, disable_tqdm: bool = False, override_bs=None
+    ):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []
@@ -856,7 +879,9 @@ class HFLM(LM):
            # encode, pad, and truncate contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
-                contexts, left_truncate_len=max_ctx_len
+                contexts,
+                left_truncate_len=max_ctx_len,
+                truncation=self.truncation,
            )
            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)
......
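To illustrate what the `truncation` flag forwarded to the tokenizer above controls, here is a small sketch assuming `transformers` is installed and the `gpt2` tokenizer can be downloaded; the strings and `max_length` are arbitrary:

```python
from transformers import AutoTokenizer

# Batch-encode with padding="longest" as in tok_batch_encode; with truncation
# enabled, over-long inputs are cut down to max_length tokens.
tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token

batch = tok(
    ["a very long context " * 200, "short prompt"],
    truncation=True,
    max_length=32,
    padding="longest",
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # torch.Size([2, 32]): long input truncated, short one padded
```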
@@ -69,7 +69,7 @@ class OpenaiCompletionsLM(LM):
        engine: str = "text-davinci-003",
        truncate: bool = False,
        batch_size: int = 1,
-    ):
+    ) -> None:
        """
        :param engine: str
@@ -99,12 +99,12 @@ class OpenaiCompletionsLM(LM):
        return self.end_of_text_token_id

    @property
-    def max_length(self):
+    def max_length(self) -> int:
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
-    def max_gen_toks(self):
+    def max_gen_toks(self) -> int:
        return 256

    @property
@@ -152,7 +152,7 @@ class OpenaiCompletionsLM(LM):
        return self._loglikelihood_tokens(new_reqs)

    def _loglikelihood_tokens(
-        self, requests, disable_tqdm=False
+        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:
        res = []
......
@@ -41,7 +41,7 @@ def textsynth_completion(**kwargs):
@register_model("textsynth")
class TextSynthLM(LM):
-    def __init__(self, engine, truncate=False):
+    def __init__(self, engine, truncate: bool = False) -> None:
        """
        :param engine: str
            TextSynth API engine (e.g. `gptj_6B`)
@@ -62,12 +62,12 @@ class TextSynthLM(LM):
        raise NotImplementedError()

    @property
-    def max_length(self):
+    def max_length(self) -> int:
        # NOTE: Turn on truncation to avoid errors on long inputs.
        return 2048

    @property
-    def max_gen_toks(self):
+    def max_gen_toks(self) -> int:
        return 256

    @property
......
+import ast
+from typing import Dict
+
from lm_eval import utils
from lm_eval.logger import eval_logger
@@ -5,7 +8,7 @@ from lm_eval.logger import eval_logger
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
# This allows us to access prompts
-PROMPT_REGISTRY = {
+PROMPT_REGISTRY: Dict[str, Dict[str, str]] = {
    "qa-basic": {
        "question-newline-answer": "Question: {{question}}\nAnswer:",
        "q-newline-a": "Q: {{question}}\nA:",
@@ -13,7 +16,7 @@ PROMPT_REGISTRY = {
}

-def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
+def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None):
    # unpack prompt name
    category_name, prompt_name = prompt_id.split(":")
    if subset_name is None:
@@ -63,6 +66,12 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwa
    else:
        prompts = DatasetTemplates(dataset_name=dataset_name, subset_name=subset_name)

-    category_name, prompt_name = use_prompt.split(":")
+    category_name, *prompt_name = use_prompt.split(":")
+    # TODO allow to multiple prompt naming
+    # if len(prompt_name) > 1:
+    #     prompt_list = []
+    #     for prompt in prompt_name:
+    #         prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names))
+    # else:
    prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
    return [":".join([category_name, prompt]) for prompt in prompt_list]
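As a rough sketch of the two-level `category:name` lookup described above (assuming `jinja2` is installed; the toy registry, helper name, and rendering step are illustrative, not the library's own API):

```python
from jinja2 import Template

# Toy registry mirroring the structure above; lookup follows the "category:name" ID format.
PROMPT_REGISTRY = {
    "qa-basic": {
        "question-newline-answer": "Question: {{question}}\nAnswer:",
        "q-newline-a": "Q: {{question}}\nA:",
    },
}

def lookup_prompt(prompt_id: str) -> str:
    category_name, prompt_name = prompt_id.split(":")
    return PROMPT_REGISTRY[category_name][prompt_name]

template = lookup_prompt("qa-basic:q-newline-a")
print(Template(template).render(question="What is the capital of France?"))
# Q: What is the capital of France?
# A:
```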
@@ -5,8 +5,8 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] Glue
- [x] SuperGlue
-- [ ] CoQA (Lintang)
+- [x] CoQA
-- [ ] DROP (Lintang)
+- [x] DROP
- [x] ~~Lambada~~
- [x] Lambada (Cloze variants)
- [x] ~~Lambada (Multilingual)~~
@@ -29,7 +29,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] HeadQA
- [x] MathQA
- [x] WebQs
-- [ ] WSC273 (Lintang)
+- [x] WSC273
- [x] Winogrande
- [x] ANLI
- [x] Hendrycks Ethics (missing some tasks/metrics, see PR 660: <https://github.com/EleutherAI/lm-evaluation-harness/pull/660> for more info)
@@ -38,11 +38,11 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] TruthfulQA (gen)
- [ ] MuTual
- [ ] Hendrycks Math (Hailey)
-- [ ] Asdiv
+- [x] Asdiv
- [ ] GSM8k
- [x] Arithmetic
- [ ] MMMLU (Hailey)
-- [ ] Translation (WMT) suite (Hailey)
+- [x] Translation (WMT) suite
- [x] Unscramble
- [x] ~~Pile (perplexity)~~
- [x] BLiMP
@@ -56,7 +56,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] XWinograd
- [x] PAWS-X
- [x] XNLI
-- [ ] MGSM (Lintang)
+- [x] MGSM
- [ ] SCROLLS
- [x] Babi
......
import os
import yaml
-from typing import List, Union
+from typing import List, Union, Dict

from lm_eval import utils
from lm_eval import prompts
@@ -15,7 +15,7 @@ from lm_eval.api.registry import (
)

-def register_configurable_task(config):
+def register_configurable_task(config: Dict[str, str]) -> int:
    SubClass = type(
        config["task"] + "ConfigurableTask",
        (ConfigurableTask,),
@@ -38,7 +38,7 @@ def register_configurable_task(config):
    return 0

-def check_prompt_config(config):
+def check_prompt_config(config: Dict[str, str]) -> List[Dict[str, str]]:
    all_configs = []
    if "use_prompt" in config:
        prompt_list = prompts.load_prompt_list(
@@ -69,14 +69,14 @@ def check_prompt_config(config):
    return all_configs

-def get_task_name_from_config(task_config):
+def get_task_name_from_config(task_config: Dict[str, str]) -> str:
    if "dataset_name" in task_config:
        return "{dataset_path}_{dataset_name}".format(**task_config)
    else:
        return "{dataset_path}".format(**task_config)

-def include_task_folder(task_dir):
+def include_task_folder(task_dir: str) -> None:
    """
    Calling this function
    """
@@ -128,7 +128,7 @@ def get_task_name_from_object(task_object):
# TODO: pass num_fewshot and other cmdline overrides in a better way
-def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
+def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
    config = {**kwargs}
@@ -136,6 +136,9 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

+    if type(task_name_list) != list:
+        task_name_list = [task_name_list]
+
    for task_element in task_name_list:
        if isinstance(task_element, str):
@@ -143,12 +146,20 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
                group_name = task_element
                for task_name in GROUP_REGISTRY[task_element]:
                    if task_name not in task_name_from_registry_dict:
+                        task_obj = get_task_dict(task_name)
+                        if task_name in task_obj.keys():
+                            task_dict = {
+                                task_name: (group_name, task_obj[task_name]),
+                            }
+                        else:
+                            task_dict = {
+                                task_name: (group_name, None),
+                                **task_obj,
+                            }
+
                        task_name_from_registry_dict = {
                            **task_name_from_registry_dict,
-                            task_name: (
-                                group_name,
-                                get_task(task_name=task_name, config=config),
-                            ),
+                            **task_dict,
                        }
            else:
                task_name = task_element
......
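A toy sketch of the group-expansion behaviour added above, using fake registries and plain strings in place of real Task objects (all names here are made up):

```python
# Fake registries standing in for lm_eval's GROUP_REGISTRY / TASK_REGISTRY.
GROUP_REGISTRY = {"my_group": ["task_a", "task_b"]}
TASK_REGISTRY = {"task_a": "TaskA()", "task_b": "TaskB()"}

def get_task_dict_sketch(task_name_list):
    # a bare string is wrapped into a list, as in the change above
    if not isinstance(task_name_list, list):
        task_name_list = [task_name_list]
    out = {}
    for element in task_name_list:
        if element in GROUP_REGISTRY:
            # each member task is keyed by its own name and tagged with the group it came from
            for task_name in GROUP_REGISTRY[element]:
                out[task_name] = (element, TASK_REGISTRY[task_name])
        else:
            out[element] = TASK_REGISTRY[element]
    return out

print(get_task_dict_sketch("my_group"))
# {'task_a': ('my_group', 'TaskA()'), 'task_b': ('my_group', 'TaskB()')}
```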
task: asdiv
dataset_path: EleutherAI/asdiv
output_type: loglikelihood
validation_split: validation
doc_to_text: "{{body}}\nQuestion:{{question}}\nAnswer:"
doc_to_target: "{{answer.split(' (')[0]}}"
should_decontaminate: true
doc_to_decontamination_query: "{{body}} {{question}}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
# CoQA
### Paper
Title: `CoQA: A Conversational Question Answering Challenge`
Abstract: https://arxiv.org/pdf/1808.07042.pdf
CoQA is a large-scale dataset for building Conversational Question Answering
systems. The goal of the CoQA challenge is to measure the ability of machines to
understand a text passage and answer a series of interconnected questions that
appear in a conversation.
Homepage: https://stanfordnlp.github.io/coqa/
### Citation
```
@misc{reddy2018coqa,
    title={CoQA: A Conversational Question Answering Challenge},
    author={Siva Reddy and Danqi Chen and Christopher D. Manning},
    year={2018},
    eprint={1808.07042},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```
### Groups and Tasks
#### Groups
* Not part of a group yet
#### Tasks
* `coqa`
### Checklist
For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
task: coqa
dataset_path: EleutherAI/coqa
output_type: greedy_until
training_split: train
validation_split: validation
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
process_results: !function utils.process_results
should_decontaminate: true
doc_to_decontamination_query: "{{story}} {{question.input_text|join('\n')}}"
generation_kwargs:
  until:
    - "\nQ:"
metric_list:
  - metric: em
    aggregation: mean
    higher_is_better: true
  - metric: f1
    aggregation: mean
    higher_is_better: true
from itertools import zip_longest

import transformers.data.metrics.squad_metrics as squad_metrics


def doc_to_text(doc):
    # Given a passage p, the conversation history {q1, a1, ..., q_{i-1}, a_{i-1}}
    # and a question q_i, the task is to predict the answer a_i
    doc_text = doc["story"] + "\n\n"
    for (q, a) in zip_longest(
        doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
    ):  # omit target answer a_i
        question = f"Q: {q}\n\n"
        answer = f"A: {a}\n\n" if a is not None else "A:"
        doc_text += question + answer
    return doc_text


def doc_to_target(doc):
    turn_id = len(doc["questions"]["input_text"])
    # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
    answers = []
    answer_forturn = doc["answers"]["input_text"][turn_id - 1]
    answers.append(answer_forturn)

    additional_answers = doc.get("additional_answers")
    if additional_answers:
        for key in additional_answers:
            additional_answer_for_turn = additional_answers[key]["input_text"][
                turn_id - 1
            ]
            if additional_answer_for_turn.lower() not in map(str.lower, answers):
                answers.append(additional_answer_for_turn)
    return answers
def em(gold_list, pred):
    # tests for exact match and on the normalised answer (compute_exact)
    em_sum = 0.0
    if len(gold_list) > 1:
        for i in range(len(gold_list)):
            gold_answers = gold_list[0:i] + gold_list[i + 1 :]
            # predictions compared against (n) golds and take maximum
            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
    else:
        em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)

    return em_sum / max(1, len(gold_list))


def compute_scores(gold_list, pred):
    # tests for exact match and on the normalised answer (compute_exact)
    # test for overlap (compute_f1)
    f1_sum = 0.0
    em_sum = 0.0
    if len(gold_list) > 1:
        for i in range(len(gold_list)):
            gold_answers = gold_list[0:i] + gold_list[i + 1 :]
            # predictions compared against (n) golds and take maximum
            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
            f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
    else:
        em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
        f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)

    return {
        "em": em_sum / max(1, len(gold_list)),
        "f1": f1_sum / max(1, len(gold_list)),
    }


def process_results(doc, results):
    gold_list = doc_to_target(doc)
    pred = results[0].strip().split("\n")[0]

    scores = compute_scores(gold_list, pred)
    return scores
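A quick sanity check of the squad-metrics primitives these helpers build on (assumes `transformers` is installed; the strings are made up):

```python
import transformers.data.metrics.squad_metrics as squad_metrics

# An exactly matching single gold answer gives EM = 1 and F1 = 1.0,
# so compute_scores(["5"], "5") above returns {'em': 1.0, 'f1': 1.0}.
print(squad_metrics.compute_exact("5", "5"), squad_metrics.compute_f1("5", "5"))  # 1 1.0
```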
# DROP
### Paper
Title: `DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs`
Abstract: https://aclanthology.org/attachments/N19-1246.Supplementary.pdf
DROP is a QA dataset which tests comprehensive understanding of paragraphs. In
this crowdsourced, adversarially-created, 96k question-answering benchmark, a
system must resolve multiple references in a question, map them onto a paragraph,
and perform discrete operations over them (such as addition, counting, or sorting).
Homepage: https://allenai.org/data/drop
Acknowledgement: This implementation is based on the official evaluation for `DROP`:
https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
### Citation
```
@misc{dua2019drop,
title={DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs},
author={Dheeru Dua and Yizhong Wang and Pradeep Dasigi and Gabriel Stanovsky and Sameer Singh and Matt Gardner},
year={2019},
eprint={1903.00161},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
### Groups and Tasks
#### Groups
* Not part of a group yet.
#### Tasks
* `drop`
### Checklist
For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
task: drop
dataset_path: EleutherAI/drop
output_type: greedy_until
training_split: train
validation_split: validation
process_docs: !function utils.process_docs
doc_to_text: "{{passage}} {{question}}"
doc_to_target: "{{ answer|join(',')}}"
target_delimiter: ""
process_results: !function utils.process_results
should_decontaminate: true
doc_to_decontamination_query: "{{passage}} {{question}}"
generation_kwargs:
  until:
    - "."
metric_list:
  - metric: em
    aggregation: mean
    higher_is_better: true
  - metric: f1
    aggregation: mean
    higher_is_better: true
import re
import string
import numpy as np
from scipy.optimize import linear_sum_assignment
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
def process_docs(dataset):
    def _process(doc):
        return {
            "id": doc["query_id"],
            "passage": doc["passage"],
            "question": doc["question"],
            "answers": get_answers(doc),
        }

    return dataset.map(_process)


def get_answers(doc):
    def _flatten_validated_answers(validated_answers):
        """Flattens a dict of lists of validated answers.
        {"number": ['1', '8'], ...}
        -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
        """
        valid_answers = []
        for i in range(len(validated_answers["number"])):
            valid_answers.append(
                {
                    "number": validated_answers["number"][i],
                    "date": validated_answers["date"][i],
                    "spans": validated_answers["spans"][i],
                }
            )
        return valid_answers

    answers = []
    answers_set = set()
    candidates = [doc["answer"]] + _flatten_validated_answers(doc["validated_answers"])
    for candidate in candidates:
        answer = parse_answer(candidate)
        if answer in answers_set:
            continue
        answers_set.add(answer)
        answers.append(answer)
    return answers


def parse_answer(answer):
    # NOTE: Everything is returned as a tuple for uniformity and hashability.
    if answer["number"] != "":
        return (str(answer["number"]),)
    if answer["spans"] != []:
        return tuple(answer["spans"])
    return (
        " ".join(
            [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
        ).strip(),
    )


def process_results(doc, results):
    preds, golds = results, doc["answers"]
    max_em = 0
    max_f1 = 0
    for gold_answer in golds:
        exact_match, f1_score = get_metrics(preds, gold_answer)
        if gold_answer[0].strip():
            max_em = max(max_em, exact_match)
            max_f1 = max(max_f1, f1_score)
    return {"em": max_em, "f1": max_f1}
def get_metrics(predicted, gold):
    """
    Takes a predicted answer and a gold answer (that are both either a string or a list of
    strings), and returns exact match and the DROP F1 metric for the prediction. If you are
    writing a script for evaluating objects in memory (say, the output of predictions during
    validation, or while training), this is the function you want to call, after using
    :func:`answer_json_to_strings` when reading the gold answer from the released data file.
    """
    predicted_bags = _answer_to_bags(predicted)
    gold_bags = _answer_to_bags(gold)

    if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(
        gold_bags[0]
    ):
        exact_match = 1.0
    else:
        exact_match = 0.0

    f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
    f1 = np.mean(f1_per_bag)
    f1 = round(f1, 2)
    return exact_match, f1


def _answer_to_bags(answer):
    if isinstance(answer, (list, tuple)):
        raw_spans = answer
    else:
        raw_spans = [answer]
    normalized_spans = []
    token_bags = []
    for raw_span in raw_spans:
        normalized_span = _normalize(raw_span)
        normalized_spans.append(normalized_span)
        token_bags.append(set(normalized_span.split()))
    return normalized_spans, token_bags


def _align_bags(predicted, gold):
    """
    Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
    between them and gets maximum metric values over all the answers.
    """
    scores = np.zeros([len(gold), len(predicted)])
    for gold_index, gold_item in enumerate(gold):
        for pred_index, pred_item in enumerate(predicted):
            if _match_numbers_if_present(gold_item, pred_item):
                scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item)
    row_ind, col_ind = linear_sum_assignment(-scores)

    max_scores = np.zeros([max(len(gold), len(predicted))])
    for row, column in zip(row_ind, col_ind):
        max_scores[row] = max(max_scores[row], scores[row, column])
    return max_scores
def _compute_f1(predicted_bag, gold_bag):
    intersection = len(gold_bag.intersection(predicted_bag))
    if not predicted_bag:
        precision = 1.0
    else:
        precision = intersection / float(len(predicted_bag))
    if not gold_bag:
        recall = 1.0
    else:
        recall = intersection / float(len(gold_bag))
    f1 = (
        (2 * precision * recall) / (precision + recall)
        if not (precision == 0.0 and recall == 0.0)
        else 0.0
    )
    return f1


def _match_numbers_if_present(gold_bag, predicted_bag):
    gold_numbers = set()
    predicted_numbers = set()
    for word in gold_bag:
        if _is_number(word):
            gold_numbers.add(word)
    for word in predicted_bag:
        if _is_number(word):
            predicted_numbers.add(word)
    if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
        return True
    return False


def _is_number(text):
    try:
        float(text)
        return True
    except ValueError:
        return False


def _remove_articles(text):
    return _ARTICLES.sub(" ", text)


def _white_space_fix(text):
    return " ".join(text.split())


def _remove_punc(text):
    exclude = set(string.punctuation)
    if not _is_number(text):
        return "".join(ch for ch in text if ch not in exclude)
    else:
        return text


def _fix_number(text):
    return str(float(text)) if _is_number(text) else text


def _tokenize(text):
    return re.split(" |-", text)


def _normalize(answer):
    tokens = [
        _white_space_fix(_remove_articles(_fix_number(_remove_punc(token.lower()))))
        for token in _tokenize(answer)
    ]
    tokens = [token for token in tokens if token.strip()]
    normalized = " ".join(tokens).strip()
    return normalized
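A small usage sketch of the DROP scoring entry point defined above (the inputs are made up):

```python
# Both answers normalise to the same token bag ("12" -> "12.0", "points"),
# so exact match and F1 are both 1.0.
print(get_metrics("12 points", ["12 points"]))  # (1.0, 1.0)

# A number mismatch makes _match_numbers_if_present fail, so the aligned F1 is 0.
print(get_metrics("11 points", ["12 points"]))  # (0.0, 0.0)
```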
-def doc_to_text(doc):
+def doc_to_text(doc) -> str:
    return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
        doc["premise"],
        doc["hypothesis"].strip()
......