Commit 9907e0a7 authored by FarzanehNakhaee

Merge branch 'big-refactor' into add-qa4mre-config

parents 649a7f95 070b6b9c
...@@ -44,10 +44,10 @@ To install additional multilingual tokenization and text segmentation packages,
pip install -e ".[multilingual]"
```
To support loading GPTQ quantized models, install the package with the `gptq` extra:
```bash
pip install -e ".[gptq]"
```
## Basic Usage
...@@ -94,7 +94,7 @@ accelerate launch main.py \
This will perform *data-parallel evaluation*: that is, placing a **single full copy** of your model onto each available GPU and *splitting batches across GPUs* to evaluate on K GPUs K times faster than on one.
However, if your model is *too large to fit on a single GPU*, we provide an alternative method for running it: the `parallelize` argument.
```
python main.py \
...@@ -110,6 +110,8 @@ To pass even more advanced keyword arguments to `accelerate`, we allow for the f
- `max_cpu_memory`: the max amount of CPU memory to use when offloading the model weights to RAM.
- `offload_folder`: a folder where model weights will be offloaded to disk if needed.
Using this setting helps with massive models such as BLOOM, which require more memory than a single GPU provides, or helps avoid exceeding your total system RAM (by default, `accelerate launch` initializes one copy of the model per GPU in RAM before moving it to the GPU, resulting in large RAM usage spikes around the start of the script that may cause errors such as `Killed`). However, it naively splits models across GPUs, so only a single GPU performs work at any point in time; it is therefore much slower than launching with `accelerate launch`, possibly by a factor of the total number of GPUs.
**Note that this option requires launching evaluation via `python main.py` rather than `accelerate launch main.py`.**
### Commercial APIs
...@@ -158,17 +160,17 @@ For models loaded with the HuggingFace `transformers` library, any arguments pr
```bash
python main.py \
    --model hf \
    --model_args pretrained=EleutherAI/gpt-j-6b,parallelize=True,load_in_4bit=True,peft=nomic-ai/gpt4all-j-lora \
    --tasks openbookqa,arc_easy,winogrande,hellaswag,arc_challenge,piqa,boolq \
    --device cuda:0
```
[GPTQ](https://github.com/PanQiWei/AutoGPTQ) quantized models can be loaded by specifying their file names in `,gptq=NAME` (or `,gptq=True` for default names) in the `model_args` argument:
```bash
python main.py \
    --model hf \
    --model_args pretrained=model-name-or-path,gptq=model.safetensors,gptq_use_triton=True \
    --tasks hellaswag
```
......
...@@ -281,7 +281,7 @@ def evaluate(
                "doc_id": doc_id,
                "doc": doc,
                "target": target,
                "arguments": requests[0].args,
                "resps": [req.resps for req in requests],
                "filtered_resps": [req.filtered_resps[key] for req in requests],
            }
...@@ -292,6 +292,15 @@ def evaluate(
    if lm.world_size > 1:
        # if multigpu, then gather data across all ranks
        # first gather logged samples across all ranks
        for task_name, task_samples in list(samples.items()):
            full_samples = [None] * lm.world_size
            torch.distributed.all_gather_object(full_samples, task_samples)
            samples[task_name] = list(itertools.chain.from_iterable(full_samples))
        # then collect metrics across all ranks
        vals_torch = collections.defaultdict(list)
        for (task_name, key, metric), items in vals.items():
......
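The sample-gathering step added above collects every rank's logged samples onto all ranks with `torch.distributed.all_gather_object`, then flattens the per-rank lists. A minimal standalone sketch of that pattern, using placeholder sample data (script and file name are illustrative; launch with e.g. `torchrun --nproc_per_node=2 gather_demo.py`):

```python
# Sketch of gathering per-rank Python lists onto every rank with
# torch.distributed.all_gather_object, mirroring the sample-gathering step above.
# The sample contents are placeholders.
import itertools

import torch.distributed as dist


def main():
    dist.init_process_group(backend="gloo")  # gloo also works on CPU-only machines
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # Each rank holds its own slice of logged samples.
    local_samples = [{"rank": rank, "doc_id": i} for i in range(3)]

    # Allocate one slot per rank, then gather every rank's list into it.
    gathered = [None] * world_size
    dist.all_gather_object(gathered, local_samples)

    # Flatten the per-rank lists into a single list, as evaluate() does above.
    all_samples = list(itertools.chain.from_iterable(gathered))
    if rank == 0:
        print(f"collected {len(all_samples)} samples from {world_size} ranks")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```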
import torch
import transformers
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from peft import __version__ as PEFT_VERSION, PeftModel
import copy
from collections import defaultdict
from tqdm import tqdm
from pathlib import Path
import torch.nn.functional as F
...@@ -58,15 +60,16 @@ class HFLM(LM):
    def __init__(
        self,
        pretrained: Optional[str] = "gpt2",
        revision: Optional[str] = "main",
        subfolder: Optional[str] = None,
        tokenizer: Optional[str] = None,
        max_length: Optional[int] = None,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: Optional[int] = 1,
        low_cpu_mem_usage: Optional[bool] = True,
        trust_remote_code: Optional[bool] = False,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        parallelize: Optional[bool] = False,
...@@ -74,6 +77,14 @@ class HFLM(LM):
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        # PEFT and quantization options
        peft: Optional[str] = None,
        load_in_8bit: Optional[bool] = False,
        load_in_4bit: Optional[bool] = False,
        bnb_4bit_quant_type: Optional[str] = None,
        bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
        gptq: Optional[Union[bool, str]] = False,
        gptq_use_triton: Optional[bool] = False,
    ):
        super().__init__()
...@@ -82,11 +93,16 @@ class HFLM(LM):
        assert isinstance(batch_size, int)
        gpus = torch.cuda.device_count()
        accelerator = Accelerator()
        if not (parallelize or accelerator.num_processes > 1):
            # use user-passed device
            device_list = set(
                ["cuda", "cpu"]
                + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
            )
            if device:
                if device not in device_list:
                    device = int(device)
                self._device = torch.device(device)
                eval_logger.info(f"Using device '{device}'")
...@@ -100,7 +116,7 @@ class HFLM(LM):
            )
        else:
            eval_logger.info(
                f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
            )
            # TODO: include in warning that `load_in_8bit` etc. affect this too
            self._device = device
...@@ -117,10 +133,10 @@ class HFLM(LM):
        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")
        # get config
        self._config = transformers.AutoConfig.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
...@@ -133,13 +149,56 @@ class HFLM(LM):
            transformers.AutoModelForSeq2SeqLM,
        ]
        if not gptq:
            if load_in_4bit:
                assert (
                    transformers.__version__ >= "4.30.0"
                ), "load_in_4bit requires transformers >= 4.30.0"
            if transformers.__version__ >= "4.30.0":
                model_kwargs["load_in_4bit"] = load_in_4bit
                if load_in_4bit:
                    if bnb_4bit_quant_type:
                        model_kwargs["bnb_4bit_quant_type"] = bnb_4bit_quant_type
                    if bnb_4bit_compute_dtype:
                        model_kwargs["bnb_4bit_compute_dtype"] = utils.get_dtype(
                            bnb_4bit_compute_dtype
                        )
            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
                revision=revision,
                torch_dtype=utils.get_dtype(dtype),
                low_cpu_mem_usage=low_cpu_mem_usage,
                trust_remote_code=trust_remote_code,
                load_in_8bit=load_in_8bit,
                **model_kwargs,
            )
        else:
            try:
                from auto_gptq import AutoGPTQForCausalLM
            except ModuleNotFoundError:
                raise Exception(
                    "Tried to load auto_gptq, but auto-gptq is not installed ",
                    "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
                )
            self._model = AutoGPTQForCausalLM.from_quantized(
                pretrained,
                model_basename=None if gptq is True else Path(gptq).stem,
                low_cpu_mem_usage=low_cpu_mem_usage,
                trust_remote_code=trust_remote_code,
                use_safetensors=True if gptq is True else gptq.endswith(".safetensors"),
                use_triton=gptq_use_triton,
                warmup_triton=gptq_use_triton,
                **model_kwargs,
            )
        if peft:
            if load_in_4bit:
                assert PEFT_VERSION >= "0.4.0", "load_in_4bit requires peft >= 0.4.0"
            self._model = PeftModel.from_pretrained(
                self._model, peft, revision=revision
            )
        # forever after, access self._model through self.model property
        self.model.eval()
        self.model.tie_weights()
...@@ -150,6 +209,7 @@ class HFLM(LM):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        self.vocab_size = self.tokenizer.vocab_size
...@@ -162,7 +222,6 @@ class HFLM(LM):
        # multigpu data-parallel support when launched with accelerate
        if gpus > 1:
            if parallelize:
                if accelerator.num_processes > 1:
                    raise RuntimeError(
...@@ -361,16 +420,27 @@ class HFLM(LM):
        return logits

    def _encode_pair(self, context, continuation):
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        whole_enc = self.tok_encode(context + continuation)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]
        return context_enc, continuation_enc

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context
                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
                    continuation
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)
            new_reqs.append(((context, continuation), context_enc, continuation_enc))
...@@ -442,7 +512,6 @@ class HFLM(LM):
            tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
            self.batch_size,
        ):
            inps = []
            cont_toks_list = []
            inplens = []
...@@ -544,7 +613,6 @@ class HFLM(LM):
        for (cache_key, _, _), logits, inplen, cont_toks in zip(
            chunk, multi_logits, inplens, cont_toks_list
        ):
            # Slice to original seq length
            contlen = len(cont_toks)
            # take only logits in the continuation
......
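The `_encode_pair` helper added above moves any trailing whitespace from the context onto the continuation, then tokenizes the concatenated string once and slices it at the context length, so that token merges at the boundary are attributed to the continuation. This matters for BPE-style tokenizers, where `encode(context) + encode(continuation)` need not equal `encode(context + continuation)`. A standalone sketch of the same logic with a toy whitespace tokenizer (a stand-in, not the HF tokenizer):

```python
# Standalone sketch of the _encode_pair logic above, using a toy tokenizer.
def tok_encode(text):
    # Stand-in tokenizer: simple whitespace split instead of a real HF tokenizer.
    return text.split()


def encode_pair(context, continuation):
    # Move trailing spaces from the context onto the continuation so the
    # boundary is tokenized consistently.
    n_spaces = len(context) - len(context.rstrip())
    if n_spaces > 0:
        continuation = context[-n_spaces:] + continuation
        context = context[:-n_spaces]

    # Tokenize the concatenation once, then split at the context length,
    # so any merges across the boundary count toward the continuation.
    whole_enc = tok_encode(context + continuation)
    context_enc = tok_encode(context)
    continuation_enc = whole_enc[len(context_enc):]
    return context_enc, continuation_enc


print(encode_pair("The capital of France is ", "Paris"))
# (['The', 'capital', 'of', 'France', 'is'], ['Paris'])
```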
...@@ -12,7 +12,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [ ] Lambada (Multilingual)
- [x] Wikitext
- [x] PiQA
- [ ] PROST (WIP)
- [ ] MCTACO
- [ ] Pubmed QA (WIP)
- [x] SciQ
...@@ -20,11 +20,11 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [ ] QA4MRE
- [ ] TriviaQA
- [x] AI2 ARC
- [ ] LogiQA (WIP)
- [x] HellaSwag
- [ ] SWAG (WIP)
- [x] OpenBookQA
- [ ] SQuADv2 (WIP)
- [ ] RACE (WIP)
- [ ] HeadQA
- [ ] MathQA
...@@ -35,7 +35,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [ ] Hendrycks Ethics
- [ ] TruthfulQA
- [ ] MuTual
- [ ] Hendrycks Math (WIP)
- [ ] Asdiv
- [ ] GSM8k
- [ ] Arithmetic (WIP)
...@@ -45,6 +45,8 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] ~~Pile (perplexity)~~
- [ ] BLiMP
- [ ] ToxiGen
- [ ] StoryCloze
- [ ] NaturalQs
- [ ] CrowS-Pairs
- [ ] XCopa
- [ ] BIG-Bench
...@@ -55,6 +57,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [ ] MGSM
- [ ] SCROLLS
- [ ] JSON Task (reference: https://github.com/EleutherAI/lm-evaluation-harness/pull/481)
- [ ] Babi

# Novel Tasks
Tasks added in the revamped harness that were not previously available. Again, a strikethrough denotes checking performed *against the original task's implementation or published results introducing the task*.
......
...@@ -7,7 +7,7 @@ output_type: multiple_choice
training_split: train
validation_split: validation
test_split: null
template_aliases: "{% set gold = label | int %}{% set answer_choices = endings|map('trim')|map('replace', ' [title]', '. ')|map('regex_replace', '\\[.*?\\]', '')|map('replace', '  ', ' ')|list %}"
doc_to_text: "{% set text = activity_label ~ ': ' ~ ctx_a ~ ' ' ~ ctx_b.capitalize() %}{{text|trim|replace(' [title]', '. ')|regex_replace('\\[.*?\\]', '')|replace('  ', ' ')}}"
doc_to_target: "{{answer_choices[gold]}}"
gold_alias: "{{gold}}"
......
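The new `label | int` cast matters because the HellaSwag dataset stores the gold index as a string, and a string index will not select the intended entry from the answer list. A quick sketch with Jinja2 (not the harness's own rendering path; the document is made up) showing the corrected template picking the right ending:

```python
# Quick sketch (not the harness's renderer) of why the HellaSwag template now
# casts `label` to an int before indexing the answer choices.
from jinja2 import Environment

doc = {
    "label": "2",  # HellaSwag stores the gold index as a string
    "endings": ["ending A", "ending B", "ending C", "ending D"],
}

template = Environment().from_string("{% set gold = label | int %}{{ endings[gold] }}")
print(template.render(**doc))  # -> ending C
```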
group:
  - super-glue-promptsource
task: "GPT-3 Style"
dataset_path: super_glue
dataset_name: boolq
training_split: train
validation_split: validation
use_prompt: "promptsource:GPT-3 Style"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
include: promptsource-00.yaml
group:
  - super-glue-promptsource
task: "based on the previous passage"
use_prompt: "promptsource:based on the previous passage"
include: promptsource-00.yaml
group:
  - super-glue-promptsource
task: "based on the following passage"
use_prompt: "promptsource:based on the following passage"
group:
  - super-glue-lm-eval-v1-seq2seq
task: "boolq-seq2seq"
dataset_path: super_glue
dataset_name: boolq
......
group:
  - super-glue-lm-eval-v1
task: "cb"
dataset_path: super_glue
dataset_name: cb
output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: "{{premise}}\nQuestion: {{hypothesis}}. True, False, or Neither?\nAnswer:"
doc_to_target: "{{answer_choices[label]}}"
gold_alias: "{{label}}" # this will be cast to an int.
template_aliases: "{% set answer_choices = ['True', 'False', 'Neither'] %}"
metric_list:
......
group:
  - super-glue-promptsource
task: "GPT-3 style"
dataset_path: super_glue
dataset_name: cb
training_split: train
validation_split: validation
use_prompt: "promptsource:GPT-3 style"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
include: promptsource-00.yaml
group:
  - super-glue-promptsource
task: "MNLI crowdsource"
use_prompt: "promptsource:MNLI crowdsource"
include: promptsource-00.yaml
group:
  - super-glue-promptsource
task: "based on the previous passage"
use_prompt: "promptsource:based on the previous passage"
group:
  - super-glue-t5-prompt
task: super_glue-cb-t5-prompt
reference: "From Raffel et al. 2019"
dataset_path: super_glue
dataset_name: cb
training_split: train
......
group:
  - super-glue-lm-eval-v1
task: "copa"
dataset_path: super_glue
dataset_name: copa
output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
gold_alias: "{{label}}" # this will be cast to an int.
template_aliases: "{% set answer_choices = [{{doc.choice1}}, 'b'] %} {{answer_choices}}"
metric_list:
  - metric: acc
group:
  - super-glue-promptsource
task: "C1 or C2? premise, so/because…"
dataset_path: super_glue
dataset_name: copa
training_split: train
validation_split: validation
use_prompt: "promptsource:C1 or C2? premise, so/because…"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
include: promptsource-00.yaml
group:
  - super-glue-promptsource
task: "best_option"
use_prompt: "promptsource:best_option"
include: promptsource-00.yaml
group:
  - super-glue-promptsource
task: "cause_effect"
use_prompt: "promptsource:cause_effect"
group:
  - super-glue-t5-prompt
task: super_glue-copa-t5-prompt
reference: "From Raffel et al. 2019"
dataset_path: super_glue
dataset_name: copa
training_split: train
......
def convert_choice(choice):
    return choice[0].lower() + choice[1:]


def doc_to_text(doc):
    # Drop the period
    connector = {
        "cause": "because",
        "effect": "therefore",
    }[doc["question"]]
    return doc["premise"].strip()[:-1] + f" {connector}"


def doc_to_target(doc):
    correct_choice = doc["choice1"] if doc["label"] == 0 else doc["choice2"]
    # Connect the sentences
    return " " + convert_choice(correct_choice)
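To show what these helpers produce, here is a small usage sketch with a made-up COPA-style document; the functions are inlined (mirroring the ones above) so the snippet runs standalone:

```python
# Usage sketch for the COPA helpers above, inlined so it runs standalone.
# The sample document is made up; real examples come from the super_glue/copa dataset.
def convert_choice(choice):
    return choice[0].lower() + choice[1:]


def doc_to_text(doc):
    # Drop the trailing period and append the connective implied by the question type.
    connector = {"cause": "because", "effect": "therefore"}[doc["question"]]
    return doc["premise"].strip()[:-1] + f" {connector}"


def doc_to_target(doc):
    correct_choice = doc["choice1"] if doc["label"] == 0 else doc["choice2"]
    return " " + convert_choice(correct_choice)


doc = {
    "premise": "The man broke his toe.",
    "question": "cause",
    "choice1": "He got a hole in his sock.",
    "choice2": "He dropped a hammer on his foot.",
    "label": 1,
}
print(doc_to_text(doc))    # -> "The man broke his toe because"
print(doc_to_target(doc))  # -> " he dropped a hammer on his foot."
```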