Commit 4f29f7cc authored by haileyschoelkopf

merge big-refactor into fix branch

parents f832c776 9dea125b
group:
  - super-glue-lm-eval-v1
task: "boolq-seq2seq"
dataset_path: super_glue
dataset_name: boolq
output_type: greedy_until
training_split: train
validation_split: validation
doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{answer_choices[label]}}"
gold_alias: "{{label}}" # this will be cast to an int.
template_aliases: "{% set answer_choices = ['no', 'yes'] %}"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
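For readers unfamiliar with the templating scheme, here is a minimal sketch (not part of the diff) of how the `template_aliases`, `doc_to_text`, and `doc_to_target` fields above render for a single document. The passage, question, and label are invented for illustration.

```python
from jinja2 import BaseLoader, Environment, StrictUndefined

env = Environment(loader=BaseLoader, undefined=StrictUndefined)

# Invented example document following the super_glue/boolq schema.
doc = {
    "passage": "Barq's is an American brand of root beer.",
    "question": "is barq's root beer a pepsi product",
    "label": 0,
}

aliases = "{% set answer_choices = ['no', 'yes'] %}"
prompt = env.from_string(
    aliases + "{{passage}}\nQuestion: {{question}}\nAnswer:"
).render(**doc)
target = env.from_string(aliases + "{{answer_choices[label]}}").render(**doc)

print(prompt)  # Barq's is an American brand of root beer. / Question: ... / Answer:
print(target)  # no
```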
@@ -7,7 +7,8 @@ output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: "{{premise}}\nQuestion: {{hypothesis}}. True, False, or Neither?\nAnswer:"
doc_to_target: "{{label}}" # this will be cast to an int.
doc_to_target: "{{answer_choices[labe]}}"
gold_alias: "{{label}}" # this will be cast to an int.
template_aliases: "{% set answer_choices = ['True', 'False', 'Neither'] %}"
metric_list:
  - metric: acc
......
@@ -14,6 +14,7 @@ from typing import List, Union
import gc
import torch
import transformers
from omegaconf import OmegaConf
from jinja2 import BaseLoader, Environment, StrictUndefined
@@ -391,7 +392,13 @@ def load_yaml_config(yaml_path):
return yaml_config
def regex_replace(string, pattern, repl, count=0):
"""Implements the `re.sub` function as a custom Jinja filter."""
return re.sub(pattern, repl, string, count=count)
env = Environment(loader=BaseLoader, undefined=StrictUndefined)
env.filters["regex_replace"] = regex_replace
def apply_template(template, doc):
@@ -408,6 +415,111 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
return islice(raw_iterator, rank, limit, world_size)
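The body of `create_iterator` above is a thin wrapper over `itertools.islice`; this toy example (invented data) shows how it shards documents across ranks.

```python
from itertools import islice

docs = list(range(10))
world_size = 2

# Each rank takes every world_size-th document, offset by its rank,
# which is exactly islice(raw_iterator, rank, limit, world_size).
print(list(islice(docs, 0, None, world_size)))  # rank 0 -> [0, 2, 4, 6, 8]
print(list(islice(docs, 1, None, world_size)))  # rank 1 -> [1, 3, 5, 7, 9]
```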
def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
"""
Method for padding a list of tensors given the maximum tensor
length in the batch. Used for batching inputs and continuations in
seq2seq models.
"""
assert (
padding_side == "left" or padding_side == "right"
), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
for i, tensor in enumerate(tensors):
tensor_len = tensor.shape[0]
if tensor_len < max_length:
if padding_side == "right":
# right-pad
tensors[i] = torch.cat(
[
tensor, # [seq]
torch.zeros(
max_length - tensor_len,
dtype=torch.long,
device=tensor.device,
), # [padding_length - seq]
],
dim=0,
).unsqueeze(0)
else:
# left-pad
tensors[i] = torch.cat(
[
torch.zeros(
max_length - tensor_len,
dtype=torch.long,
device=tensor.device,
), # [padding_length - seq]
tensor, # [seq]
],
dim=0,
).unsqueeze(0)
else:
tensors[i] = tensor.unsqueeze(0)
return torch.cat(tensors, dim=0)
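A small usage sketch for `pad_and_concat`, assuming the definition above is in scope; the tensors and expected output are invented for illustration.

```python
import torch

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

batch = pad_and_concat(3, seqs, padding_side="left")
print(batch)
# tensor([[1, 2, 3],
#         [0, 4, 5]])
```

Note that the helper mutates the `tensors` list in place (each entry is unsqueezed to shape `[1, max_length]`) before concatenating along dim 0.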
def clear_torch_cache():
gc.collect()
torch.cuda.empty_cache()
def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
"""Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
if isinstance(dtype, str) and dtype != "auto":
# Convert `str` args torch dtype: `float16` -> `torch.float16`
_torch_dtype = getattr(torch, dtype)
else:
_torch_dtype = dtype
return _torch_dtype
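A quick behavior check for `get_dtype` (using the definition above): string names map to torch dtypes, while `"auto"` and actual dtype objects pass through unchanged.

```python
import torch

assert get_dtype("float16") is torch.float16
assert get_dtype(torch.bfloat16) is torch.bfloat16
assert get_dtype("auto") == "auto"  # "auto" is deliberately left as a string
```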
# Multi-token stopping criteria
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
"""Criteria to stop on the specified multi-token sequence."""
def __init__(
self,
sequence: str,
tokenizer: transformers.PreTrainedTokenizer,
initial_decoder_input_length: int,
batch_size: int,
):
self.initial_decoder_input_length = initial_decoder_input_length
self.done_tracker = [False] * batch_size
self.sequence = sequence
self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
self.sequence_id_len = len(self.sequence_ids)
self.tokenizer = tokenizer
def __call__(self, input_ids, scores, **kwargs) -> bool:
# For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
:, -self.sequence_id_len :
]
lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
for i, done in enumerate(self.done_tracker):
if not done:
self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
return False not in self.done_tracker
def stop_sequences_criteria(
tokenizer: transformers.PreTrainedTokenizer,
stop_sequences: List[str],
initial_decoder_input_length: int,
batch_size: int,
) -> transformers.StoppingCriteriaList:
return transformers.StoppingCriteriaList(
[
*[
MultiTokenEOSCriteria(
sequence, tokenizer, initial_decoder_input_length, batch_size
)
for sequence in stop_sequences
],
]
)
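A hedged end-to-end sketch of how `stop_sequences_criteria` plugs into Hugging Face generation; the model name, prompt, and stop strings here are illustrative, not taken from the diff.

```python
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

enc = tokenizer("Q: What is the capital of France?\nA:", return_tensors="pt")

# Stop once every sequence in the batch has emitted "\n" or "Q:".
# For a decoder-only model, initial_decoder_input_length is the prompt
# length, so the criteria only look at newly generated tokens.
criteria = stop_sequences_criteria(
    tokenizer,
    stop_sequences=["\n", "Q:"],
    initial_decoder_input_length=enc["input_ids"].shape[1],
    batch_size=enc["input_ids"].shape[0],
)
out = model.generate(**enc, max_new_tokens=32, stopping_criteria=criteria)
print(tokenizer.decode(out[0][enc["input_ids"].shape[1]:]))
```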
@@ -41,7 +41,6 @@ def parse_args():
parser.add_argument("--data_sampling", type=float, default=None)
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--description_dict_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
parser.add_argument("--write_out", action="store_true", default=False)
parser.add_argument("--output_base_path", type=str, default=None)
@@ -78,12 +77,6 @@ def main():
eval_logger.info(f"Selected Tasks: {task_names}")
# TODO: description_dict?
# description_dict = {}
# if args.description_dict_path:
# with open(args.description_dict_path, "r") as f:
# description_dict = json.load(f)
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
@@ -94,7 +87,6 @@ def main():
device=args.device,
no_cache=args.no_cache,
limit=args.limit,
# description_dict=description_dict,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
write_out=args.write_out,
@@ -103,8 +95,7 @@ def main():
if results is not None:
samples = results.pop("samples")
dumped = json.dumps(results, indent=2)
dumped = json.dumps(results, indent=2, default=lambda o: str(o))
print(dumped)
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
......
@@ -13,12 +13,10 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--output_base_path", required=True)
parser.add_argument("--tasks", default="all_tasks")
parser.add_argument("--provide_description", action="store_true")
parser.add_argument("--sets", type=str, default="val") # example: val,test
parser.add_argument("--num_fewshot", type=int, default=1)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--num_examples", type=int, default=1)
parser.add_argument("--description_dict_path", default=None)
return parser.parse_args()
@@ -32,11 +30,6 @@ def main():
task_names = args.tasks.split(",")
task_dict = tasks.get_task_dict(task_names)
# description_dict = {}
# if args.description_dict_path:
# with open(args.description_dict_path, "r") as f:
# description_dict = json.load(f)
os.makedirs(args.output_base_path, exist_ok=True)
for task_name, task in task_dict.items():
rnd = random.Random()
@@ -55,12 +48,6 @@ def main():
docs = join_iters(iters)
# description = (
# description_dict[task_name]
# if description_dict and task_name in description_dict
# else ""
# )
with open(os.path.join(args.output_base_path, task_name), "w") as f:
for i, doc in (
zip(range(args.num_examples), docs)
@@ -72,7 +59,6 @@ def main():
doc=doc,
num_fewshot=args.num_fewshot,
rnd=rnd,
# description=description,
)
f.write(ctx + "\n")
......
@@ -28,7 +28,9 @@ setuptools.setup(
python_requires=">=3.9",
install_requires=[
"accelerate>=0.18.0",
"evaluate",
"datasets>=2.0.0",
"evaluate>=0.4.0",
"jsonlines",
"numexpr",
"openai>=0.6.4",
......
@@ -3,7 +3,7 @@ import lm_eval.tasks
import lm_eval.models
def test_description_dict():
def test_description():
seed = 42
num_examples = 1
task_names = ["arc_challenge", "lambada"]
@@ -41,6 +41,5 @@ def test_description_dict():
doc=doc,
num_fewshot=1,
rnd=rnd,
description=description,
)
assert description in ctx
@@ -61,7 +61,6 @@ def test_evaluator(taskname, task_class):
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
description_dict=None,
)
e2 = evaluator.evaluate(
lm=lm,
@@ -69,7 +68,6 @@ def test_evaluator(taskname, task_class):
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
description_dict=None,
)
# check that caching is working
......