Unverified commit d6ceced5, authored by Stella Biderman and committed by GitHub

Merge branch 'master' into auto-batching

parents 4d21ab6b fc4428dc
@@ -79,7 +79,7 @@ class GradeSchoolMath8K(Task):
         """
         # NOTE: The paper implements "verifiers" that assign a score to multiple
         # solutions and output the highest ranked solution.
-        completion = rf.greedy_until(ctx, ["\n"])
+        completion = rf.greedy_until(ctx, {'until': ["\n"]})
         return completion

     def _extract_answer(self, completion):
......
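The hunks above and below all make the same change: the second argument to `rf.greedy_until` moves from a bare list of stop sequences to a dict of generation arguments keyed by `'until'`. A minimal sketch of a task written against the new convention (the class name is illustrative only and not part of this diff):

```python
from lm_eval.base import Task, rf


class ExampleGenerationTask(Task):
    """Hypothetical task showing the updated rf.greedy_until calling convention."""

    def construct_requests(self, doc, ctx):
        # Old convention (removed in this merge): rf.greedy_until(ctx, ["\n"])
        # New convention: generation arguments as a dict; stop strings go under 'until'.
        return rf.greedy_until(ctx, {"until": ["\n"]})
```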
@@ -63,7 +63,7 @@ class Math(Task):
         return " " + doc["solution"]

     def construct_requests(self, doc, ctx):
-        return rf.greedy_until(ctx, ["\n"])
+        return rf.greedy_until(ctx, {'until': ["\n"]})

     def process_results(self, doc, results):
         retval = 0
......
@@ -214,7 +214,7 @@ class QASPER(Task):
         """
         # unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
         if doc["answer_type"] in ("free form answer"):
-            return [rf.greedy_until(ctx, ["\n"])]
+            return [rf.greedy_until(ctx, {'until': ["\n"]})]
        elif doc["answer_type"] in ("bool"):
             ll_yes, _ = rf.loglikelihood(ctx, " yes")
             ll_no, _ = rf.loglikelihood(ctx, " no")
......
@@ -107,7 +107,7 @@ class SQuAD2(Task):
         language description, as well as the few shot examples, and the question
         part of the document for `doc`.
         """
-        continuation = rf.greedy_until(ctx, ["\n"])
+        continuation = rf.greedy_until(ctx, {'until': ["\n"]})
         is_unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
         return continuation, is_unanswerable
......
@@ -184,7 +184,7 @@ class GeneralTranslationTask(Task):
         language description, as well as the few shot examples, and the question
         part of the document for `doc`.
         """
-        return rf.greedy_until(ctx, ["\n"])
+        return rf.greedy_until(ctx, {'until': ["\n"]})

     def process_results(self, doc, results):
         # Add spaces between words for BLEU score calculation of target languages like Chinese
......
@@ -247,7 +247,7 @@ class TruthfulQAGeneration(Task):
         part of the document for `doc`.
         """
         # TODO: Find a way to cap the number of generated tokens to `50` as in the official implementation.
-        completion = rf.greedy_until(ctx, ["."])
+        completion = rf.greedy_until(ctx, {'until': ["."]})
         return completion

     def process_results(self, doc, results):
......
@@ -59,7 +59,7 @@ class WordUnscrambleTask(Task):
         return doc["completion"]

     def construct_requests(self, doc, ctx):
-        completion = rf.greedy_until(ctx, ["\n"])
+        completion = rf.greedy_until(ctx, {'until': ["\n"]})
         return completion

     def process_results(self, doc, results):
......
@@ -10,8 +10,6 @@ NOTE: This `Task` is based on WikiText-2.

 Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
 """
 import re
-import inspect

-import lm_eval.datasets.wikitext.wikitext
 from lm_eval.base import PerplexityTask
@@ -63,7 +61,7 @@ def wikitext_detokenizer(string):

 class WikiText(PerplexityTask):
     VERSION = 1
-    DATASET_PATH = inspect.getfile(lm_eval.datasets.wikitext.wikitext)
+    DATASET_PATH = "EleutherAI/wikitext_document_level"
     DATASET_NAME = "wikitext-2-raw-v1"

     def has_training_docs(self):
......
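The WikiText perplexity task now resolves its data from a dataset hosted on the Hugging Face Hub rather than a bundled dataset script located via `inspect.getfile`. A hedged sketch of how such a path and config name are typically loaded with the `datasets` library (illustrative only; the diff does not show the harness's download code):

```python
from datasets import load_dataset

# Assumption: DATASET_PATH and DATASET_NAME map directly onto load_dataset's
# `path` and `name` arguments, as they do for Hub-hosted datasets.
wikitext = load_dataset("EleutherAI/wikitext_document_level", "wikitext-2-raw-v1")
print(wikitext)  # DatasetDict with train/validation/test splits, one document per row
```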
@@ -5,7 +5,9 @@ import collections
 import functools
 import inspect
 import sys
-from typing import List
+from typing import List, Union

+import torch
 from omegaconf import OmegaConf
@@ -116,6 +118,26 @@ def make_disjoint_window(pair):
     return a[: len(a) - (len(b) - 1)], b


+def select_continuation_from_batch_left_padding(
+    generations: Union[List[List[int]], torch.Tensor], max_context_size: int
+):
+    """Select the continuation from the batch, removing prompts of different lengths.
+
+    Args:
+        generations (Union[List[List[int]], torch.Tensor]):
+            A tensor or list-of-lists of shape [batch_size, sequence length].
+        max_context_size (int):
+            The size of the biggest context; generations will proceed from that
+            index.
+
+    Example:
+        PAD PAD Continue : The dog chased the cat [every day of the week]
+        Riddle me this   : The dog chased the cat [yesterday] PAD PAD PAD PAD
+    Output:
+        [every day of the week]
+        [yesterday] PAD PAD PAD PAD
+    """
+    return generations[:, max_context_size:]
+
+
 class Reorderer:
     def __init__(self, arr, fn):
         self.size = len(arr)
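The new helper assumes left-padded batches: every prompt ends at the same index (`max_context_size`), so slicing from that index keeps only generated tokens (plus any right padding). A small usage sketch with made-up token ids, assuming the helper lives in `lm_eval.utils` as the surrounding hunk suggests:

```python
import torch

from lm_eval.utils import select_continuation_from_batch_left_padding

# Two left-padded prompts (0 = pad). Both prompts end at index 4, so the
# continuations start there; the shorter generation carries right padding.
generations = torch.tensor(
    [
        [0, 0, 11, 12, 21, 22, 23, 24],  # PAD PAD prompt prompt | continuation
        [13, 14, 15, 16, 31, 0, 0, 0],   # prompt prompt prompt prompt | cont PAD PAD PAD
    ]
)
print(select_continuation_from_batch_left_padding(generations, max_context_size=4))
# tensor([[21, 22, 23, 24],
#         [31,  0,  0,  0]])
```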
@@ -201,3 +223,4 @@ def run_task_tests(task_list: List[str]):
         raise ValueError(
             f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
         )
@@ -26,7 +26,7 @@ class DryrunLM(LM):
     def greedy_until(self, requests):
         res = []

-        for ctx, until in requests:
+        for ctx, _ in requests:
             res.append("lol")

             # assume worst case - generates until 256
......
@@ -26,6 +26,7 @@ setuptools.setup(
         "numexpr",
         "openai>=0.6.4",
         "omegaconf>=2.2",
+        "peft>=0.2.0",
         "pybind11>=2.6.2",
         "pycountry",
         "pytablewriter",
@@ -42,5 +43,6 @@ setuptools.setup(
     extras_require={
         "dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
         "multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
+        "sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
     },
 )
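The new `peft` requirement points at support for evaluating parameter-efficient fine-tuned (e.g. LoRA) checkpoints layered on a Hugging Face base model. The diff does not show how the harness wires this in, so the following is only a sketch of the standard peft loading pattern, with placeholder model and adapter names:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Attach a trained PEFT adapter (placeholder path) on top of the frozen base weights.
model = PeftModel.from_pretrained(base, "path/to/peft-adapter")
model.eval()
```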