Commit a702689d authored by Alexander

merge with upstream

parents 8d66cfef 008fc2a2
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Custom TriviaQA because HF version sanitizes the dataset differently.
# https://github.com/huggingface/datasets/blob/9977ade72191ff0b6907ec63935448c6269a91a1/datasets/trivia_qa/trivia_qa.py#L285
"""TriviaQA (Unfiltered Raw) dataset."""
import json
import os
import datasets
_CITATION = """\
@InProceedings{JoshiTriviaQA2017,
author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},
title = {TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension},
booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics},
month = {July},
year = {2017},
address = {Vancouver, Canada},
publisher = {Association for Computational Linguistics},
}
"""
_DESCRIPTION = """\
TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence
triples. TriviaQA includes 95K question-answer pairs authored by trivia enthusiasts
and independently gathered evidence documents, six per question on average, that provide
high quality distant supervision for answering the questions.
"""
_HOMEPAGE = "https://nlp.cs.washington.edu/triviaqa/"
_LICENSE = "Apache License 2.0"
_URLS = "https://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz"
class Triviaqa(datasets.GeneratorBasedBuilder):
"""TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence triples"""
VERSION = datasets.Version("0.0.2")
BUILDER_CONFIGS = [
datasets.BuilderConfig(
name="triviaqa", version=VERSION, description="The TriviaQA dataset"
),
]
def _info(self):
features = datasets.Features(
{
"question_id": datasets.Value("string"),
"question_source": datasets.Value("string"),
"question": datasets.Value("string"),
"answer": {
"aliases": datasets.features.Sequence(
datasets.Value("string"),
),
"value": datasets.Value("string"),
},
"search_results": datasets.features.Sequence(
{
"description": datasets.Value("string"),
"filename": datasets.Value("string"),
"rank": datasets.Value("int32"),
"title": datasets.Value("string"),
"url": datasets.Value("string"),
"search_context": datasets.Value("string"),
}
),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION,
)
def _split_generators(self, dl_manager):
urls = _URLS
data_dir = dl_manager.download_and_extract(urls)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
# These kwargs will be passed to _generate_examples
gen_kwargs={
"filepath": os.path.join(
data_dir, "triviaqa-unfiltered", "unfiltered-web-train.json"
),
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
# These kwargs will be passed to _generate_examples
gen_kwargs={
"filepath": os.path.join(
data_dir, "triviaqa-unfiltered", "unfiltered-web-dev.json"
),
},
),
]
# method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
def _generate_examples(self, filepath):
with open(filepath, encoding="utf-8") as f:
json_data = json.load(f)["Data"]
for key, data in enumerate(json_data):
search_results = []
for search_result in data["SearchResults"]:
search_results.append(
{
"description": search_result["Description"]
if "Description" in search_result
else "",
"filename": search_result["Filename"]
if "Filename" in search_result
else "",
"rank": search_result["Rank"]
if "Rank" in search_result
else -1,
"title": search_result["Title"]
if "Title" in search_result
else "",
"url": search_result["Url"]
if "Url" in search_result
else "",
"search_context": search_result["SearchContext"]
if "SearchContext" in search_result
else "",
}
)
yield key, {
"question_id": data["QuestionId"],
"question_source": data["QuestionSource"],
"question": data["Question"],
"answer": {
"aliases": data["Answer"]["Aliases"],
"value": data["Answer"]["Value"],
},
"search_results": search_results,
}
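# Editor's sketch (illustrative, not part of this commit): loading the custom
# builder above with the `datasets` library. The local filename "triviaqa.py"
# is an assumption; adjust the path to wherever this script is saved.
if __name__ == "__main__":
    ds = datasets.load_dataset("triviaqa.py", name="triviaqa")
    sample = ds["train"][0]
    print(sample["question"])
    print(sample["answer"]["value"])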
......@@ -42,8 +42,7 @@ addition, or deletion of characters, and asking it to recover the original word.
_HOMEPAGE = "https://github.com/openai/gpt-3/tree/master/data"
# TODO: Add the licence for the dataset here if you can find it
_LICENSE = ""
_LICENSE = "No license found"
_BASE_URL = "https://raw.githubusercontent.com/openai/gpt-3/master/data"
......
import collections
import itertools
import numpy as np
import random
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
from lm_eval.utils import positional_deprecated, run_task_tests
from lm_eval.models.gpt2 import HFLM
import numpy as np
import transformers
@positional_deprecated
......@@ -16,6 +20,7 @@ def simple_evaluate(
tasks=[],
num_fewshot=0,
batch_size=None,
max_batch_size=None,
device=None,
no_cache=False,
limit=None,
......@@ -30,7 +35,7 @@ def simple_evaluate(
"""Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM]
Name of model or LM object, see lm_eval.models.get_model
Name of model, transformers.PreTrainedModel object, or LM object, see lm_eval.models.get_model
:param model_args: Optional[str]
String arguments for each model class, see LM.create_from_arg_string.
Ignored if `model` argument is a LM object.
......@@ -38,8 +43,10 @@ def simple_evaluate(
List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
:param num_fewshot: int
Number of examples in few-shot context
:param batch_size: int, optional
:param batch_size: int or str, optional
Batch size for model
:param max_batch_size: int, optional
Maximal batch size to try with automatic batch size detection
:param device: str, optional
PyTorch device (e.g. "cpu" or "cuda:0") for running models
:param no_cache: bool
......@@ -68,8 +75,24 @@ def simple_evaluate(
if model_args is None:
model_args = ""
lm = lm_eval.models.get_model(model).create_from_arg_string(
model_args,
{
"batch_size": batch_size,
"max_batch_size": max_batch_size,
"device": device,
"tokenizer": tokenizer,
"trust_remote_code": True,
},
)
elif isinstance(model, transformers.PreTrainedModel):
lm = lm_eval.models.get_model("hf-causal")(
pretrained=model,
batch_size=batch_size,
max_batch_size=max_batch_size,
)
no_cache = True
else:
assert isinstance(model, lm_eval.base.LM)
lm = model
......@@ -78,7 +101,7 @@ def simple_evaluate(
lm = lm_eval.base.CachingLM(
lm,
"lm_cache/"
+ model
+ (model if isinstance(model, str) else model.model.config._name_or_path)
+ "_"
+ model_args.replace("=", "-").replace(",", "_").replace("/", "-")
+ ".db",
......@@ -102,11 +125,19 @@ def simple_evaluate(
)
# add info about the model and few shot config
model_name = None
if isinstance(model, str):
model_name = model
elif isinstance(model, transformers.PreTrainedModel):
model_name = "pretrained=" + model.config._name_or_path
results["config"] = {
"model": model,
"model": model_name,
"model_args": model_args,
"num_fewshot": num_fewshot,
"batch_size": batch_size,
"batch_sizes": list(lm.batch_sizes.values())
if hasattr(lm, "batch_sizes")
else [],
"device": device,
"no_cache": no_cache,
"limit": limit,
......
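# Editor's sketch (illustrative, not part of this commit): exercising the new
# code path that accepts a transformers.PreTrainedModel directly. The checkpoint
# name and task name below are assumptions, not taken from this diff.
import transformers
from lm_eval import evaluator

hf_model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
results = evaluator.simple_evaluate(
    model=hf_model,            # handled by the `elif isinstance(..., PreTrainedModel)` branch
    tasks=["lambada_openai"],  # assumed task name
    num_fewshot=0,
    batch_size="auto",         # enables automatic batch size detection
    max_batch_size=16,
)
print(results["config"]["model"])  # reported as "pretrained=gpt2"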
......@@ -4,6 +4,7 @@ from . import anthropic_llms
from . import huggingface
from . import textsynth
from . import dummy
from . import gguf
MODEL_REGISTRY = {
"hf": gpt2.HFLM,
......@@ -15,7 +16,7 @@ MODEL_REGISTRY = {
"anthropic": anthropic_llms.AnthropicLM,
"textsynth": textsynth.TextSynthLM,
"dummy": dummy.DummyLM,
"optimum-causal": gpt2.OPTIMUMLM,
"gguf": gguf.GGUFLM
}
......
......@@ -4,7 +4,9 @@ from tqdm import tqdm
import time
def anthropic_completion(client, model, prompt, max_tokens_to_sample, temperature, stop):
def anthropic_completion(
client, model, prompt, max_tokens_to_sample, temperature, stop
):
"""Query Anthropic API for completion.
Retry with back-off until they respond
......@@ -14,7 +16,7 @@ def anthropic_completion(client, model, prompt, max_tokens_to_sample, temperatur
backoff_time = 3
while True:
try:
response = client.completion(
response = client.completions.create(
prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
model=model,
# NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
......@@ -24,7 +26,7 @@ def anthropic_completion(client, model, prompt, max_tokens_to_sample, temperatur
temperature=temperature,
)
print(response)
return response["completion"]
return response.completion
except RuntimeError:
# TODO: I don't actually know what error Anthropic raises when it times out
# So err update this error when we find out.
......@@ -38,7 +40,7 @@ def anthropic_completion(client, model, prompt, max_tokens_to_sample, temperatur
class AnthropicLM(BaseLM):
REQ_CHUNK_SIZE = 20
def __init__(self, model):
def __init__(self, model="claude-2"):
"""
:param model: str
......@@ -46,8 +48,9 @@ class AnthropicLM(BaseLM):
"""
super().__init__()
import anthropic
self.model = model
self.client = anthropic.Client(os.environ['ANTHROPIC_API_KEY'])
self.client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
@property
def eot_token_id(self):
......
import requests
import logging
import time
from tqdm import tqdm
from requests.exceptions import RequestException
import transformers
from lm_eval.utils import Reorderer
from lm_eval.base import BaseLM
logger = logging.getLogger(__name__)
def get_result(logprobs, context_length):
is_greedy = True
offsets = logprobs['text_offset']
tokens = logprobs['tokens']
tokens_logprobs = logprobs['token_logprobs']
idx = 0
while offsets[idx] < context_length:
idx += 1
continuation_logprobs = sum(tokens_logprobs[idx:-1])
for i in range(idx, len(tokens)):
token = tokens[i]
top_tokens = logprobs["top_logprobs"][i]
top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
if top_token != token:
is_greedy = False
break
return continuation_logprobs, is_greedy
class GGUFLM(BaseLM):
def __init__(self, base_url, max_length=2048):
super().__init__()
self.base_url = base_url
self.logprobs = 10
self.temperature = 0.0
self._max_length = max_length
def gguf_completion(self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs):
for _ in range(retries):
try:
prompt = context
request = {'prompt': prompt, 'logprobs': self.logprobs,
'temperature': self.temperature}
if continuation:
prompt += continuation
request.update({'prompt': prompt, 'max_tokens': 1, 'echo': True})
if stop is not None:
request['stop'] = stop
response = requests.post(f"{self.base_url}/v1/completions", json=request)
response.raise_for_status()
return response.json()
except RequestException as e:
logger.error(f"RequestException: {e}")
time.sleep(delay) # wait before retrying
else:
raise Exception(f"Failed to get a valid response after {retries} retries.")
def loglikelihood(self, requests):
if not requests:
return []
res = []
for context, continuation in tqdm(requests):
response = self.gguf_completion(context=context, continuation=continuation)
if response and "choices" in response and response["choices"]:
choice = response["choices"][0]
logprobs = choice.get("logprobs")
if logprobs and "token_logprobs" in logprobs and logprobs["token_logprobs"]:
logprob, is_greedy = get_result(logprobs, len(context))
res.append((logprob, is_greedy))
else:
logger.warning("Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list.")
else:
logger.error(f"Invalid response for loglikelihood. Response: {response}")
assert False
return res
def greedy_until(self, requests):
if not requests:
return []
res = []
for request in tqdm(requests):
inp = request[0]
request_args = request[1]
until = request_args["until"]
response = self.gguf_completion(context=inp, stop=until)
if response and "choices" in response and response["choices"]:
choice = response["choices"][0]
if "text" in choice:
generated_text = choice["text"].strip()
res.append(generated_text)
else:
logger.error(f"Invalid response for greedy_until. Response: {response}")
res.append(None) # Add default value in case of error
else:
logger.error(f"Invalid response for greedy_until. Response: {response}")
res.append(None) # Add default value in case of error
return res
def loglikelihood_rolling(self, requests):
raise NotImplementedError("loglikelihood_rolling not yet supported for GGUF models")
def _model_call(self, inps):
# Placeholder implementation
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Placeholder implementation
raise NotImplementedError()
def tok_encode(self, string: str):
raise NotImplementedError()
def tok_decode(self, tokens):
raise NotImplementedError()
@property
def batch_size(self):
# Placeholder implementation
raise NotImplementedError()
@property
def device(self):
# Placeholder implementation
raise NotImplementedError()
@property
def eot_token_id(self):
# Placeholder implementation
raise NotImplementedError()
@property
def max_length(self):
return self._max_length
@property
def max_gen_toks(self):
# Placeholder implementation
raise NotImplementedError()
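# Editor's sketch (illustrative, not part of this commit): querying a local
# llama.cpp-style server that exposes an OpenAI-compatible /v1/completions
# endpoint. The URL and prompts below are assumptions.
if __name__ == "__main__":
    lm = GGUFLM(base_url="http://localhost:8000")
    print(lm.loglikelihood([("The capital of France is", " Paris")]))
    print(lm.greedy_until([("Q: Name a primary color.\nA:", {"until": ["\n"]})]))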
......@@ -6,9 +6,7 @@ import optimum
from optimum.intel.openvino import OVModelForCausalLM
def _get_dtype(
dtype: Union[str, torch.dtype]
) -> torch.dtype:
def _get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
"""Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
if isinstance(dtype, str) and dtype != "auto":
# Convert `str` args torch dtype: `float16` -> `torch.float16`
......@@ -19,6 +17,9 @@ def _get_dtype(
class HFLM(BaseLM):
_DEFAULT_MAX_LENGTH = 2048
def __init__(
self,
device="cuda",
......@@ -28,67 +29,91 @@ class HFLM(BaseLM):
subfolder=None,
tokenizer=None,
batch_size=1,
max_batch_size=512,
max_length=None,
load_in_8bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
dtype: Optional[Union[str, torch.dtype]]="auto",
dtype: Optional[Union[str, torch.dtype]] = "auto",
):
super().__init__()
assert isinstance(device, str)
assert isinstance(pretrained, str)
assert isinstance(batch_size, (int, str))
# Initialize model
if isinstance(pretrained, transformers.PreTrainedModel):
self.model = pretrained
self._device = self.model.device
if tokenizer:
assert isinstance(
tokenizer, transformers.PreTrainedTokenizer
) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
self.tokenizer = tokenizer
else:
# Get tokenizer
model_name = self.model.name_or_path
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name,
revision=revision,
trust_remote_code=trust_remote_code,
)
elif isinstance(pretrained, str):
# Initialize device
assert isinstance(device, str)
device_list = set(
["cuda", "cpu"]
+ [f"cuda:{i}" for i in range(torch.cuda.device_count())]
)
if device and device in device_list:
self._device = torch.device(device)
print(f"Using device '{device}'")
else:
print("Device not specified")
print(f"Cuda Available? {torch.cuda.is_available()}")
self._device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
)
revision = revision + ("/" + subfolder if subfolder is not None else "")
# Initialize new model and tokenizer instances
self.model = transformers.AutoModelForCausalLM.from_pretrained(
pretrained,
load_in_8bit=load_in_8bit,
low_cpu_mem_usage=low_cpu_mem_usage,
revision=revision,
torch_dtype=_get_dtype(dtype),
trust_remote_code=trust_remote_code,
).to(self.device)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
tokenizer if tokenizer else pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
)
device_list = set(
["cuda", "cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
)
if device and device in device_list:
self._device = torch.device(device)
print(f"Using device '{device}'")
else:
print("Device not specified")
print(f"Cuda Available? {torch.cuda.is_available()}")
self._device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
raise TypeError(
"Parameter pretrained should be of type str or transformers.PreTrainedModel"
)
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")
self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
pretrained,
load_in_8bit=load_in_8bit,
low_cpu_mem_usage=low_cpu_mem_usage,
revision=revision,
torch_dtype=_get_dtype(dtype),
trust_remote_code=trust_remote_code,
).to(self.device)
self.gpt2.eval()
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer,
revision=revision,
trust_remote_code=trust_remote_code,
)
self.model.eval()
self.vocab_size = self.tokenizer.vocab_size
if isinstance(
self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
):
assert self.tokenizer.encode("hello\n\nhello") == [
31373,
198,
198,
31373,
], self.tokenizer.encode("hello\n\nhello")
# Validate batch_size
assert isinstance(batch_size, (int, str))
# setup for automatic batch size detection
if batch_size == "auto":
self.batch_size_per_gpu = batch_size
if str(batch_size).startswith("auto"):
batch_size = batch_size.split(":")
self.batch_size_per_gpu = batch_size[0]
self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
else:
self.batch_size_per_gpu = int(batch_size)
self.max_batch_size = max_batch_size
self._max_length = max_length
@property
def eot_token_id(self):
......@@ -97,11 +122,17 @@ class HFLM(BaseLM):
@property
def max_length(self):
try:
return self.gpt2.config.n_ctx
except AttributeError:
# gptneoconfig doesn't have n_ctx apparently
return self.gpt2.config.max_position_embeddings
if self._max_length: # if max length manually set, return it
return self._max_length
seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
for attr in seqlen_config_attrs:
if hasattr(self.model.config, attr):
return getattr(self.model.config, attr)
if hasattr(self.tokenizer, "model_max_length"):
if self.tokenizer.model_max_length == 1000000000000000019884624838656:
return self._DEFAULT_MAX_LENGTH
return self.tokenizer.model_max_length
return self._DEFAULT_MAX_LENGTH
@property
def max_gen_toks(self):
......@@ -132,14 +163,16 @@ class HFLM(BaseLM):
logits returned from the model
"""
with torch.no_grad():
return self.gpt2(inps)[0]
return self.model(inps)[0]
def _model_generate(self, context, max_length, eos_token_id):
generation_kwargs = {"do_sample": False, "max_length": max_length}
if eos_token_id is not None:
generation_kwargs['eos_token_id'] = eos_token_id
generation_kwargs['pad_token_id'] = eos_token_id # setting eos_token_id as pad token
return self.gpt2.generate(context, **generation_kwargs)
generation_kwargs["eos_token_id"] = eos_token_id
generation_kwargs[
"pad_token_id"
] = eos_token_id # setting eos_token_id as pad token
return self.model.generate(context, **generation_kwargs)
# for backwards compatibility
......
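# Editor's sketch (illustrative, not part of this commit): how the new
# `batch_size` argument above is interpreted, restated as a hypothetical
# standalone helper. "auto" or "auto:N" turns on automatic batch size
# detection, with N stored as an optional batch schedule factor.
def _parse_batch_size(batch_size):
    if str(batch_size).startswith("auto"):
        parts = str(batch_size).split(":")
        return parts[0], float(parts[1]) if len(parts) > 1 else 1.0
    return int(batch_size), None

# _parse_batch_size("auto")   -> ("auto", 1.0)
# _parse_batch_size("auto:4") -> ("auto", 4.0)
# _parse_batch_size(8)        -> (8, None)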
......@@ -198,14 +198,13 @@ class GPT3LM(BaseLM):
context_enc = self.tok_encode(context)
inp = context_enc[-(self.max_length - self.max_gen_toks) :]
inps.append(inp)
response = oa_completion(
engine=self.engine,
prompt=inps,
max_tokens=self.max_gen_toks,
temperature=0.0,
logprobs=10,
stop=until,
stop=until["until"],
)
for resp, (context, until_) in zip(response.choices, chunk):
......
......@@ -9,7 +9,6 @@ from typing import List, Mapping, NewType, Optional, Tuple, Union
from tqdm import tqdm
from transformers import BatchEncoding
from accelerate import find_executable_batch_size
from lm_eval import utils
from lm_eval.base import BaseLM
......@@ -76,10 +75,12 @@ class HuggingFaceAutoLM(BaseLM):
subfolder: Optional[str] = None,
revision: Optional[str] = "main",
batch_size: Optional[Union[int, str]] = 1,
max_batch_size: Optional[int] = 512,
max_gen_toks: Optional[int] = 256,
max_length: Optional[int] = None,
add_special_tokens: Optional[bool] = None,
use_accelerate: Optional[bool] = False,
low_cpu_mem_usage: Optional[bool] = True,
device_map_option: Optional[str] = "auto",
max_memory_per_gpu: Optional[Union[int, str]] = None,
max_cpu_memory: Optional[Union[int, str]] = None,
......@@ -91,6 +92,10 @@ class HuggingFaceAutoLM(BaseLM):
load_in_4bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
gptq_use_triton: Optional[bool] = False,
inject_fused_attention: Optional[bool] = True,
bnb_4bit_quant_type: Optional[str] = None,
bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
bnb_4bit_use_double_quant: Optional[bool] = False,
):
"""Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.
Args:
......@@ -111,6 +116,8 @@ class HuggingFaceAutoLM(BaseLM):
use_accelerate (bool, optional, defaults to False):
If True, uses the `accelerate` library to load a large model across
multiple devices.
low_cpu_mem_usage (bool, optional, defaults to True):
If True, uses the `accelerate` library to accelerate loading the model.
device_map_option (str, optional, defaults to "auto"):
The device map option to use when loading the model with
`accelerate`.
......@@ -152,6 +159,18 @@ class HuggingFaceAutoLM(BaseLM):
If True, will trust the remote code when loading the model.
gptq_use_triton (bool, optional, defaults to False):
Use Triton for GPTQ inference.
inject_fused_attention (bool, optional, defaults to True):
Inject fused attention into GPTQ model.
bnb_4bit_quant_type (str, optional, defaults to None):
The quantization type to use for BnB 4bit quantization. See:
https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L77
bnb_4bit_compute_dtype (Union[str, torch.dtype], optional, defaults to None):
The compute dtype to use for BnB 4bit quantization. See:
https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L74
bnb_4bit_use_double_quant (bool, optional, defaults to False):
Whether or not to use double quant to quantize the absmax.
https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L80
"""
super().__init__()
......@@ -172,10 +191,13 @@ class HuggingFaceAutoLM(BaseLM):
), "Evaluating causal models with `add_special_tokens=True` is currently not supported."
# setup for automatic batch size detection
if batch_size == "auto":
self._batch_size = batch_size
if str(batch_size).startswith("auto"):
batch_size = batch_size.split(":")
self._batch_size = batch_size[0]
self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
else:
self._batch_size = int(batch_size)
self.max_batch_size = max_batch_size
self._max_gen_toks = max_gen_toks
self._max_length = max_length
......@@ -191,6 +213,7 @@ class HuggingFaceAutoLM(BaseLM):
revision=revision,
subfolder=subfolder,
tokenizer=tokenizer,
trust_remote_code=trust_remote_code,
)
self.tokenizer.model_max_length = self.max_length
......@@ -210,8 +233,13 @@ class HuggingFaceAutoLM(BaseLM):
subfolder=subfolder,
torch_dtype=_get_dtype(dtype, self._config),
gptq_use_triton=gptq_use_triton,
inject_fused_attention=inject_fused_attention,
load_in_8bit=load_in_8bit,
load_in_4bit=load_in_4bit,
bnb_4bit_quant_type=bnb_4bit_quant_type,
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
low_cpu_mem_usage=low_cpu_mem_usage,
**model_kwargs,
)
# note: peft_path can be different than pretrained model path
......@@ -232,8 +260,13 @@ class HuggingFaceAutoLM(BaseLM):
# the user specified one so we force `self._device` to be the same as
# `lm_head`'s.
self._device = self.model.hf_device_map["lm_head"]
if not use_accelerate:
self.model.to(self._device)
if not use_accelerate and not (load_in_4bit or load_in_8bit):
try:
self.model.to(self._device)
except:
print(
"Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
)
def _create_auto_model(
self,
......@@ -242,6 +275,7 @@ class HuggingFaceAutoLM(BaseLM):
quantized: Optional[Union[bool, str]] = False,
revision: str,
subfolder: str,
low_cpu_mem_usage: Optional[bool] = True,
device_map: Optional[Union[str, _DeviceMapping]] = None,
max_memory: Optional[dict] = None,
offload_folder: Optional[str] = None,
......@@ -250,17 +284,35 @@ class HuggingFaceAutoLM(BaseLM):
trust_remote_code: Optional[bool] = False,
torch_dtype: Optional[Union[str, torch.dtype]] = None,
gptq_use_triton: Optional[bool] = False,
inject_fused_attention: Optional[bool] = True,
bnb_4bit_quant_type: Optional[str] = None,
bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
bnb_4bit_use_double_quant: Optional[bool] = False,
) -> transformers.AutoModel:
"""Returns a pre-trained pytorch model from a pre-trained model configuration."""
if not quantized:
if load_in_4bit:
assert transformers.__version__ >= "4.30.0", "load_in_4bit requires transformers >= 4.30.0"
assert (
transformers.__version__ >= "4.30.0"
), "load_in_4bit requires transformers >= 4.30.0"
model_kwargs = {}
if transformers.__version__ >= "4.30.0":
model_kwargs["load_in_4bit"] = load_in_4bit
if load_in_4bit:
if bnb_4bit_quant_type:
model_kwargs["bnb_4bit_quant_type"] = bnb_4bit_quant_type
if bnb_4bit_compute_dtype:
model_kwargs["bnb_4bit_compute_dtype"] = _get_dtype(
bnb_4bit_compute_dtype
)
if bnb_4bit_use_double_quant:
model_kwargs[
"bnb_4bit_use_double_quant"
] = bnb_4bit_use_double_quant
model = self.AUTO_MODEL_CLASS.from_pretrained(
pretrained,
revision=revision + ("/" + subfolder if subfolder is not None else ""),
low_cpu_mem_usage=low_cpu_mem_usage,
device_map=device_map,
max_memory=max_memory,
offload_folder=offload_folder,
......@@ -271,15 +323,19 @@ class HuggingFaceAutoLM(BaseLM):
)
else:
from auto_gptq import AutoGPTQForCausalLM
model = AutoGPTQForCausalLM.from_quantized(
pretrained,
model_basename=None if quantized == True else Path(quantized).stem,
device_map=device_map,
max_memory=max_memory,
trust_remote_code=trust_remote_code,
use_safetensors=True if quantized == True else quantized.endswith('.safetensors'),
use_safetensors=True
if quantized == True
else quantized.endswith(".safetensors"),
use_triton=gptq_use_triton,
warmup_triton=gptq_use_triton,
inject_fused_attention=inject_fused_attention,
)
return model
......@@ -308,11 +364,13 @@ class HuggingFaceAutoLM(BaseLM):
revision: str,
subfolder: str,
tokenizer: Optional[str] = None,
trust_remote_code: Optional[bool] = False,
) -> transformers.PreTrainedTokenizer:
"""Returns a pre-trained tokenizer from a pre-trained tokenizer configuration."""
tokenizer = self.AUTO_TOKENIZER_CLASS.from_pretrained(
pretrained if tokenizer is None else tokenizer,
revision=revision + ("/" + subfolder if subfolder is not None else ""),
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
......@@ -411,19 +469,7 @@ class HuggingFaceAutoLM(BaseLM):
if self.batch_size == "auto":
# using rolling window with maximum context
print("Passed argument batch_size = auto. Detecting largest batch size")
@find_executable_batch_size(
starting_batch_size=512
) # if OOM, then halves batch_size and tries again
def forward_batch(batch_size):
test_batch = torch.ones(
(batch_size, self.max_length), device=self.device
).long()
for _ in range(5):
_ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
return batch_size
batch_size = forward_batch()
batch_size = self._detect_batch_size()
print(f"Determined Largest batch size: {batch_size}")
adaptive_batch_size = batch_size
......@@ -488,12 +534,14 @@ class AutoCausalLM(HuggingFaceAutoLM):
revision: str,
subfolder: str,
tokenizer: Optional[str] = None,
trust_remote_code: Optional[bool] = False,
) -> transformers.PreTrainedTokenizer:
tokenizer = super()._create_auto_tokenizer(
pretrained=pretrained,
revision=revision,
subfolder=subfolder,
tokenizer=tokenizer,
trust_remote_code=trust_remote_code,
)
tokenizer.padding_side = "left"
return tokenizer
......
......@@ -4,6 +4,7 @@ from typing import List, Union
import sacrebleu
import lm_eval.base
from . import babi
from . import superglue
from . import glue
from . import arc
......@@ -19,6 +20,7 @@ from . import swag
from . import openbookqa
from . import squad
from . import naturalqs
from . import nqopen
from . import sat
from . import arithmetic
from . import lambada
......@@ -60,6 +62,11 @@ from . import xwinograd
from . import pawsx
from . import xnli
from . import mgsm
from . import scrolls
from . import ceval
from . import csatqa
from . import haerae
from . import cmmlu
########################################
# Translation tasks
......@@ -92,6 +99,7 @@ all_translation_benchmarks = {
TASK_REGISTRY = {
"babi": babi.Babi,
# GLUE
"cola": glue.CoLA,
"mnli": glue.MNLI,
......@@ -144,6 +152,7 @@ TASK_REGISTRY = {
"squad2": squad.SQuAD2,
"race": race.RACE,
# "naturalqs": naturalqs.NaturalQs, # not implemented yet
"nq_open": nqopen.NQOpen,
"headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es
"headqa_es": headqa.HeadQAEs,
"headqa_en": headqa.HeadQAEn,
......@@ -314,6 +323,19 @@ TASK_REGISTRY = {
"crows_pairs_french_nationality": crowspairs.CrowsPairsFrenchNationality,
"crows_pairs_french_physical_appearance": crowspairs.CrowsPairsFrenchPhysicalAppearance,
"crows_pairs_french_autre": crowspairs.CrowsPairsFrenchAutre,
"csatqa_wr": csatqa.WR,
"csatqa_gr": csatqa.GR,
"csatqa_rcs": csatqa.RCS,
"csatqa_rcss": csatqa.RCSS,
"csatqa_rch": csatqa.RCH,
"csatqa_li": csatqa.LI,
"haerae_hi": haerae.HI,
"haerae_kgk": haerae.KGK,
"haerae_lw": haerae.LW,
"haerae_rc": haerae.RC,
"haerae_rw": haerae.RW,
"haerae_sn": haerae.SN,
# Requires manual download
# Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018,
......@@ -325,6 +347,9 @@ TASK_REGISTRY = {
**pawsx.construct_tasks(),
**xnli.construct_tasks(),
**mgsm.construct_tasks(),
**scrolls.construct_tasks(),
**ceval.create_all_tasks(),
**cmmlu.create_all_tasks()
}
......
"""
Inspired by https://github.com/stanford-crfm/helm/blob/0eaaa62a2263ddb94e9850ee629423b010f57e4a/src/helm/benchmark/scenarios/babi_qa_scenario.py
"""
import numpy as np
from collections import defaultdict
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
@article{weston2015towards,
title={Towards ai-complete question answering: A set of prerequisite toy tasks},
author={Weston, Jason and Bordes, Antoine and Chopra, Sumit and Rush, Alexander M and Van Merri{\"e}nboer, Bart and Joulin, Armand and Mikolov, Tomas},
journal={arXiv preprint arXiv:1502.05698},
year={2015}
}
"""
class Babi(Task):
VERSION = 0
DATASET_PATH = "Muennighoff/babi"
DATASET_NAME = None
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self.has_training_docs():
return self.dataset["train"]
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["valid"]
def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]
def doc_to_text(self, doc):
return doc["passage"] + doc["question"]
def should_decontaminate(self):
return False # TODO Necessary?
def doc_to_decontamination_query(self, doc):
return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
def doc_to_target(self, doc):
return " " + doc["answer"]
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
return rf.greedy_until(ctx, ["\n"])
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
gold = doc["answer"]
pred = gold.strip() == results[0].strip()
return {"em": pred}
def aggregation(self):
return {
"em": mean,
}
def higher_is_better(self):
return {
"em": True,
}
"""
C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models
https://arxiv.org/pdf/2305.08322.pdf
C-Eval is a comprehensive Chinese evaluation suite for foundation models.
It consists of 13948 multi-choice questions spanning 52 diverse disciplines
and four difficulty levels.
Homepage: https://cevalbenchmark.com/
"""
from lm_eval.base import MultipleChoiceTask
_CITATION = """
@article{huang2023ceval,
title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian},
journal={arXiv preprint arXiv:2305.08322},
year={2023}
}
"""
SUBJECTS = {
"computer_network": "计算机网络",
"operating_system": "操作系统",
"computer_architecture": "计算机组成",
"college_programming": "大学编程",
"college_physics": "大学物理",
"college_chemistry": "大学化学",
"advanced_mathematics": "高等数学",
"probability_and_statistics": "概率统计",
"discrete_mathematics": "离散数学",
"electrical_engineer": "注册电气工程师",
"metrology_engineer": "注册计量师",
"high_school_mathematics": "高中数学",
"high_school_physics": "高中物理",
"high_school_chemistry": "高中化学",
"high_school_biology": "高中生物",
"middle_school_mathematics": "初中数学",
"middle_school_biology": "初中生物",
"middle_school_physics": "初中物理",
"middle_school_chemistry": "初中化学",
"veterinary_medicine": "兽医学",
"college_economics": "大学经济学",
"business_administration": "工商管理",
"marxism": "马克思主义基本原理",
"mao_zedong_thought": "毛泽东思想和中国特色社会主义理论体系概论",
"education_science": "教育学",
"teacher_qualification": "教师资格",
"high_school_politics": "高中政治",
"high_school_geography": "高中地理",
"middle_school_politics": "初中政治",
"middle_school_geography": "初中地理",
"modern_chinese_history": "近代史纲要",
"ideological_and_moral_cultivation": "思想道德修养与法律基础",
"logic": "逻辑学",
"law": "法学",
"chinese_language_and_literature": "中国语言文学",
"art_studies": "艺术学",
"professional_tour_guide": "导游资格",
"legal_professional": "法律职业资格",
"high_school_chinese": "高中语文",
"high_school_history": "高中历史",
"middle_school_history": "初中历史",
"civil_servant": "公务员",
"sports_science": "体育学",
"plant_protection": "植物保护",
"basic_medicine": "基础医学",
"clinical_medicine": "临床医学",
"urban_and_rural_planner": "注册城乡规划师",
"accountant": "注册会计师",
"fire_engineer": "注册消防工程师",
"environmental_impact_assessment_engineer": "环境影响评价工程师",
"tax_accountant": "税务师",
"physician": "医师资格",
}
def create_all_tasks():
"""Creates a dictionary of tasks from a list of subjects
:return: {task_name: task}
e.g. {Ceval-computer_network: Task, Ceval-clinical_medicine: Task}
"""
return {f"Ceval-valid-{sub}": create_task(sub) for sub in SUBJECTS.keys()}
def create_task(subject):
class Ceval(CevalSubject):
def __init__(self):
super().__init__(subject)
return Ceval
class CevalSubject(MultipleChoiceTask):
VERSION = 1
DATASET_PATH = "ceval/ceval-exam"
DATASET_NAME = None
def __init__(self, subject):
self.DATASET_NAME = subject
super().__init__()
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def validation_docs(self):
if self.has_validation_docs():
return map(self._process_doc, self.dataset["val"])
def test_docs(self):
if self.has_test_docs():
return map(self._process_doc, self.dataset["test"])
def _format_subject(self, subject):
words = subject.split("_")
return " ".join(words)
def fewshot_context(self, doc, num_fewshot, **kwargs):
subject = self.DATASET_NAME
description = f"以下是中国关于{SUBJECTS[subject]}的单项选择题,请选出其中的正确答案。"
kwargs["description"] = description
return super().fewshot_context(doc=doc, num_fewshot=num_fewshot, **kwargs)
def _process_doc(self, doc):
def format_example(doc, keys):
"""
<prompt>
A. <choice1>
B. <choice2>
C. <choice3>
D. <choice4>
答案:
"""
question = doc["question"].strip()
choices = "".join([f"{key}. {doc[key]}\n" for key in keys])
prompt = f"{question}\n{choices}答案:"
return prompt
keys = ["A", "B", "C", "D"]
return {
"query": format_example(doc, keys),
"choices": keys,
"gold": ord(doc["answer"]) - ord("A"),
}
def fewshot_examples(self, k, rnd):
if self._fewshot_docs is None:
self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))
# use the unchanged order of the dev set without sampling,
return self._fewshot_docs[:k]
def doc_to_text(self, doc):
return doc["query"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
"""
CMMLU: Measuring massive multitask language understanding in Chinese
https://arxiv.org/abs/2306.09212
CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge
and reasoning abilities of LLMs within the Chinese language and cultural context. CMMLU covers a wide range of
subjects, comprising 67 topics that span from elementary to advanced professional levels. It includes subjects that
require computational expertise, such as physics and mathematics, as well as disciplines within humanities and
social sciences. Many of these tasks are not easily translatable from other languages due to their specific
contextual nuances and wording. Furthermore, numerous tasks within CMMLU have answers that are specific to
China and may not be universally applicable or considered correct in other regions or languages.
Homepage: https://github.com/haonan-li/CMMLU
Huggingface homepage: https://huggingface.co/datasets/haonan-li/cmmlu
"""
import os
from lm_eval.base import MultipleChoiceTask, rf
_CITATION = """
@misc{li2023cmmlu,
title={CMMLU: Measuring massive multitask language understanding in Chinese},
author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin},
year={2023},
eprint={2306.09212},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
"""
SUBJECTS = [
"agronomy",
"anatomy",
"ancient_chinese",
"arts",
"astronomy",
"business_ethics",
"chinese_civil_service_exam",
"chinese_driving_rule",
"chinese_food_culture",
"chinese_foreign_policy",
"chinese_history",
"chinese_literature",
"chinese_teacher_qualification",
"clinical_knowledge",
"college_actuarial_science",
"college_education",
"college_engineering_hydrology",
"college_law",
"college_mathematics",
"college_medical_statistics",
"college_medicine",
"computer_science",
"computer_security",
"conceptual_physics",
"construction_project_management",
"economics",
"education",
"electrical_engineering",
"elementary_chinese",
"elementary_commonsense",
"elementary_information_and_technology",
"elementary_mathematics",
"ethnology",
"food_science",
"genetics",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_geography",
"high_school_mathematics",
"high_school_physics",
"high_school_politics",
"human_sexuality",
"international_law",
"journalism",
"jurisprudence",
"legal_and_moral_basis",
"logical",
"machine_learning",
"management",
"marketing",
"marxist_theory",
"modern_chinese",
"nutrition",
"philosophy",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_study",
"sociology",
"sports_science",
"traditional_chinese_medicine",
"virology",
"world_history",
"world_religions"
]
SUBJECT_MAPPING = {
"agronomy": "农学",
"anatomy": "解剖学",
"ancient_chinese": "古汉语",
"arts": "艺术学",
"astronomy": "天文学",
"business_ethics": "商业伦理",
"chinese_civil_service_exam": "中国公务员考试",
"chinese_driving_rule": "中国驾驶规则",
"chinese_food_culture": "中国饮食文化",
"chinese_foreign_policy": "中国外交政策",
"chinese_history": "中国历史",
"chinese_literature": "中国文学",
"chinese_teacher_qualification": "中国教师资格",
"clinical_knowledge": "临床知识",
"college_actuarial_science": "大学精算学",
"college_education": "大学教育学",
"college_engineering_hydrology": "大学工程水文学",
"college_law": "大学法律",
"college_mathematics": "大学数学",
"college_medical_statistics": "大学医学统计",
"college_medicine": "大学医学",
"computer_science": "计算机科学",
"computer_security": "计算机安全",
"conceptual_physics": "概念物理学",
"construction_project_management": "建设工程管理",
"economics": "经济学",
"education": "教育学",
"electrical_engineering": "电气工程",
"elementary_chinese": "小学语文",
"elementary_commonsense": "小学常识",
"elementary_information_and_technology": "小学信息技术",
"elementary_mathematics": "初等数学",
"ethnology": "民族学",
"food_science": "食品科学",
"genetics": "遗传学",
"global_facts": "全球事实",
"high_school_biology": "高中生物",
"high_school_chemistry": "高中化学",
"high_school_geography": "高中地理",
"high_school_mathematics": "高中数学",
"high_school_physics": "高中物理学",
"high_school_politics": "高中政治",
"human_sexuality": "人类性行为",
"international_law": "国际法学",
"journalism": "新闻学",
"jurisprudence": "法理学",
"legal_and_moral_basis": "法律与道德基础",
"logical": "逻辑学",
"machine_learning": "机器学习",
"management": "管理学",
"marketing": "市场营销",
"marxist_theory": "马克思主义理论",
"modern_chinese": "现代汉语",
"nutrition": "营养学",
"philosophy": "哲学",
"professional_accounting": "专业会计",
"professional_law": "专业法学",
"professional_medicine": "专业医学",
"professional_psychology": "专业心理学",
"public_relations": "公共关系",
"security_study": "安全研究",
"sociology": "社会学",
"sports_science": "体育学",
"traditional_chinese_medicine": "中医中药",
"virology": "病毒学",
"world_history": "世界历史",
"world_religions": "世界宗教",
}
SUBJECT_CATEGORIES = {
"agronomy": ['other'],
"anatomy": ['biology'],
"ancient_chinese": ['linguistics','china specific'],
"arts": ['arts'],
"astronomy": ['physics'],
"business_ethics": ['business'],
"chinese_civil_service_exam": ['politics','china specific'],
"chinese_driving_rule": ['other','china specific'],
"chinese_food_culture": ['culture','china specific'],
"chinese_foreign_policy": ['politics','china specific'],
"chinese_history":['history','china specific'],
"chinese_literature": ['literature','china specific'],
"chinese_teacher_qualification": ['education','china specific'],
"college_actuarial_science":['math'],
"college_education":['education'],
"college_engineering_hydrology": ['engineering'],
"college_law": ['law'],
"college_mathematics": ['math'],
"college_medical_statistics":['statistics'],
"clinical_knowledge": ['other'],
"college_medicine": ['other'],
"computer_science": ['computer science'],
"computer_security": ['other'],
"conceptual_physics": ['physics'],
"construction_project_management": ['other','china specific'],
"economics": ['economics'],
"education": ['education'],
"elementary_chinese":['linguistics','china specific'],
"elementary_commonsense":['other','china specific'],
"elementary_information_and_technology": ['other'],
"electrical_engineering": ['engineering'],
"elementary_mathematics": ['math'],
"ethnology": ['culture','china specific'],
"food_science": ['other'],
"genetics": ['biology'],
"global_facts": ['global'],
"high_school_biology": ['biology'],
"high_school_chemistry": ['chemistry'],
"high_school_geography": ['geography'],
"high_school_mathematics": ['math'],
"high_school_physics": ['physics'],
"high_school_politics": ['politics','china specific'],
"human_sexuality": ['other'],
"international_law": ['law'],
"journalism": ['sociology'],
"jurisprudence": ['law'],
"legal_and_moral_basis": ['other'],
"logical": ['philosophy'],
"machine_learning": ['computer science'],
"management": ['business'],
"marketing": ['business'],
"marxist_theory": ['philosophy'],
"modern_chinese": ['linguistics','china specific'],
"nutrition": ['other'],
"philosophy": ['philosophy'],
"professional_accounting": ['business'],
"professional_law": ['law'],
"professional_medicine": ['other'],
"professional_psychology": ['psychology'],
"public_relations": ['politics'],
"security_study": ['politics'],
"sociology": ['culture'],
"sports_science": ['other'],
"traditional_chinese_medicine": ['other','china specific'],
"virology": ['biology'],
"world_history":['history'],
"world_religions": ['global'],
}
CATEGORIES = {
"STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering", "statistics"],
"Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
"Social Science": ['linguistics',"business", "politics", "culture", "economics", "geography", "psychology", "education", "sociology"],
"Other":["other"],
"China specific": ["china specific"],
}
def create_all_tasks():
"""Creates a dictionary of tasks from a list of subjects
:return: {task_name: task}
e.g. {cmmlu-physician: Task, cmmlu-tax_accountant: Task}
"""
return {f"cmmlu-{sub}": create_task(sub) for sub in SUBJECTS}
def create_task(subject):
class CmmluTest(GeneralCmmluTest):
def __init__(self):
super().__init__(subject)
return CmmluTest
class GeneralCmmluTest(MultipleChoiceTask):
VERSION = 1
DATASET_PATH = os.path.join("haonan-li/cmmlu")
DATASET_NAME = None
def __init__(self, subject):
self.DATASET_NAME = subject
super().__init__()
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def fewshot_context(self, doc, num_fewshot, **kwargs):
subject = self.DATASET_NAME
description = f"以下是关于{SUBJECT_MAPPING[subject]}的单项选择题,请直接给出正确答案的选项。"
kwargs["description"] = description
return super().fewshot_context(doc=doc, num_fewshot=num_fewshot, **kwargs)
def _process_doc(self, doc):
def format_example(doc, keys):
"""
题目:<prompt>
A. <choice1>
B. <choice2>
C. <choice3>
D. <choice4>
答案是:
"""
question = doc["Question"].strip()
choices = "".join(
[f"{key}. {doc[key]}\n" for key in keys]
)
prompt = f"题目:{question}\n{choices}答案是:"
return prompt
keys = ["A", "B", "C", "D"]
return {
"query": format_example(doc, keys),
"choices": keys,
"gold": keys.index(doc["Answer"]),
}
def fewshot_examples(self, k, rnd):
if self._fewshot_docs is None:
self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))
return self._fewshot_docs[:k]
def construct_requests(self, doc, ctx):
lls = [
rf.loglikelihood(ctx, "{}".format(choice))[0] for choice in doc["choices"]
]
return lls
def doc_to_text(self, doc):
return doc["query"]
def doc_to_target(self, doc):
return doc["choices"][doc["gold"]]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
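# Editor's sketch (illustrative, not part of this commit): how the per-subject
# tasks defined above are typically instantiated. The subject name is an
# assumption, and constructing a task triggers a dataset download.
if __name__ == "__main__":
    tasks = create_all_tasks()
    print(sorted(tasks)[:3])                # e.g. ['cmmlu-agronomy', 'cmmlu-anatomy', ...]
    task = tasks["cmmlu-computer_science"]()
    doc = next(iter(task.test_docs()))
    print(task.doc_to_text(doc))            # "题目:...\nA. ...\n...答案是:"
    print(task.doc_to_target(doc))          # e.g. "A"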
......@@ -132,7 +132,7 @@ class CrowsPairsMutilingual(Task):
def higher_is_better(self):
# For all metrics lower is better
return {"likelihood_difference": False, "pct_stereotype": True}
return {"likelihood_difference": False, "pct_stereotype": False}
def aggregation(self):
return {"likelihood_difference": mean, "pct_stereotype": mean}
......
from lm_eval.base import MultipleChoiceTask
class CSATQA(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "EleutherAI/csatqa"
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
instruction = f"""다음을 읽고 정답으로 알맞은 것을 고르시요.
### Context: {doc["context"]}
### Question: {doc["question"]}
### Options:
(1) {doc['option#1']}\n(2) {doc["option#2"]}\n(3) {doc["option#3"]}\n(4) {doc['option#4']}\n(5) {doc['option#5']}
### Answer: 주어진 문제의 정답은"""
choices = [
doc["option#1"],
doc["option#2"],
doc["option#3"],
doc["option#4"],
doc["option#5"],
]
out_doc = {
"question": instruction,
"choices": ["(1)", "(2)", "(3)", "(4)", "(5)"],
"gold": int(doc["gold"]) - 1,
}
return out_doc
def doc_to_text(self, doc):
return doc["question"]
class WR(CSATQA):
DATASET_NAME = "WR"
class GR(CSATQA):
DATASET_NAME = "GR"
class RCS(CSATQA):
DATASET_NAME = "RCS"
class RCSS(CSATQA):
DATASET_NAME = "RCSS"
class RCH(CSATQA):
DATASET_NAME = "RCH"
class LI(CSATQA):
DATASET_NAME = "LI"
from lm_eval.base import MultipleChoiceTask
class Haerae(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "amphora/haerae_bench"
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
choices = [doc["o1"], doc["o2"], doc["o3"], doc["o4"]]
if doc.get("o5") is not None:
choices.append(doc["o5"])
out_doc = {
"query": doc["query"],
"choices": choices,
"gold": int(doc["gold"]) - 1,
}
return out_doc
def doc_to_text(self, doc):
return doc["query"]
class HI(Haerae):
DATASET_NAME = "HI"
class KGK(Haerae):
DATASET_NAME = "KGK"
class LW(Haerae):
DATASET_NAME = "LW"
class RC(Haerae):
DATASET_NAME = "RC"
class RW(Haerae):
DATASET_NAME = "RW"
class SN(Haerae):
DATASET_NAME = "SN"
......@@ -14,7 +14,6 @@ Homepage: https://github.com/hendrycks/test
"""
from lm_eval.base import MultipleChoiceTask
_CITATION = """
@article{hendryckstest2021,
title={Measuring Massive Multitask Language Understanding},
......@@ -103,8 +102,8 @@ def create_task(subject):
class GeneralHendrycksTest(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "hendrycks_test"
VERSION = 1
DATASET_PATH = "cais/mmlu"
DATASET_NAME = None
def __init__(self, subject):
......@@ -112,7 +111,7 @@ class GeneralHendrycksTest(MultipleChoiceTask):
super().__init__()
def has_training_docs(self):
return False
return True
def has_validation_docs(self):
return True
......@@ -126,41 +125,50 @@ class GeneralHendrycksTest(MultipleChoiceTask):
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _format_subject(self, subject):
words = subject.split("_")
return " ".join(words)
def fewshot_context(self, doc, num_fewshot, **kwargs):
subject = self.DATASET_NAME
description = f"The following are multiple choice questions (with answers) about {self._format_subject(subject)}."
kwargs["description"] = description
return super().fewshot_context(doc=doc, num_fewshot=num_fewshot, **kwargs)
def _process_doc(self, doc):
def format_example(doc, keys):
"""
Question: <prompt>
Choices:
<prompt>
A. <choice1>
B. <choice2>
C. <choice3>
D. <choice4>
Answer:
"""
prompt = "Question: " + doc["question"] + "\nChoices:\n"
prompt += "".join(
question = doc["question"].strip()
choices = "".join(
[f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])]
)
prompt += "Answer:"
prompt = f"{question}\n{choices}Answer:"
return prompt
keys = ["A", "B", "C", "D"]
return {
"query": format_example(doc, keys),
"choices": doc["choices"],
"gold": keys.index(doc["answer"])
if isinstance(doc["answer"], str)
else doc["answer"],
"choices": keys,
"gold": doc["answer"],
}
def fewshot_examples(self, k, rnd):
# fewshot_examples is not just sampling from train_docs because dev is
# in the same distribution as val/test but auxiliary_train isn't
if self._fewshot_docs is None:
self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))
return rnd.sample(list(self._fewshot_docs), k)
# use the unchanged order of the dev set without sampling,
# just as in the original code https://github.com/hendrycks/test/blob/master/evaluate.py#L28
return self._fewshot_docs[:k]
def doc_to_text(self, doc):
return doc["query"]
......
"""
Latent Retrieval for Weakly Supervised Open Domain Question Answering
https://arxiv.org/pdf/1906.00300.pdf
Natural Questions: a Benchmark for Question Answering Research
https://storage.googleapis.com/pub-tools-public-publication-data/pdf/1f7b46b5378d757553d3e92ead36bda2e4254244.pdf
The NQ-Open task, introduced by Lee et. al. 2019, is an open-domain question
answering benchmark that is derived from Natural Questions. The goal is to predict
an English answer string for an input English question. All questions can be
answered using the contents of English Wikipedia.
Homepage: https://github.com/google-research-datasets/natural-questions/tree/master/nq_open
"""
import regex
import string
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
_CITATION = """
@inproceedings{lee-etal-2019-latent,
title = "Latent Retrieval for Weakly Supervised Open Domain Question Answering",
author = "Lee, Kenton and
Chang, Ming-Wei and
Toutanova, Kristina",
booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
month = jul,
year = "2019",
address = "Florence, Italy",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/P19-1612",
doi = "10.18653/v1/P19-1612",
pages = "6086--6096",
abstract = "Recent work on open domain question answering (QA) assumes strong supervision of the supporting evidence and/or assumes a blackbox information retrieval (IR) system to retrieve evidence candidates. We argue that both are suboptimal, since gold evidence is not always available, and QA is fundamentally different from IR. We show for the first time that it is possible to jointly learn the retriever and reader from question-answer string pairs and without any IR system. In this setting, evidence retrieval from all of Wikipedia is treated as a latent variable. Since this is impractical to learn from scratch, we pre-train the retriever with an Inverse Cloze Task. We evaluate on open versions of five QA datasets. On datasets where the questioner already knows the answer, a traditional IR system such as BM25 is sufficient. On datasets where a user is genuinely seeking an answer, we show that learned retrieval is crucial, outperforming BM25 by up to 19 points in exact match.",
}
"""
class NQOpen(Task):
VERSION = 0
DATASET_PATH = "nq_open"
DATASET_NAME = None
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
return self.dataset["train"]
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
raise NotImplementedError()
def doc_to_text(self, doc):
return f"Q: {doc['question']}\nA:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["question"]
def doc_to_target(self, doc):
return " " + doc["answer"][0]
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
continuation = rf.greedy_until(ctx, {"until": ["\n", ".", ","]})
return continuation
def _normalize_answer(self, text):
# Lowercase and remove punctuation, strip whitespace
text = text.strip().lower().translate(str.maketrans("", "", string.punctuation))
# Remove articles, resulting in duplicate whitespace
text = regex.sub(r"\b(a|an|the)\b", " ", text)
# Remove duplicate whitespace
text = " ".join(text.split())
return text
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
continuation = self._normalize_answer(results[0])
answers = [self._normalize_answer(answer) for answer in doc["answer"]]
return {"em": float(continuation in answers)}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"em": mean,
}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"em": True,
}
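# Editor's sketch (illustrative, not part of this commit): the exact-match
# normalization used above, restated as a hypothetical standalone function so
# it can be checked without downloading the dataset.
def _demo_normalize_answer(text):
    import regex
    import string

    text = text.strip().lower().translate(str.maketrans("", "", string.punctuation))
    text = regex.sub(r"\b(a|an|the)\b", " ", text)  # drop articles
    return " ".join(text.split())                   # collapse duplicate whitespace

# _demo_normalize_answer(" The Eiffel Tower! ") == "eiffel tower"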
......@@ -33,34 +33,43 @@ _CITATION = """
class Pubmed_QA(Task):
VERSION = 0
DATASET_PATH = "bigbio/pubmed_qa"
DATASET_NAME = "pubmed_qa_labeled_fold0_source"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = self.dataset["train"]
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]
def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]
def doc_to_text(self, doc):
ctxs = "\n".join(doc["CONTEXTS"])
return "Abstract: {}\nQuestion: {}\nAnswer:".format(ctxs, doc["QUESTION"])
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["QUESTION"] + " " + "\n".join(doc["CONTEXTS"])
def doc_to_target(self, doc):
return " {}".format(doc["final_decision"])
"""
SCROLLS: Standardized CompaRison Over Long Language Sequences
https://arxiv.org/abs/2201.03533
SCROLLS is a suite of datasets that require synthesizing information over long texts.
The benchmark includes seven natural language tasks across multiple domains,
including summarization, question answering, and natural language inference.
Homepage: https://www.scrolls-benchmark.com/
Since SCROLLS tasks are generally longer than the maximum sequence length of many models,
it is possible to create "subset" tasks that contain only those samples whose tokenized length
is less than some pre-defined limit. For example, to create a subset of "Qasper" that would
be suitable for a model using the GPTNeoX tokenizer and a 4K maximum sequence length:
```
class QasperGPTNeoX4K(Qasper):
PRUNE_TOKENIZERS = ["EleutherAI/pythia-410m-deduped"]
PRUNE_MAX_TOKENS = 4096
PRUNE_NUM_PROC = _num_cpu_cores() # optional, to speed up pruning of large datasets like NarrativeQA
```
`PRUNE_TOKENIZERS` can contain more than one tokenizer; this will include only samples that are
less than `PRUNE_MAX_TOKENS` for ALL of the tokenizers. This can be useful for comparing models
that use different tokenizers but the same maximum sequence length.
Once the subset task class has been defined in this file, it can be used by adding the class
to `lm_eval/tasks/__init__.py`.
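For example, a registry entry along these lines (a sketch only; the exact layout of
`lm_eval/tasks/__init__.py` and the chosen task name are assumptions):
```
from . import scrolls  # inside lm_eval/tasks/__init__.py

TASK_REGISTRY = {
    # ... existing entries ...
    "scrolls_qasper_gptneox_4k": scrolls.QasperGPTNeoX4K,
}
```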
NOTE: GovReport may need `max_gen_toks` set larger for causal models.
"""
from abc import abstractmethod
from datasets import load_metric
from transformers import AutoTokenizer
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
from functools import reduce
import transformers.data.metrics.squad_metrics as squad_metrics
import numpy as np
import re
_CITATION = """
@inproceedings{shaham-etal-2022-scrolls,
title = "{SCROLLS}: Standardized {C}ompa{R}ison Over Long Language Sequences",
author = "Shaham, Uri and
Segal, Elad and
Ivgi, Maor and
Efrat, Avia and
Yoran, Ori and
Haviv, Adi and
Gupta, Ankit and
Xiong, Wenhan and
Geva, Mor and
Berant, Jonathan and
Levy, Omer",
booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing",
month = dec,
year = "2022",
address = "Abu Dhabi, United Arab Emirates",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2022.emnlp-main.823",
pages = "12007--12021"
}
"""
# SCROLLS is formulated as a sequence-to-sequence task.
# To allow for evaluation of causal models, we'll
# reformulate these with appropriate prompts
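# e.g. a QA-style SCROLLS sample is rendered by doc_to_text below as
#   "<document text>\n\nQuestion: <question>\nAnswer:"
# with the reference answer(s) as the target continuation.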
def _download_metric():
import os
import shutil
from huggingface_hub import hf_hub_download
scrolls_metric_path = hf_hub_download(
repo_id="tau/scrolls", repo_type="dataset", filename="metrics/scrolls.py"
)
updated_scrolls_metric_path = (
os.path.dirname(scrolls_metric_path)
+ os.path.basename(scrolls_metric_path).replace(".", "_")
+ ".py"
)
shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
return updated_scrolls_metric_path
def _process_doc_prepended_question(doc):
# "When a query is given in addition to the raw text (as
# in QMSum, Qasper, NarrativeQA, QuALITY, and ContractNLI),
# we prepend it to the text, using two newlines as a natural separator"
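# Illustrative (hypothetical) example:
#   input = "What is the paper's main finding?\n\nFull document text ..."
#   -> question = "What is the paper's main finding?", text = "Full document text ..."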
input = doc["input"]
split = input.find("\n\n")
return {
"id": doc["id"],
"pid": doc["pid"],
"input": input,
"outputs": doc["outputs"],
"question": input[0:split],
"text": input[split + 2 :],
}
def _drop_duplicates_in_input(untokenized_dataset):
# from scrolls/evaluator/dataset_evaluator.py
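# Collapse rows sharing an "id" (same input, multiple references) into one row
# whose "outputs" column holds the list of all reference outputs for that input.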
indices_to_keep = []
id_to_idx = {}
outputs = []
for i, (id_, output) in enumerate(
zip(untokenized_dataset["id"], untokenized_dataset["output"])
):
if id_ in id_to_idx:
outputs[id_to_idx[id_]].append(output)
continue
indices_to_keep.append(i)
id_to_idx[id_] = len(outputs)
outputs.append([output])
untokenized_dataset = untokenized_dataset.select(indices_to_keep).flatten_indices()
untokenized_dataset = untokenized_dataset.remove_columns("output")
untokenized_dataset = untokenized_dataset.add_column("outputs", outputs)
return untokenized_dataset
def _num_cpu_cores():
# https://stackoverflow.com/questions/1006289/how-to-find-out-the-number-of-cpus-using-python/55423170#55423170
try:
import psutil
return psutil.cpu_count(logical=False)
except ImportError:
import os
return len(os.sched_getaffinity(0))
class _SCROLLSTask(Task):
VERSION = 0
DATASET_PATH = "tau/scrolls"
DATASET_NAME = None
PRUNE_TOKENIZERS = None
PRUNE_MAX_TOKENS = None
PRUNE_NUM_PROC = None
def __init__(self, no_metric=False):
super().__init__()
self.metric = (
load_metric(_download_metric(), config_name=self.DATASET_NAME)
if not no_metric
else None
)
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
for doc in self.dataset["train"]:
yield from self._process_doc(doc)
def validation_docs(self):
for doc in self.dataset["validation"]:
yield from self._process_doc(doc)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["input"]
def download(self, *args, **kwargs):
super().download(*args, **kwargs)
del self.dataset["test"]
for split in self.dataset:
self.dataset[split] = _drop_duplicates_in_input(self.dataset[split])
if self.PRUNE_TOKENIZERS is not None and self.PRUNE_MAX_TOKENS is not None:
self.prune()
def _get_prune_text(self, sample):
return self.doc_to_text(self._process_doc(sample)[0])
def prune(self):
"""Create a pruned version of a SCROLLS task dataset containing only inputs
that are less than `max_tokens` when tokenized by each tokenizer
"""
tokenizers = [
AutoTokenizer.from_pretrained(tokenizer)
for tokenizer in self.PRUNE_TOKENIZERS
]
cache = {}
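# Memoize the keep/drop decision per pruning text; NarrativeQA in particular asks
# many questions about the same long document (see its _get_prune_text override).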
def _filter(sample):
text = self._get_prune_text(sample)
cached = cache.get(text, None)
if cached is None:
for tokenizer in tokenizers:
if len(tokenizer(text).input_ids) > self.PRUNE_MAX_TOKENS:
cache[text] = False
return False
cache[text] = True
return True
else:
return cached
self.dataset = self.dataset.filter(_filter, num_proc=self.PRUNE_NUM_PROC)
def doc_to_target(self, doc):
return " " + ", ".join(doc["outputs"])
def doc_to_text(self, doc):
return f"{doc['text']}\n\nQuestion: {doc['question']}\nAnswer:"
def higher_is_better(self):
return {x: True for x in self._scrolls_metrics().keys()}
@abstractmethod
def _scrolls_metrics(self):
pass
def _make_compute_metrics(self, value):
def compute_metrics(samples):
predictions, references = zip(*samples) # unzip, if you will
computed = self.metric.compute(
predictions=predictions, references=references
)
return computed[value]
return compute_metrics
def aggregation(self):
return {
key: self._make_compute_metrics(value)
for key, value in self._scrolls_metrics().items()
}
class _SCROLLSMultipleChoiceTask(_SCROLLSTask):
def __init__(self):
super().__init__(no_metric=True)
def _scrolls_metrics(self):
return None
def aggregation(self):
return {"em": mean, "acc": mean, "acc_norm": mean}
def higher_is_better(self):
return {"em": True, "acc": True, "acc_norm": True}
def process_results(self, doc, results):
gold = doc["gold"]
acc = 1.0 if np.argmax(results) == gold else 0.0
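# acc_norm length-normalizes the per-choice loglikelihoods by the character length
# of each choice; the SCROLLS-style "em" below is reported on a 0-100 scale.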
completion_len = np.array([float(len(i)) for i in doc["choices"]])
acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
return {
"acc": acc,
"acc_norm": acc_norm,
"em": acc_norm * 100.0,
}
def construct_requests(self, doc, ctx):
lls = [
rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
]
return lls
class _SCROLLSSummaryTask(_SCROLLSTask):
def _process_doc(self, doc):
return [doc]
def _scrolls_metrics(self):
return {
"rouge1": "rouge/rouge1",
"rouge2": "rouge/rouge2",
"rougeL": "rouge/rougeL",
}
def process_results(self, doc, results):
return {
"rouge1": (results[0], doc["outputs"]),
"rouge2": (results[0], doc["outputs"]),
"rougeL": (results[0], doc["outputs"]),
}
def construct_requests(self, doc, ctx):
return [rf.greedy_until(ctx, {"until": ["\n"]})]
def doc_to_text(self, doc):
return f"{doc['input']}\n\nQuestion: What is a summary of the preceding text?\nAnswer:"
class Qasper(_SCROLLSTask):
"""A Dataset of Information-Seeking Questions and Answers Anchored in Research Papers
https://arxiv.org/abs/2105.03011
"""
DATASET_NAME = "qasper"
def _process_doc(self, doc):
doc = _process_doc_prepended_question(doc)
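# Flag questions whose reference answers are all yes/no; these are scored via
# " yes" vs " no" loglikelihood comparison rather than free-form generation.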
doc["is_yes_no"] = reduce(
lambda prev, cur: prev
and squad_metrics.normalize_answer(cur) in ["yes", "no"],
doc["outputs"],
True,
)
return [doc]
def _scrolls_metrics(self):
return {"f1": "f1"}
def process_results(self, doc, results):
if doc["is_yes_no"]:
prediction = " yes" if results[0] > results[1] else " no"
elif len(results[0].strip()) == 0:
prediction = "Unanswerable"
else:
prediction = results[0]
return {"f1": (prediction, doc["outputs"])}
def construct_requests(self, doc, ctx):
if doc["is_yes_no"]:
ll_yes, _ = rf.loglikelihood(ctx, " yes")
ll_no, _ = rf.loglikelihood(ctx, " no")
return [ll_yes, ll_no]
else:
return [rf.greedy_until(ctx, {"until": ["\n"]})]
class QuALITY(_SCROLLSMultipleChoiceTask):
"""QuALITY: Question Answering with Long Input Texts, Yes!
https://arxiv.org/abs/2112.08608
"""
DATASET_NAME = "quality"
_multiple_choice_pattern = re.compile(r" *\([A-D]\) *")
@staticmethod
def _normalize_answer(text):
return " ".join(text.split()).strip()
def _process_doc(self, doc):
doc = _process_doc_prepended_question(doc)
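# The "(A) ... (D) ..." answer options precede the passage; split them off at the
# first blank line after "(D)" and keep the remainder as the passage text.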
split = doc["text"].find("\n\n", doc["text"].find("(D)"))
choices_text = doc["text"][:split]
doc["text"] = doc["text"][split:].strip()
doc["choices"] = [
QuALITY._normalize_answer(choice)
for choice in re.split(QuALITY._multiple_choice_pattern, choices_text)[1:]
]
doc["gold"] = doc["choices"].index(QuALITY._normalize_answer(doc["outputs"][0]))
return [doc]
class NarrativeQA(_SCROLLSTask):
"""The NarrativeQA Reading Comprehension Challenge
https://arxiv.org/abs/1712.07040
"""
DATASET_NAME = "narrative_qa"
def _process_doc(self, doc):
return [_process_doc_prepended_question(doc)]
def _scrolls_metrics(self):
return {"f1": "f1"}
def _get_prune_text(self, doc):
# pruning narrativeqa takes forever -- let's cheat a bit
# and just cache on the text, not the question, since
# the dataset is different questions about the same large
# documents
return self._process_doc(doc)[0]["text"]
def process_results(self, doc, results):
return {"f1": (results[0], doc["outputs"])}
def construct_requests(self, doc, ctx):
return [rf.greedy_until(ctx, {"until": ["\n"]})]
class ContractNLI(_SCROLLSMultipleChoiceTask):
"""ContractNLI: A Dataset for Document-level Natural Language Inference for Contracts
https://arxiv.org/abs/2110.01799
"""
DATASET_NAME = "contract_nli"
CHOICES = ["Not mentioned", "Entailment", "Contradiction"]
def _process_doc(self, doc):
doc = _process_doc_prepended_question(doc)
doc["choices"] = ContractNLI.CHOICES
doc["gold"] = ContractNLI.CHOICES.index(doc["outputs"][0])
return [doc]
def doc_to_text(self, doc):
return f"{doc['text']}\n\nHypothesis: {doc['question']}\nConclusion:"
class GovReport(_SCROLLSSummaryTask):
"""Efficient Attentions for Long Document Summarization
https://arxiv.org/abs/2104.02112
Note: The average length of the reference summaries is ~3,000
characters, or ~600 tokens as tokenized by GPT-NeoX. For causal models,
it is recommended to set `max_gen_toks` sufficiently large (e.g. 1024)
to allow a full summary to be generated.
"""
DATASET_NAME = "gov_report"
class SummScreenFD(_SCROLLSSummaryTask):
"""SummScreen: A Dataset for Abstractive Screenplay Summarization
https://arxiv.org/abs/2104.07091
"""
DATASET_NAME = "summ_screen_fd"
class QMSum(_SCROLLSSummaryTask):
"""QMSum: A New Benchmark for Query-based Multi-domain
Meeting Summarization
https://arxiv.org/abs/2104.05938
"""
DATASET_NAME = "qmsum"
def _process_doc(self, doc):
return [_process_doc_prepended_question(doc)]
def doc_to_text(self, doc):
return f"{doc['text']}\n\nQuestion: {doc['question']}\nAnswer:"
def construct_tasks():
return {
"scrolls_qasper": Qasper,
"scrolls_quality": QuALITY,
"scrolls_narrativeqa": NarrativeQA,
"scrolls_contractnli": ContractNLI,
"scrolls_govreport": GovReport,
"scrolls_summscreenfd": SummScreenFD,
"scrolls_qmsum": QMSum,
}
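# Hypothetical usage (assumed CLI flags; adjust to your harness version):
#   python main.py --model hf-causal \
#       --model_args pretrained=EleutherAI/pythia-410m-deduped \
#       --tasks scrolls_qasper,scrolls_govreport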