Unverified Commit 23f30926 authored by gakada's avatar gakada Committed by GitHub
Browse files

Fix LLaMA tokenization issue (#531)



* Fix tokenization issue in BaseLM.loglikelihood

* Add a regression script

* Use entire non-continuation length as context

---------
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 441e6ac1
......@@ -170,16 +170,27 @@ class BaseLM(LM):
# subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
# TODO: enforce this somehow
def _encode_pair(self, context, continuation):
whole_enc = self.tok_encode(context + continuation)
whole_enc_len = len(whole_enc)
context_enc = self.tok_encode(context)
context_enc_len = len(context_enc)
if context_enc_len < whole_enc_len:
continuation_enc = whole_enc[context_enc_len:]
else:
continuation_enc = self.tok_encode(continuation)
continuation_enc_len = len(continuation_enc)
context_enc = whole_enc[:-continuation_enc_len]
return context_enc, continuation_enc
def loglikelihood(self, requests):
new_reqs = []
for context, continuation in requests:
if context == "":
# end of text as context
context_enc = [self.eot_token_id]
context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(continuation)
else:
context_enc = self.tok_encode(context)
continuation_enc = self.tok_encode(continuation)
context_enc, continuation_enc = self._encode_pair(context, continuation)
new_reqs.append(((context, continuation), context_enc, continuation_enc))
......@@ -264,7 +275,7 @@ class BaseLM(LM):
_, context_enc, continuation_enc = re_ord.get_reordered()[0]
max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
if (self.batch_size == 'auto'):
if override_bs is None:
print('Passed argument batch_size = auto. Detecting largest batch size')
@find_executable_batch_size(starting_batch_size=512) # if OOM, then halves batch_size and tries again
......
......@@ -5,6 +5,7 @@ import collections
import functools
import inspect
import sys
import fnmatch
from typing import List, Union
import torch
......@@ -84,6 +85,42 @@ def group(arr, fn):
return list(res.values())
def _is_json_task(task_name):
return task_name == "json" or task_name.startswith("json=")
class MultiChoice:
def __init__(self, choices):
self.choices = choices
# Simple wildcard support (linux filename patterns)
def __contains__(self, values):
for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0 and not _is_json_task(
value
):
return False
return True
def __iter__(self):
for choice in self.choices:
yield choice
# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
task_names = set()
for pattern in patterns:
if _is_json_task(pattern):
task_names.add(pattern)
for matching in fnmatch.filter(source_list, pattern):
task_names.add(matching)
return sorted(list(task_names))
def general_detokenize(string):
string = string.replace(" n't", "n't")
string = string.replace(" )", ")")
......
import argparse
import json
import logging
import fnmatch
import os
from lm_eval import tasks, evaluator
from lm_eval import tasks, evaluator, utils
logging.getLogger("openai").setLevel(logging.WARNING)
def _is_json_task(task_name):
return task_name == "json" or task_name.startswith("json=")
class MultiChoice:
def __init__(self, choices):
self.choices = choices
# Simple wildcard support (linux filename patterns)
def __contains__(self, values):
for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0 and not _is_json_task(
value
):
return False
return True
def __iter__(self):
for choice in self.choices:
yield choice
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True)
parser.add_argument("--model_args", default="")
parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))
parser.add_argument("--tasks", default=None, choices=utils.MultiChoice(tasks.ALL_TASKS))
parser.add_argument("--provide_description", action="store_true")
parser.add_argument("--num_fewshot", type=int, default=0)
parser.add_argument("--batch_size", type=str, default=None)
......@@ -56,19 +32,6 @@ def parse_args():
return parser.parse_args()
# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
task_names = set()
for pattern in patterns:
if _is_json_task(pattern):
task_names.add(pattern)
for matching in fnmatch.filter(source_list, pattern):
task_names.add(matching)
return sorted(list(task_names))
def main():
args = parse_args()
......@@ -82,7 +45,7 @@ def main():
if args.tasks is None:
task_names = tasks.ALL_TASKS
else:
task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
task_names = utils.pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
print(f"Selected Tasks: {task_names}")
......
import argparse
import json
import os
import subprocess
import time
from pathlib import Path
from lm_eval import tasks, utils
seq2seq_models = ["google/flan-t5-small"]
causal_models = ["gpt2", "facebook/opt-125m", "EleutherAI/gpt-neo-125m", "EleutherAI/pythia-160m"]
model_names = seq2seq_models + causal_models
completion_tasks = ["boolq", "lambada_openai", "winogrande"]
choice_tasks = ["hellaswag", "openbookqa", "piqa"]
perplexity_tasks = ["wikitext"]
generation_tasks = []
task_names = completion_tasks + choice_tasks + perplexity_tasks + generation_tasks
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--branches", default=[])
parser.add_argument("--models", default=model_names)
parser.add_argument("--tasks", default=task_names)
parser.add_argument("--acc_norm", type=bool, default=False)
parser.add_argument("--perplexity", default=None)
# TODO: implement num_fewshot and limit per task, e.g. task1:5,task2:1:100,task3::1000
parser.add_argument("--num_fewshot", type=int, default=0)
parser.add_argument("--limit", type=float, default=None)
# TODO: implement hf-auto to pick between causal and seq2seq models so we don't need this
parser.add_argument("--model", default="hf-causal-experimental")
# Use whatever is faster here
parser.add_argument("--model_args", default="use_accelerate=True,load_in_8bit=True")
parser.add_argument("--batch_size", default="auto")
return parser.parse_args()
def eval_models(args, branch=None):
if branch is not None:
if os.system(f"git checkout {branch}") != 0:
return {}, 0
branch = branch or initial_branch
start_time = time.time()
results = {}
for model in args.models:
model_type = "hf-causal-experimental" if model in causal_models \
else "hf-seq2seq" if model in seq2seq_models else args.model
model_args = f"pretrained={model},{args.model_args}"
# TODO: split_and_pad_windows in AutoSeq2SeqLM doesn"t exist, #527
tasks = args.tasks if model in causal_models or model_type == "hf-causal-experimental" \
else list(filter(lambda task: task not in perplexity_tasks, args.tasks))
# TODO: OOM with auto for seq2seq models, also can OOM with llama
batch_size = args.batch_size if model in causal_models or model_type == "hf-causal-experimental" \
else 64 if args.batch_size == "auto" else args.batch_size
output_path = f"data/regression/{int(start_time)}-{branch}-{Path(model).name}.json"
command = f"python3 main.py --model {model_type} --model_args {model_args} --tasks {','.join(tasks)} " \
f"--num_fewshot {args.num_fewshot}{'' if args.limit is None else f' --limit {args.limit}'} " \
f"--batch_size {batch_size} --no_cache --output_path {output_path}"
print(f"{'=' * 80}\nEvaluating {model} on {', '.join(tasks)} at {branch} with:\n\n{command}\n{'=' * 80}")
ret = os.system(command)
results[model] = json.load(open(output_path)) if ret == 0 else {"results": {}}
end_time = time.time()
return results, end_time - start_time
def extract_value(args, results, model, task, err=False):
if model not in results:
return 0
results = results[model]["results"]
if task not in results:
return 0
results = results[task]
if args.acc_norm and "acc_norm" in results:
return results["acc_norm"] if not err else results["acc_norm_stderr"]
if "acc" in results:
return results["acc"] if not err else results["acc_stderr"]
if (args.perplexity or "word_perplexity") in results:
return results[args.perplexity or "word_perplexity"] if not err else 0
return 0
def format_value(args, results, model, task):
val = 100 * extract_value(args, results, model, task)
err = 100 * extract_value(args, results, model, task, err=True)
return f"{val:.2f}{f' ± {err:.2f}' if err != 0 else ''}"
def format_diff(args, results1, results2, model, task):
val1 = 100 * extract_value(args, results1, model, task)
val2 = 100 * extract_value(args, results2, model, task)
diff = val2 - val1
return f"**+{diff:.2f}**" if diff > 0 else f"{diff:.2f}"
def main():
args = parse_args()
args.branches = args.branches.split(",") if type(args.branches) == str else args.branches
args.models = args.models.split(",") if type(args.models) == str else args.models
args.tasks = tasks.ALL_TASKS if args.tasks == "all_tasks" \
else utils.pattern_match(args.tasks.split(",") if type(args.tasks) == str else args.tasks, tasks.ALL_TASKS)
global initial_branch
initial_branch = subprocess.check_output("git branch --show-current", shell=True).decode("ascii").strip()
# TODO: implement proper timing for each task
# TODO: reduce IO by sharing tasks between models?
results, runtime = eval_models(args)
print(results, runtime)
runs = []
for branch in args.branches:
runs.append((branch, *eval_models(args, branch)))
os.system(f"git checkout {initial_branch}")
print("")
print(f"|task|{'|'.join(map(lambda model: Path(model).name, args.models))}|")
print(f"|--|{'--|' * len(args.models)}")
for task in args.tasks:
print(f"|{task} ({initial_branch})|{'|'.join(map(lambda model: format_value(args, results, model, task), args.models))}|")
for branch, branch_results, branch_runtime in runs:
print(f"|{task} ({branch})|{'|'.join(map(lambda model: format_value(args, branch_results, model, task), args.models))}|")
print(f"|{task} (diff)|{'|'.join(map(lambda model: format_diff(args, results, branch_results, model, task), args.models))}|")
print("")
print("|branch|runtime|%|")
print("|--|--|--|")
print(f"|{initial_branch}|{runtime:.1f}s|100%|")
for branch, _, branch_runtime in runs:
print(f"|{branch}|{branch_runtime:.1f}s|{100 * branch_runtime / runtime:.2f}%|")
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment