Unverified Commit 073b0808 authored by Oskar van der Wal, committed by GitHub

Merge branch 'bigscience-workshop:master' into master

parents 2d861a29 29bff88d
......@@ -353,11 +353,13 @@ class BaseLM(LM):
for context, request_args in tqdm(reord.get_reordered()):
stopping_criteria = request_args["stopping_criteria"]
max_generation_length = request_args["max_generation_length"]
num_fewshot = request_args["num_fewshot"]
assert isinstance(stopping_criteria, str) or stopping_criteria is None
assert (
isinstance(max_generation_length, int) or max_generation_length is None
)
assert isinstance(num_fewshot, int) or num_fewshot is None
if stopping_criteria is None:
until = [self.eot_token]
......@@ -382,9 +384,10 @@ class BaseLM(LM):
context_enc,
max_length,
torch.tensor(primary_until),
num_fewshot,
)
s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
s = self.tok_decode(cont.tolist())
for term in until:
s = s.split(term)[0]
......@@ -536,7 +539,7 @@ class Task(abc.ABC):
pass
@abstractmethod
def construct_requests(self, doc, ctx):
def construct_requests(self, doc, ctx, args):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
......@@ -546,6 +549,8 @@ class Task(abc.ABC):
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
:param args: dict
The specifics of the context, including the number of few-shot examples.
"""
pass
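For task authors, a hedged sketch of what a concrete implementation of the updated hook looks like; the class name and literal values are hypothetical, and the `lm_eval.base` import path is assumed — only the signature and the `num_fewshot` forwarding mirror this diff:

```python
# Hypothetical task illustrating construct_requests(doc, ctx, args).
from lm_eval.base import Task, rf  # assumed package layout for this fork


class MyGenerationTask(Task):
    def construct_requests(self, doc, ctx, args):
        request_args = {
            "stopping_criteria": "\n###\n",      # or self.stopping_criteria()
            "max_generation_length": None,       # or self.max_generation_length()
            "num_fewshot": args["num_fewshot"],  # forwarded from evaluate()
        }
        # One greedy-until request per document, carrying the generation settings.
        return [rf.greedy_until(ctx, request_args)]
```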
......@@ -689,11 +694,9 @@ class PromptSourceTask(Task):
def stopping_criteria(self) -> Optional[str]:
"""Denote where the generation should end.
For example, for coqa, this is '\nQ:' and for drop '.'.
By default, it's None, meaning to generate up to max or EOT, whichever comes first.
By default, it's "\n###\n".
"""
return None
return "\n###\n"
def max_generation_length(self) -> Optional[int]:
"""Denote where the max length of the generation if it is obvious from the task."""
......@@ -724,7 +727,7 @@ class PromptSourceTask(Task):
text, _ = self.prompt.apply(doc)
return text
def construct_requests(self, doc, ctx):
def construct_requests(self, doc, ctx, args):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
......@@ -734,6 +737,8 @@ class PromptSourceTask(Task):
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
:param args: dict
The specifics of the context, including the number of few-shot examples.
"""
_requests = []
answer_choices_list = self.prompt.get_answer_choices_list(doc)
......@@ -749,6 +754,7 @@ class PromptSourceTask(Task):
request_args = {
"stopping_criteria": self.stopping_criteria(),
"max_generation_length": self.max_generation_length(),
"num_fewshot": args["num_fewshot"],
}
cont_request = rf.greedy_until(ctx, request_args)
_requests.append(cont_request)
......@@ -915,12 +921,12 @@ class PromptSourceTask(Task):
if num_fewshot == 0:
labeled_examples = ""
fewshotex, fewshotidx, fewshotsource = [], [], None
fewshotex, fewshotidx, self.fewshotsource = [], [], None
else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex, fewshotidx = self.fewshot_examples(k=num_fewshot, rnd=rnd)
fewshotsource = "train"
self.fewshotsource = "train"
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
......@@ -929,18 +935,18 @@ class PromptSourceTask(Task):
else self.test_docs()
)
if self.has_validation_docs():
fewshotsource = "val"
self.fewshotsource = "val"
elif self.test_docs():
fewshotsource = "test"
self.fewshotsource = "test"
fewshotex, fewshotidx = self._get_fewshot_examples(
self._fewshot_docs, k=num_fewshot + 1, rnd=rnd
)
fewshotex, fewshotidx = [
fewshotex, fewshotidx = zip(*[
(shot, idx)
for shot, idx in zip(fewshotex, fewshotidx)
if shot != doc
]
])
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex, fewshotidx = (
fewshotex[:num_fewshot],
......@@ -966,7 +972,7 @@ class PromptSourceTask(Task):
ctx,
{
"fewshot_idx": fewshotidx,
"fewshot_source": fewshotsource,
"fewshot_source": self.fewshotsource,
"fewshot_num": num_fewshot,
"ctx": ctx,
},
......
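The switch from a bare list comprehension to `zip(*[...])` matters: the comprehension alone yields one list of `(shot, idx)` pairs, while the downstream code expects two parallel sequences. A small standalone illustration:

```python
# Filtering parallel lists of few-shot examples and their indices.
fewshotex = ["ex_a", "ex_b", "ex_c"]
fewshotidx = [0, 1, 2]
doc = "ex_b"  # the document currently being evaluated

# Old form: a single list of (shot, idx) tuples -- wrong shape downstream.
pairs = [(shot, idx) for shot, idx in zip(fewshotex, fewshotidx) if shot != doc]

# New form: zip(*) transposes the filtered pairs back into two sequences.
fewshotex, fewshotidx = zip(*pairs)
print(fewshotex)   # ('ex_a', 'ex_c')
print(fewshotidx)  # (0, 2)
```

Note that `zip(*...)` returns tuples rather than lists (the later `[:num_fewshot]` slicing still works) and would fail to unpack on an empty result, which drawing `num_fewshot + 1` candidates is meant to avoid.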
......@@ -206,7 +206,8 @@ def evaluate(
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
fewshotex_logging_info["doc_id"] = original_doc_id
reqs = task.construct_requests(doc, ctx)
args = {"num_fewshot": num_fewshot}
reqs = task.construct_requests(doc, ctx, args)
if not isinstance(reqs, (list, tuple)):
reqs = [reqs]
for i, req in enumerate(reqs):
......
......@@ -12,6 +12,7 @@ class HFLM(BaseLM):
subfolder=None,
tokenizer=None,
batch_size=1,
parallelize=False
):
super().__init__()
......@@ -32,7 +33,7 @@ class HFLM(BaseLM):
self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
pretrained,
revision=revision + ("/" + subfolder if subfolder is not None else ""),
).to(self.device)
)
self.gpt2.eval()
# pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
......@@ -68,9 +69,11 @@ class HFLM(BaseLM):
self.batch_size_per_gpu = batch_size # todo: adaptive batch size
# TODO: fix multi-gpu
# gpus = torch.cuda.device_count()
# if gpus > 1:
# self.gpt2 = nn.DataParallel(self.gpt2)
if parallelize:
self.gpt2.parallelize()
self._device = torch.device('cuda:0')
else:
self.gpt2.to(self._device)
@property
def eot_token(self):
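The new `parallelize` flag maps onto the `model.parallelize()` API that GPT-2 and GPT-J checkpoints expose in transformers (deprecated in recent versions). A rough standalone sketch of the same device-placement branching, with the checkpoint name chosen only for illustration:

```python
import torch
import transformers

# Illustrative only: mirrors the parallelize/.to(device) branch in HFLM.__init__.
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
parallelize = torch.cuda.device_count() > 1

if parallelize:
    # Naive pipeline parallelism: spread the layers across all visible GPUs.
    model.parallelize()
    device = torch.device("cuda:0")  # inputs are fed to the first shard
else:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
```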
......@@ -146,16 +149,26 @@ class HFLM(BaseLM):
EOSCriteria(self.tokenizer.eos_token)
])
def _model_generate(self, context, max_length, stopping_criteria_ids):
def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
return self.gpt2.generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
do_sample=False,
)
if num_fewshot == 0:
generations = self.gpt2.generate(
context,
max_length=max_length,
eos_token_id=self.eot_token_id,
do_sample=False,
)
else:
generations = self.gpt2.generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
do_sample=False,
)
# Remove the context from the generations
return generations[0, context.shape[1] :]
# for backwards compatibility
GPT2LM = HFLM
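For readers unfamiliar with the `stopping_criteria` argument used in the non-zero-shot branch above: `generate()` accepts a `StoppingCriteriaList` of callables that are polled after each step. A hedged sketch of one such criterion; the harness's own `EOSCriteria`/`_get_stopping_criteria` may be implemented differently:

```python
import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class SuffixCriteria(StoppingCriteria):
    """Stop generation once the sequence ends with the given token ids.

    Illustrative only; not the harness's actual criterion class.
    """

    def __init__(self, stop_ids):
        self.stop_ids = torch.tensor(stop_ids)

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        if input_ids.shape[1] < len(self.stop_ids):
            return False
        tail = input_ids[0, -len(self.stop_ids):].cpu()
        return bool(torch.equal(tail, self.stop_ids))


# Usage, mirroring the few-shot branch of _model_generate above:
# stopping_criteria = StoppingCriteriaList([SuffixCriteria(stop_ids)])
# model.generate(context, max_length=max_length, stopping_criteria=stopping_criteria)
```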
......@@ -8,6 +8,7 @@ class GPTJLM(BaseLM):
self,
device="cuda",
batch_size=1,
parallelize=False,
):
super().__init__()
......@@ -35,9 +36,11 @@ class GPTJLM(BaseLM):
self.batch_size_per_gpu = batch_size # todo: adaptive batch size
# TODO: fix multi-gpu
# gpus = torch.cuda.device_count()
# if gpus > 1:
# self.gptj = nn.DataParallel(self.gptj)
if parallelize:
self.gptj.parallelize()
self._device = torch.device('cuda:0')
else:
self.gptj.to(self._device)
@property
def eot_token(self):
......@@ -113,11 +116,23 @@ class GPTJLM(BaseLM):
EOSCriteria(self.tokenizer.eos_token)
])
def _model_generate(self, context, max_length, stopping_criteria_ids):
def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
return self.gptj.generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
do_sample=False,
)
if num_fewshot == 0:
generations = self.gptj.generate(
context,
max_length=max_length,
eos_token_id=self.eot_token_id,
do_sample=False,
)
else:
generations = self.gptj.generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
do_sample=False,
)
# Remove the context from the generations
return generations[0, context.shape[1] :]
......@@ -56,7 +56,7 @@ class T0LM(BaseLM):
@property
def max_gen_toks(self):
return self.tokenizer.model_max_length
return 256
@property
def batch_size(self):
......@@ -94,6 +94,14 @@ class T0LM(BaseLM):
inputs, targets = zip(*chunk)
# Fill in empty encoder inputs with eos_token
inputs = (
f"{self.eot_token}"
if len(input_) == 0
else input_
for input_ in inputs
)
inputs_tok = self.tokenizer(
list(inputs),
max_length=self.max_length,
......@@ -172,11 +180,21 @@ class T0LM(BaseLM):
EOSCriteria(self.tokenizer.eos_token)
])
def _model_generate(self, context, max_length, stopping_criteria_ids):
def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
return self.t0.generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
do_sample=False,
)
if num_fewshot == 0:
generations = self.t0.generate(
context,
max_length=max_length,
eos_token_id=self.eot_token_id,
do_sample=False,
)
else:
generations = self.t0.generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
do_sample=False,
)
return generations[0]
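Earlier in this file, empty encoder inputs are filled with the EOT token before tokenization; a standalone illustration of that guard, where the token string is just a placeholder:

```python
# Illustrative version of the empty-input guard added to T0LM above.
eot_token = "</s>"  # placeholder; the real value comes from the tokenizer
raw_inputs = ["Translate to French: hello", "", "Summarize: ..."]

inputs = (eot_token if len(input_) == 0 else input_ for input_ in raw_inputs)
print(list(inputs))  # ['Translate to French: hello', '</s>', 'Summarize: ...']
```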
......@@ -62,7 +62,7 @@ class T5LM(BaseLM):
@property
def max_gen_toks(self):
return self.tokenizer.model_max_length
return 256
@property
def batch_size(self):
......@@ -186,11 +186,21 @@ class T5LM(BaseLM):
EOSCriteria(self.tokenizer.eos_token)
])
def _model_generate(self, context, max_length, stopping_criteria_ids):
def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
return self.t5.generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
do_sample=False,
)
if num_fewshot == 0:
generations = self.t5.generate(
context,
max_length=max_length,
eos_token_id=self.eot_token_id,
do_sample=False,
)
else:
generations = self.t5.generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
do_sample=False,
)
return generations[0]
......@@ -90,8 +90,8 @@ class CoQA(PromptSourceTask):
"f1": f1_sum / max(1, len(gold_list)),
}
def stopping_criteria(self):
return "\n\n"
# def stopping_criteria(self):
# return "\n\n"
# def construct_requests(self, doc, ctx):
# """Uses RequestFactory to construct Requests and returns an iterable of
......
......@@ -92,8 +92,8 @@ class DROP(PromptSourceTask):
# """
# conts = [rf.greedy_until(ctx, ["."])]
# return conts
def stopping_criteria(self):
return "."
# def stopping_criteria(self):
# return "."
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......
......@@ -78,8 +78,8 @@ class AssetTurk(PromptSourceTask):
def test_docs(self):
return self.dataset[str(self.SPLIT)]
def stopping_criteria(self):
return None
# def stopping_criteria(self):
# return None
def max_generation_length(self):
return 200
......
......@@ -70,8 +70,8 @@ class WebNLG(PromptSourceTask):
else:
return self.dataset["test"]
def stopping_criteria(self):
return None
# def stopping_criteria(self):
# return None
def max_generation_length(self):
return 250
......
......@@ -236,8 +236,8 @@ class MRPC(PromptSourceTask):
def has_test_docs(self):
return False
def stopping_criteria(self):
return "\n"
# def stopping_criteria(self):
# return "\n###\n"
def training_docs(self):
if self._training_docs is None:
......
......@@ -305,3 +305,39 @@ class SGWinogradSchemaChallenge(PromptSourceTask):
def aggregation(self):
return {"acc": mean}
class WinogenderSchemaDiagnostics(PromptSourceTask):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "axg"
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def test_docs(self):
return self.dataset["test"]
class BroadcoverageDiagnostics(PromptSourceTask):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "axb"
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def test_docs(self):
return self.dataset["test"]
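The two new diagnostic tasks point at the `axg` and `axb` configurations of `super_glue`, both of which ship only a test split, hence the `has_*_docs` flags above. A quick check with the `datasets` package (assumed installed):

```python
from datasets import load_dataset

# Both SuperGLUE diagnostic sets expose only a "test" split.
axg = load_dataset("super_glue", "axg")  # Winogender schema diagnostics
axb = load_dataset("super_glue", "axb")  # broad-coverage diagnostics
print(axg)  # DatasetDict with a single 'test' split
print(axb)
```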
......@@ -54,8 +54,8 @@ class WinoBias(PromptSourceTask):
def test_docs(self):
return self.dataset["test"]
def stopping_criteria(self):
return "\n"
# def stopping_criteria(self):
# return "\n"
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......
......@@ -73,10 +73,10 @@ class NewTask(PromptSourceTask):
return self.dataset["test"]
def stopping_criteria(self):
# TODO: Denote the string where the generation should be split.
# For example, for `coqa`, this is '\nQ:' and for `drop` '.'.
# Only define this method when you want to control few-shot generations on specific tokens.
# The default is set to '\n###\n'.
# NOTE: You may delete this function if the task does not require generation.
return None
return "\n###\n"
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
......