Unverified commit 1444a36c authored by KhalidAlt, committed by GitHub

Merge branch 'master' into master

parents 13676905 22155f7d
@@ -694,11 +694,9 @@ class PromptSourceTask(Task):
     def stopping_criteria(self) -> Optional[str]:
         """Denote where the generation should end.
-        For example, for coqa, this is '\nQ:' and for drop '.'.
-        By default, its None, meaning to generate up to max or EOT, whichever comes first.
+        By default, it's "\n###\n".
         """
-        return None
+        return "\n###\n"

     def max_generation_length(self) -> Optional[int]:
         """Denote the max length of the generation if it is obvious from the task."""
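For context, this changes the harness-wide default stop sequence from None (generate until max length or EOT) to the literal string "\n###\n". A minimal sketch of how such a stop string is typically applied to a decoded generation; the helper name trim_at_stop is hypothetical, not part of this commit:

    def trim_at_stop(generation: str, stop: str = "\n###\n") -> str:
        """Truncate a decoded generation at the first occurrence of the stop string."""
        idx = generation.find(stop)
        return generation if idx == -1 else generation[:idx]

    print(trim_at_stop("Answer: 42\n###\nQ: next question"))  # -> "Answer: 42"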
@@ -8,6 +8,7 @@ class GPTJLM(BaseLM):
         self,
         device="cuda",
         batch_size=1,
+        parallelize=False,
     ):
         super().__init__()
@@ -35,9 +36,11 @@ class GPTJLM(BaseLM):
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
         # TODO: fix multi-gpu
-        # gpus = torch.cuda.device_count()
-        # if gpus > 1:
-        #     self.gptj = nn.DataParallel(self.gptj)
+        if parallelize:
+            self.gptj.parallelize()
+            self._device = torch.device('cuda:0')
+        else:
+            self.gptj.to(self._device)

     @property
     def eot_token(self):
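The commented-out nn.DataParallel idea is replaced with Hugging Face's layer-wise model parallelism. A hedged sketch of the same pattern outside this class, assuming a transformers version that still ships the (since-deprecated) parallelize() API for GPT-J:

    import torch
    from transformers import GPTJForCausalLM

    model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
    if torch.cuda.device_count() > 1:
        model.parallelize()              # shard transformer blocks across visible GPUs
        device = torch.device("cuda:0")  # inputs must live on the first shard's device
    else:
        device = torch.device("cuda")
        model.to(device)

Note the design difference: parallelize() splits the layers across GPUs (model parallelism), whereas nn.DataParallel replicates the whole model per GPU (data parallelism); for a 6B-parameter model, the former is what lets it fit at all.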
@@ -113,11 +116,23 @@ class GPTJLM(BaseLM):
             EOSCriteria(self.tokenizer.eos_token)
         ])

-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.gptj.generate(
-            context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+        if num_fewshot == 0:
+            generations = self.gptj.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.gptj.generate(
+                context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                do_sample=False,
+            )
+        # Remove the context from the generations
+        return generations[0, context.shape[1] :]
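The new return value strips the prompt: decoder-only models like GPT-J echo the input ids at the start of generate()'s output, so the newly generated tokens begin at index context.shape[1]. A self-contained illustration with dummy tensors:

    import torch

    context = torch.tensor([[10, 11, 12]])            # (batch=1, prompt_len=3)
    generations = torch.tensor([[10, 11, 12, 7, 8]])  # prompt ids followed by 2 new tokens
    new_tokens = generations[0, context.shape[1]:]
    print(new_tokens)  # tensor([7, 8])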
@@ -56,7 +56,7 @@ class T0LM(BaseLM):
     @property
     def max_gen_toks(self):
-        return self.tokenizer.model_max_length
+        return 256

     @property
     def batch_size(self):
@@ -94,6 +94,14 @@ class T0LM(BaseLM):
             inputs, targets = zip(*chunk)

+            # Fill in empty encoder inputs with eos_token
+            inputs = (
+                f"{self.eot_token}"
+                if len(input_) == 0
+                else input_
+                for input_ in inputs
+            )
+
             inputs_tok = self.tokenizer(
                 list(inputs),
                 max_length=self.max_length,
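The guard protects against empty prompts, which some templates produce and which the tokenizer would otherwise turn into empty tensors; each empty string is swapped for the end-of-text token instead. The generator expression is lazy, which is fine here because list(inputs) in the tokenizer call consumes it exactly once. A standalone sketch of the behavior (the eot_token value is an assumption; T0 inherits T5's EOS token):

    eot_token = "</s>"  # assumed EOS string
    inputs = ("Translate: hello", "", "Summarize: the article says ...")
    inputs = (eot_token if len(input_) == 0 else input_ for input_ in inputs)
    print(list(inputs))  # ['Translate: hello', '</s>', 'Summarize: the article says ...']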
@@ -172,11 +180,21 @@ class T0LM(BaseLM):
             EOSCriteria(self.tokenizer.eos_token)
         ])

-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.t0.generate(
-            context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+        if num_fewshot == 0:
+            generations = self.t0.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.t0.generate(
+                context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                do_sample=False,
+            )
+        return generations[0]
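Unlike the GPT-J change above, this returns generations[0] without slicing off the prompt: T0 is an encoder-decoder model, so generate() emits only decoder tokens and never echoes the input. A hedged sketch with a small public seq2seq checkpoint (t5-small stands in for T0 here):

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("t5-small")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    ids = tok("translate English to French: hello", return_tensors="pt").input_ids
    out = model.generate(ids, max_length=16, do_sample=False)
    print(tok.decode(out[0], skip_special_tokens=True))  # decoder output only; no prompt echo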
@@ -62,6 +62,7 @@ from . import gem_mlsum
 from . import wino_bias
 from . import e2e_nlg_cleaned
 from . import gem_asset_turk
+from . import crows_pairs_multilingual
 from . import lama

 ########################################
@@ -333,6 +334,10 @@ TASK_REGISTRY = {
     "wino_bias_type1_anti": wino_bias.WinoBiasType1Anti,
     "wino_bias_type2_pro": wino_bias.WinoBiasType2Pro,
     "wino_bias_type2_anti": wino_bias.WinoBiasType2Anti,
+    # CrowS-Pairs
+    "crows_pairs_english": crows_pairs_multilingual.CrowsPairsEnglish,
+    "crows_pairs_french": crows_pairs_multilingual.CrowsPairsFrench,
 }
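Registering the classes makes them addressable by name from the evaluation entry point. A minimal lookup sketch, assuming the registry is importable from lm_eval.tasks as this hunk suggests:

    from lm_eval.tasks import TASK_REGISTRY

    task_class = TASK_REGISTRY["crows_pairs_english"]
    print(task_class.DATASET_PATH, task_class.DATASET_NAME)
    # -> oskarvanderwal/crows_pairs_multilingual english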
@@ -90,8 +90,8 @@ class CoQA(PromptSourceTask):
             "f1": f1_sum / max(1, len(gold_list)),
         }

-    def stopping_criteria(self):
-        return "\n\n"
+    # def stopping_criteria(self):
+    #     return "\n\n"

     # def construct_requests(self, doc, ctx):
     #     """Uses RequestFactory to construct Requests and returns an iterable of
"""
French CrowS-Pairs: Extending a challenge dataset for measuring social bias in masked language models to a language other than English
https://hal.inria.fr/hal-03629677/file/ACLFinal.pdf
Measuring social biases in masked language models in English and French.
https://gitlab.inria.fr/french-crows-pairs/acl-2022-paper-data-and-code/-/tree/main
"""
from lm_eval.base import PromptSourceTask
_CITATION = """\
@inproceedings{neveol2022french,
title={French CrowS-Pairs: Extending a challenge dataset for measuring social bias in masked language models to a language other than English},
author={N{\'e}v{\'e}ol, Aur{\'e}lie and Dupont, Yoann and Bezan{\c{c}}on, Julien and Fort, Kar{\"e}n},
booktitle={ACL 2022-60th Annual Meeting of the Association for Computational Linguistics},
year={2022}
"""
class CrowsPairsEnglish(PromptSourceTask):
VERSION = 0
DATASET_PATH = "oskarvanderwal/crows_pairs_multilingual"
DATASET_NAME = "english"
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def training_docs(self):
pass
def validation_docs(self):
pass
def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]
class CrowsPairsFrench(PromptSourceTask):
VERSION = 0
DATASET_PATH = "oskarvanderwal/crows_pairs_multilingual"
DATASET_NAME = "french"
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def training_docs(self):
pass
def validation_docs(self):
pass
def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]
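Both classes expose only a test split, matching how the underlying data is distributed. To eyeball what the task will feed the model, the dataset can be loaded directly; only the path, the config names, and the test split come from this file, the record contents are an assumption:

    from datasets import load_dataset

    ds = load_dataset("oskarvanderwal/crows_pairs_multilingual", "french")
    print(ds["test"][0])  # one stereotype / anti-stereotype sentence pair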
@@ -92,8 +92,8 @@ class DROP(PromptSourceTask):
     #     """
     #     conts = [rf.greedy_until(ctx, ["."])]
     #     return conts

-    def stopping_criteria(self):
-        return "."
+    # def stopping_criteria(self):
+    #     return "."

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -78,8 +78,8 @@ class AssetTurk(PromptSourceTask):
     def test_docs(self):
         return self.dataset[str(self.SPLIT)]

-    def stopping_criteria(self):
-        return None
+    # def stopping_criteria(self):
+    #     return None

     def max_generation_length(self):
         return 200
@@ -70,8 +70,8 @@ class WebNLG(PromptSourceTask):
         else:
             return self.dataset["test"]

-    def stopping_criteria(self):
-        return None
+    # def stopping_criteria(self):
+    #     return None

     def max_generation_length(self):
         return 250
@@ -236,8 +236,8 @@ class MRPC(PromptSourceTask):
     def has_test_docs(self):
         return False

-    def stopping_criteria(self):
-        return "\n"
+    # def stopping_criteria(self):
+    #     return "\n###\n"

     def training_docs(self):
         if self._training_docs is None:
@@ -54,8 +54,8 @@ class WinoBias(PromptSourceTask):
     def test_docs(self):
         return self.dataset["test"]

-    def stopping_criteria(self):
-        return "\n"
+    # def stopping_criteria(self):
+    #     return "\n"

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -73,10 +73,10 @@ class NewTask(PromptSourceTask):
         return self.dataset["test"]

     def stopping_criteria(self):
-        # TODO: Denote the string where the generation should be split.
-        # For example, for `coqa`, this is '\nQ:' and for `drop` '.'.
+        # Only define this method when you need to stop generations at task-specific tokens.
+        # The default is set to '\n###\n'.
         # NOTE: You may delete this function if the task does not require generation.
-        return None
+        return "\n###\n"

     def construct_requests(self, doc, ctx):
         """Uses RequestFactory to construct Requests and returns an iterable of
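With the default now baked into PromptSourceTask, a task only overrides this hook to deviate from it. A sketch of such an override (the class name is hypothetical; PromptSourceTask comes from this repo):

    from lm_eval.base import PromptSourceTask

    class MySingleLineTask(PromptSourceTask):
        def stopping_criteria(self):
            # Stop at the first blank line instead of the harness-wide "\n###\n" default.
            return "\n\n"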