Unverified commit 1444a36c authored by KhalidAlt, committed by GitHub

Merge branch 'master' into master

parents 13676905 22155f7d
@@ -694,11 +694,9 @@ class PromptSourceTask(Task):
     def stopping_criteria(self) -> Optional[str]:
         """Denote where the generation should end.
         For example, for coqa, this is '\nQ:' and for drop '.'.
-        By default, it's None, meaning to generate up to max length or EOT, whichever comes first.
+        By default, it's "\n###\n".
         """
-        return None
+        return "\n###\n"

     def max_generation_length(self) -> Optional[int]:
         """Denote the max length of the generation if it is obvious from the task."""
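For context on how a stop string like "\n###\n" is consumed downstream: the harness truncates the decoded generation at the first occurrence of the stop sequence. A minimal sketch of that post-processing (the trim_generation helper is hypothetical, not part of this commit):

    def trim_generation(generation: str, stop: str = "\n###\n") -> str:
        # Keep only the text generated before the first stop sequence.
        return generation.split(stop)[0] if stop else generation

    # trim_generation("Paris\n###\nQ: What is ...") -> "Paris"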
@@ -8,6 +8,7 @@ class GPTJLM(BaseLM):
         self,
         device="cuda",
         batch_size=1,
+        parallelize=False,
     ):
         super().__init__()
@@ -35,9 +36,11 @@ class GPTJLM(BaseLM):
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

-        # TODO: fix multi-gpu
-        # gpus = torch.cuda.device_count()
-        # if gpus > 1:
-        #     self.gptj = nn.DataParallel(self.gptj)
+        if parallelize:
+            self.gptj.parallelize()
+            self._device = torch.device('cuda:0')
+        else:
+            self.gptj.to(self._device)

     @property
     def eot_token(self):
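A note on the new parallelize path: this appears to use the transformers model-parallelism helper, which spreads the model's layers across all visible GPUs; inputs still enter on the first device, which is why _device is pinned to cuda:0. A hedged usage sketch (assumes a multi-GPU host and this fork's GPTJLM wrapper):

    # Hypothetical usage; parallelize=True shards GPT-J layers across GPUs,
    # while input tensors are placed on cuda:0.
    lm = GPTJLM(device="cuda", batch_size=1, parallelize=True)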
@@ -113,11 +116,23 @@ class GPTJLM(BaseLM):
                 EOSCriteria(self.tokenizer.eos_token)
             ])

-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.gptj.generate(
-            context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+        if num_fewshot == 0:
+            generations = self.gptj.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.gptj.generate(
+                context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                do_sample=False,
+            )
+        # Remove the context from the generations
+        return generations[0, context.shape[1] :]
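Since GPT-J is decoder-only, generate() returns the prompt tokens followed by the continuation, so the new return statement slices the prompt off. A self-contained illustration with dummy tensors (values are illustrative only):

    import torch

    context = torch.tensor([[101, 102, 103]])            # (batch=1, prompt_len=3)
    generations = torch.tensor([[101, 102, 103, 7, 8]])  # prompt + 2 new tokens
    continuation = generations[0, context.shape[1] :]    # tensor([7, 8])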
@@ -56,7 +56,7 @@ class T0LM(BaseLM):
     @property
     def max_gen_toks(self):
-        return self.tokenizer.model_max_length
+        return 256

     @property
     def batch_size(self):
@@ -94,6 +94,14 @@ class T0LM(BaseLM):
             inputs, targets = zip(*chunk)

+            # Fill in empty encoder inputs with eos_token
+            inputs = (
+                f"{self.eot_token}"
+                if len(input_) == 0
+                else input_
+                for input_ in inputs
+            )
+
             inputs_tok = self.tokenizer(
                 list(inputs),
                 max_length=self.max_length,
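The added generator expression swaps empty prompts for the end-of-text token so the encoder never sees an empty string. Its behavior, reproduced standalone (the eot_token value here is illustrative):

    eot_token = "</s>"
    inputs = ["Translate: chat", "", "Translate: chien"]
    inputs = [eot_token if len(input_) == 0 else input_ for input_ in inputs]
    # -> ["Translate: chat", "</s>", "Translate: chien"]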
@@ -172,11 +180,21 @@ class T0LM(BaseLM):
                 EOSCriteria(self.tokenizer.eos_token)
             ])

-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.t0.generate(
-            context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+        if num_fewshot == 0:
+            generations = self.t0.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.t0.generate(
+                context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                do_sample=False,
+            )
+        return generations[0]
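Unlike the GPT-J change above, no prompt-stripping slice is needed here: T0 is an encoder-decoder model, so generate() returns only decoder-side tokens. A hedged sketch of consuming the returned ids with a standard transformers tokenizer:

    # generations[0] already contains only newly generated tokens;
    # tokenizer here is the model's T0 tokenizer.
    text = tokenizer.decode(generations[0], skip_special_tokens=True)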
@@ -62,6 +62,7 @@ from . import gem_mlsum
 from . import wino_bias
 from . import e2e_nlg_cleaned
 from . import gem_asset_turk
+from . import crows_pairs_multilingual
 from . import lama

 ########################################
@@ -333,6 +334,10 @@ TASK_REGISTRY = {
     "wino_bias_type1_anti": wino_bias.WinoBiasType1Anti,
     "wino_bias_type2_pro": wino_bias.WinoBiasType2Pro,
     "wino_bias_type2_anti": wino_bias.WinoBiasType2Anti,
+    # CrowS-Pairs
+    "crows_pairs_english": crows_pairs_multilingual.CrowsPairsEnglish,
+    "crows_pairs_french": crows_pairs_multilingual.CrowsPairsFrench,
 }
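With these entries in place, the new tasks are addressable by name through the registry, e.g.:

    from lm_eval.tasks import TASK_REGISTRY

    task_class = TASK_REGISTRY["crows_pairs_english"]  # -> CrowsPairsEnglish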
@@ -90,8 +90,8 @@ class CoQA(PromptSourceTask):
             "f1": f1_sum / max(1, len(gold_list)),
         }

-    def stopping_criteria(self):
-        return "\n\n"
+    # def stopping_criteria(self):
+    #     return "\n\n"

     # def construct_requests(self, doc, ctx):
     #     """Uses RequestFactory to construct Requests and returns an iterable of
"""
French CrowS-Pairs: Extending a challenge dataset for measuring social bias in masked language models to a language other than English
https://hal.inria.fr/hal-03629677/file/ACLFinal.pdf
Measuring social biases in masked language models in English and French.
https://gitlab.inria.fr/french-crows-pairs/acl-2022-paper-data-and-code/-/tree/main
"""
from lm_eval.base import PromptSourceTask
_CITATION = """\
@inproceedings{neveol2022french,
title={French CrowS-Pairs: Extending a challenge dataset for measuring social bias in masked language models to a language other than English},
author={N{\'e}v{\'e}ol, Aur{\'e}lie and Dupont, Yoann and Bezan{\c{c}}on, Julien and Fort, Kar{\"e}n},
booktitle={ACL 2022-60th Annual Meeting of the Association for Computational Linguistics},
year={2022}
"""
class CrowsPairsEnglish(PromptSourceTask):
VERSION = 0
DATASET_PATH = "oskarvanderwal/crows_pairs_multilingual"
DATASET_NAME = "english"
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def training_docs(self):
pass
def validation_docs(self):
pass
def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]
class CrowsPairsFrench(PromptSourceTask):
VERSION = 0
DATASET_PATH = "oskarvanderwal/crows_pairs_multilingual"
DATASET_NAME = "french"
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def training_docs(self):
pass
def validation_docs(self):
pass
def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]
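A quick smoke test of the new task classes might look like the sketch below; direct construction is hypothetical here, since in this fork task instances are normally built by the harness together with a promptsource template:

    task = CrowsPairsFrench()  # hypothetical; constructor args supplied by the harness in practice
    assert task.has_test_docs()
    docs = task.test_docs()    # rows of the "french" config's test split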
@@ -92,8 +92,8 @@ class DROP(PromptSourceTask):
     #     """
     #     conts = [rf.greedy_until(ctx, ["."])]
     #     return conts

-    def stopping_criteria(self):
-        return "."
+    # def stopping_criteria(self):
+    #     return "."

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -78,8 +78,8 @@ class AssetTurk(PromptSourceTask):
     def test_docs(self):
         return self.dataset[str(self.SPLIT)]

-    def stopping_criteria(self):
-        return None
+    # def stopping_criteria(self):
+    #     return None

     def max_generation_length(self):
         return 200
@@ -70,8 +70,8 @@ class WebNLG(PromptSourceTask):
         else:
             return self.dataset["test"]

-    def stopping_criteria(self):
-        return None
+    # def stopping_criteria(self):
+    #     return None

     def max_generation_length(self):
         return 250
@@ -236,8 +236,8 @@ class MRPC(PromptSourceTask):
     def has_test_docs(self):
         return False

-    def stopping_criteria(self):
-        return "\n"
+    # def stopping_criteria(self):
+    #     return "\n###\n"

     def training_docs(self):
         if self._training_docs is None:
@@ -54,8 +54,8 @@ class WinoBias(PromptSourceTask):
     def test_docs(self):
         return self.dataset["test"]

-    def stopping_criteria(self):
-        return "\n"
+    # def stopping_criteria(self):
+    #     return "\n"

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -73,10 +73,10 @@ class NewTask(PromptSourceTask):
         return self.dataset["test"]

     def stopping_criteria(self):
         # TODO: Denote the string where the generation should be split.
         # For example, for `coqa`, this is '\nQ:' and for `drop` '.'.
-        # Only define this method when you want to control few-shot generations on specific tokens.
+        # The default is set to '\n###\n'.
         # NOTE: You may delete this function if the task does not require generation.
-        return None
+        return "\n###\n"

     def construct_requests(self, doc, ctx):
         """Uses RequestFactory to construct Requests and returns an iterable of