Commit 4f29f7cc authored by haileyschoelkopf's avatar haileyschoelkopf

merge big-refactor into fix branch

parents f832c776 9dea125b
group:
  - super-glue-lm-eval-v1
task: "boolq-seq2seq"
dataset_path: super_glue
dataset_name: boolq
output_type: greedy_until
training_split: train
validation_split: validation
doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{answer_choices[label]}}"
gold_alias: "{{label}}" # this will be cast to an int.
template_aliases: "{% set answer_choices = ['no', 'yes'] %}"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
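Not part of the diff: a minimal sketch of how these Jinja fields render, using the same Environment settings as utils.py below; the sample document is invented.

from jinja2 import BaseLoader, Environment, StrictUndefined

env = Environment(loader=BaseLoader, undefined=StrictUndefined)

# Invented sample doc; `template_aliases` is prepended so that
# `answer_choices` is in scope when `doc_to_target` renders.
doc = {"passage": "The sky is blue.", "question": "Is the sky blue?", "label": 1}
template = "{% set answer_choices = ['no', 'yes'] %}" + "{{answer_choices[label]}}"
print(env.from_string(template).render(**doc))  # -> yes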
@@ -7,7 +7,8 @@ output_type: multiple_choice
 training_split: train
 validation_split: validation
 doc_to_text: "{{premise}}\nQuestion: {{hypothesis}}. True, False, or Neither?\nAnswer:"
-doc_to_target: "{{label}}" # this will be cast to an int.
+doc_to_target: "{{answer_choices[label]}}"
+gold_alias: "{{label}}" # this will be cast to an int.
 template_aliases: "{% set answer_choices = ['True', 'False', 'Neither'] %}"
 metric_list:
   - metric: acc
...
@@ -14,6 +14,7 @@ from typing import List, Union
 import gc

 import torch
+import transformers
 from omegaconf import OmegaConf
 from jinja2 import BaseLoader, Environment, StrictUndefined

@@ -391,7 +392,13 @@ def load_yaml_config(yaml_path):
     return yaml_config


+def regex_replace(string, pattern, repl, count=0):
+    """Implements the `re.sub` function as a custom Jinja filter."""
+    return re.sub(pattern, repl, string, count=count)
+
+
 env = Environment(loader=BaseLoader, undefined=StrictUndefined)
+env.filters["regex_replace"] = regex_replace


 def apply_template(template, doc):
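Illustration, not from the commit: what the new regex_replace filter makes possible inside a task template. The template string here is invented.

import re

from jinja2 import BaseLoader, Environment, StrictUndefined


def regex_replace(string, pattern, repl, count=0):
    """Implements the `re.sub` function as a custom Jinja filter."""
    return re.sub(pattern, repl, string, count=count)


env = Environment(loader=BaseLoader, undefined=StrictUndefined)
env.filters["regex_replace"] = regex_replace

# Collapse digit runs to a placeholder inside a doc_to_text-style template:
out = env.from_string("{{ text | regex_replace('[0-9]+', 'N') }}").render(
    text="Question 12 of 20"
)
print(out)  # -> Question N of N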
@@ -408,6 +415,111 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
     return islice(raw_iterator, rank, limit, world_size)


+def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
+    """
+    Method for padding a list of tensors given the maximum tensor
+    length in the batch. Used for batching inputs and continuations in
+    seq2seq models.
+    """
+    assert (
+        padding_side == "left" or padding_side == "right"
+    ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
+
+    for i, tensor in enumerate(tensors):
+        tensor_len = tensor.shape[0]
+        if tensor_len < max_length:
+            if padding_side == "right":
+                # right-pad
+                tensors[i] = torch.cat(
+                    [
+                        tensor,  # [seq]
+                        torch.zeros(
+                            max_length - tensor_len,
+                            dtype=torch.long,
+                            device=tensor.device,
+                        ),  # [padding_length - seq]
+                    ],
+                    dim=0,
+                ).unsqueeze(0)
+            else:
+                # left-pad
+                tensors[i] = torch.cat(
+                    [
+                        torch.zeros(
+                            max_length - tensor_len,
+                            dtype=torch.long,
+                            device=tensor.device,
+                        ),  # [padding_length - seq]
+                        tensor,  # [seq]
+                    ],
+                    dim=0,
+                ).unsqueeze(0)
+        else:
+            tensors[i] = tensor.unsqueeze(0)
+
+    return torch.cat(tensors, dim=0)
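A usage sketch for pad_and_concat, with invented tensor values, showing both padding sides:

import torch

a = torch.tensor([1, 2, 3], dtype=torch.long)
b = torch.tensor([4, 5], dtype=torch.long)

# max_length is the longest sequence length in the batch (3 here)
print(pad_and_concat(3, [a, b], padding_side="right"))
# tensor([[1, 2, 3],
#         [4, 5, 0]])
print(pad_and_concat(3, [a, b], padding_side="left"))
# tensor([[1, 2, 3],
#         [0, 4, 5]])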
 def clear_torch_cache():
     gc.collect()
     torch.cuda.empty_cache()


+def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
+    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig."""
+    if isinstance(dtype, str) and dtype != "auto":
+        # Convert `str` args to torch dtype: `float16` -> `torch.float16`
+        _torch_dtype = getattr(torch, dtype)
+    else:
+        _torch_dtype = dtype
+    return _torch_dtype
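For illustration, the three cases get_dtype handles:

import torch

print(get_dtype("float16"))       # torch.float16 (str -> torch.dtype)
print(get_dtype("auto"))          # "auto" passes through unchanged
print(get_dtype(torch.bfloat16))  # an actual torch.dtype is returned as-is

Note that "auto" comes back as a string despite the torch.dtype return annotation, so callers are expected to handle that case themselves.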
+# Multi-token stopping criteria
+class MultiTokenEOSCriteria(transformers.StoppingCriteria):
+    """Criteria to stop on the specified multi-token sequence."""
+
+    def __init__(
+        self,
+        sequence: str,
+        tokenizer: transformers.PreTrainedTokenizer,
+        initial_decoder_input_length: int,
+        batch_size: int,
+    ):
+        self.initial_decoder_input_length = initial_decoder_input_length
+        self.done_tracker = [False] * batch_size
+        self.sequence = sequence
+        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
+        self.sequence_id_len = len(self.sequence_ids)
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids, scores, **kwargs) -> bool:
+        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
+        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
+            :, -self.sequence_id_len :
+        ]
+        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
+        for i, done in enumerate(self.done_tracker):
+            if not done:
+                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
+        return False not in self.done_tracker
+
+
+def stop_sequences_criteria(
+    tokenizer: transformers.PreTrainedTokenizer,
+    stop_sequences: List[str],
+    initial_decoder_input_length: int,
+    batch_size: int,
+) -> transformers.StoppingCriteriaList:
+    return transformers.StoppingCriteriaList(
+        [
+            *[
+                MultiTokenEOSCriteria(
+                    sequence, tokenizer, initial_decoder_input_length, batch_size
+                )
+                for sequence in stop_sequences
+            ],
+        ]
+    )
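A hedged sketch of how these criteria plug into HF generation; "gpt2" is just a stand-in checkpoint. initial_decoder_input_length is the prompt length, so only newly generated tokens are scanned for stop strings.

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Question: What color is the sky?\nAnswer:", return_tensors="pt")
stopping_criteria = stop_sequences_criteria(
    tokenizer,
    stop_sequences=["\n\n", "Question:"],
    initial_decoder_input_length=inputs["input_ids"].shape[1],
    batch_size=inputs["input_ids"].shape[0],
)
output = model.generate(
    **inputs,
    max_new_tokens=32,
    stopping_criteria=stopping_criteria,
)
print(tokenizer.decode(output[0]))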
@@ -41,7 +41,6 @@ def parse_args():
     parser.add_argument("--data_sampling", type=float, default=None)
     parser.add_argument("--no_cache", action="store_true")
     parser.add_argument("--decontamination_ngrams_path", default=None)
-    parser.add_argument("--description_dict_path", default=None)
     parser.add_argument("--check_integrity", action="store_true")
     parser.add_argument("--write_out", action="store_true", default=False)
     parser.add_argument("--output_base_path", type=str, default=None)
@@ -78,12 +77,6 @@ def main():
     eval_logger.info(f"Selected Tasks: {task_names}")

-    # TODO: description_dict?
-    # description_dict = {}
-    # if args.description_dict_path:
-    #     with open(args.description_dict_path, "r") as f:
-    #         description_dict = json.load(f)
-
     results = evaluator.simple_evaluate(
         model=args.model,
         model_args=args.model_args,
@@ -94,7 +87,6 @@ def main():
         device=args.device,
         no_cache=args.no_cache,
         limit=args.limit,
-        # description_dict=description_dict,
         decontamination_ngrams_path=args.decontamination_ngrams_path,
         check_integrity=args.check_integrity,
         write_out=args.write_out,
@@ -103,8 +95,7 @@ def main():
     if results is not None:
         samples = results.pop("samples")
-        dumped = json.dumps(results, indent=2)
+        dumped = json.dumps(results, indent=2, default=lambda o: str(o))
         print(dumped)

         batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
...
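Why the default= fallback matters: the results dict can hold objects json cannot serialize, for example a torch.dtype stored in the run config. A small illustration (the dict contents are invented):

import json
import torch

results = {"config": {"dtype": torch.float16}}
# json.dumps(results, indent=2) raises:
#   TypeError: Object of type dtype is not JSON serializable
print(json.dumps(results, indent=2, default=lambda o: str(o)))
# the dtype is rendered as the string "torch.float16"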
@@ -13,12 +13,10 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--output_base_path", required=True)
     parser.add_argument("--tasks", default="all_tasks")
-    parser.add_argument("--provide_description", action="store_true")
     parser.add_argument("--sets", type=str, default="val")  # example: val,test
     parser.add_argument("--num_fewshot", type=int, default=1)
     parser.add_argument("--seed", type=int, default=42)
     parser.add_argument("--num_examples", type=int, default=1)
-    parser.add_argument("--description_dict_path", default=None)
     return parser.parse_args()
@@ -32,11 +30,6 @@ def main():
     task_names = args.tasks.split(",")
     task_dict = tasks.get_task_dict(task_names)

-    # description_dict = {}
-    # if args.description_dict_path:
-    #     with open(args.description_dict_path, "r") as f:
-    #         description_dict = json.load(f)
-
     os.makedirs(args.output_base_path, exist_ok=True)
     for task_name, task in task_dict.items():
         rnd = random.Random()
@@ -55,12 +48,6 @@ def main():
         docs = join_iters(iters)

-        # description = (
-        #     description_dict[task_name]
-        #     if description_dict and task_name in description_dict
-        #     else ""
-        # )
-
         with open(os.path.join(args.output_base_path, task_name), "w") as f:
             for i, doc in (
                 zip(range(args.num_examples), docs)
@@ -72,7 +59,6 @@ def main():
                     doc=doc,
                     num_fewshot=args.num_fewshot,
                     rnd=rnd,
-                    # description=description,
                 )
                 f.write(ctx + "\n")
...
@@ -28,7 +28,9 @@ setuptools.setup(
     python_requires=">=3.9",
     install_requires=[
         "accelerate>=0.18.0",
+        "evaluate",
         "datasets>=2.0.0",
+        "evaluate>=0.4.0",
         "jsonlines",
         "numexpr",
         "openai>=0.6.4",
...
@@ -3,7 +3,7 @@ import lm_eval.tasks
 import lm_eval.models


-def test_description_dict():
+def test_description():
     seed = 42
     num_examples = 1
     task_names = ["arc_challenge", "lambada"]
@@ -41,6 +41,5 @@ def test_description_dict():
                 doc=doc,
                 num_fewshot=1,
                 rnd=rnd,
-                description=description,
             )
             assert description in ctx
@@ -61,7 +61,6 @@ def test_evaluator(taskname, task_class):
         num_fewshot=0,
         limit=limit,
         bootstrap_iters=10,
-        description_dict=None,
     )
     e2 = evaluator.evaluate(
         lm=lm,
@@ -69,7 +68,6 @@ def test_evaluator(taskname, task_class):
         num_fewshot=0,
         limit=limit,
         bootstrap_iters=10,
-        description_dict=None,
     )

     # check that caching is working
...