"tutorials/models/vscode:/vscode.git/clone" did not exist on "01e8794f124a163935b97d1486f53e6757d1d437"
Commit e0281126 authored by Benjamin Fattori

Merge remote-tracking branch 'upstream/big-refactor' into big-refactor-autobatching

parents 96ea9f54 4e44f0aa
@@ -8,10 +8,11 @@ on:
     branches:
       - big-refactor
   workflow_dispatch:
+# comment/edit out the above to stop/change the triggers
 jobs:
   changed_files:
     runs-on: ubuntu-latest # windows-latest || macos-latest
+    timeout-minutes: 120
     name: Scan for changed tasks
     steps:
       - name: checkout
@@ -19,11 +20,15 @@ jobs:
         with:
           fetch-depth: 0 # OR "2" -> To retrieve the preceding commit.
-      # Example 1
+      # Uses the tj-actions/changed-files@v37 action to check for changes.
+      # Outputs provided here: https://github.com/tj-actions/changed-files#outputs
+      # The `files_yaml` input optionally takes a yaml string to specify filters,
+      # and prepends the filter name to the standard output names.
       - name: Check task folders
         id: changed-tasks
         uses: tj-actions/changed-files@v37.1.2
         with:
+          # tasks checks the tasks folder and api checks the api folder for changes
           files_yaml: |
             tasks:
               - lm_eval/tasks/**
@@ -31,31 +36,35 @@ jobs:
             api:
               - lm_eval/api/**
           write_output_files: true
+      # The next step is optional; the files are written to the workspace by default (above).
+      # so it's just for debugging
       - name: Run Tests
         if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
         run: |
           echo .github/outputs/tasks_all_changed_and_modified_files.txt >> 'GITHUB_ENV'
           echo "One or more test file(s) has changed."
           echo "List of all the files that have changed: ${{ steps.changed-tasks.outputs.tasks_all_modified_files }}"
       - name: Set up Python 3.9
         if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
         uses: actions/setup-python@v4
         with:
           python-version: 3.9
+          cache: 'pip'
       - name: Install dependencies
         if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
         run: |
           python -m pip install --upgrade pip
           pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
           # Install optional git dependencies
           # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
           # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
       - name: Test with pytest
+        # if new tasks are added, run tests on them
         if: steps.changed-tasks.outputs.tasks_any_modified == 'true'
-        run: python -m pytest tests/test_tasks.py -s -vv -n=auto --new_task
+        run: python -m pytest tests/extra/test_new_tasks.py -s -vv -n=auto
+      # if api is modified, run tests on it
       - name: Test more tasks with pytest
         env:
           API: true
         if: steps.changed-tasks.outputs.api_any_modified == 'true'
-        run: python -m pytest tests/test_api.py -s -vv -n=auto --new_task
+        run: python -m pytest tests/extra/test_new_tasks.py -s -vv -n=auto
 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
+# just comment out unwanted steps to turn off the test.
 name: Unit Tests
 on:
@@ -11,7 +11,8 @@ on:
     branches:
       - big-refactor
   workflow_dispatch:
+# Jobs run concurrently and steps run sequentially within a job.
+# jobs: linter and cpu_tests. Add more jobs/steps as required.
 jobs:
   linter:
     name: Linters
@@ -35,9 +36,10 @@ jobs:
           flake8 . --count --select=F,E9,E71,E72,E501,E112,E113,W6 --extend-ignore=F541 --show-source --statistics --exit-zero
           # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
           flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
-      - name: Lint with mypy
-        run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
+      # mypy turned off for now
+      # - name: Lint with mypy
+      #   run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
+  # Job 2
   testcpu:
     name: CPU Tests
     runs-on: ubuntu-latest
...
@@ -3,6 +3,15 @@ env
 data/
 lm_cache
 .idea
+build
+dist
-*.egg-info/
+*.egg-info
+venv
 .vscode/
 temp
+__pycache__
+.ipynb_checkpoints
+temp
+# IPython
+profile_default/
+ipython_config.py
@@ -471,6 +471,9 @@ class Task(abc.ABC):
             return labeled_examples + example
         elif type(example) == list:
             return [labeled_examples + ex for ex in example]
+        elif type(example) == int:
+            choices = self.doc_to_choice(doc)
+            return labeled_examples + choices[example]

     def apply_filters(self):
...
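For readers skimming the diff: the new branch lets a task's `doc_to_text` return an integer, which is then treated as an index into the document's choice list when the few-shot context is assembled. A standalone sketch of that dispatch is below; `append_example` and the injected `doc_to_choice` callable are hypothetical names used only to keep the sketch self-contained, not the library's actual API.

```python
# Standalone sketch of the branch added above: an int "example" selects one of
# the doc's choices. In the library, doc_to_choice is a method on the task;
# here it is passed in so the snippet runs on its own.
from typing import Callable, Dict, List, Union


def append_example(
    labeled_examples: str,
    example: Union[str, List[str], int],
    doc: Dict,
    doc_to_choice: Callable[[Dict], List[str]],
) -> Union[str, List[str]]:
    if isinstance(example, str):
        return labeled_examples + example
    elif isinstance(example, list):
        return [labeled_examples + ex for ex in example]
    elif isinstance(example, int):
        # "multiple input" mode: the integer indexes into the choice strings
        choices = doc_to_choice(doc)
        return labeled_examples + choices[example]
    raise TypeError(f"unexpected example type: {type(example)}")
```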
@@ -223,10 +223,7 @@ def evaluate(
         # aggregate Instances by LM method requested to get output.
         reqtype = (
             "loglikelihood"
-            if (
-                task.OUTPUT_TYPE == "multiple_choice"
-                or task.OUTPUT_TYPE == "winograd_schema"
-            )
+            if task.OUTPUT_TYPE == "multiple_choice"
             else task.OUTPUT_TYPE
         )  # TODO: this is hacky, fix in task.py
         requests[reqtype].extend(task.instances)
...
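With the `winograd_schema` branch removed, only `multiple_choice` tasks are rerouted to loglikelihood requests; every other output type is used as the request type as-is. A one-function sketch of the mapping (the non-multiple-choice output-type string below is illustrative, not taken from this commit):

```python
# Sketch of the simplified mapping above; "greedy_until" is only an
# illustrative example of a non-multiple-choice output type.
def request_type(output_type: str) -> str:
    return "loglikelihood" if output_type == "multiple_choice" else output_type


assert request_type("multiple_choice") == "loglikelihood"
assert request_type("greedy_until") == "greedy_until"
```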
@@ -812,7 +812,7 @@ class HFLM(LM):
         else:
             max_gen_toks = self.max_gen_toks
         # first stop sequence is used to halt generation upon encountering
-        (primary_until) = until[0]
+        primary_until = [until[0]]
         # set the max length in tokens of inputs ("context_enc")
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
...
# XWinograd
### Paper
Title: `It's All in the Heads: Using Attention Heads as a Baseline for Cross-Lingual Transfer in Commonsense Reasoning`
Abstract: `https://arxiv.org/abs/2106.12066`
A multilingual Winograd Schema Challenge covering English, French, Japanese, Portuguese, Russian, and Chinese. The schemas come from the XWinograd dataset introduced in Tikhonov et al. (2021); because it contains only 16 Chinese schemas, 488 additional Chinese schemas from clue/cluewsc2020 are included.
Homepage: `https://huggingface.co/datasets/Muennighoff/xwinograd`
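Not part of the original dataset card, but for orientation: each document carries a `sentence` with a `_` blank, two candidate referents (`option1`, `option2`), and an `answer` of `"1"` or `"2"`. A hedged sketch of inspecting one document with the Hugging Face `datasets` library, assuming the `en` subset exposes the `test` split configured in this folder:

```python
# Hedged sketch: load one XWinograd document. Assumes the `datasets` library is
# installed; the dataset path and split come from xwinograd_common_yaml, the
# field names from utils.py in this folder.
from datasets import load_dataset

docs = load_dataset("Muennighoff/xwinograd", "en", split="test")
doc = docs[0]
print(doc["sentence"])                      # contains a "_" blank to resolve
print(doc["option1"], "|", doc["option2"])  # the two candidate referents
print(doc["answer"])                        # "1" or "2"
```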
### Citation
```
@misc{muennighoff2022crosslingual,
title={Crosslingual Generalization through Multitask Finetuning},
author={Niklas Muennighoff and Thomas Wang and Lintang Sutawika and Adam Roberts and Stella Biderman and Teven Le Scao and M Saiful Bari and Sheng Shen and Zheng-Xin Yong and Hailey Schoelkopf and Xiangru Tang and Dragomir Radev and Alham Fikri Aji and Khalid Almubarak and Samuel Albanie and Zaid Alyafeai and Albert Webson and Edward Raff and Colin Raffel},
year={2022},
eprint={2211.01786},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@misc{tikhonov2021heads,
title={It's All in the Heads: Using Attention Heads as a Baseline for Cross-Lingual Transfer in Commonsense Reasoning},
author={Alexey Tikhonov and Max Ryabinin},
year={2021},
eprint={2106.12066},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
### Subtasks
The following subtasks are defined in this folder:
* `xwinograd_en`: Winograd schema challenges in English.
* `xwinograd_fr`: Winograd schema challenges in French.
* `xwinograd_jp`: Winograd schema challenges in Japanese.
* `xwinograd_pt`: Winograd schema challenges in Portuguese.
* `xwinograd_ru`: Winograd schema challenges in Russian.
* `xwinograd_zh`: Winograd schema challenges in Chinese.
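A hedged example of running one of these subtasks through the harness's Python entry point; it assumes `lm_eval.evaluator.simple_evaluate` keeps its usual signature on this branch, and the model checkpoint, `model_args` string, and `limit` are purely illustrative.

```python
# Hedged sketch, not taken from this commit: smoke-test xwinograd_en on a small
# model. The results-dict layout ("results" -> task name) is also assumed.
from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=EleutherAI/pythia-70m",
    tasks=["xwinograd_en"],
    limit=10,  # only 10 documents, for a quick check
)
print(results["results"]["xwinograd_en"])
```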
### Checklist
For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
import argparse
from typing import Dict, List
import yaml
# Different languages that are part of xwinograd.
# These correspond to dataset names (Subsets) on HuggingFace.
# A yaml file is generated by this script for each language.
LANGUAGES = ["en", "fr", "jp", "pt", "ru", "zh"]
def doc_to_text(doc: Dict) -> int:
"""
Return index of the correct choice.
Note: We are using the "multiple input" mode of the multiple-choice
output-type, which means we use different contexts with the same target
for the different choices, rather than the same context and different targets.
"""
answer_to_num = {"1": 0, "2": 1}
return answer_to_num[doc["answer"]]
def doc_to_target(doc: Dict) -> str:
"""
Return the target completion.
Note that this does not depend on the correct choice as we are using
"multiple input" mode.
"""
idx = doc["sentence"].index("_") + 1
return doc["sentence"][idx:].strip()
def doc_to_choice(doc: Dict) -> List[str]:
"""Return the choices that will be used as contexts in "multiple input" mode."""
idx = doc["sentence"].index("_")
options = [doc["option1"], doc["option2"]]
return [doc["sentence"][:idx] + opt for opt in options]
def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
"""
Generate a yaml file for each language.
:param output_dir: The directory to output the files to.
:param overwrite: Whether to overwrite files if they already exist.
"""
err = []
for lang in LANGUAGES:
file_name = f"xwinograd_{lang}.yaml"
try:
with open(f"{output_dir}/{file_name}", "w" if overwrite else "x") as f:
f.write("# Generated by utils.py\n")
yaml.dump(
{
"include": "xwinograd_common_yaml",
"dataset_name": lang,
"task": f"xwinograd_{lang}",
},
f,
)
except FileExistsError:
err.append(file_name)
if len(err) > 0:
raise FileExistsError(
"Files were not created because they already exist (use --overwrite flag):"
f" {', '.join(err)}"
)
def main() -> None:
"""Parse CLI args and generate language-specific yaml files."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--overwrite",
default=False,
action="store_true",
help="Overwrite files if they already exist",
)
parser.add_argument(
"--output-dir", default=".", help="Directory to write yaml files to"
)
args = parser.parse_args()
gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite)
if __name__ == "__main__":
main()
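To make the "multiple input" convention described in the docstrings above concrete, here is a small worked example. The sentence is invented, and importing from `utils` assumes this script's directory is on the Python path.

```python
# Worked example of the helpers above on a hypothetical XWinograd-style doc.
# The sentence text is made up; the field names mirror those used in utils.py.
from utils import doc_to_choice, doc_to_target, doc_to_text

doc = {
    "sentence": "The trophy does not fit in the suitcase because _ is too big.",
    "option1": "the trophy",
    "option2": "the suitcase",
    "answer": "1",
}

print(doc_to_choice(doc))
# ['The trophy does not fit in the suitcase because the trophy',
#  'The trophy does not fit in the suitcase because the suitcase']
print(doc_to_target(doc))  # 'is too big.'  (shared continuation for both contexts)
print(doc_to_text(doc))    # 0  (index of the correct context, i.e. option1)
```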
# This file will be included in the generated language-specific task configs.
# It doesn't have a yaml file extension as it is not meant to be imported directly
# by the harness.
group:
- winograd
- commonsense
- multilingual
dataset_path: Muennighoff/xwinograd
dataset_name: null # Overridden by language-specific config.
output_type: multiple_choice
training_split: null
validation_split: null
test_split: test
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
doc_to_choice: !function utils.doc_to_choice
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
# Generated by utils.py
dataset_name: en
include: xwinograd_common_yaml
task: xwinograd_en
# Generated by utils.py
dataset_name: fr
include: xwinograd_common_yaml
task: xwinograd_fr
# Generated by utils.py
dataset_name: jp
include: xwinograd_common_yaml
task: xwinograd_jp
# Generated by utils.py
dataset_name: pt
include: xwinograd_common_yaml
task: xwinograd_pt
# Generated by utils.py
dataset_name: ru
include: xwinograd_common_yaml
task: xwinograd_ru
# Generated by utils.py
dataset_name: zh
include: xwinograd_common_yaml
task: xwinograd_zh
@@ -18,6 +18,9 @@ setuptools.setup(
         "lm_eval": ["**/*.yaml"],
         "examples": ["**/*.yaml"],
     },
+    entry_points={
+        "console_scripts": ["lm-eval = main:main", "lm_eval = main:main"],
+    },
     include_package_data=True,
     classifiers=[
         "Development Status :: 3 - Alpha",
...
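The new `entry_points` block exposes the evaluator as `lm-eval` / `lm_eval` console commands. A hedged way to confirm how they resolve after an editable install (this check is not part of the commit, and the `group=` selection requires Python 3.10+):

```python
# Hedged check: list the console scripts declared above and the callable they
# resolve to (expected value: "main:main"). Requires Python 3.10+.
from importlib.metadata import entry_points

for ep in entry_points(group="console_scripts"):
    if ep.name in ("lm-eval", "lm_eval"):
        print(f"{ep.name} -> {ep.value}")
```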
def pytest_addoption(parser):
parser.addoption(
"--new_task",
action="store_true",
help="new_tasks_found",
)
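For context, a test can consume the `--new_task` flag registered above through pytest's config object; this mirrors the `any_new_tasks` fixture visible on the old side of the test diff further below.

```python
# Reading the --new_task flag registered above (pattern taken from the
# any_new_tasks fixture shown in the diff below).
import pytest


@pytest.fixture()
def any_new_tasks(request) -> bool:
    return request.config.getoption("--new_task")
```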
import pytest
from itertools import islice
import lm_eval.tasks as tasks
from .utilities_testing import load_changed_files, parser
from typing import List
from lm_eval.api.task import ConfigurableTask
import os
# GitHub CI
def new_tasks() -> List[str]:
FILENAME = ".github/outputs/tasks_all_changed_and_modified_files.txt"
if os.path.exists(FILENAME):
# If tasks folder has changed then we get the list of files from FILENAME
# and parse the yaml files to get the task names.
return parser(load_changed_files(FILENAME))
elif os.getenv("API") is not None:
# Or if API has changed then we set the ENV variable API to True
# and run given tasks.
return ["arc_easy", "hellaswag", "piqa", "wikitext"]
# if both not true just do arc_easy
else:
return ["arc_easy"]
def get_task_class() -> List[ConfigurableTask]:
task_name = new_tasks()
x = [cls for name, cls in tasks.TASK_REGISTRY.items() if name in task_name]
return x
@pytest.fixture()
def limit() -> int:
return 10
# Tests
@pytest.mark.parametrize("task_class", get_task_class())
class TestNewTasks:
def test_download(self, task_class: ConfigurableTask):
task_class().download()
assert task_class().dataset is not None
def test_has_training_docs(self, task_class: ConfigurableTask):
assert task_class().has_training_docs() in [True, False]
def test_check_training_docs(self, task_class: ConfigurableTask):
task = task_class()
if task.has_training_docs():
assert task._config["training_split"] is not None
def test_has_validation_docs(self, task_class):
assert task_class().has_validation_docs() in [True, False]
def test_check_validation_docs(self, task_class):
task = task_class()
if task.has_validation_docs():
assert task._config["validation_split"] is not None
def test_has_test_docs(self, task_class):
assert task_class().has_test_docs() in [True, False]
def test_check_test_docs(self, task_class):
task = task_class()
if task.has_test_docs():
assert task._config["test_split"] is not None
def test_should_decontaminate(self, task_class):
task = task_class()
assert task.should_decontaminate() in [True, False]
if task.should_decontaminate():
assert task._config["doc_to_decontamination_query"] is not None
def test_doc_to_text(self, task_class, limit):
task = task_class()
arr = (
list(islice(task.test_docs(), limit))
if task.has_test_docs()
else list(islice(task.validation_docs(), limit))
)
_array = [task.doc_to_text(doc) for doc in arr]
# space convention; allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
assert all(
isinstance(x, str) and (x[-1] != " " if len(x) != 0 else True)
for x in _array
)
def test_create_choices(self, task_class, limit):
task = task_class()
arr = (
list(islice(task.test_docs(), limit))
if task.has_test_docs()
else list(islice(task.validation_docs(), limit))
)
if "multiple_choice" in task._config.group:
_array = [task.doc_to_choice(doc) for doc in arr]
# assert all(len(x) == 4 for x in _array)
assert all(isinstance(x, list) for x in _array)
assert all(isinstance(x[0], str) for x in _array)
def test_doc_to_target(self, task_class, limit):
task = task_class()
arr = (
list(islice(task.test_docs(), limit))
if task.has_test_docs()
else list(islice(task.validation_docs(), limit))
)
_array_target = [task.doc_to_target(doc) for doc in arr]
assert all(isinstance(label, int) for label in _array_target)
assert len(_array_target) == limit if limit else True
# _array_text = [task.doc_to_text(doc) for doc in arr]
# Not working
# assert all(tgt[0] == " " or txt[-1] == "\n" if len(txt) != 0 else True for txt, tgt in zip(_array_text, _array_target))
def test_build_all_requests(self, task_class, limit):
        task = task_class()
        task.build_all_requests(rank=1, limit=limit, world_size=1)
        assert task.instances is not None
def test_construct_requests(self, task_class, limit):
task = task_class()
arr = (
list(islice(task.test_docs(), limit))
if task.has_test_docs()
else list(islice(task.validation_docs(), limit))
)
requests = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
assert all(isinstance(doc, list) for doc in requests)
assert len(requests) == limit if limit else True
@@ -2,16 +2,25 @@ import json
 from typing import List
 from lm_eval.utils import load_yaml_config
 from pathlib import Path
+import sys

+# This is the path where the output for the changed files for the tasks folder is stored
+# FILE_PATH = file_path = ".github/outputs/tasks_all_changed_and_modified_files.txt"
-FILE_PATH = file_path = ".github/outputs/tasks_all_changed_and_modified_files.txt"

+# reads a text file and returns a list of words
-def load_changed_files(file_path: str = FILE_PATH) -> List[str]:  # used to read the output of the changed txt from tj-actions/changed-files
+def load_changed_files(file_path: str) -> List[str]:
     with open(file_path, "r") as f:
-        return [line.strip() for line in f.readlines()]
+        content = f.read()
+        words_list = [x for x in content.split()]
+        sys.stdout.write(f"list of files: {words_list}")
+        return words_list

+# checks the txt file for list of changed files.
+# if file ends with .yaml then check yaml for task name
+# if file ends with .py then parse the folder for all yaml files
 def parser(full_path: List[str]) -> List[str]:
     _output = set()
     for x in full_path:
...
-import pytest
 from itertools import islice

+import pytest
+from typing import List

 import lm_eval.tasks as tasks
-from tests.extra.test_utils import load_changed_files, parser
-from typing import List, ClassVar
-import os
+from lm_eval.api.task import ConfigurableTask

+# Using fixtures to get the task class and limit
 @pytest.fixture()
-def any_new_tasks(request) -> bool:
-    return request.config.getoption("--new_task")
+def task_class() -> ConfigurableTask:
+    task_name = ["arc_easy"]
+    x = [cls for name, cls in tasks.TASK_REGISTRY.items() if name in task_name]
+    return x[0]

-# ["arc_easy] else get list of new tasks
-def new_tasks(any_new_tasks: bool) -> List[str]:
-    FILENAME = ".github/outputs/tasks_all_changed_and_modified_files.txt"
-    if any_new_tasks and os.path.exists(FILENAME):
-        return [parser(load_changed_files(FILENAME))]
-    elif os.getenv("API") is not None:
-        return ["arc_easy", "hellaswag", "piqa", "wikitext"]
-    else:
-        return ["arc_easy"]

-@pytest.fixture(params=new_tasks(any_new_tasks))
-def task_class(request):
-    task_name = request.param
-    return [cls for name, cls in tasks.TASK_REGISTRY.items() if name in task_name][0]

 @pytest.fixture()
-def limit(any_new_tasks: bool) -> int:
-    return 100 if any_new_tasks else 10
+def limit() -> int:
+    return 10

 # Tests
-def test_download(task_class):
+def test_download(task_class: ConfigurableTask):
     task_class().download()
     assert task_class().dataset is not None

-def test_has_training_docs(task_class):
+def test_has_training_docs(task_class: ConfigurableTask):
     assert task_class().has_training_docs() in [True, False]

-def test_check_training_docs(task_class):
+def test_check_training_docs(task_class: ConfigurableTask):
     task = task_class()
-    assert task.has_training_docs() if task._config["training_split"] else True
+    if task.has_training_docs():
+        assert task._config["training_split"] is not None

 def test_has_validation_docs(task_class):
-    assert task_class().has_training_docs() in [True, False]
+    assert task_class().has_validation_docs() in [True, False]

 def test_check_validation_docs(task_class):
     task = task_class()
-    assert (
-        task_class().has_training_docs() if task._config["validation_split"] else True
-    )
+    if task.has_validation_docs():
+        assert task._config["validation_split"] is not None

 def test_has_test_docs(task_class):
-    assert task_class().has_training_docs() in [True, False]
+    assert task_class().has_test_docs() in [True, False]

 def test_check_test_docs(task_class):
     task = task_class()
-    assert task_class().has_training_docs() if task._config["test_split"] else True
+    if task.has_test_docs():
+        assert task._config["test_split"] is not None

 def test_should_decontaminate(task_class):
-    task_class = task_class()
-    assert task_class.should_decontaminate() in [True, False]
-    if task_class.should_decontaminate():
-        assert task_class._config["doc_to_decontamination_query"] is not None
+    task = task_class()
+    assert task.should_decontaminate() in [True, False]
+    if task.should_decontaminate():
+        assert task._config["doc_to_decontamination_query"] is not None

 def test_doc_to_text(task_class, limit):
+    task = task_class()
     arr = (
-        list(islice(task_class().test_docs(), limit))
-        if limit
-        else list(task_class().test_docs())
+        list(islice(task.test_docs(), limit))
+        if task.has_test_docs()
+        else list(islice(task.validation_docs(), limit))
     )
-    _array = [task_class().doc_to_text(doc) for doc in arr]
+    _array = [task.doc_to_text(doc) for doc in arr]
     # space convention; allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
     assert all(
         isinstance(x, str) and (x[-1] != " " if len(x) != 0 else True) for x in _array
@@ -91,24 +77,27 @@ def test_doc_to_text(task_class, limit):

 def test_create_choices(task_class, limit):
+    task = task_class()
     arr = (
-        list(islice(task_class().test_docs(), limit))
-        if limit
-        else list(task_class().test_docs())
+        list(islice(task.test_docs(), limit))
+        if task.has_test_docs()
+        else list(islice(task.validation_docs(), limit))
     )
-    _array = [task_class().doc_to_choice(doc) for doc in arr]
-    # assert all(len(x) == 4 for x in _array)
-    assert all(isinstance(x, list) for x in _array)
-    assert all(isinstance(x[0], str) for x in _array)
+    if "multiple_choice" in task._config.group:
+        _array = [task.doc_to_choice(doc) for doc in arr]
+        # assert all(len(x) == 4 for x in _array)
+        assert all(isinstance(x, list) for x in _array)
+        assert all(isinstance(x[0], str) for x in _array)

 def test_doc_to_target(task_class, limit):
+    task = task_class()
     arr = (
-        list(islice(task_class().test_docs(), limit))
-        if limit
-        else list(task_class().test_target())
+        list(islice(task.test_docs(), limit))
+        if task.has_test_docs()
+        else list(islice(task.validation_docs(), limit))
     )
-    _array_target = [task_class().doc_to_target(doc) for doc in arr]
+    _array_target = [task.doc_to_target(doc) for doc in arr]
     assert all(isinstance(label, int) for label in _array_target)
     assert len(_array_target) == limit if limit else True
     # _array_text = [task.doc_to_text(doc) for doc in arr]
@@ -122,15 +111,13 @@ def test_build_all_requests(task_class, limit):

 def test_construct_requests(task_class, limit):
+    task = task_class()
     arr = (
-        list(islice(task_class().test_docs(), limit))
-        if limit
-        else list(task_class().test_docs())
+        list(islice(task.test_docs(), limit))
+        if task.has_test_docs()
+        else list(islice(task.validation_docs(), limit))
     )
-    requests = [
-        task_class().construct_requests(doc, task_class().doc_to_text(doc))
-        for doc in arr
-    ]
+    requests = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
     assert all(isinstance(doc, list) for doc in requests)
     assert len(requests) == limit if limit else True
...