Unverified commit d6ceced5 authored by Stella Biderman, committed by GitHub

Merge branch 'master' into auto-batching

parents 4d21ab6b fc4428dc
@@ -79,7 +79,7 @@ class GradeSchoolMath8K(Task):
         """
         # NOTE: The paper implements "verifiers" that assign a score to multiple
         # solutions and output the highest ranked solution.
-        completion = rf.greedy_until(ctx, ["\n"])
+        completion = rf.greedy_until(ctx, {'until': ["\n"]})
         return completion
 
     def _extract_answer(self, completion):
...
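Note: the task hunks in this commit all apply the same change — the second argument of `rf.greedy_until` switches from a bare list of stop sequences to a dict of generation arguments keyed by "until". A minimal sketch of the new call shape, assuming the usual `lm_eval.base` imports (the standalone function below is illustrative, not part of this diff):

from lm_eval.base import rf  # RequestFactory shorthand used by the task files above

def construct_requests(doc, ctx):
    # Old form: rf.greedy_until(ctx, ["\n"])
    # New form: stop sequences live under the "until" key of a kwargs dict,
    # which leaves room for further generation settings alongside it.
    return rf.greedy_until(ctx, {"until": ["\n"]})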
@@ -63,7 +63,7 @@ class Math(Task):
         return " " + doc["solution"]
 
     def construct_requests(self, doc, ctx):
-        return rf.greedy_until(ctx, ["\n"])
+        return rf.greedy_until(ctx, {'until': ["\n"]})
 
     def process_results(self, doc, results):
         retval = 0
...
@@ -214,7 +214,7 @@ class QASPER(Task):
         """
         # unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
         if doc["answer_type"] in ("free form answer"):
-            return [rf.greedy_until(ctx, ["\n"])]
+            return [rf.greedy_until(ctx, {'until': ["\n"]})]
         elif doc["answer_type"] in ("bool"):
             ll_yes, _ = rf.loglikelihood(ctx, " yes")
             ll_no, _ = rf.loglikelihood(ctx, " no")
...
@@ -107,7 +107,7 @@ class SQuAD2(Task):
         language description, as well as the few shot examples, and the question
         part of the document for `doc`.
         """
-        continuation = rf.greedy_until(ctx, ["\n"])
+        continuation = rf.greedy_until(ctx, {'until': ["\n"]})
         is_unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
         return continuation, is_unanswerable
...
@@ -184,7 +184,7 @@ class GeneralTranslationTask(Task):
         language description, as well as the few shot examples, and the question
         part of the document for `doc`.
         """
-        return rf.greedy_until(ctx, ["\n"])
+        return rf.greedy_until(ctx, {'until': ["\n"]})
 
     def process_results(self, doc, results):
         # Add spaces between words for BLEU score calculation of target languages like Chinese
...
@@ -247,7 +247,7 @@ class TruthfulQAGeneration(Task):
         part of the document for `doc`.
         """
         # TODO: Find a way to cap the number of generated tokens to `50` as in the official implementation.
-        completion = rf.greedy_until(ctx, ["."])
+        completion = rf.greedy_until(ctx, {'until': ["."]})
         return completion
 
     def process_results(self, doc, results):
...
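Because the second argument is now a dict, the TODO above about capping generation at 50 tokens could plausibly be handled by adding a length entry next to the stop sequence; the "max_length" key below is a hypothetical name and is not established anywhere in this diff:

# Hypothetical sketch only; "max_length" is an assumed key name — the hunks
# in this commit only exercise the "until" key.
completion = rf.greedy_until(ctx, {"until": ["."], "max_length": 50})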
@@ -59,7 +59,7 @@ class WordUnscrambleTask(Task):
         return doc["completion"]
 
     def construct_requests(self, doc, ctx):
-        completion = rf.greedy_until(ctx, ["\n"])
+        completion = rf.greedy_until(ctx, {'until': ["\n"]})
         return completion
 
     def process_results(self, doc, results):
...
@@ -10,8 +10,6 @@ NOTE: This `Task` is based on WikiText-2.
 Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
 """
 import re
-import inspect
-import lm_eval.datasets.wikitext.wikitext
 from lm_eval.base import PerplexityTask
@@ -63,7 +61,7 @@ def wikitext_detokenizer(string):
 class WikiText(PerplexityTask):
     VERSION = 1
-    DATASET_PATH = inspect.getfile(lm_eval.datasets.wikitext.wikitext)
+    DATASET_PATH = "EleutherAI/wikitext_document_level"
     DATASET_NAME = "wikitext-2-raw-v1"
 
     def has_training_docs(self):
...
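The WikiText task now points DATASET_PATH at a dataset hosted on the Hugging Face Hub rather than a bundled dataset script, which is why the `inspect`-based path lookup above is removed. A minimal sketch of what the new value resolves to, assuming the standard `datasets` API and the usual train/validation/test split names:

from datasets import load_dataset

# Load the Hub-hosted, document-level WikiText-2 data referenced by the new
# DATASET_PATH / DATASET_NAME pair.
wikitext = load_dataset("EleutherAI/wikitext_document_level", "wikitext-2-raw-v1")
print(wikitext["test"][0].keys())  # field names depend on the Hub dataset's schema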
@@ -5,7 +5,9 @@ import collections
 import functools
 import inspect
 import sys
-from typing import List
+from typing import List, Union
+
+import torch
 from omegaconf import OmegaConf
@@ -116,6 +118,26 @@ def make_disjoint_window(pair):
     return a[: len(a) - (len(b) - 1)], b
 
 
+def select_continuation_from_batch_left_padding(
+    generations: Union[List[List[int]], torch.Tensor], max_context_size: int
+):
+    """Select the continuation from the batch, removing prompts of different lengths.
+    Args:
+        generations (Union[List[List[int]], torch.Tensor]):
+            A tensor or list-of-lists of shape [batch_size, sequence length].
+        max_context_size (int):
+            The size of the biggest context; generations will proceed from that
+            index.
+    Example:
+        PAD PAD Continue : The dog chased the cat [every day of the week]
+        Riddle me this   : The dog chased the cat [yesterday] PAD PAD PAD PAD
+        Output:
+        [every day of the week]
+        [yesterday] PAD PAD PAD PAD
+    """
+    return generations[:, max_context_size:]
+
+
 class Reorderer:
     def __init__(self, arr, fn):
         self.size = len(arr)
@@ -201,3 +223,4 @@ def run_task_tests(task_list: List[str]):
         raise ValueError(
             f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
         )
...
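A minimal usage sketch for the new `select_continuation_from_batch_left_padding` helper, assuming the hunk above lives in `lm_eval/utils.py`; the token ids are made up for illustration:

import torch

from lm_eval.utils import select_continuation_from_batch_left_padding

# Two prompts left-padded to max_context_size=4 (0 is the pad id here),
# each followed by three generated tokens.
generations = torch.tensor(
    [
        [0, 0, 11, 12, 101, 102, 103],
        [21, 22, 23, 24, 201, 202, 203],
    ]
)

continuations = select_continuation_from_batch_left_padding(generations, max_context_size=4)
# continuations == tensor([[101, 102, 103], [201, 202, 203]])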
@@ -26,7 +26,7 @@ class DryrunLM(LM):
     def greedy_until(self, requests):
         res = []
 
-        for ctx, until in requests:
+        for ctx, _ in requests:
             res.append("lol")
 
             # assume worst case - generates until 256
...
@@ -26,6 +26,7 @@ setuptools.setup(
         "numexpr",
         "openai>=0.6.4",
         "omegaconf>=2.2",
+        "peft>=0.2.0",
         "pybind11>=2.6.2",
         "pycountry",
         "pytablewriter",
@@ -42,5 +43,6 @@ setuptools.setup(
     extras_require={
         "dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
         "multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
+        "sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
     },
 )