Unverified commit 23f30926 authored by gakada, committed by GitHub

Fix LLaMA tokenization issue (#531)



* Fix tokenization issue in BaseLM.loglikelihood

* Add a regression script

* Use entire non-continuation length as context

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 441e6ac1
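Why the fix is needed, in a minimal sketch (toy_encode below is a stand-in for a SentencePiece-style tokenizer such as LLaMA's, not the harness's real tok_encode; it only illustrates the kind of boundary mismatch): encoding the context and the continuation separately can yield different tokens than encoding the concatenated string, because such tokenizers handle the start of a string and leading whitespace specially.

    # Toy SentencePiece-like tokenizer: a "dummy prefix" space is prepended before
    # encoding, so a continuation that begins with a space picks up a spurious "▁"
    # piece when it is encoded on its own.
    def toy_encode(text):
        return ("▁" + text.replace(" ", "▁")).replace("▁", " ▁").split()

    context, continuation = "The capital of France is", " Paris"
    separate = toy_encode(context) + toy_encode(continuation)
    whole = toy_encode(context + continuation)
    print(separate)  # ['▁The', '▁capital', '▁of', '▁France', '▁is', '▁', '▁Paris']
    print(whole)     # ['▁The', '▁capital', '▁of', '▁France', '▁is', '▁Paris']

Splitting the encoding of the concatenated string, as the new _encode_pair does below, avoids that extra piece at the boundary.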
@@ -170,16 +170,27 @@ class BaseLM(LM):
    # subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
    # TODO: enforce this somehow
    def _encode_pair(self, context, continuation):
        # Encode the concatenated string once, then split it at the context's token
        # length, so tokens at the context/continuation boundary are not re-tokenized.
        whole_enc = self.tok_encode(context + continuation)
        whole_enc_len = len(whole_enc)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        if context_enc_len < whole_enc_len:
            continuation_enc = whole_enc[context_enc_len:]
        else:
            # Fallback: the context encoding is at least as long as the whole encoding,
            # so encode the continuation separately and use everything before it as context.
            continuation_enc = self.tok_encode(continuation)
            continuation_enc_len = len(continuation_enc)
            context_enc = whole_enc[:-continuation_enc_len]
        return context_enc, continuation_enc
    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            if context == "":
                # end of text as context
                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(continuation)
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))
@@ -264,7 +275,7 @@ class BaseLM(LM):
        _, context_enc, continuation_enc = re_ord.get_reordered()[0]
        max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
        if (self.batch_size == 'auto'):
            if override_bs is None:
                print('Passed argument batch_size = auto. Detecting largest batch size')
                @find_executable_batch_size(starting_batch_size=512)  # if OOM, then halves batch_size and tries again
...
@@ -5,6 +5,7 @@ import collections
import functools
import inspect
import sys
import fnmatch
from typing import List, Union

import torch
@@ -84,6 +85,42 @@ def group(arr, fn):
    return list(res.values())
def _is_json_task(task_name):
    return task_name == "json" or task_name.startswith("json=")


class MultiChoice:
    def __init__(self, choices):
        self.choices = choices

    # Simple wildcard support (linux filename patterns)
    def __contains__(self, values):
        for value in values.split(","):
            if len(fnmatch.filter(self.choices, value)) == 0 and not _is_json_task(
                value
            ):
                return False
        return True

    def __iter__(self):
        for choice in self.choices:
            yield choice


# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
    task_names = set()
    for pattern in patterns:
        if _is_json_task(pattern):
            task_names.add(pattern)

        for matching in fnmatch.filter(source_list, pattern):
            task_names.add(matching)
    return sorted(list(task_names))
def general_detokenize(string):
    string = string.replace(" n't", "n't")
    string = string.replace(" )", ")")
...
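For reference, a hypothetical use of the wildcard matching that now lives in the shared utils module (the task names below are just examples; this assumes the harness is installed and importable as lm_eval):

    from lm_eval import utils

    all_tasks = ["arc_challenge", "arc_easy", "hellaswag", "lambada_openai"]
    # "arc_*" is expanded via fnmatch; exact names pass through unchanged
    print(utils.pattern_match(["arc_*", "hellaswag"], all_tasks))
    # ['arc_challenge', 'arc_easy', 'hellaswag']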
import argparse
import json
import logging
import os

from lm_eval import tasks, evaluator, utils

logging.getLogger("openai").setLevel(logging.WARNING)
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--model_args", default="")
    parser.add_argument("--tasks", default=None, choices=utils.MultiChoice(tasks.ALL_TASKS))
    parser.add_argument("--provide_description", action="store_true")
    parser.add_argument("--num_fewshot", type=int, default=0)
    parser.add_argument("--batch_size", type=str, default=None)
@@ -56,19 +32,6 @@ def parse_args():
    return parser.parse_args()
def main():
    args = parse_args()

@@ -82,7 +45,7 @@ def main():

    if args.tasks is None:
        task_names = tasks.ALL_TASKS
    else:
        task_names = utils.pattern_match(args.tasks.split(","), tasks.ALL_TASKS)

    print(f"Selected Tasks: {task_names}")
...
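With utils.MultiChoice validating --tasks and utils.pattern_match expanding it, a hypothetical invocation mixing exact task names and wildcards looks like the following (unchanged behavior, now routed through the shared utils helpers; the model and task names are illustrative, not prescribed by this commit):

    python3 main.py --model hf-causal-experimental \
        --model_args pretrained=EleutherAI/pythia-160m \
        --tasks "arc_*,hellaswag" --batch_size auto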
import argparse
import json
import os
import subprocess
import time
from pathlib import Path
from lm_eval import tasks, utils
seq2seq_models = ["google/flan-t5-small"]
causal_models = ["gpt2", "facebook/opt-125m", "EleutherAI/gpt-neo-125m", "EleutherAI/pythia-160m"]
model_names = seq2seq_models + causal_models
completion_tasks = ["boolq", "lambada_openai", "winogrande"]
choice_tasks = ["hellaswag", "openbookqa", "piqa"]
perplexity_tasks = ["wikitext"]
generation_tasks = []
task_names = completion_tasks + choice_tasks + perplexity_tasks + generation_tasks
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--branches", default=[])
    parser.add_argument("--models", default=model_names)
    parser.add_argument("--tasks", default=task_names)
    parser.add_argument("--acc_norm", type=bool, default=False)
    parser.add_argument("--perplexity", default=None)
    # TODO: implement num_fewshot and limit per task, e.g. task1:5,task2:1:100,task3::1000
    parser.add_argument("--num_fewshot", type=int, default=0)
    parser.add_argument("--limit", type=float, default=None)
    # TODO: implement hf-auto to pick between causal and seq2seq models so we don't need this
    parser.add_argument("--model", default="hf-causal-experimental")
    # Use whatever is faster here
    parser.add_argument("--model_args", default="use_accelerate=True,load_in_8bit=True")
    parser.add_argument("--batch_size", default="auto")
    return parser.parse_args()
def eval_models(args, branch=None):
    if branch is not None:
        if os.system(f"git checkout {branch}") != 0:
            return {}, 0

    branch = branch or initial_branch
    start_time = time.time()
    results = {}

    for model in args.models:
        model_type = "hf-causal-experimental" if model in causal_models \
            else "hf-seq2seq" if model in seq2seq_models else args.model
        model_args = f"pretrained={model},{args.model_args}"
        # TODO: split_and_pad_windows in AutoSeq2SeqLM doesn't exist, #527
        tasks = args.tasks if model in causal_models or model_type == "hf-causal-experimental" \
            else list(filter(lambda task: task not in perplexity_tasks, args.tasks))
        # TODO: OOM with auto for seq2seq models, also can OOM with llama
        batch_size = args.batch_size if model in causal_models or model_type == "hf-causal-experimental" \
            else 64 if args.batch_size == "auto" else args.batch_size
        output_path = f"data/regression/{int(start_time)}-{branch}-{Path(model).name}.json"

        command = f"python3 main.py --model {model_type} --model_args {model_args} --tasks {','.join(tasks)} " \
                  f"--num_fewshot {args.num_fewshot}{'' if args.limit is None else f' --limit {args.limit}'} " \
                  f"--batch_size {batch_size} --no_cache --output_path {output_path}"

        print(f"{'=' * 80}\nEvaluating {model} on {', '.join(tasks)} at {branch} with:\n\n{command}\n{'=' * 80}")

        ret = os.system(command)
        results[model] = json.load(open(output_path)) if ret == 0 else {"results": {}}

    end_time = time.time()
    return results, end_time - start_time
def extract_value(args, results, model, task, err=False):
    if model not in results:
        return 0
    results = results[model]["results"]
    if task not in results:
        return 0
    results = results[task]
    if args.acc_norm and "acc_norm" in results:
        return results["acc_norm"] if not err else results["acc_norm_stderr"]
    if "acc" in results:
        return results["acc"] if not err else results["acc_stderr"]
    if (args.perplexity or "word_perplexity") in results:
        return results[args.perplexity or "word_perplexity"] if not err else 0
    return 0
def format_value(args, results, model, task):
    val = 100 * extract_value(args, results, model, task)
    err = 100 * extract_value(args, results, model, task, err=True)
    return f"{val:.2f}{f' ± {err:.2f}' if err != 0 else ''}"
def format_diff(args, results1, results2, model, task):
    val1 = 100 * extract_value(args, results1, model, task)
    val2 = 100 * extract_value(args, results2, model, task)
    diff = val2 - val1
    return f"**+{diff:.2f}**" if diff > 0 else f"{diff:.2f}"
def main():
    args = parse_args()

    args.branches = args.branches.split(",") if type(args.branches) == str else args.branches
    args.models = args.models.split(",") if type(args.models) == str else args.models
    args.tasks = tasks.ALL_TASKS if args.tasks == "all_tasks" \
        else utils.pattern_match(args.tasks.split(",") if type(args.tasks) == str else args.tasks, tasks.ALL_TASKS)

    global initial_branch
    initial_branch = subprocess.check_output("git branch --show-current", shell=True).decode("ascii").strip()

    # TODO: implement proper timing for each task
    # TODO: reduce IO by sharing tasks between models?

    results, runtime = eval_models(args)
    print(results, runtime)

    runs = []
    for branch in args.branches:
        runs.append((branch, *eval_models(args, branch)))

    os.system(f"git checkout {initial_branch}")

    print("")
    print(f"|task|{'|'.join(map(lambda model: Path(model).name, args.models))}|")
    print(f"|--|{'--|' * len(args.models)}")
    for task in args.tasks:
        print(f"|{task} ({initial_branch})|{'|'.join(map(lambda model: format_value(args, results, model, task), args.models))}|")
        for branch, branch_results, branch_runtime in runs:
            print(f"|{task} ({branch})|{'|'.join(map(lambda model: format_value(args, branch_results, model, task), args.models))}|")
            print(f"|{task} (diff)|{'|'.join(map(lambda model: format_diff(args, results, branch_results, model, task), args.models))}|")

    print("")
    print("|branch|runtime|%|")
    print("|--|--|--|")
    print(f"|{initial_branch}|{runtime:.1f}s|100%|")
    for branch, _, branch_runtime in runs:
        print(f"|{branch}|{branch_runtime:.1f}s|{100 * branch_runtime / runtime:.2f}%|")
if __name__ == "__main__":
    main()
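For reference, a hypothetical invocation of the regression script (the branch name is a placeholder, and the script's path in the repository is not shown in this diff): it evaluates each model on each task on the current branch and on every listed branch, then prints markdown tables comparing scores and runtimes.

    python3 regression.py --branches my-feature-branch \
        --models gpt2,EleutherAI/pythia-160m \
        --tasks lambada_openai,hellaswag --limit 0.1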