Commit 16bc6bc0 authored by haileyschoelkopf

merge conflicts

parents 3d7f777d 465c695b
@@ -8,10 +8,11 @@ on:
     branches:
       - big-refactor
   workflow_dispatch:
+# comment/edit out the above to stop/change the triggers
 jobs:
   changed_files:
     runs-on: ubuntu-latest  # windows-latest || macos-latest
+    timeout-minutes: 120
     name: Scan for changed tasks
     steps:
       - name: checkout
@@ -19,11 +20,15 @@ jobs:
         with:
           fetch-depth: 0  # OR "2" -> To retrieve the preceding commit.
-      # Example 1
+      # Uses the tj-actions/changed-files@v37 action to check for changes.
+      # Outputs provided here: https://github.com/tj-actions/changed-files#outputs
+      # The `files_yaml` input optionally takes a yaml string to specify filters,
+      # and prepends the filter name to the standard output names.
       - name: Check task folders
         id: changed-tasks
         uses: tj-actions/changed-files@v37.1.2
         with:
+          # tasks checks the tasks folder and api checks the api folder for changes
           files_yaml: |
             tasks:
               - lm_eval/tasks/**
@@ -31,18 +36,20 @@ jobs:
               - lm_eval/api/**
           write_output_files: true
+      # The next step is optional; the files are written to the workspace by default (above).
+      # so it's just for debugging
       - name: Run Tests
         if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
         run: |
           echo .github/outputs/tasks_all_changed_and_modified_files.txt >> 'GITHUB_ENV'
           echo "One or more test file(s) has changed."
           echo "List of all the files that have changed: ${{ steps.changed-tasks.outputs.tasks_all_modified_files }}"
       - name: Set up Python 3.9
         if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
         uses: actions/setup-python@v4
         with:
           python-version: 3.9
+          cache: 'pip'
       - name: Install dependencies
         if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
         run: |
@@ -52,10 +59,12 @@ jobs:
           # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
           # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
       - name: Test with pytest
+        # if new tasks are added, run tests on them
         if: steps.changed-tasks.outputs.tasks_any_modified == 'true'
-        run: python -m pytest tests/test_tasks.py -s -vv -n=auto --new_task
+        run: python -m pytest tests/extra/test_new_tasks.py -s -vv -n=auto
+      # if api is modified, run tests on it
       - name: Test more tasks with pytest
         env:
           API: true
         if: steps.changed-tasks.outputs.api_any_modified == 'true'
-        run: python -m pytest tests/test_api.py -s -vv -n=auto --new_task
+        run: python -m pytest tests/extra/test_new_tasks.py -s -vv -n=auto
 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
+# just comment out unwanted steps to turn off the test.
 name: Unit Tests
 on:
@@ -11,7 +11,8 @@ on:
     branches:
       - big-refactor
   workflow_dispatch:
+# Jobs run concurrently and steps run sequentially within a job.
+# jobs: linter and cpu_tests. Add more jobs/steps as required.
 jobs:
   linter:
     name: Linters
@@ -35,9 +36,10 @@ jobs:
         flake8 . --count --select=F,E9,E71,E72,E501,E112,E113,W6 --extend-ignore=F541 --show-source --statistics --exit-zero
         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
         flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
-      - name: Lint with mypy
-        run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
+      # mypy turned off for now
+      # - name: Lint with mypy
+      # run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
+  # Job 2
   testcpu:
     name: CPU Tests
     runs-on: ubuntu-latest
@@ -53,7 +55,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install -e '.[testing,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
           # Install optional git dependencies
           # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
           # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
...
@@ -3,6 +3,15 @@ env
 data/
 lm_cache
 .idea
-*.egg-info/
+build
+dist
+*.egg-info
+venv
 .vscode/
+temp
+__pycache__
+.ipynb_checkpoints
+temp
+# IPython
+profile_default/
+ipython_config.py
@@ -119,6 +119,9 @@ class TaskConfig(dict):
     def __getitem__(self, item):
         return getattr(self, item)

+    def __setitem__(self, item, value):
+        return setattr(self, item, value)
+
     def to_dict(self):
         """dumps the current config as a dictionary object, as a printable format.
         null fields will not be printed.
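The pair of dunder methods above lets `TaskConfig` fields be read and written with dict syntax, which the evaluator change further down relies on (`task_dict[task_name]._config["num_fewshot"] = num_fewshot`). A minimal standalone sketch of that behaviour, using a stand-in class rather than the real `TaskConfig`:

```python
# Stand-in for TaskConfig: dict-style access is routed to attributes.
class ConfigSketch(dict):
    num_fewshot: int = 0  # class default, mirrors a config field

    def __getitem__(self, item):
        return getattr(self, item)

    def __setitem__(self, item, value):
        return setattr(self, item, value)


cfg = ConfigSketch()
cfg["num_fewshot"] = 5          # __setitem__ -> setattr
assert cfg.num_fewshot == 5     # attribute and dict views stay in sync
assert cfg["num_fewshot"] == 5  # __getitem__ -> getattr
```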
@@ -460,6 +463,9 @@ class Task(abc.ABC):
             return labeled_examples + example
         elif type(example) == list:
             return [labeled_examples + ex for ex in example]
+        elif type(example) == int:
+            choices = self.doc_to_choice(doc)
+            return labeled_examples + choices[example]

     def apply_filters(self):
@@ -944,22 +950,21 @@ class ConfigurableTask(Task):
             if self.multiple_target:
                 acc = 1.0 if pred in gold else 0.0
                 acc_norm = 1.0 if pred_norm in gold else 0.0
+                exact_match = int(any([is_greedy[i] for i in gold]))
             else:
                 acc = 1.0 if pred == gold else 0.0
                 acc_norm = 1.0 if pred_norm == gold else 0.0
+                # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
+                exact_match = int(is_greedy[gold])

             result_dict = {
                 **({"acc": acc} if "acc" in use_metric else {}),
                 **({"f1": (gold, pred)} if "f1" in use_metric else {}),
                 **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
                 **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
+                **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
             }

-            if "exact_match" in self._metric_fn_list.keys():
-                # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
-                is_greedy = is_greedy[gold]  # take value for the gold answer
-                result_dict["exact_match"] = int(is_greedy)

             if "acc_mutual_info" in use_metric:
                 lls_mutual_info = [
                     ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
...
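To make the new `exact_match` bookkeeping above concrete, here is a standalone sketch (not the harness code itself): `is_greedy[i]` is taken to mean "greedy decoding would reproduce choice `i` exactly", and `gold` is either a single index or, for multi-target documents, a list of accepted indices.

```python
# Standalone sketch of the exact_match scoring shown in the hunk above.
def exact_match_score(is_greedy: list, gold, multiple_target: bool) -> int:
    if multiple_target:
        # credit if any accepted answer would be produced greedily
        return int(any([is_greedy[i] for i in gold]))
    # single gold index: only that choice counts
    return int(is_greedy[gold])


assert exact_match_score([False, True, False], gold=[0, 1], multiple_target=True) == 1
assert exact_match_score([False, True, False], gold=0, multiple_target=False) == 0
```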
@@ -35,7 +35,7 @@ def simple_evaluate(
     model,
     model_args=None,
     tasks=[],
-    num_fewshot=0,
+    num_fewshot=None,
     batch_size=None,
     max_batch_size=None,
     device=None,
@@ -112,7 +112,17 @@ def simple_evaluate(
             + "_rank" + str(lm.rank) + ".db",
         )

-    task_dict = lm_eval.tasks.get_task_dict(tasks, num_fewshot=num_fewshot)
+    task_dict = lm_eval.tasks.get_task_dict(tasks)
+    for task_name in task_dict.keys():
+        config = task_dict[task_name]._config
+        if num_fewshot is not None:
+            if config["num_fewshot"] > 0:
+                default_num_fewshot = config["num_fewshot"]
+                eval_logger.warning(
+                    f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
+                )
+            task_dict[task_name]._config["num_fewshot"] = num_fewshot

     if check_integrity:
         run_task_tests(task_list=tasks)
@@ -134,7 +144,6 @@ def simple_evaluate(
         if isinstance(model, str)
         else model.model.config._name_or_path,
         "model_args": model_args,
-        "num_fewshot": num_fewshot,
         "batch_size": batch_size,
         "batch_sizes": list(lm.batch_sizes.values())
         if hasattr(lm, "batch_sizes")
@@ -169,8 +178,6 @@ def evaluate(
         Language Model
     :param task_dict: dict[str, Task]
         Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
-    :param num_fewshot: int
-        Number of examples in few-shot context
     :param limit: int, optional
         Limit the number of examples per task (only use this for testing)
     :param bootstrap_iters:
...
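With `num_fewshot` now defaulting to `None`, each task keeps the value from its own config unless the caller overrides it. A hedged usage sketch (the `simple_evaluate` signature comes from the hunk above; the model name, model_args, and task are illustrative placeholders, not part of this commit):

```python
from lm_eval.evaluator import simple_evaluate

# num_fewshot=None (the new default): every task uses its configured num_fewshot.
# num_fewshot=5: overrides each task's config and logs a warning whenever a
# non-zero default is being replaced.
results = simple_evaluate(
    model="hf",  # assumed registered model name for the HuggingFace backend
    model_args="pretrained=EleutherAI/pythia-70m",
    tasks=["webqs"],
    num_fewshot=5,
)
```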
@@ -3,21 +3,28 @@ from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from tqdm import tqdm
 import time
+import anthropic
+from lm_eval.logger import eval_logger
+from typing import List, Literal, Any


 def anthropic_completion(
-    client, model, prompt, max_tokens_to_sample, temperature, stop
+    client: anthropic.Anthropic,
+    model: str,
+    prompt: str,
+    max_tokens_to_sample: int,
+    temperature: float,
+    stop: List[str],
+    **kwargs: Any,
 ):
     """Query Anthropic API for completion.

     Retry with back-off until they respond
     """
-    import anthropic
-
     backoff_time = 3
     while True:
         try:
-            response = client.completion(
+            response = client.completions.create(
                 prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
                 model=model,
                 # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
@@ -25,36 +32,53 @@ def anthropic_completion(
                 stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
                 max_tokens_to_sample=max_tokens_to_sample,
                 temperature=temperature,
+                **kwargs,
             )
-            return response["completion"]
-        except RuntimeError:
-            # TODO: I don't actually know what error Anthropic raises when it times out
-            # So err update this error when we find out.
-            import traceback
-
-            traceback.print_exc()
+            return response.completion
+        except anthropic.RateLimitError as e:
+            eval_logger.warning(
+                f"RateLimitError occurred: {e.__cause__}\n Retrying in {backoff_time} seconds"
+            )
             time.sleep(backoff_time)
             backoff_time *= 1.5


 @register_model("anthropic")
 class AnthropicLM(LM):
-    REQ_CHUNK_SIZE = 20
+    REQ_CHUNK_SIZE = 20  # TODO: not used

-    def __init__(self, model):
-        """
+    def __init__(
+        self,
+        batch_size: int = 1,
+        model: str = "claude-2.0",
+        max_tokens_to_sample: int = 256,
+        temperature: float = 0,  # defaults to 1
+        **kwargs,  # top_p, top_k, etc.
+    ):
+        """Anthropic API wrapper.

         :param model: str
-            Anthropic model e.g. claude-instant-v1
+            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
+        :param max_tokens_to_sample: int
+            Maximum number of tokens to sample from the model
+        :param temperature: float
+            Sampling temperature
+        :param kwargs: Any
+            Additional model_args to pass to the API client
         """
         super().__init__()
-        import anthropic

         self.model = model
-        self.client = anthropic.Client(os.environ["ANTHROPIC_API_KEY"])
+        # defaults to os.environ.get("ANTHROPIC_API_KEY")
+        self.client = anthropic.Anthropic()
+        self.temperature = temperature
+        self.max_tokens_to_sample = max_tokens_to_sample
+        self.tokenizer = self.client.get_tokenizer()
+        self.kwargs = kwargs

     @property
     def eot_token_id(self):
+        # Not sure but anthropic.AI_PROMPT -> [203, 203, 50803, 30]
         raise NotImplementedError("No idea about anthropic tokenization.")

     @property
@@ -63,23 +87,23 @@ class AnthropicLM(LM):
     @property
     def max_gen_toks(self):
-        return 256
+        return self.max_tokens_to_sample

     @property
     def batch_size(self):
         # Isn't used because we override _loglikelihood_tokens
-        raise NotImplementedError()
+        raise NotImplementedError("No support for logits.")

     @property
     def device(self):
         # Isn't used because we override _loglikelihood_tokens
-        raise NotImplementedError()
+        raise NotImplementedError("No support for logits.")

-    def tok_encode(self, string: str):
-        raise NotImplementedError("No idea about anthropic tokenization.")
+    def tok_encode(self, string: str) -> List[int]:
+        return self.tokenizer.encode(string).ids

-    def tok_decode(self, tokens):
-        raise NotImplementedError("No idea about anthropic tokenization.")
+    def tok_decode(self, tokens: List[int]) -> str:
+        return self.tokenizer.decode(tokens)

     def _loglikelihood_tokens(self, requests, disable_tqdm=False):
         raise NotImplementedError("No support for logits.")
@@ -92,20 +116,31 @@ class AnthropicLM(LM):
         res = []
         for request in tqdm(requests):
+            try:
                 inp = request[0]
                 request_args = request[1]
-                until = request_args["until"]
+                # generation_kwargs
+                until = request_args.get("until")
+                max_gen_toks = request_args.get("max_gen_toks", self.max_length)
+                temperature = request_args.get("temperature", self.temperature)
                 response = anthropic_completion(
                     client=self.client,
                     model=self.model,
                     prompt=inp,
-                    max_tokens_to_sample=self.max_gen_toks,
-                    temperature=0.0,  # TODO: implement non-greedy sampling for Anthropic
+                    max_tokens_to_sample=max_gen_toks,
+                    temperature=temperature,  # TODO: implement non-greedy sampling for Anthropic
                     stop=until,
+                    **self.kwargs,
                 )
                 res.append(response)
                 self.cache_hook.add_partial("greedy_until", request, response)
+            except anthropic.APIConnectionError as e:
+                eval_logger.critical(f"Server unreachable: {e.__cause__}")
+                break
+            except anthropic.APIStatusError as e:
+                eval_logger.critical(f"API error {e.status_code}: {e.message}")
+                break

         return res
@@ -116,3 +151,9 @@ class AnthropicLM(LM):
     def _model_generate(self, context, max_length, eos_token_id):
         # Isn't used because we override greedy_until
         raise NotImplementedError()
+
+    def loglikelihood(self, requests):
+        raise NotImplementedError("No support for logits.")
+
+    def loglikelihood_rolling(self, requests):
+        raise NotImplementedError("No support for logits.")
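A rough usage sketch of the updated wrapper, assuming the module lives at `lm_eval/models/anthropic_llms.py` and that `ANTHROPIC_API_KEY` is set in the environment (the prompt and stop sequence are placeholders):

```python
from lm_eval.models.anthropic_llms import AnthropicLM, anthropic_completion

lm = AnthropicLM(model="claude-2.0", max_tokens_to_sample=64, temperature=0)

# anthropic_completion wraps client.completions.create with retry/backoff
# on RateLimitError, as shown in the diff above.
text = anthropic_completion(
    client=lm.client,
    model=lm.model,
    prompt="What is the capital of France?",
    max_tokens_to_sample=64,
    temperature=0,
    stop=["\n\n"],
)
print(text)
```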
@@ -740,7 +740,7 @@ class HFLM(LM):
             else:
                 max_gen_toks = self.max_gen_toks
             # first stop sequence is used to halt generation upon encountering
-            (primary_until) = until[0]
+            primary_until = [until[0]]
             # set the max length in tokens of inputs ("context_enc")
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
...
# WebQuestions
### Paper
Title: `Semantic Parsing on Freebase from Question-Answer Pairs`
Abstract: `https://cs.stanford.edu/~pliang/papers/freebase-emnlp2013.pdf`
WebQuestions is a question answering benchmark consisting of 6,642 question/answer pairs.
The questions are intended to be answerable by Freebase, a large knowledge graph, are mostly
centered around a single named entity, and are popular questions asked on the web (at least as of 2013).
Homepage: `https://worksheets.codalab.org/worksheets/0xba659fe363cb46e7a505c5b6a774dc8a`
### Citation
```
@inproceedings{berant-etal-2013-semantic,
title = "Semantic Parsing on {F}reebase from Question-Answer Pairs",
author = "Berant, Jonathan and
Chou, Andrew and
Frostig, Roy and
Liang, Percy",
booktitle = "Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing",
month = oct,
year = "2013",
address = "Seattle, Washington, USA",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/D13-1160",
pages = "1533--1544",
}
```
### Subtasks
List or describe tasks defined in this folder, and their names here:
* `webqs`: `Questions with multiple accepted answers.`
### Checklist
For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
from typing import Dict, List
def doc_to_choice(doc: Dict) -> List[str]:
"""Return all of the accepted answers as choices."""
return _remove_prefixes(doc["answers"])
def doc_to_target(doc: Dict) -> List[int]:
"""Return list of indices of accepted answers (all of them)."""
remaining = _remove_prefixes(doc["answers"])
return list(range(len(remaining)))
def _remove_prefixes(aliases):
"""
Remove any alias that has a strict prefix elsewhere in the list.
This is an optimization: if a shorter alias is accepted by the greedy (exact-match) check,
any longer alias that merely extends it is redundant, so we can stop looking.
"""
aliases.sort()
ret = [aliases[0]]
for alias in aliases[1:]:
if not alias.startswith(ret[-1]):
ret.append(alias)
return ret
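As a quick illustration of the pruning performed by `_remove_prefixes` above (the aliases are made up, and the snippet assumes the function above is in scope, e.g. run in the same module):

```python
# "Barack Obama" extends "Barack" and "Obama, Barack" extends "Obama",
# so only the two shortest accepted aliases survive after sorting.
print(_remove_prefixes(["Barack Obama", "Obama", "Obama, Barack", "Barack"]))
# ['Barack', 'Obama']
```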
group:
- freebase
- question_answer
task: webqs
dataset_path: web_questions
dataset_name: null
output_type: multiple_choice
training_split: train
validation_split: null
test_split: test
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: !function utils.doc_to_target
doc_to_choice: !function utils.doc_to_choice
should_decontaminate: true
doc_to_decontamination_query: question
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
# XWinograd
### Paper
Title: `It's All in the Heads: Using Attention Heads as a Baseline for Cross-Lingual Transfer in Commonsense Reasoning`
Abstract: `https://arxiv.org/abs/2106.12066`
XWinograd is a multilingual Winograd schema challenge covering English, French, Japanese, Portuguese, Russian, and Chinese. The schemas come from the XWinograd dataset introduced in Tikhonov et al.; because it contains only 16 Chinese schemas, 488 additional Chinese schemas from clue/cluewsc2020 are added.
Homepage: `https://huggingface.co/datasets/Muennighoff/xwinograd`
### Citation
```
@misc{muennighoff2022crosslingual,
title={Crosslingual Generalization through Multitask Finetuning},
author={Niklas Muennighoff and Thomas Wang and Lintang Sutawika and Adam Roberts and Stella Biderman and Teven Le Scao and M Saiful Bari and Sheng Shen and Zheng-Xin Yong and Hailey Schoelkopf and Xiangru Tang and Dragomir Radev and Alham Fikri Aji and Khalid Almubarak and Samuel Albanie and Zaid Alyafeai and Albert Webson and Edward Raff and Colin Raffel},
year={2022},
eprint={2211.01786},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@misc{tikhonov2021heads,
title={It's All in the Heads: Using Attention Heads as a Baseline for Cross-Lingual Transfer in Commonsense Reasoning},
author={Alexey Tikhonov and Max Ryabinin},
year={2021},
eprint={2106.12066},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
### Subtasks
List or describe tasks defined in this folder, and their names here:
* `xwinograd_en`: Winograd schema challenges in English.
* `xwinograd_fr`: Winograd schema challenges in French.
* `xwinograd_jp`: Winograd schema challenges in Japanese.
* `xwinograd_pt`: Winograd schema challenges in Portuguese.
* `xwinograd_ru`: Winograd schema challenges in Russian.
* `xwinograd_zh`: Winograd schema challenges in Chinese.
### Checklist
For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
import argparse
from typing import Dict, List
import yaml
# Different languages that are part of xwinograd.
# These correspond to dataset names (Subsets) on HuggingFace.
# A yaml file is generated by this script for each language.
LANGUAGES = ["en", "fr", "jp", "pt", "ru", "zh"]
def doc_to_text(doc: Dict) -> int:
"""
Return index of the correct choice.
Note: We are using the "multiple input" mode of the multiple-choice
output-type, which means we use different contexts with the same target
for the different choices, rather than the same context and different targets.
"""
answer_to_num = {"1": 0, "2": 1}
return answer_to_num[doc["answer"]]
def doc_to_target(doc: Dict) -> str:
"""
Return the target completion.
Note that this does not depend on the correct choice as we are using
"multiple input" mode.
"""
idx = doc["sentence"].index("_") + 1
return doc["sentence"][idx:].strip()
def doc_to_choice(doc: Dict) -> List[str]:
"""Return the choices that will be used as contexts in "multiple input" mode."""
idx = doc["sentence"].index("_")
options = [doc["option1"], doc["option2"]]
return [doc["sentence"][:idx] + opt for opt in options]
def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
"""
Generate a yaml file for each language.
:param output_dir: The directory to output the files to.
:param overwrite: Whether to overwrite files if they already exist.
"""
err = []
for lang in LANGUAGES:
file_name = f"xwinograd_{lang}.yaml"
try:
with open(f"{output_dir}/{file_name}", "w" if overwrite else "x") as f:
f.write("# Generated by utils.py\n")
yaml.dump(
{
"include": "xwinograd_common_yaml",
"dataset_name": lang,
"task": f"xwinograd_{lang}",
},
f,
)
except FileExistsError:
err.append(file_name)
if len(err) > 0:
raise FileExistsError(
"Files were not created because they already exist (use --overwrite flag):"
f" {', '.join(err)}"
)
def main() -> None:
"""Parse CLI args and generate language-specific yaml files."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--overwrite",
default=False,
action="store_true",
help="Overwrite files if they already exist",
)
parser.add_argument(
"--output-dir", default=".", help="Directory to write yaml files to"
)
args = parser.parse_args()
gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite)
if __name__ == "__main__":
main()
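To make the "multiple input" wiring above concrete, here is a toy document (fabricated purely for illustration, in the XWinograd field layout) pushed through the three functions defined in this file; running `python utils.py` afterwards regenerates the per-language configs listed below.

```python
# Fabricated XWinograd-style document, for illustration only.
doc = {
    "sentence": "The trophy does not fit in the suitcase because _ is too large.",
    "option1": "the trophy",
    "option2": "the suitcase",
    "answer": "1",
}

# Two contexts (one per option) share a single target continuation.
print(doc_to_choice(doc))
# ['The trophy does not fit in the suitcase because the trophy',
#  'The trophy does not fit in the suitcase because the suitcase']
print(doc_to_target(doc))  # 'is too large.'
print(doc_to_text(doc))    # 0  (index of the correct context: option1)
```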
# This file will be included in the generated language-specific task configs.
# It doesn't have a yaml file extension as it is not meant to be imported directly
# by the harness.
group:
- winograd
- commonsense
- multilingual
dataset_path: Muennighoff/xwinograd
dataset_name: null # Overridden by language-specific config.
output_type: multiple_choice
training_split: null
validation_split: null
test_split: test
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
doc_to_choice: !function utils.doc_to_choice
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
# Generated by utils.py
dataset_name: en
include: xwinograd_common_yaml
task: xwinograd_en
# Generated by utils.py
dataset_name: fr
include: xwinograd_common_yaml
task: xwinograd_fr
# Generated by utils.py
dataset_name: jp
include: xwinograd_common_yaml
task: xwinograd_jp
# Generated by utils.py
dataset_name: pt
include: xwinograd_common_yaml
task: xwinograd_pt
# Generated by utils.py
dataset_name: ru
include: xwinograd_common_yaml
task: xwinograd_ru
# Generated by utils.py
dataset_name: zh
include: xwinograd_common_yaml
task: xwinograd_zh
@@ -265,10 +265,20 @@ def make_table(result_dict):
     md_writer = MarkdownTableWriter()
     latex_writer = LatexTableWriter()
-    md_writer.headers = ["Task", "Version", "Filter", "Metric", "Value", "", "Stderr"]
+    md_writer.headers = [
+        "Task",
+        "Version",
+        "Fewshot",
+        "Filter",
+        "Metric",
+        "Value",
+        "",
+        "Stderr",
+    ]
     latex_writer.headers = [
         "Task",
         "Version",
+        "Fewshot",
         "Filter",
         "Metric",
         "Value",
@@ -280,6 +290,7 @@ def make_table(result_dict):
     for k, dic in result_dict["results"].items():
         version = result_dict["versions"][k]
+        n = str(result_dict["configs"][k]["num_fewshot"])
         for (mf), v in dic.items():
             m, _, f = mf.partition(",")
             if m.endswith("_stderr"):
@@ -287,10 +298,11 @@ def make_table(result_dict):
             if m + "_stderr" + "," + f in dic:
                 se = dic[m + "_stderr" + "," + f]
-                values.append([k, version, f, m, "%.4f" % v, "±", "%.4f" % se])
+                values.append([k, version, n, f, m, "%.4f" % v, "±", "%.4f" % se])
             else:
-                values.append([k, version, f, m, "%.4f" % v, "", ""])
+                values.append([k, version, n, f, m, "%.4f" % v, "", ""])
             k = ""
+            n = ""
             version = ""
     md_writer.value_matrix = values
     latex_writer.value_matrix = values
...
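For reference, the new column appears in the rendered results table. A hedged sketch of the layout using the same `pytablewriter` writer the harness imports (the task name and all numbers are invented placeholders, not evaluation results):

```python
from pytablewriter import MarkdownTableWriter

writer = MarkdownTableWriter()
writer.headers = ["Task", "Version", "Fewshot", "Filter", "Metric", "Value", "", "Stderr"]
# One row per metric: [task, version, fewshot, filter, metric, value, "±", stderr]
writer.value_matrix = [["webqs", "1", "0", "none", "exact_match", "0.1234", "±", "0.0123"]]
print(writer.dumps())
```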