".idea/vscode:/vscode.git/clone" did not exist on "a2fb7b791eaa89c02dfaccd327c8e35b3b00b568"
Commit 6a6a0ebb authored by Benjamin Fattori

Merge remote-tracking branch 'upstream/big-refactor' into big-refactor-autobatching

parents e4acfcaa 2820042d
# WebQuestions
### Paper
Title: `Semantic Parsing on Freebase from Question-Answer Pairs`
Abstract: `https://cs.stanford.edu/~pliang/papers/freebase-emnlp2013.pdf`
WebQuestions is a question answering benchmark consisting of 6,642 question/answer pairs.
The questions are designed to be answerable using Freebase, a large knowledge graph, and
are mostly centered on a single named entity. They were drawn from popular queries asked
on the web (as of 2013).
Homepage: `https://worksheets.codalab.org/worksheets/0xba659fe363cb46e7a505c5b6a774dc8a`
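For orientation, a minimal sketch of loading the data with the Hugging Face `datasets` library (which backs the `web_questions` dataset path used by this task); field names follow the dataset card:
```python
# Load WebQuestions and inspect one record (assumes the `datasets` package is installed).
from datasets import load_dataset

webqs = load_dataset("web_questions")  # splits: "train" and "test"
example = webqs["test"][0]
print(example["question"])  # natural-language question
print(example["answers"])   # list of accepted answer aliases
```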
### Citation
```
@inproceedings{berant-etal-2013-semantic,
title = "Semantic Parsing on {F}reebase from Question-Answer Pairs",
author = "Berant, Jonathan and
Chou, Andrew and
Frostig, Roy and
Liang, Percy",
booktitle = "Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing",
month = oct,
year = "2013",
address = "Seattle, Washington, USA",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/D13-1160",
pages = "1533--1544",
}
```
### Subtasks
Tasks defined in this folder:
* `webqs`: WebQuestions question answering; each question may have multiple accepted answers, all of which are treated as valid targets.
### Checklist
For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
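As a rough usage sketch (model choice, arguments, and limit are illustrative; this mirrors the `simple_evaluate` call used in the test suite below):
```python
# Evaluate `webqs` through the Python API; pythia-160m and limit=10 are only examples.
from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=EleutherAI/pythia-160m,dtype=float32",
    tasks=["webqs"],
    limit=10,  # small limit for a quick smoke test
)
print(results["results"]["webqs"])
```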
from typing import Dict, List
def doc_to_choice(doc: Dict) -> List[str]:
"""Return all of the accepted answers as choices."""
return _remove_prefixes(doc["answers"])
def doc_to_target(doc: Dict) -> List[int]:
"""Return list of indices of accepted answers (all of them)."""
remaining = _remove_prefixes(doc["answers"])
return list(range(len(remaining)))
def _remove_prefixes(aliases):
"""
Remove any alias that has a strict prefix elsewhere in the list.
This is an optimization. We can do this because if the prefix is acceptable by isgreedy,
we can stop looking.
"""
aliases.sort()
ret = [aliases[0]]
for alias in aliases[1:]:
if not alias.startswith(ret[-1]):
ret.append(alias)
return ret
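# Illustration only: the document below is made up, not drawn from the dataset.
# It shows how _remove_prefixes prunes aliases that merely extend a shorter one,
# and how doc_to_target then marks every surviving alias index as acceptable.
if __name__ == "__main__":
    doc = {
        "question": "what is the capital of france?",
        "answers": ["Paris", "Paris, France"],
    }
    print(doc_to_choice(doc))  # ['Paris']  ("Paris, France" is pruned: it extends "Paris")
    print(doc_to_target(doc))  # [0]        (any surviving alias counts as correct)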
group:
- freebase
- question_answer
task: webqs
dataset_path: web_questions
dataset_name: null
output_type: multiple_choice
training_split: train
validation_split: null
test_split: test
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: !function utils.doc_to_target
doc_to_choice: !function utils.doc_to_choice
should_decontaminate: true
doc_to_decontamination_query: question
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
def partial_context(doc, option):
# Substitute the pronoun in the sentence with the specified option
# and ignore everything after.
pronoun_loc = doc["sentence"].index("_")
return doc["sentence"][:pronoun_loc] + option
def doc_to_text(doc):
    # Note: despite the name, this returns the zero-based index of the correct
    # option, mapping the dataset's "answer" label ("1" or "2") to 0 or 1.
    answer_to_num = {"1": 0, "2": 1}
    return answer_to_num[doc["answer"]]
def partial_target(doc):
# The target is everything after the document specified pronoun.
pronoun_loc = doc["sentence"].index("_") + 1
return doc["sentence"][pronoun_loc:].strip()
def create_choices(doc):
choices = []
for option in [doc["option1"], doc["option2"]]:
partial_ctx = partial_context(doc, option)
choices.append(partial_ctx)
return choices
def doc_to_target(doc):
idx = doc["sentence"].index("_") + 1
return doc["sentence"][idx:].strip()
def gold_alias(doc):
answer_to_num = {"1": 0, "2": 1}
return answer_to_num[doc['answer']]
def doc_to_choice(doc):
idx = doc["sentence"].index("_")
options = [doc["option1"], doc["option2"]]
return [doc["sentence"][:idx] + opt for opt in options]
task: winogrande
dataset_path: winogrande
dataset_name: winogrande_xl
output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: !function preprocess_winogrande.doc_to_text
doc_to_target: !function preprocess_winogrande.doc_to_target
doc_to_choice: !function preprocess_winogrande.doc_to_choice
metric_list:
  - metric: acc
    aggregation: mean
@@ -5,6 +5,7 @@ import fnmatch
import jsonlines
import argparse
import logging
from pathlib import Path
from lm_eval import evaluator, utils
from lm_eval.api.registry import ALL_TASKS
@@ -15,22 +16,41 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True)
parser.add_argument("--model_args", default="")
parser.add_argument("--model", required=True, help="Name of model e.g. `hf`")
parser.add_argument(
"--model_args",
default="",
help="String arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
)
parser.add_argument(
"--tasks", default=None, choices=utils.MultiChoice(sorted(ALL_TASKS))
)
parser.add_argument("--config", default=None)
parser.add_argument("--num_fewshot", type=int, default=0)
parser.add_argument("--batch_size", type=str, default=1)
parser.add_argument(
"--num_fewshot",
type=int,
default=0,
help="Number of examples in few-shot context",
)
parser.add_argument("--batch_size", type=int, default=1) # TODO: only integers
parser.add_argument(
"--max_batch_size",
type=int,
default=None,
help="Maximal batch size to try with --batch_size auto",
)
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--output_path", default=None)
parser.add_argument(
"--device",
type=str,
default=None,
help="Device to use (e.g. cuda, cuda:0, cpu)",
)
parser.add_argument(
"--output_path",
default=None,
type=str,
metavar="= [dir/file.jsonl] [DIR]",
help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
)
parser.add_argument(
"--limit",
type=float,
@@ -38,11 +58,30 @@ def parse_args():
help="Limit the number of examples per task. "
"If <1, limit is a percentage of the total number of examples.",
)
parser.add_argument("--data_sampling", type=float, default=None)
parser.add_argument("--use_cache", type=str, default=None)
parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
parser.add_argument("--write_out", action="store_true", default=False)
parser.add_argument(
"--use_cache",
type=str,
default=None,
help="A path to a sqlite db file for caching model responses. `None` if not caching.",
)
parser.add_argument("--decontamination_ngrams_path", default=None) # TODO: not used
parser.add_argument(
"--check_integrity",
action="store_true",
help="Whether to run the relevant part of the test suite for the tasks",
)
parser.add_argument(
"--write_out",
action="store_true",
default=False,
help="Prints the prompt for the first few documents",
)
parser.add_argument(
"--log_samples",
action="store_true",
default=False,
help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis",
)
return parser.parse_args()
@@ -74,6 +113,25 @@ def main():
config = utils.load_yaml_config(task)
task_names.append(config)
if args.output_path:
path = Path(args.output_path)
# check if file or 'dir/results.json' exists
if path.is_file() or Path(args.output_path).joinpath("results.json").is_file():
eval_logger.warning(
f"File already exists at {path}. Results will be overwritten."
)
assert not path.is_file(), "File already exists"
# if path json then get parent dir
elif path.suffix in (".json", ".jsonl"):
output_path_file = path
path.parent.mkdir(parents=True, exist_ok=True)
path = path.parent
else:
path.mkdir(parents=True, exist_ok=True)
output_path_file = path.joinpath("results.json")
elif args.log_samples and not args.output_path:
assert args.output_path, "Specify --output_path"
eval_logger.info(f"Selected Tasks: {task_names}")
results = evaluator.simple_evaluate(
@@ -89,34 +147,29 @@ def main():
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
write_out=args.write_out,
log_samples=args.log_samples,
)
if results is not None:
samples = results.pop("samples")
if args.log_samples:
samples = results.pop("samples")
dumped = json.dumps(results, indent=2, default=lambda o: str(o))
print(dumped)
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
if args.output_path:
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
with open(args.output_path, "w") as f:
f.write(dumped)
for task_name, config in results["configs"].items():
output_name = "{}_{}".format(
re.sub("/", "__", args.model_args), task_name
)
if os.path.isdir(args.output_path):
filename = f"./{args.output_path}/{output_name}.jsonl"
elif os.path.isfile(args.output_path):
filename = (
f"./{os.path.dirname(args.output_path)}/{output_name}.jsonl"
output_path_file.open("w").write(dumped)
if args.log_samples:
for task_name, config in results["configs"].items():
output_name = "{}_{}".format(
re.sub("/", "__", args.model_args), task_name
)
filename = path.joinpath(f"{output_name}.jsonl")
with jsonlines.open(filename, "w") as f:
f.write_all(samples[task_name])
with jsonlines.open(filename, "w") as f:
f.write_all(samples[task_name])
print(
f"{args.model} ({args.model_args}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
@@ -58,7 +58,6 @@ def main():
ctx = task.fewshot_context(
doc=doc,
num_fewshot=args.num_fewshot,
rnd=rnd,
)
f.write(ctx + "\n")
@@ -50,6 +50,13 @@ setuptools.setup(
],
extras_require={
"dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
"linting": [
"flake8",
"pylint",
"mypy",
"pre-commit",
],
"testing": ["pytest", "pytest-cov", "pytest-xdist"],
"multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
"sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
"promptsource": [
......
def pytest_addoption(parser):
    parser.addoption(
        "--new_task",
        action="store_true",
        help="Run the task tests against newly added or modified tasks",
    )
import json
from typing import List
from lm_eval.utils import load_yaml_config
from pathlib import Path
FILE_PATH = ".github/outputs/tasks_all_changed_and_modified_files.txt"
def load_changed_files(file_path: str = FILE_PATH) -> List[str]:
with open(file_path, "r") as f:
return [line.strip() for line in f.readlines()]
def parser(full_path: List[str]) -> List[str]:
_output = set()
for x in full_path:
if x.endswith(".yaml"):
_output.add(load_yaml_config(x)["task"])
elif x.endswith(".py"):
path = [str(x) for x in (list(Path(x).parent.glob("*.yaml")))]
_output |= {load_yaml_config(x)["task"] for x in path}
return list(_output)
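# Illustration only (the changed paths are hypothetical and must point at real
# files when run from the repository root): a changed .yaml contributes its own
# task name, while a changed .py pulls in every task defined by the .yaml
# configs that sit next to it.
if __name__ == "__main__":
    changed = [
        "lm_eval/tasks/webqs/webqs.yaml",
        "lm_eval/tasks/winogrande/preprocess_winogrande.py",
    ]
    print(parser(changed))  # e.g. ['webqs', 'winogrande']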
@@ -5,7 +5,7 @@ import lm_eval.api.registry as registry
import lm_eval.tasks as tasks
# import lm_eval.models as models
import lm_eval.api as api
import lm_eval.evaluator as evaluator
import random
import pytest
@@ -15,60 +15,52 @@ import pytest
# test once we break evaluator into smaller, more manageable pieces
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_evaluator(taskname, task_class):
task_dict = tasks.get_task_dict([taskname])
# TODO: re-add cachingLM
# os.system("rm test_cache.db")
# lm = base.CachingLM(models.get_model("dummy")(), "test_cache.db")
lm = registry.get_model("dummy")()
def ll_fn(reqs):
for ctx, cont in reqs:
if len(ctx) == 0:
continue
# space convention
assert ctx[-1] != " "
assert cont[0] == " " or ctx[-1] == "\n"
res = []
random.seed(42)
for _ in reqs:
res.append((-random.random(), False))
return res
def ll_perp_fn(reqs):
for (string,) in reqs:
assert isinstance(string, str)
res = []
random.seed(42)
for _ in reqs:
res.append(-random.random())
return res
lm.loglikelihood = ll_fn
lm.loglikelihood_rolling = ll_perp_fn
@pytest.mark.parametrize(
"task_name,limit,model,model_args",
[
(
["arc_easy"],
10,
"hf",
"pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
)
],
)
def test_evaluator(task_name: list[str], limit: int, model: str, model_args: str):
task_name = task_name
limit = 10
e1 = evaluator.evaluate(
lm=lm,
task_dict=task_dict,
num_fewshot=0,
e1 = evaluator.simple_evaluate(
model=model,
tasks=task_name,
limit=limit,
bootstrap_iters=10,
model_args=model_args,
)
assert e1 is not None
lm = api.registry.get_model(model).create_from_arg_string(
model_args,
{
"batch_size": None,
"max_batch_size": None,
"device": None,
},
)
task_dict = tasks.get_task_dict(task_name, num_fewshot=0)
e2 = evaluator.evaluate(
lm=lm,
task_dict=task_dict,
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
)
assert e2 is not None
# check that caching is working
assert e1 == e2
def r(x):
return x["results"]["arc_easy"]
assert all(
x == y
for x, y in zip([y for _, y in r(e1).items()], [y for _, y in r(e2).items()])
)
import os
from itertools import islice
from typing import ClassVar, List

import pytest

import lm_eval.tasks as tasks
from tests.extra.test_utils import load_changed_files, parser
@pytest.fixture()
def any_new_tasks(request) -> bool:
return request.config.getoption("--new_task")
# ["arc_easy] else get list of new tasks
def new_tasks(any_new_tasks: bool) -> List[str]:
FILENAME = ".github/outputs/tasks_all_changed_and_modified_files.txt"
if any_new_tasks and os.path.exists(FILENAME):
return [parser(load_changed_files(FILENAME))]
elif os.getenv("API") is not None:
return ["arc_easy", "hellaswag", "piqa", "wikitext"]
else:
return ["arc_easy"]
@pytest.fixture(params=new_tasks(any_new_tasks))
def task_class(request):
task_name = request.param
return [cls for name, cls in tasks.TASK_REGISTRY.items() if name in task_name][0]
@pytest.fixture()
def limit(any_new_tasks: bool) -> int:
return 100 if any_new_tasks else 10
# Tests
def test_download(task_class):
task_class().download()
assert task_class().dataset is not None
def test_has_training_docs(task_class):
assert task_class().has_training_docs() in [True, False]
def test_check_training_docs(task_class):
task = task_class()
assert task.has_training_docs() if task._config["training_split"] else True
def test_has_validation_docs(task_class):
    assert task_class().has_validation_docs() in [True, False]
def test_check_validation_docs(task_class):
    task = task_class()
    assert task.has_validation_docs() if task._config["validation_split"] else True
def test_has_test_docs(task_class):
    assert task_class().has_test_docs() in [True, False]
def test_check_test_docs(task_class):
task = task_class()
assert task_class().has_training_docs() if task._config["test_split"] else True
def test_should_decontaminate(task_class):
    task = task_class()
    assert task.should_decontaminate() in [True, False]
    if task.should_decontaminate():
        assert task._config["doc_to_decontamination_query"] is not None
def test_doc_to_text(task_class, limit):
arr = (
list(islice(task_class().test_docs(), limit))
if limit
else list(task_class().test_docs())
)
_array = [task_class().doc_to_text(doc) for doc in arr]
# space convention; allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
assert all(
isinstance(x, str) and (x[-1] != " " if len(x) != 0 else True) for x in _array
)
def test_create_choices(task_class, limit):
arr = (
list(islice(task_class().test_docs(), limit))
if limit
else list(task_class().test_docs())
)
_array = [task_class().doc_to_choice(doc) for doc in arr]
# assert all(len(x) == 4 for x in _array)
assert all(isinstance(x, list) for x in _array)
assert all(isinstance(x[0], str) for x in _array)
def test_doc_to_target(task_class, limit):
arr = (
list(islice(task_class().test_docs(), limit))
if limit
else list(task_class().test_docs())
)
_array_target = [task_class().doc_to_target(doc) for doc in arr]
assert all(isinstance(label, int) for label in _array_target)
assert len(_array_target) == limit if limit else True
# _array_text = [task.doc_to_text(doc) for doc in arr]
# Not working
# assert all(tgt[0] == " " or txt[-1] == "\n" if len(txt) != 0 else True for txt, tgt in zip(_array_text, _array_target))
def test_build_all_requests(task_class, limit):
    task = task_class()
    task.build_all_requests(rank=1, limit=limit, world_size=1)
    assert task.instances is not None
def test_construct_requests(task_class, limit):
arr = (
list(islice(task_class().test_docs(), limit))
if limit
else list(task_class().test_docs())
)
requests = [
task_class().construct_requests(doc, task_class().doc_to_text(doc))
for doc in arr
]
assert all(isinstance(doc, list) for doc in requests)
assert len(requests) == limit if limit else True
# def test_create_choices(task_class):
# arr = list(islice(task_class().test_docs(), 1))
# choices = task_class().create_choices(arr[0])
# assert choices is not None
# checking if number of choices is correct
# @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
# def test_basic_interface(taskname, task_class):
# print("Evaluating task", taskname)
# task = task_class()
#
# assert task.has_training_docs() in [True, False]
# assert task.has_validation_docs() in [True, False]
# assert task.has_test_docs() in [True, False]
#
# assert isinstance(task.aggregation(), dict)
# assert isinstance(task.higher_is_better(), dict)
# assert task.aggregation().keys() == task.higher_is_better().keys()
#
# for v in task.higher_is_better().values():
# assert v in [True, False]
#
# assert isinstance(task.VERSION, int)
#
# # test deterministic docs
# # (don't test train because it's slow)
#
# task2 = task_class()
#
# limit = None
#
# if taskname in ["triviaqa"] or taskname.startswith("pile_"):
# limit = 10000
# if task.has_validation_docs():
# arr = list(islice(task.validation_docs(), limit))
# arr2 = list(islice(task2.validation_docs(), limit))
#
# assert arr == arr2
#
# reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
# reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
# assert reqs == reqs2
#
# if task.has_test_docs():
# arr = list(islice(task.test_docs(), limit))
# arr2 = list(islice(task2.test_docs(), limit))
#
# assert arr == arr2
#
# reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
# reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
# assert reqs == reqs2
#
# if task.has_training_docs():
# arr = list(islice(task.training_docs(), limit))
# arr2 = list(islice(task2.training_docs(), limit))
#
# assert arr == arr2
#
# reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
# reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
# assert reqs == reqs2
#
#
# @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
# def test_documents_and_requests(taskname, task_class):
# print("Evaluating task", taskname)
# task = task_class()
# fns = []
# if task.has_training_docs():
# fns.append(task.training_docs)
# if task.has_validation_docs():
# fns.append(task.validation_docs)
# # test doc might not have labels
# # if task.has_test_docs(): fns.append(task.test_docs)
#
# for fn in fns:
# # print(list(islice(fn(), 10)))
# for doc in islice(fn(), 10):
#
# txt = task.doc_to_text(doc)
# tgt = task.doc_to_target(doc)
#
# assert isinstance(txt, str)
# assert isinstance(tgt, str)
#
# # space convention
# # allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
# if len(txt) != 0:
# assert txt[-1] != " "
# assert tgt[0] == " " or txt[-1] == "\n"
#
# reqs = task.construct_requests(doc, txt)
#
# # construct_requests can return just one request
# if not isinstance(reqs, (list, tuple)):
# reqs = [reqs]
#
# # todo: mock lm after refactoring evaluator.py to not be a mess
# # for req in reqs:
# # assert isinstance(req, base.Request)
@@ -6,7 +6,7 @@ import lm_eval.models
def test_description():
seed = 42
num_examples = 1
task_names = ["arc_challenge", "lambada"]
task_names = ["arc_challenge", "arc_easy"]
description_dict = {
"arc_challenge": "Label for the relevant action:\nSentences describing context, with an incomplete sentence trailing answer that plausibly completes the situation.",
"lambada": "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in.",
@@ -40,6 +40,5 @@ def test_description():
ctx = task.fewshot_context(
doc=doc,
num_fewshot=1,
rnd=rnd,
)
assert description in ctx
@@ -44,9 +44,9 @@ def test_generate_13_grams_1(caplog):
pass
os.makedirs(test_working_directory)
assert not os.path.exists("pile")
os.makedirs("pile")
archive = Archive(os.path.join("pile", "test.jsonl.zst"))
assert not os.path.exists("../pile")
os.makedirs("../pile")
archive = Archive(os.path.join("../pile", "test.jsonl.zst"))
archive.add_data(data)
archive.commit()
......