Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
460584ca
Commit
460584ca
authored
Apr 23, 2023
by
haileyschoelkopf
Committed by
Hailey Schoelkopf
Apr 24, 2023
Browse files
add template_vars + make jinja rendering stricter
parent
f275301a
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
18 additions
and
10 deletions
+18
-10
examples/configurable_task/sglue_cb.yaml
examples/configurable_task/sglue_cb.yaml
+2
-1
lm_eval/api/task.py
lm_eval/api/task.py
+8
-3
lm_eval/models/dummy.py
lm_eval/models/dummy.py
+2
-1
lm_eval/models/gpt2.py
lm_eval/models/gpt2.py
+0
-1
lm_eval/models/gpt3.py
lm_eval/models/gpt3.py
+2
-1
lm_eval/models/textsynth.py
lm_eval/models/textsynth.py
+2
-1
lm_eval/utils.py
lm_eval/utils.py
+2
-2
No files found.
examples/configurable_task/sglue_cb.yaml
View file @
460584ca
...
...
@@ -2,7 +2,8 @@ dataset_path: super_glue
dataset_name
:
cb
training_split
:
train
validation_split
:
validation
doc_to_text
:
"
Suppose
{{premise}}
Can
we
infer
that
\"
{{hypothesis}}
\"
?
Yes,
no,
or
maybe?"
template_aliases
:
"
{%
set
hypo
=
hypothesis
%}"
doc_to_text
:
"
Suppose
{{premise}}
Can
we
infer
that
\"
{{hypo}}
\"
?
Yes,
no,
or
maybe?"
doc_to_target
:
"
{%
set
answer_choices
=
['Yes',
'No',
'Maybe']
%}{{answer_choices[label]}}"
metric_list
:
[
[
exact_match
,
mean
,
true
]
...
...
lm_eval/api/task.py
View file @
460584ca
...
...
@@ -31,9 +31,7 @@ class TaskConfig(dict):
test_split
:
str
=
None
fewshot_split
:
str
=
None
# TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
# TODO: add this as more jinja2 appended to start of jinja2 templates. Should allow users to set vars
# s.t. they can define e.g. {% set question = query %} to map dataset columns to "canonical" names in prompts.
template_aliases
:
str
=
None
template_aliases
:
str
=
""
doc_to_text
:
str
=
None
doc_to_target
:
str
=
None
...
...
@@ -49,6 +47,13 @@ class TaskConfig(dict):
normalization
:
str
=
None
# TODO: add length-normalization of various types, mutual info
stop_sequences
:
list
=
None
# TODO: allow passing of stop sequences to greedy gen.
def
__post_init__
(
self
):
# allow user-specified aliases so that users can
# force prompt-compatibility for some prompt regardless of
# field names in prompt
self
.
doc_to_text
=
self
.
template_aliases
+
self
.
doc_to_text
self
.
doc_to_target
=
self
.
template_aliases
+
self
.
doc_to_target
def
__getitem__
(
self
,
item
):
return
getattr
(
self
,
item
)
...
...
lm_eval/models/dummy.py
View file @
460584ca
import
random
from
lm_eval.api.model
import
LM
from
lm_eval.api.model
import
LM
,
register_model
@
register_model
(
"dummy"
)
class
DummyLM
(
LM
):
def
__init__
(
self
):
pass
...
...
lm_eval/models/gpt2.py
View file @
460584ca
...
...
@@ -8,7 +8,6 @@ import torch.nn.functional as F
from
lm_eval
import
utils
from
lm_eval.api.model
import
LM
,
register_model
# from lm_eval.models import register_model
@
register_model
(
"hf-causal"
)
class
HFLM
(
LM
):
...
...
lm_eval/models/gpt3.py
View file @
460584ca
import
os
import
numpy
as
np
import
transformers
from
lm_eval.api.model
import
LM
from
lm_eval.api.model
import
LM
,
register_model
from
lm_eval
import
utils
from
tqdm
import
tqdm
import
time
...
...
@@ -54,6 +54,7 @@ def oa_completion(**kwargs):
backoff_time
*=
1.5
@
register_model
(
"openai"
)
class
GPT3LM
(
LM
):
REQ_CHUNK_SIZE
=
20
...
...
lm_eval/models/textsynth.py
View file @
460584ca
...
...
@@ -16,7 +16,7 @@ import os
import
requests
as
_requests
import
time
from
tqdm
import
tqdm
from
lm_eval.api.model
import
LM
from
lm_eval.api.model
import
LM
,
register_model
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -38,6 +38,7 @@ def textsynth_completion(**kwargs):
backoff_time
*=
1.5
@
register_model
(
"textsynth"
)
class
TextSynthLM
(
LM
):
def
__init__
(
self
,
engine
,
truncate
=
False
):
"""
...
...
lm_eval/utils.py
View file @
460584ca
...
...
@@ -8,7 +8,7 @@ import sys
from
typing
import
List
from
omegaconf
import
OmegaConf
from
jinja2
import
BaseLoader
,
Environment
from
jinja2
import
BaseLoader
,
Environment
,
StrictUndefined
class
ExitCodeError
(
Exception
):
...
...
@@ -240,7 +240,7 @@ def run_task_tests(task_list: List[str]):
)
env
=
Environment
(
loader
=
BaseLoader
)
env
=
Environment
(
loader
=
BaseLoader
,
undefined
=
StrictUndefined
)
def
apply_template
(
template
,
doc
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment