Unverified commit 5c72066b authored by Hailey Schoelkopf, committed by GitHub

Merge pull request #599 from EleutherAI/fix-gold-aliases

[Refactor] More MCQA fixes
parents 6a000adb abe05c5d
......@@ -85,13 +85,31 @@ Such that {{question}} will be replaced by `doc["question"]` when rendering the
Our intended output is for the model to predict a single whitespace, and then the answer to the question. We do this via:
```yaml
doc_to_target: "{{answer}}"
gold_alias: "{{answer}}"
```
where `doc_to_target` is *the string that will be appended to inputs for each few-shot example*, and `gold_alias` is *the reference (gold) answer that is passed to our metric function for scoring*. For example, for GSM8k word problems, `doc_to_target` should be the dataset's reference reasoning chain culminating in the answer, while `gold_alias` should be **only the numeric answer** given at the end of that reasoning chain, against which the evaluated model's answer will be compared.
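To make this distinction concrete, here is a hypothetical GSM8k-style sketch; the field contents and the answer-extraction step are illustrative, not the harness's actual GSM8k config:
```python
# Hypothetical GSM8k-style document; the "####"-terminated answer format is
# illustrative of the dataset, not the harness's exact processing code.
doc = {
    "question": "Natalia sold clips to 48 of her friends in April, and then she "
                "sold half as many clips in May. How many clips did Natalia sell "
                "altogether in April and May?",
    "answer": "Natalia sold 48/2 = 24 clips in May.\n"
              "Natalia sold 48+24 = 72 clips altogether in April and May.\n"
              "#### 72",
}

# doc_to_target: the full reference reasoning chain, appended to few-shot examples.
doc_to_target = doc["answer"]

# gold_alias: only the final numeric answer, passed to the metric as the reference.
gold_alias = doc["answer"].split("####")[-1].strip()  # -> "72"

# The metric (e.g. exact match) then compares the model's extracted answer to gold_alias.
model_answer = "72"
exact_match = float(model_answer == gold_alias)
```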
**Important**: We always add one whitespace between the input and output, such that the full input-output string is `doc_to_text(doc) + " " + doc_to_target(doc)`. Therefore `doc_to_text` should not end with trailing whitespace, and `doc_to_target` should not begin with leading whitespace.
Users can also fill out the optional `template_aliases` YAML field, which is prepended to both the `doc_to_text` and `doc_to_target` fields. This field should not contain any text, only Jinja variable definitions (`{% ... %}` clauses). This can be used to perform more involved string manipulations and renamings of dataset columns while keeping the main prompt fields easy to parse visually.
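As a minimal sketch of how this works (illustrative only, using the `jinja2` library directly rather than the harness's internal rendering, and a hypothetical dataset with `sentence1`/`sentence2` columns):
```python
# Illustrative only: shows how template_aliases (pure Jinja variable definitions)
# is prepended to doc_to_text before rendering, so dataset columns can be renamed.
from jinja2 import Template

# Hypothetical doc with generic column names.
doc = {"sentence1": "The cat sat on the mat.", "sentence2": "A cat is on a mat."}

# template_aliases contains no text output, only {% set ... %} definitions.
template_aliases = "{% set premise = sentence1 %}{% set hypothesis = sentence2 %}"
doc_to_text = "{{premise}}\nQuestion: {{hypothesis}} True or False?\nAnswer:"

print(Template(template_aliases + doc_to_text).render(**doc))
# The cat sat on the mat.
# Question: A cat is on a mat. True or False?
# Answer:
```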
#### Multiple choice format
For tasks which are multiple choice (a fixed, finite set of label words per document) and evaluated by comparing the loglikelihoods of all label words (the `multiple_choice` task output type), we enforce a particular convention on prompt format.
An annotated example in the case of SciQ is as follows:
```yaml
template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # `template_aliases` must set the list of possible answer choices to the jinja variable `answer_choices` (List[str]), and set what the index within `answer_choices` of this doc's gold label (correct answer choice).
doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:" # This is the input portion of the prompt for this doc. It will have " {{choice}}" appended to it as target for each choice in answer_choices.
doc_to_target: "{{answer_choices[gold]}}" # this contains the gold-standard answer choice, selected via indexing to index `gold` in the answer choice list.
gold_alias: "{{gold}}" # this must be castable to an integer. It must output only the index within `answer_choices` that is the correct label.
```
Task implementers are thus able to decide what the answer choices should be for a document, and what prompt format to use.
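To make the convention concrete, here is a rough sketch (not the harness's actual request-construction code) of how such a config maps onto per-choice loglikelihood scoring; the SciQ field values below are abbreviated and illustrative:
```python
# Illustrative only: how the answer_choices / gold convention maps onto
# per-choice loglikelihood requests for a SciQ-style doc.
doc = {
    "support": "Mesophiles grow best in moderate temperature ...",
    "question": "What type of organism is commonly used in preparation of foods "
                "such as cheese and yogurt?",
    "distractor1": "viruses",
    "distractor2": "protozoa",
    "distractor3": "gymnosperms",
    "correct_answer": "mesophilic organisms",
}

answer_choices = [
    doc["distractor1"], doc["distractor2"], doc["distractor3"], doc["correct_answer"]
]
gold = 3  # index of the correct choice within answer_choices

# doc_to_text rendered for this doc:
context = f"{doc['support'].lstrip()}\nQuestion: {doc['question']}\nAnswer:"

# One loglikelihood request per choice: the context, plus " " + choice as the continuation.
requests = [(context, " " + choice) for choice in answer_choices]

# Accuracy then checks whether the argmax over per-choice loglikelihoods equals `gold`.
```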
### Using Python Functions for Prompts
There may be cases where the prompt we want to implement is more easily expressed in Python than in Jinja 2. For this, we can use Python helper functions that are referenced from the YAML config. Note that the script defining these functions must be in the same directory as the YAML file.
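For example, a hypothetical `utils.py` living next to the YAML might define the prompt functions in plain Python (the BoolQ-style field names are illustrative, and the exact syntax for referencing these functions from the config may differ):
```python
# utils.py -- hypothetical helper script placed in the same directory as the YAML.
# Field names follow a BoolQ-style dataset and are illustrative.

def doc_to_text(doc: dict) -> str:
    # Build the input prompt for one document in Python rather than Jinja.
    return f"{doc['passage']}\nQuestion: {doc['question']}\nAnswer:"


def doc_to_target(doc: dict) -> str:
    # Map the integer label to its answer string.
    return ["no", "yes"][doc["label"]]
```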
......@@ -124,21 +142,6 @@ use_prompt: "promptsource:GPT-3 Style"
```
#### Multiple choice format
For tasks which are multiple choice (a fixed, finite set of label words per each document) and evaluated via comparing loglikelihoods of all label words (the `multiple_choice` task output type) we enforce a particular convention on prompt format.
An annotated example in the case of SciQ is as follows:
```yaml
template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # `template_aliases` must set the list of possible answer choices to the jinja variable `answer_choices` (List[str]), and set what the index within `answer_choices` of this doc's gold label (correct answer choice).
doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:" # This is the input portion of the prompt for this doc. It will have " {{choice}}" appended to it as target for each choice in answer_choices.
doc_to_target: "{{gold}}" # this must be castable to an integer. It must output only the index within `answer_choices` that is the correct label.
```
Task implementers are thus able to decide what the answer choices should be for a document, and what prompt format to use.
### Setting metrics
You're almost done! Now we need to choose how to score our task.
......
......@@ -104,3 +104,17 @@ class LM(abc.ABC):
args = utils.simple_parse_args_string(arg_string)
args2 = {k: v for k, v in additional_config.items() if v is not None}
return cls(**args, **args2)
@property
def rank(self):
# Used in the case of parallelism. Hardcoded to 0 so that
# API models, which neither support nor expect multi-device
# parallelism, run without errors.
return 0
@property
def world_size(self):
# Used in the case of parallelism. Hardcoded to 1 so that
# API models, which neither support nor expect multi-device
# parallelism, run without errors.
return 1
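For context, these defaults feed into rank/world-size-based sharding of requests across processes (cf. the `create_iterator` helper further down in this diff); a minimal sketch of that `islice`-based pattern, not the harness's exact implementation:
```python
from itertools import islice

def shard(raw_iterator, rank, world_size, limit=None):
    # Illustrative sharding: each process takes every world_size-th item starting
    # at its own rank. With rank=0 and world_size=1 (the defaults above), every
    # item is yielded, so single-process API models behave as before.
    return islice(raw_iterator, rank, limit, world_size)

list(shard(range(10), rank=0, world_size=1))  # [0, 1, 2, ..., 9]
list(shard(range(10), rank=1, world_size=4))  # [1, 5, 9]
```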
......@@ -7,7 +7,8 @@ class Sampler:
self.task = task
self.config = task._config
self.delimiter = self.config.delimiter
self.target_delimiter = self.config.target_delimiter
self.fewshot_delimiter = self.config.fewshot_delimiter
self.docs = docs # HF dataset split, provided by task._fewshot_docs()
if fewshot_indices: # subset few-shot docs from
......@@ -30,9 +31,12 @@ class Sampler:
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = (
self.delimiter.join(
self.fewshot_delimiter.join(
[
self.task.doc_to_text(doc) + self.task.doc_to_target(doc)
# TODO: is separating doc_to_text and doc_to_target by one space always desired?
self.task.doc_to_text(doc)
+ self.target_delimiter
+ self.task.doc_to_target(doc)
for doc in selected_docs
]
)
......
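With the default `target_delimiter` (`" "`) and `fewshot_delimiter` (`"\n\n"`), the assembled few-shot context looks roughly like the following sketch (illustrative, not the `Sampler` class itself):
```python
# Illustrative assembly of a few-shot context (not the Sampler class itself).
target_delimiter = " "      # between each example's input and its target
fewshot_delimiter = "\n\n"  # between consecutive few-shot examples

fewshot_pairs = [
    ("Question: What is 2 + 2?\nAnswer:", "4"),
    ("Question: What is 3 + 5?\nAnswer:", "8"),
]

labeled_examples = fewshot_delimiter.join(
    text + target_delimiter + target for text, target in fewshot_pairs
)

# The evaluated document's own doc_to_text is then appended after another
# fewshot_delimiter (shown here for illustration).
print(labeled_examples + fewshot_delimiter + "Question: What is 7 + 6?\nAnswer:")
```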
......@@ -66,7 +66,6 @@ class TaskConfig(dict):
doc_to_text: Union[Callable, str] = None
doc_to_target: Union[Callable, str] = None
use_prompt: str = None
delimiter: str = "\n\n"
description: str = ""
num_fewshot: int = 0
......@@ -77,6 +76,8 @@ class TaskConfig(dict):
gold_alias: Union[Callable, str] = None
output_type: str = "greedy_until"
generation_kwargs: dict = None
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
filter_list: Union[str, list] = None
should_decontaminate: bool = False
doc_to_decontamination_query: str = None
......
import os
from lm_eval.base import BaseLM
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from tqdm import tqdm
import time
......@@ -37,7 +38,8 @@ def anthropic_completion(
backoff_time *= 1.5
class AnthropicLM(BaseLM):
@register_model("anthropic")
class AnthropicLM(LM):
REQ_CHUNK_SIZE = 20
def __init__(self, model):
......
......@@ -12,7 +12,7 @@ from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from accelerate import Accelerator
from itertools import islice
from typing import Optional, Union
@register_model("hf-causal")
......@@ -23,6 +23,7 @@ class HFLM(LM):
pretrained="gpt2",
revision="main",
low_cpu_mem_usage=None,
dtype: Optional[Union[str, torch.dtype]] = "auto",
subfolder=None,
tokenizer=None,
batch_size=1,
......@@ -58,10 +59,15 @@ class HFLM(LM):
revision = revision + ("/" + subfolder if subfolder is not None else "")
self.model = transformers.AutoModelForCausalLM.from_pretrained(
pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
pretrained,
revision=revision,
low_cpu_mem_usage=low_cpu_mem_usage,
torch_dtype=utils.get_dtype(dtype),
).to(self.device)
self.model.eval()
print(self.model.dtype)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer,
revision=revision,
......
......@@ -58,7 +58,7 @@ def oa_completion(**kwargs):
@register_model("openai", "openai-completions", "gooseai")
class GPT3LM(LM):
class OpenaiCompletionsLM(LM):
REQ_CHUNK_SIZE = 20
def __init__(self, engine, truncate=False):
......
......@@ -10,7 +10,8 @@ validation_split: validation
test_split: test
template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what)
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
doc_to_target: "{{answer_choices[gold]}}"
gold_alias: "{{gold}}" # this will be cast to an int.
metric_list:
- metric: acc
aggregation: mean
......
......@@ -10,7 +10,8 @@ validation_split: validation
test_split: test
template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what)
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
doc_to_target: "{{answer_choices[gold]}}"
gold_alias: "{{gold}}" # this will be cast to an int.
metric_list:
- metric: acc
aggregation: mean
......
......@@ -16,6 +16,6 @@ metric_list:
- metric: perplexity
aggregation: perplexity
higher_is_better: false
- metric: accuracy
- metric: acc
aggregation: mean
higher_is_better: true
......@@ -17,6 +17,6 @@ metric_list:
- metric: perplexity
aggregation: perplexity
higher_is_better: false
- metric: accuracy
- metric: acc
aggregation: mean
higher_is_better: true
......@@ -15,6 +15,6 @@ metric_list:
- metric: perplexity
aggregation: perplexity
higher_is_better: false
- metric: accuracy
- metric: acc
aggregation: mean
higher_is_better: true
......@@ -9,7 +9,8 @@ validation_split: validation
test_split: null
template_aliases: "{% set question = goal %}{% set answer_choices = [sol1, sol2] %}{% set gold = label %}" # set the list of possible answer choices, and set what this doc's gold label idx is
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
doc_to_target: "{{answer_choices[gold]}}"
gold_alias: "{{gold}}" # this will be cast to an int.
metric_list:
- metric: acc
aggregation: mean
......
......@@ -9,7 +9,7 @@ validation_split: validation
test_split: test
template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # set the list of possible answer choices, and set what this doc's gold label idx is
doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: " {{correct_answer}}"
doc_to_target: "{{correct_answer}}"
gold_alias: "{{gold}}" # this will be cast to an int.
metric_list:
- metric: acc
......
......@@ -7,5 +7,6 @@ output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{label}}" # this will be cast to an int.
doc_to_target: "{{answer_choices[labe]}}"
gold_alias: "{{label}}" # this will be cast to an int.
template_aliases: "{% set answer_choices = ['no', 'yes'] %}"
......@@ -7,7 +7,8 @@ output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: "{{premise}}\nQuestion: {{hypothesis}}. True, False, or Neither?\nAnswer:"
doc_to_target: "{{label}}" # this will be cast to an int.
doc_to_target: "{{answer_choices[labe]}}"
gold_alias: "{{label}}" # this will be cast to an int.
template_aliases: "{% set answer_choices = ['True', 'False', 'Neither'] %}"
metric_list:
- metric: acc
......
......@@ -419,3 +419,13 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
def clear_torch_cache():
gc.collect()
torch.cuda.empty_cache()
def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
"""Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
if isinstance(dtype, str) and dtype != "auto":
# Convert `str` args torch dtype: `float16` -> `torch.float16`
_torch_dtype = getattr(torch, dtype)
else:
_torch_dtype = dtype
return _torch_dtype
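A quick usage sketch of this helper (assuming it is importable as `lm_eval.utils.get_dtype`):
```python
import torch
from lm_eval import utils  # assuming the helper above lives in lm_eval/utils.py

utils.get_dtype("float16")       # -> torch.float16
utils.get_dtype("auto")          # -> "auto" (left for from_pretrained to resolve)
utils.get_dtype(torch.bfloat16)  # -> torch.bfloat16
```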
......@@ -28,6 +28,7 @@ setuptools.setup(
python_requires=">=3.9",
install_requires=[
"accelerate>=0.18.0",
"evaluate",
"datasets>=2.0.0",
"jsonlines",
"numexpr",
......