Commit 460584ca authored by haileyschoelkopf's avatar haileyschoelkopf Committed by Hailey Schoelkopf
Browse files

add template_vars + make jinja rendering stricter

parent f275301a
......@@ -2,7 +2,8 @@ dataset_path: super_glue
dataset_name: cb
training_split: train
validation_split: validation
doc_to_text: "Suppose {{premise}} Can we infer that \"{{hypothesis}}\"? Yes, no, or maybe?"
template_aliases: "{% set hypo = hypothesis %}"
doc_to_text: "Suppose {{premise}} Can we infer that \"{{hypo}}\"? Yes, no, or maybe?"
doc_to_target: "{% set answer_choices = ['Yes', 'No', 'Maybe'] %}{{answer_choices[label]}}"
metric_list: [
[exact_match, mean, true]
......
......@@ -31,9 +31,7 @@ class TaskConfig(dict):
test_split: str = None
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
# TODO: add this as more jinja2 appended to start of jinja2 templates. Should allow users to set vars
# s.t. they can define e.g. {% set question = query %} to map dataset columns to "canonical" names in prompts.
template_aliases: str = None
template_aliases: str = ""
doc_to_text: str = None
doc_to_target: str = None
......@@ -49,6 +47,13 @@ class TaskConfig(dict):
normalization: str = None # TODO: add length-normalization of various types, mutual info
stop_sequences: list = None # TODO: allow passing of stop sequences to greedy gen.
def __post_init__(self):
    # Prepend user-specified Jinja2 aliases to each template so that users
    # can force prompt-compatibility regardless of the dataset's field
    # names, e.g. template_aliases = "{% set question = query %}" puts
    # `question` in scope when doc_to_text is rendered.
    #
    # Guard against None: doc_to_text / doc_to_target default to None,
    # and `"" + None` would raise a TypeError for any config that does
    # not set them.
    if self.doc_to_text is not None:
        self.doc_to_text = self.template_aliases + self.doc_to_text
    if self.doc_to_target is not None:
        self.doc_to_target = self.template_aliases + self.doc_to_target
def __getitem__(self, item):
    """Dict-style access: ``config[key]`` resolves to the attribute ``config.key``."""
    value = getattr(self, item)
    return value
......
import random
from lm_eval.api.model import LM
from lm_eval.api.model import LM, register_model
@register_model("dummy")
class DummyLM(LM):
def __init__(self):
    # No-op initializer: the dummy model takes no configuration and
    # sets up no state.
    pass
......
......@@ -8,7 +8,6 @@ import torch.nn.functional as F
from lm_eval import utils
from lm_eval.api.model import LM, register_model
# from lm_eval.models import register_model
@register_model("hf-causal")
class HFLM(LM):
......
import os
import numpy as np
import transformers
from lm_eval.api.model import LM
from lm_eval.api.model import LM, register_model
from lm_eval import utils
from tqdm import tqdm
import time
......@@ -54,6 +54,7 @@ def oa_completion(**kwargs):
backoff_time *= 1.5
@register_model("openai")
class GPT3LM(LM):
REQ_CHUNK_SIZE = 20
......
......@@ -16,7 +16,7 @@ import os
import requests as _requests
import time
from tqdm import tqdm
from lm_eval.api.model import LM
from lm_eval.api.model import LM, register_model
logger = logging.getLogger(__name__)
......@@ -38,6 +38,7 @@ def textsynth_completion(**kwargs):
backoff_time *= 1.5
@register_model("textsynth")
class TextSynthLM(LM):
def __init__(self, engine, truncate=False):
"""
......
......@@ -8,7 +8,7 @@ import sys
from typing import List
from omegaconf import OmegaConf
from jinja2 import BaseLoader, Environment
from jinja2 import BaseLoader, Environment, StrictUndefined
class ExitCodeError(Exception):
......@@ -240,7 +240,7 @@ def run_task_tests(task_list: List[str]):
)
env = Environment(loader=BaseLoader)
env = Environment(loader=BaseLoader, undefined=StrictUndefined)
def apply_template(template, doc):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment