"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5469369480d01b357947ac938e6069e117d154d8"
Unverified commit 1d651868, authored by Leandro von Werra, committed by GitHub

add custom stopping criteria to human eval script (#14897)

parent 6b655cc6
-transformers==4.12.2
+transformers==4.15.0
 datasets==1.16.0
 accelerate==0.5.1
 wandb==0.12.0
......
@@ -8,12 +8,40 @@ from tqdm import tqdm
 import transformers
 from arguments import HumanEvalArguments
-from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, pipeline, set_seed
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    HfArgumentParser,
+    StoppingCriteria,
+    StoppingCriteriaList,
+    pipeline,
+    set_seed,
+)
+
+
+EOF_STRINGS = ["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif"]
+
+
+class EndOfFunctionCriteria(StoppingCriteria):
+    """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""
+
+    def __init__(self, start_length, eof_strings, tokenizer):
+        self.start_length = start_length
+        self.eof_strings = eof_strings
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids, scores, **kwargs):
+        """Returns true if all generated sequences contain any of the end-of-function strings."""
+        decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])
+        done = []
+        for decoded_generation in decoded_generations:
+            done.append(any([stop_string in decoded_generation for stop_string in self.eof_strings]))
+        return all(done)
+
+
 def first_block(string):
     """Split off first block of code by scanning for class, def etc. on newlines."""
-    return re.split("\nclass|\ndef|\n#|\n@|\nprint|\nif", string)[0].rstrip()
+    return re.split("|".join(EOF_STRINGS), string)[0].rstrip()
 
 
 def complete_code(pipe, prompt, num_completions=1, **gen_kwargs):
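As a quick illustration (not part of the commit), here is a minimal sketch of how the new pieces fit together outside the script, assuming the EndOfFunctionCriteria class and EOF_STRINGS list from the diff above are in scope; the checkpoint name is only an example.

    import re

    from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

    checkpoint = "codeparrot/codeparrot-small"  # example checkpoint, any causal LM works
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)

    prompt = 'def add(a, b):\n    """Return the sum of a and b."""\n'
    inputs = tokenizer(prompt, return_tensors="pt")
    start_length = inputs["input_ids"].shape[1]

    # Generation halts as soon as the decoded continuation contains one of the
    # EOF_STRINGS, instead of always running for the full max_new_tokens budget.
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        stopping_criteria=StoppingCriteriaList(
            [EndOfFunctionCriteria(start_length, EOF_STRINGS, tokenizer)]
        ),
    )

    # The stop string itself is still in the output, so first_block re-splits on
    # EOF_STRINGS and keeps only the completed function body.
    generation = tokenizer.decode(outputs[0][start_length:])
    print(re.split("|".join(EOF_STRINGS), generation)[0].rstrip())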
@@ -39,6 +67,11 @@ def main():
     set_seed(args.seed)
 
+    # Load model and tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
+    model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
+    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=args.device_int)
+
     # Generation settings
     gen_kwargs = {
         "do_sample": args.do_sample,
@@ -46,13 +79,9 @@ def main():
         "max_new_tokens": args.max_new_tokens,
         "top_p": args.top_p,
         "top_k": args.top_k,
+        "stopping_criteria": StoppingCriteriaList([EndOfFunctionCriteria(0, EOF_STRINGS, tokenizer)]),
     }
 
-    # Load model and tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
-    model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
-    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=args.device_int)
-
     # Load evaluation dataset and metric
     human_eval = load_dataset("openai_humaneval")
     code_eval_metric = load_metric("code_eval")
@@ -72,6 +101,7 @@ def main():
     for task in tqdm(range(n_tasks)):
         task_generations = []
         prompt = human_eval["test"][task]["prompt"].strip()
+        gen_kwargs["stopping_criteria"][0].start_length = len(tokenizer(prompt)["input_ids"])
         for batch in range(args.n_samples // args.batch_size):
             task_generations.extend(complete_code(pipe, prompt, num_completions=args.batch_size, **gen_kwargs))
         generations.append([prompt + gen for gen in task_generations])
......
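Two details in the hunks above are easy to miss. The criteria list is built once with start_length=0, and the task loop then mutates that single EndOfFunctionCriteria instance: HumanEval prompts can themselves contain strings like "\ndef", so scanning the full sequence would stop generation immediately; resetting start_length to the prompt's token length makes the check apply only to newly generated tokens. Also, __call__ returns all(done), so a batch keeps generating until every sequence has produced a stop string, and first_block trims whatever the already-finished sequences generated in the meantime. A hedged sketch of one task's flow, assuming the script's definitions above are in scope (the task index and sample count are illustrative):

    prompt = human_eval["test"][0]["prompt"].strip()

    # Point the shared criteria at the end of this prompt so only new tokens are scanned.
    gen_kwargs["stopping_criteria"][0].start_length = len(tokenizer(prompt)["input_ids"])

    # The pipeline forwards gen_kwargs, including stopping_criteria, to model.generate;
    # generated_text contains the prompt, so the prompt prefix is sliced off before
    # first_block trims everything past the completed function.
    # (num_return_sequences > 1 assumes do_sample=True in gen_kwargs.)
    code_gens = pipe(prompt, num_return_sequences=4, **gen_kwargs)
    candidates = [first_block(gen["generated_text"][len(prompt):]) for gen in code_gens]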