Unverified Commit 2387f39d authored by Lintang Sutawika, committed by GitHub

Merge branch 'big-refactor' into flan-benchmark

parents 7601d828 784fe037
@@ -5,7 +5,7 @@ from lm_eval.api.registry import register_model
 @register_model("dummy")
 class DummyLM(LM):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()

     @classmethod
...
@@ -22,7 +22,7 @@ from lm_eval.api.registry import register_model
 from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria

-from accelerate import Accelerator, find_executable_batch_size
+from accelerate import Accelerator, find_executable_batch_size, DistributedType
 from typing import List, Optional, Union
@@ -94,7 +94,7 @@ class HFLM(LM):
         bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
         gptq: Optional[Union[bool, str]] = False,
         gptq_use_triton: Optional[bool] = False,
-    ):
+    ) -> None:
         super().__init__()

         assert isinstance(device, str)
@@ -294,6 +294,13 @@ class HFLM(LM):
                 eval_logger.info(
                     "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
                 )
+            else:
+                assert accelerator.distributed_type in [
+                    DistributedType.FSDP,
+                    DistributedType.MULTI_GPU,
+                ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
+                if accelerator.distributed_type == DistributedType.FSDP:
+                    self._model = accelerator.prepare(self.model)
                 else:
                     self._model = accelerator.prepare_model(
                         self.model, evaluation_mode=True
@@ -340,7 +347,7 @@ class HFLM(LM):
         return self._DEFAULT_MAX_LENGTH

     @property
-    def max_gen_toks(self):
+    def max_gen_toks(self) -> int:
         return 256

     @property
@@ -359,7 +366,7 @@ class HFLM(LM):
     def world_size(self):
         return self._world_size

-    def _detect_batch_size(self, requests=None, pos=0):
+    def _detect_batch_size(self, requests=None, pos: int = 0):
         if requests:
             _, context_enc, continuation_enc = requests[pos]
             max_length = len(
@@ -428,9 +435,9 @@ class HFLM(LM):
     def tok_batch_encode(
         self,
         strings: List[str],
-        padding_side="left",
-        left_truncate_len=None,
-        truncation=False,
+        padding_side: str = "left",
+        left_truncate_len: int = None,
+        truncation: bool = False,
     ):
         # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
         old_padding_side = self.tokenizer.padding_side
@@ -611,7 +618,9 @@ class HFLM(LM):
         return loglikelihoods

-    def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs=None):
+    def _loglikelihood_tokens(
+        self, requests, disable_tqdm: bool = False, override_bs=None
+    ):
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []
...
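The new `else` branch above restricts distributed evaluation to DDP and FSDP and prepares the model differently in each case: FSDP goes through the full `prepare()` path, while DDP only needs an inference-mode wrapper. Below is a minimal standalone sketch of that branching; it is not part of the diff, `prepare_for_eval` is a hypothetical helper name, and it is only meaningful inside an `accelerate launch` context.

```python
import torch
from accelerate import Accelerator, DistributedType


def prepare_for_eval(model: torch.nn.Module, accelerator: Accelerator) -> torch.nn.Module:
    """Hypothetical helper mirroring the branch added in the diff above."""
    assert accelerator.distributed_type in (
        DistributedType.FSDP,
        DistributedType.MULTI_GPU,
    ), "Unsupported distributed type provided. Only DDP and FSDP are supported."
    if accelerator.distributed_type == DistributedType.FSDP:
        # FSDP shards parameters, so the model must go through the full prepare() path.
        return accelerator.prepare(model)
    # Plain DDP: wrap for inference only; no optimizer or dataloader handling needed.
    return accelerator.prepare_model(model, evaluation_mode=True)
```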
@@ -69,7 +69,7 @@ class OpenaiCompletionsLM(LM):
         engine: str = "text-davinci-003",
         truncate: bool = False,
         batch_size: int = 1,
-    ):
+    ) -> None:
         """
         :param engine: str
@@ -99,12 +99,12 @@ class OpenaiCompletionsLM(LM):
         return self.end_of_text_token_id

     @property
-    def max_length(self):
+    def max_length(self) -> int:
         # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
         return 2048

     @property
-    def max_gen_toks(self):
+    def max_gen_toks(self) -> int:
         return 256

     @property
@@ -152,7 +152,7 @@ class OpenaiCompletionsLM(LM):
         return self._loglikelihood_tokens(new_reqs)

     def _loglikelihood_tokens(
-        self, requests, disable_tqdm=False
+        self, requests, disable_tqdm: bool = False
     ) -> List[Tuple[float, bool]]:
         res = []
...
@@ -41,7 +41,7 @@ def textsynth_completion(**kwargs):
 @register_model("textsynth")
 class TextSynthLM(LM):
-    def __init__(self, engine, truncate=False):
+    def __init__(self, engine, truncate: bool = False) -> None:
         """
         :param engine: str
             TextSynth API engine (e.g. `gptj_6B`)
@@ -62,12 +62,12 @@ class TextSynthLM(LM):
         raise NotImplementedError()

     @property
-    def max_length(self):
+    def max_length(self) -> int:
         # NOTE: Turn on truncation to avoid errors on long inputs.
         return 2048

     @property
-    def max_gen_toks(self):
+    def max_gen_toks(self) -> int:
         return 256

     @property
...
@@ -5,7 +5,7 @@ from lm_eval.logger import eval_logger
 # Stores prompts in a dictionary indexed by 2 levels:
 # prompt category name, and prompt name.
 # This allows us to access prompts
-PROMPT_REGISTRY = {
+PROMPT_REGISTRY: dict[str, dict[str, str]] = {
     "qa-basic": {
         "question-newline-answer": "Question: {{question}}\nAnswer:",
         "q-newline-a": "Q: {{question}}\nA:",
@@ -13,7 +13,7 @@ PROMPT_REGISTRY = {
 }

-def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
+def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None):
     # unpack prompt name
     category_name, prompt_name = prompt_id.split(":")
     if subset_name is None:
...
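For readers unfamiliar with the two-level registry above: prompt ids combine the category and prompt name with a colon, which is exactly what `get_prompt` unpacks. A small self-contained illustration (the trimmed dict below copies the entries visible in this hunk; nothing else is assumed):

```python
# Trimmed copy of the registry entries visible above, to show how an id resolves.
PROMPT_REGISTRY = {
    "qa-basic": {
        "question-newline-answer": "Question: {{question}}\nAnswer:",
        "q-newline-a": "Q: {{question}}\nA:",
    },
}

prompt_id = "qa-basic:q-newline-a"
category_name, prompt_name = prompt_id.split(":")  # same unpacking as get_prompt()
print(PROMPT_REGISTRY[category_name][prompt_name])  # prints the template "Q: {{question}}\nA:"
```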
@@ -29,7 +29,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
 - [x] HeadQA
 - [x] MathQA
 - [x] WebQs
-- [ ] WSC273 (Lintang)
+- [x] WSC273
 - [x] Winogrande
 - [x] ANLI
 - [x] Hendrycks Ethics (missing some tasks/metrics, see PR 660: <https://github.com/EleutherAI/lm-evaluation-harness/pull/660> for more info)
@@ -42,7 +42,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
 - [ ] GSM8k
 - [x] Arithmetic
 - [ ] MMMLU (Hailey)
-- [ ] Translation (WMT) suite (Hailey)
+- [x] Translation (WMT) suite
 - [x] Unscramble
 - [x] ~~Pile (perplexity)~~
 - [x] BLiMP
...
@@ -15,7 +15,7 @@ from lm_eval.api.registry import (
 )

-def register_configurable_task(config):
+def register_configurable_task(config: dict[str, str]) -> int:
     SubClass = type(
         config["task"] + "ConfigurableTask",
         (ConfigurableTask,),
@@ -38,7 +38,7 @@ def register_configurable_task(config):
     return 0

-def check_prompt_config(config):
+def check_prompt_config(config: dict[str, str]) -> List[dict[str, str]]:
     all_configs = []
     if "use_prompt" in config:
         prompt_list = prompts.load_prompt_list(
@@ -69,14 +69,14 @@ def check_prompt_config(config):
     return all_configs

-def get_task_name_from_config(task_config):
+def get_task_name_from_config(task_config: dict[str, str]) -> str:
     if "dataset_name" in task_config:
         return "{dataset_path}_{dataset_name}".format(**task_config)
     else:
         return "{dataset_path}".format(**task_config)

-def include_task_folder(task_dir):
+def include_task_folder(task_dir: str) -> None:
     """
     Calling this function
     """
...
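The registry helpers above operate on plain config dicts; the translation YAMLs added later in this commit deserialize into exactly this kind of mapping. A hedged illustration, with values copied from the wmt14 fr-en YAML in this commit (no extra required keys are implied):

```python
# Example config dict, mirroring the wmt14 fr-en YAML added in this commit.
config = {
    "task": "wmt14-fr-en",
    "dataset_path": "wmt14",
    "dataset_name": "fr-en",
    "group": ["greedy_until", "translation", "wmt14", "gpt3_translation_benchmarks"],
    "doc_to_text": 'French phrase: {{translation["fr"]}}\nEnglish phrase:',
    "doc_to_target": ' {{translation["en"]}}',
}

# get_task_name_from_config(config) returns "wmt14_fr-en" for this dict,
# because "dataset_name" is present alongside "dataset_path".
```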
-def doc_to_text(doc):
+def doc_to_text(doc) -> str:
     return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
         doc["premise"],
         doc["hypothesis"].strip()
...
@@ -15,7 +15,7 @@ def _preproc_doc(doc):
     return doc

-def doc_to_text(doc):
+def doc_to_text(doc) -> str:
     doc = _preproc_doc(doc)
     return f"Scenario 1: {doc['scenarios'][0]}\nScenario 2: {doc['scenarios'][1]}\nQuestion: Is Scenario 1 preferable?\nAnswer:"
...
-def doc_to_text(doc):
+def doc_to_text(doc) -> str:
     ctxs = "\n".join(doc["context"]["contexts"])
     return "Abstract: {}\nQuestion: {}\nAnswer:".format(
         ctxs, doc["question"], doc["final_decision"]
     )

-def doc_to_target(doc):
+def doc_to_target(doc) -> str:
     return " {}".format(doc["final_decision"])
...
 import os
-import sys
 import json
 import requests
 import numpy as np

 from lm_eval.logger import eval_logger

-try:
-    API_KEY = os.environ["PERSPECTIVE_API_KEY"]
-except KeyError:
-    API_KEY = None
-    eval_logger.warning(
-        "PERSPECTIVE_API_KEY is not set. If you are running the `realtoxicityprompts` task, please set this environment variable."
-    )
-
-url = f"https://commentanalyzer.googleapis.com/v1alpha1/comments:analyze?key={API_KEY}"

 def toxicity_perspective_api(references, predictions, **kwargs):
     """Toxicity Perspective API is a metric which uses Perspective API to score the toxicity of a generated sentence.
@@ -25,6 +14,16 @@ def toxicity_perspective_api(references, predictions, **kwargs):
     Lower is better
     """

+    try:
+        API_KEY = os.environ["PERSPECTIVE_API_KEY"]
+    except KeyError:
+        eval_logger.warning(
+            "PERSPECTIVE_API_KEY is not set. If you are running the `realtoxicityprompts` task, please set this environment variable."
+        )
+        raise
+
+    url = f"https://commentanalyzer.googleapis.com/v1alpha1/comments:analyze?key={API_KEY}"
+
     scores = []
     for pred in predictions:
         data = {
@@ -35,7 +34,7 @@ def toxicity_perspective_api(references, predictions, **kwargs):
         headers = {
             "content-type": "application/json",
         }
-        req_response = requests.post(url, data=data, headers=headers)
+        req_response = requests.post(url, json=data, headers=headers)
         if req_response.ok:
             response = json.loads(req_response.text)
             if (
@@ -54,6 +53,6 @@ def toxicity_perspective_api(references, predictions, **kwargs):
             raise SystemExit(0)
         else:
             eval_logger.error("Unhandled Exception")
-            raise SystemExit(0)
+            req_response.raise_for_status()

     return np.mean(scores)
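The metric above posts each generated continuation to the Perspective API. The exact payload the harness builds is only partially visible in this hunk, so the sketch below shows a minimal standalone call; the attribute name and response fields follow the public Perspective API documentation rather than this diff, and it requires `PERSPECTIVE_API_KEY` to be set.

```python
# Hedged sketch of a single Perspective API request/response round trip.
import os

import requests

api_key = os.environ["PERSPECTIVE_API_KEY"]
url = f"https://commentanalyzer.googleapis.com/v1alpha1/comments:analyze?key={api_key}"

payload = {
    "comment": {"text": "This is a perfectly polite sentence."},
    "languages": ["en"],
    "requestedAttributes": {"TOXICITY": {}},
}
resp = requests.post(url, json=payload, headers={"content-type": "application/json"})
resp.raise_for_status()

body = resp.json()
toxicity = body["attributeScores"]["TOXICITY"]["summaryScore"]["value"]
print(f"toxicity = {toxicity:.3f}")  # lower is better, as the docstring above notes
```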
# Translation Tasks
### Paper
### Citation
```
```
### Groups and Tasks
#### Groups
* `gpt3_translation_tasks`
* `wmt14`
* `wmt16`
* `wmt20`
* `iwslt2017`
#### Tasks
*
### Checklist
For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [x] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
* [ ] Checked for equivalence with v0.3.0 LM Evaluation Harness
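One way to exercise these tasks once the branch is installed is through the harness's Python entry point. This is an illustrative sketch only: the registered model name (`hf`), the example checkpoint, and the exact `simple_evaluate` keyword names are assumptions that may differ on this branch.

```python
# Hedged sketch: evaluate two of the new translation tasks via the Python API.
from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="hf",                                      # assumed registry name for the HF backend
    model_args="pretrained=EleutherAI/pythia-160m",  # any small HF checkpoint for a smoke test
    tasks=["wmt14-fr-en", "iwslt2017-en-ar"],
    num_fewshot=0,
)
print(results["results"])
```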
# Generated by utils.py
dataset_name: iwslt2017-en-ar
dataset_path: iwslt2017
doc_to_target: ' {{translation["en"]}}'
doc_to_text: 'Arabic phrase: {{translation["ar"]}}
English phrase:'
group:
- greedy_until
- translation
- iwslt2017
include: wmt_common_yaml
task: iwslt2017-ar-en
# Generated by utils.py
dataset_name: iwslt2017-en-ar
dataset_path: iwslt2017
doc_to_target: ' {{translation["ar"]}}'
doc_to_text: 'English phrase: {{translation["en"]}}
Arabic phrase:'
group:
- greedy_until
- translation
- iwslt2017
include: wmt_common_yaml
task: iwslt2017-en-ar
import argparse
from typing import Dict, List
import yaml
import sacrebleu
try:
import pycountry
except ModuleNotFoundError:
raise Exception(
"`pycountry` is required for generating translation task prompt templates. \
Please install pycountry via pip install lm-eval[multilingual] or pip install -e .[multilingual]",
)
# Different translation benchmarks included in the library. Mostly WMT.
# These correspond to dataset names (subsets) on HuggingFace for each dataset.
# A yaml file is generated by this script for each language pair.
gpt3_translation_benchmarks = {
"wmt14": ["fr-en"], # ["en-fr", "fr-en"], # French
"wmt16": [
"ro-en",
"de-en",
], # ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian
}
# 28 total
LANGUAGES = {
**gpt3_translation_benchmarks,
# "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt2017": ["en-ar"], # Arabic
}
def code_to_language(code):
# key is alpha_2 or alpha_3 depending on the code length
language_tuple = pycountry.languages.get(**{f"alpha_{len(code)}": code})
return language_tuple.name
def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
"""
Generate a yaml file for each language.
:param output_dir: The directory to output the files to.
:param overwrite: Whether to overwrite files if they already exist.
"""
err = []
for lang in LANGUAGES.keys():
for dataset_name in LANGUAGES[lang]:
src_lang, _, tgt_lang = dataset_name.partition("-")
for src, tgt in [[src_lang, tgt_lang], [tgt_lang, src_lang]]:
# both translation directions for each lang pair
lang_pair = src + "-" + tgt
file_name = f"{lang}_{lang_pair}.yaml"
try:
source, target = code_to_language(src), code_to_language(tgt)
groups = ["greedy_until", "translation", lang]
if lang in gpt3_translation_benchmarks.keys():
groups += ["gpt3_translation_benchmarks"]
with open(
f"{output_dir}/{file_name}",
"w" if overwrite else "x",
encoding="utf8",
) as f:
f.write("# Generated by utils.py\n")
yaml.dump(
{
"include": "wmt_common_yaml",
"group": groups,
"dataset_path": lang,
"dataset_name": dataset_name
if not (lang == "iwslt2017")
else "iwslt2017-" + dataset_name,
"task": f"{lang}-{lang_pair}",
"doc_to_text": f"{source} phrase: "
+ "{{translation["
+ f'"{src}"'
+ "]}}\n"
+ f"{target} phrase:",
"doc_to_target": " {{"
+ "translation["
+ f'"{tgt}"]'
+ "}}",
},
f,
)
except FileExistsError:
err.append(file_name)
if len(err) > 0:
raise FileExistsError(
"Files were not created because they already exist (use --overwrite flag):"
f" {', '.join(err)}"
)
def main() -> None:
"""Parse CLI args and generate language-specific yaml files."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--overwrite",
default=False,
action="store_true",
help="Overwrite files if they already exist",
)
parser.add_argument(
"--output-dir", default=".", help="Directory to write yaml files to"
)
args = parser.parse_args()
gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite)
if __name__ == "__main__":
main()
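A short usage note for the generator above. The flags come from the `argparse` setup in `main()`; the working directory and import path are assumptions (the script is meant to be run from inside the translation task folder).

```python
# Hypothetical usage of the YAML generator above.
#
# CLI form, from the task folder:
#   python utils.py --output-dir . --overwrite
#
# Programmatic form (assumes this snippet lives next to utils.py):
from utils import gen_lang_yamls

gen_lang_yamls(output_dir=".", overwrite=True)
# Writes one YAML per translation direction, e.g. wmt14_fr-en.yaml, wmt14_en-fr.yaml,
# wmt16_ro-en.yaml, ..., iwslt2017_en-ar.yaml -- matching the generated files in this commit.
```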
# Generated by utils.py
dataset_name: fr-en
dataset_path: wmt14
doc_to_target: ' {{translation["fr"]}}'
doc_to_text: 'English phrase: {{translation["en"]}}
French phrase:'
group:
- greedy_until
- translation
- wmt14
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt14-en-fr
# Generated by utils.py
dataset_name: fr-en
dataset_path: wmt14
doc_to_target: ' {{translation["en"]}}'
doc_to_text: 'French phrase: {{translation["fr"]}}
English phrase:'
group:
- greedy_until
- translation
- wmt14
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt14-fr-en
# Generated by utils.py
dataset_name: de-en
dataset_path: wmt16
doc_to_target: ' {{translation["en"]}}'
doc_to_text: 'German phrase: {{translation["de"]}}
English phrase:'
group:
- greedy_until
- translation
- wmt16
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt16-de-en
# Generated by utils.py
dataset_name: de-en
dataset_path: wmt16
doc_to_target: ' {{translation["de"]}}'
doc_to_text: 'English phrase: {{translation["en"]}}
German phrase:'
group:
- greedy_until
- translation
- wmt16
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt16-en-de
# Generated by utils.py
dataset_name: ro-en
dataset_path: wmt16
doc_to_target: ' {{translation["ro"]}}'
doc_to_text: 'English phrase: {{translation["en"]}}
Romanian phrase:'
group:
- greedy_until
- translation
- wmt16
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt16-en-ro