Unverified Commit 073b0808 authored by Oskar van der Wal, committed by GitHub

Merge branch 'bigscience-workshop:master' into master

parents 2d861a29 29bff88d
@@ -353,11 +353,13 @@ class BaseLM(LM):
         for context, request_args in tqdm(reord.get_reordered()):
             stopping_criteria = request_args["stopping_criteria"]
             max_generation_length = request_args["max_generation_length"]
+            num_fewshot = request_args["num_fewshot"]
             assert isinstance(stopping_criteria, str) or stopping_criteria is None
             assert (
                 isinstance(max_generation_length, int) or max_generation_length is None
             )
+            assert isinstance(num_fewshot, int) or num_fewshot is None
             if stopping_criteria is None:
                 until = [self.eot_token]
@@ -382,9 +384,10 @@ class BaseLM(LM):
                     context_enc,
                     max_length,
                     torch.tensor(primary_until),
+                    num_fewshot,
                 )
-                s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
+                s = self.tok_decode(cont.tolist())
                 for term in until:
                     s = s.split(term)[0]
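Taken together, these two hunks change the contract between `greedy_until` and `_model_generate`: every generation request now carries its few-shot count, and `_model_generate` is expected to return only the continuation tokens, so the caller no longer slices off the prompt before decoding. A minimal sketch of the assumed per-request payload (values are illustrative, not taken from this commit):

```python
# Illustrative request_args dict, mirroring what construct_requests builds further down.
request_args = {
    "stopping_criteria": "\n###\n",   # str or None
    "max_generation_length": 256,     # int or None
    "num_fewshot": 1,                 # int or None; new key added by this change
}

# With the new contract, cont holds only continuation token ids, so decoding is simply:
#     s = self.tok_decode(cont.tolist())
```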
@@ -536,7 +539,7 @@ class Task(abc.ABC):
         pass
 
     @abstractmethod
-    def construct_requests(self, doc, ctx):
+    def construct_requests(self, doc, ctx, args):
         """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
@@ -546,6 +549,8 @@ class Task(abc.ABC):
             The context string, generated by fewshot_context. This includes the natural
             language description, as well as the few shot examples, and the question
             part of the document for `doc`.
+        :param args: dict
+            The specifics of the context, including number of few shots.
         """
         pass
@@ -689,11 +694,9 @@ class PromptSourceTask(Task):
     def stopping_criteria(self) -> Optional[str]:
         """Denote where the generation should end.
 
-        For example, for coqa, this is '\nQ:' and for drop '.'.
-        By default, its None, meaning to generate up to max or EOT, whichever comes first.
+        By default, its "\n###\n".
         """
-        return None
+        return "\n###\n"
 
     def max_generation_length(self) -> Optional[int]:
         """Denote where the max length of the generation if it is obvious from the task."""
@@ -724,7 +727,7 @@ class PromptSourceTask(Task):
         text, _ = self.prompt.apply(doc)
         return text
 
-    def construct_requests(self, doc, ctx):
+    def construct_requests(self, doc, ctx, args):
         """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
@@ -734,6 +737,8 @@ class PromptSourceTask(Task):
             The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
+        :param args: dict
+            The specifics of the context, including number of few shots.
         """
         _requests = []
         answer_choices_list = self.prompt.get_answer_choices_list(doc)
@@ -749,6 +754,7 @@ class PromptSourceTask(Task):
             request_args = {
                 "stopping_criteria": self.stopping_criteria(),
                 "max_generation_length": self.max_generation_length(),
+                "num_fewshot": args["num_fewshot"],
             }
             cont_request = rf.greedy_until(ctx, request_args)
             _requests.append(cont_request)
@@ -915,12 +921,12 @@ class PromptSourceTask(Task):
         if num_fewshot == 0:
             labeled_examples = ""
-            fewshotex, fewshotidx, fewshotsource = [], [], None
+            fewshotex, fewshotidx, self.fewshotsource = [], [], None
         else:
             # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
             if self.has_training_docs():
                 fewshotex, fewshotidx = self.fewshot_examples(k=num_fewshot, rnd=rnd)
-                fewshotsource = "train"
+                self.fewshotsource = "train"
             else:
                 if self._fewshot_docs is None:
                     self._fewshot_docs = list(
@@ -929,18 +935,18 @@ class PromptSourceTask(Task):
                         else self.test_docs()
                     )
                 if self.has_validation_docs():
-                    fewshotsource = "val"
+                    self.fewshotsource = "val"
                 elif self.test_docs():
-                    fewshotsource = "test"
+                    self.fewshotsource = "test"
                 fewshotex, fewshotidx = self._get_fewshot_examples(
                     self._fewshot_docs, k=num_fewshot + 1, rnd=rnd
                 )
-                fewshotex, fewshotidx = [
+                fewshotex, fewshotidx = zip(*[
                     (shot, idx)
                     for shot, idx in zip(fewshotex, fewshotidx)
                     if shot != doc
-                ]
+                ])
                 # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                 fewshotex, fewshotidx = (
                     fewshotex[:num_fewshot],
@@ -966,7 +972,7 @@ class PromptSourceTask(Task):
             ctx,
             {
                 "fewshot_idx": fewshotidx,
-                "fewshot_source": fewshotsource,
+                "fewshot_source": self.fewshotsource,
                 "fewshot_num": num_fewshot,
                 "ctx": ctx,
             },
...
@@ -206,7 +206,8 @@ def evaluate(
             doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
         )
         fewshotex_logging_info["doc_id"] = original_doc_id
-        reqs = task.construct_requests(doc, ctx)
+        args = {"num_fewshot": num_fewshot}
+        reqs = task.construct_requests(doc, ctx, args)
         if not isinstance(reqs, (list, tuple)):
             reqs = [reqs]
         for i, req in enumerate(reqs):
...
@@ -12,6 +12,7 @@ class HFLM(BaseLM):
         subfolder=None,
         tokenizer=None,
         batch_size=1,
+        parallelize=False
     ):
         super().__init__()
@@ -32,7 +33,7 @@ class HFLM(BaseLM):
         self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
             pretrained,
             revision=revision + ("/" + subfolder if subfolder is not None else ""),
-        ).to(self.device)
+        )
         self.gpt2.eval()
 
         # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
@@ -68,9 +69,11 @@ class HFLM(BaseLM):
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
 
         # TODO: fix multi-gpu
-        # gpus = torch.cuda.device_count()
-        # if gpus > 1:
-        #     self.gpt2 = nn.DataParallel(self.gpt2)
+        if parallelize:
+            self.gpt2.parallelize()
+            self._device = torch.device('cuda:0')
+        else:
+            self.gpt2.to(self._device)
 
     @property
     def eot_token(self):
@@ -146,16 +149,26 @@ class HFLM(BaseLM):
             EOSCriteria(self.tokenizer.eos_token)
         ])
 
-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.gpt2.generate(
+        if num_fewshot == 0:
+            generations = self.gpt2.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.gpt2.generate(
                 context,
                 max_length=max_length,
                 stopping_criteria=stopping_criteria,
                 do_sample=False,
             )
+        # Remove the context from the generations
+        return generations[0, context.shape[1] :]
 
 # for backwards compatibility
 GPT2LM = HFLM
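For context, the new `parallelize` flag spreads the model's layers across all visible GPUs via transformers' `.parallelize()` and drives inputs from `cuda:0`; without it, the previous single-device behaviour is kept. A hedged usage sketch, assuming the harness's usual module layout (`lm_eval.models.gpt2`):

```python
# Usage sketch; the import path assumes HFLM lives in lm_eval/models/gpt2.py as in this repo.
from lm_eval.models.gpt2 import HFLM

# Single-device behaviour, unchanged from before:
lm = HFLM(pretrained="gpt2", batch_size=1)

# Model-parallel layout across all visible GPUs (requires a model that implements
# transformers' .parallelize(), e.g. GPT-2 or GPT-J):
# lm = HFLM(pretrained="gpt2", batch_size=1, parallelize=True)
```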
@@ -8,6 +8,7 @@ class GPTJLM(BaseLM):
         self,
         device="cuda",
         batch_size=1,
+        parallelize=False,
     ):
         super().__init__()
@@ -35,9 +36,11 @@ class GPTJLM(BaseLM):
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
 
         # TODO: fix multi-gpu
-        # gpus = torch.cuda.device_count()
-        # if gpus > 1:
-        #     self.gptj = nn.DataParallel(self.gptj)
+        if parallelize:
+            self.gptj.parallelize()
+            self._device = torch.device('cuda:0')
+        else:
+            self.gptj.to(self._device)
 
     @property
     def eot_token(self):
@@ -113,11 +116,23 @@ class GPTJLM(BaseLM):
             EOSCriteria(self.tokenizer.eos_token)
         ])
 
-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.gptj.generate(
+        if num_fewshot == 0:
+            generations = self.gptj.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.gptj.generate(
                 context,
                 max_length=max_length,
                 stopping_criteria=stopping_criteria,
                 do_sample=False,
             )
+        # Remove the context from the generations
+        return generations[0, context.shape[1] :]
@@ -56,7 +56,7 @@ class T0LM(BaseLM):
     @property
     def max_gen_toks(self):
-        return self.tokenizer.model_max_length
+        return 256
 
     @property
     def batch_size(self):
@@ -94,6 +94,14 @@ class T0LM(BaseLM):
             inputs, targets = zip(*chunk)
 
+            # Fill in empty encoder inputs with eos_token
+            inputs = (
+                f"{self.eot_token}"
+                if len(input_) == 0
+                else input_
+                for input_ in inputs
+            )
+
             inputs_tok = self.tokenizer(
                 list(inputs),
                 max_length=self.max_length,
@@ -172,11 +180,21 @@ class T0LM(BaseLM):
             EOSCriteria(self.tokenizer.eos_token)
         ])
 
-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.t0.generate(
+        if num_fewshot == 0:
+            generations = self.t0.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.t0.generate(
                 context,
                 max_length=max_length,
                 stopping_criteria=stopping_criteria,
                 do_sample=False,
             )
+        return generations[0]
@@ -62,7 +62,7 @@ class T5LM(BaseLM):
     @property
     def max_gen_toks(self):
-        return self.tokenizer.model_max_length
+        return 256
 
     @property
     def batch_size(self):
@@ -186,11 +186,21 @@ class T5LM(BaseLM):
             EOSCriteria(self.tokenizer.eos_token)
         ])
 
-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.t5.generate(
+        if num_fewshot == 0:
+            generations = self.t5.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.t5.generate(
                 context,
                 max_length=max_length,
                 stopping_criteria=stopping_criteria,
                 do_sample=False,
             )
+        return generations[0]
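The post-processing differs between the decoder-only and encoder-decoder models above: GPT-2/GPT-J outputs echo the prompt and are sliced by the context length, while T0/T5 decoder outputs already contain only the generated answer. A toy illustration of the two conventions (not part of the commit):

```python
import torch

# Decoder-only (HFLM, GPTJLM): generate() returns prompt + continuation,
# so the context length is stripped before decoding.
context = torch.tensor([[1, 2, 3]])               # shape (1, ctx_len)
generations = torch.tensor([[1, 2, 3, 7, 8, 9]])  # prompt echoed back
continuation = generations[0, context.shape[1]:]  # tensor([7, 8, 9])

# Encoder-decoder (T0LM, T5LM): the decoder output is already just the answer.
seq2seq_generations = torch.tensor([[7, 8, 9]])
continuation_seq2seq = seq2seq_generations[0]     # tensor([7, 8, 9])
```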
@@ -90,8 +90,8 @@ class CoQA(PromptSourceTask):
             "f1": f1_sum / max(1, len(gold_list)),
         }
 
-    def stopping_criteria(self):
-        return "\n\n"
+    # def stopping_criteria(self):
+    #     return "\n\n"
 
     # def construct_requests(self, doc, ctx):
     #     """Uses RequestFactory to construct Requests and returns an iterable of
...
@@ -92,8 +92,8 @@ class DROP(PromptSourceTask):
     #     """
     #     conts = [rf.greedy_until(ctx, ["."])]
     #     return conts
 
-    def stopping_criteria(self):
-        return "."
+    # def stopping_criteria(self):
+    #     return "."
 
     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
...
@@ -78,8 +78,8 @@ class AssetTurk(PromptSourceTask):
     def test_docs(self):
         return self.dataset[str(self.SPLIT)]
 
-    def stopping_criteria(self):
-        return None
+    # def stopping_criteria(self):
+    #     return None
 
     def max_generation_length(self):
         return 200
...
@@ -70,8 +70,8 @@ class WebNLG(PromptSourceTask):
         else:
             return self.dataset["test"]
 
-    def stopping_criteria(self):
-        return None
+    # def stopping_criteria(self):
+    #     return None
 
     def max_generation_length(self):
         return 250
...
@@ -236,8 +236,8 @@ class MRPC(PromptSourceTask):
     def has_test_docs(self):
         return False
 
-    def stopping_criteria(self):
-        return "\n"
+    # def stopping_criteria(self):
+    #     return "\n###\n"
 
     def training_docs(self):
         if self._training_docs is None:
...
@@ -305,3 +305,39 @@ class SGWinogradSchemaChallenge(PromptSourceTask):
     def aggregation(self):
         return {"acc": mean}
 
+
+class WinogenderSchemaDiagnostics(PromptSourceTask):
+    VERSION = 0
+    DATASET_PATH = "super_glue"
+    DATASET_NAME = "axg"
+
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return False
+
+    def has_test_docs(self):
+        return True
+
+    def test_docs(self):
+        return self.dataset["test"]
+
+
+class BroadcoverageDiagnostics(PromptSourceTask):
+    VERSION = 0
+    DATASET_PATH = "super_glue"
+    DATASET_NAME = "axb"
+
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return False
+
+    def has_test_docs(self):
+        return True
+
+    def test_docs(self):
+        return self.dataset["test"]
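Both new tasks point at SuperGLUE diagnostic configs that ship with a test split only, which is why only `has_test_docs` returns True. A quick, illustrative check with the `datasets` library:

```python
from datasets import load_dataset

axg = load_dataset("super_glue", "axg")    # Winogender schema diagnostics
axb = load_dataset("super_glue", "axb")    # broad-coverage diagnostics
print(list(axg.keys()), list(axb.keys()))  # each exposes only a "test" split
```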
@@ -54,8 +54,8 @@ class WinoBias(PromptSourceTask):
     def test_docs(self):
         return self.dataset["test"]
 
-    def stopping_criteria(self):
-        return "\n"
+    # def stopping_criteria(self):
+    #     return "\n"
 
     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
...
@@ -73,10 +73,10 @@ class NewTask(PromptSourceTask):
         return self.dataset["test"]
 
     def stopping_criteria(self):
-        # TODO: Denote the string where the generation should be split.
-        # For example, for `coqa`, this is '\nQ:' and for `drop` '.'.
+        # Only define this method when you want to control few-shot generations on specific tokens.
+        # The default is set to '\n###\n'.
         # NOTE: You may delete this function if the task does not required generation.
-        return None
+        return "\n###\n"
 
     def construct_requests(self, doc, ctx):
         """Uses RequestFactory to construct Requests and returns an iterable of
...
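Since `"\n###\n"` is now the PromptSourceTask-wide default, a concrete task only needs to override `stopping_criteria` when it wants a different marker. A minimal sketch (the class name and strings below are illustrative; `PromptSourceTask` lives in `lm_eval.base` in this repo):

```python
from typing import Optional

from lm_eval.base import PromptSourceTask


class MyGenerationTask(PromptSourceTask):
    def stopping_criteria(self) -> Optional[str]:
        # Stop on a task-specific marker instead of the "\n###\n" default.
        return "\nQ:"

    def max_generation_length(self) -> Optional[int]:
        return 128
```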