Commit 460584ca authored by haileyschoelkopf, committed by Hailey Schoelkopf
Browse files

add template_vars + make jinja rendering stricter

parent f275301a
...@@ -2,7 +2,8 @@ dataset_path: super_glue ...@@ -2,7 +2,8 @@ dataset_path: super_glue
dataset_name: cb dataset_name: cb
training_split: train training_split: train
validation_split: validation validation_split: validation
doc_to_text: "Suppose {{premise}} Can we infer that \"{{hypothesis}}\"? Yes, no, or maybe?" template_aliases: "{% set hypo = hypothesis %}"
doc_to_text: "Suppose {{premise}} Can we infer that \"{{hypo}}\"? Yes, no, or maybe?"
doc_to_target: "{% set answer_choices = ['Yes', 'No', 'Maybe'] %}{{answer_choices[label]}}" doc_to_target: "{% set answer_choices = ['Yes', 'No', 'Maybe'] %}{{answer_choices[label]}}"
metric_list: [ metric_list: [
[exact_match, mean, true] [exact_match, mean, true]
......
...@@ -31,9 +31,7 @@ class TaskConfig(dict): ...@@ -31,9 +31,7 @@ class TaskConfig(dict):
test_split: str = None test_split: str = None
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?) fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
# TODO: add this as more jinja2 appended to start of jinja2 templates. Should allow users to set vars template_aliases: str = ""
# s.t. they can define e.g. {% set question = query %} to map dataset columns to "canonical" names in prompts.
template_aliases: str = None
doc_to_text: str = None doc_to_text: str = None
doc_to_target: str = None doc_to_target: str = None
...@@ -49,6 +47,13 @@ class TaskConfig(dict): ...@@ -49,6 +47,13 @@ class TaskConfig(dict):
normalization: str = None # TODO: add length-normalization of various types, mutual info normalization: str = None # TODO: add length-normalization of various types, mutual info
stop_sequences: list = None # TODO: allow passing of stop sequences to greedy gen. stop_sequences: list = None # TODO: allow passing of stop sequences to greedy gen.
def __post_init__(self):
    """Prepend the user-supplied Jinja aliases to the prompt templates.

    ``template_aliases`` holds Jinja statements such as
    ``{% set hypo = hypothesis %}``. Prefixing them to ``doc_to_text`` and
    ``doc_to_target`` lets a prompt use its own variable names regardless
    of the dataset's field names.

    Both templates default to ``None``, so a template is only prefixed when
    it is actually a string. Unset templates are left untouched, where an
    unconditional ``str + None`` would raise ``TypeError``. An empty or
    ``None`` ``template_aliases`` makes this a no-op.
    """
    # allow user-specified aliases so that users can
    # force prompt-compatibility for some prompt regardless of
    # field names in prompt
    aliases = self.template_aliases
    if not aliases:
        # Nothing to prepend; "" + template would be the template itself.
        return
    if isinstance(self.doc_to_text, str):
        self.doc_to_text = aliases + self.doc_to_text
    if isinstance(self.doc_to_target, str):
        self.doc_to_target = aliases + self.doc_to_target
def __getitem__(self, item): def __getitem__(self, item):
return getattr(self, item) return getattr(self, item)
......
import random import random
from lm_eval.api.model import LM from lm_eval.api.model import LM, register_model
@register_model("dummy")
class DummyLM(LM): class DummyLM(LM):
def __init__(self): def __init__(self):
pass pass
......
...@@ -8,7 +8,6 @@ import torch.nn.functional as F ...@@ -8,7 +8,6 @@ import torch.nn.functional as F
from lm_eval import utils from lm_eval import utils
from lm_eval.api.model import LM, register_model from lm_eval.api.model import LM, register_model
# from lm_eval.models import register_model
@register_model("hf-causal") @register_model("hf-causal")
class HFLM(LM): class HFLM(LM):
......
import os import os
import numpy as np import numpy as np
import transformers import transformers
from lm_eval.api.model import LM from lm_eval.api.model import LM, register_model
from lm_eval import utils from lm_eval import utils
from tqdm import tqdm from tqdm import tqdm
import time import time
...@@ -54,6 +54,7 @@ def oa_completion(**kwargs): ...@@ -54,6 +54,7 @@ def oa_completion(**kwargs):
backoff_time *= 1.5 backoff_time *= 1.5
@register_model("openai")
class GPT3LM(LM): class GPT3LM(LM):
REQ_CHUNK_SIZE = 20 REQ_CHUNK_SIZE = 20
......
...@@ -16,7 +16,7 @@ import os ...@@ -16,7 +16,7 @@ import os
import requests as _requests import requests as _requests
import time import time
from tqdm import tqdm from tqdm import tqdm
from lm_eval.api.model import LM from lm_eval.api.model import LM, register_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -38,6 +38,7 @@ def textsynth_completion(**kwargs): ...@@ -38,6 +38,7 @@ def textsynth_completion(**kwargs):
backoff_time *= 1.5 backoff_time *= 1.5
@register_model("textsynth")
class TextSynthLM(LM): class TextSynthLM(LM):
def __init__(self, engine, truncate=False): def __init__(self, engine, truncate=False):
""" """
......
...@@ -8,7 +8,7 @@ import sys ...@@ -8,7 +8,7 @@ import sys
from typing import List from typing import List
from omegaconf import OmegaConf from omegaconf import OmegaConf
from jinja2 import BaseLoader, Environment from jinja2 import BaseLoader, Environment, StrictUndefined
class ExitCodeError(Exception): class ExitCodeError(Exception):
...@@ -240,7 +240,7 @@ def run_task_tests(task_list: List[str]): ...@@ -240,7 +240,7 @@ def run_task_tests(task_list: List[str]):
) )
env = Environment(loader=BaseLoader) env = Environment(loader=BaseLoader, undefined=StrictUndefined)
def apply_template(template, doc): def apply_template(template, doc):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment