gaoqiong/lm-evaluation-harness, commit e7c18e53
authored Apr 23, 2023 by haileyschoelkopf
parent d2a9b759

refactor Instance dataclass
4 changed files with 12 additions and 61 deletions:

  lm_eval/api/instance.py   +3  -51
  lm_eval/api/task.py       +5  -4
  lm_eval/tasks/gsm8k.py    +2  -2
  lm_eval/tasks/lambada.py  +2  -4
lm_eval/api/instance.py

 from dataclasses import dataclass, field
+from typing import Literal


 @dataclass
 class Instance:
-    request_type: str = None  # TODO: make this an enum?
+    request_type: str = Literal["loglikelihood", "loglikelihood_rolling", "greedy_until"]
     doc: dict = None
     arguments: tuple = None
     id_: int = None

@@ -15,6 +16,7 @@ class Instance:
     repeats: str = None

     def __post_init__(self):
         # unpack metadata field
         self.task_name, self.doc_id, self.repeats = self.metadata

     @property

@@ -23,53 +25,3 @@ class Instance:
         Returns (string,) where `string` is the string to calculate loglikelihood over
         """
         return self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
-
-# import abc
-
-# class Instance(abc.ABC):
-#     """
-#     A class used to bind together all necessary information and metadata for
-#     running forward pass of a model on a specific datapoint.
-#     """
-#     # all Instance subclasses have an attribute which is the name of the LM() class function they call to get outputs.
-#     request_type = None
-
-#     def __init__(self, doc, arguments=None, id_=None, metadata=("", None, None)):
-#         self.doc = doc  # store the document which we're using. this is a dict
-#         self.arguments = arguments
-#         # need: task name, doc idx, num. repeats
-#         self.task_name, self.doc_id, self.repeats = metadata
-#         # id_ = idx within a doc's requests
-#         self.id_ = id_
-#         # handle repeats internally. should be able to run K times on exact same input/output pair
-#         # self.repeats = repeats
-#         # list containing the returns from each call of the model on this particular set of arguments.
-#         self.resps = []
-#         # filtered_resps should end up a dict, with a different key for each set of filters to apply. calculate results against each key in filtered_resps
-#         self.filtered_resps = {}
-#     #TODO: add more info as needed for detailed logging
-
-#     def __repr__(self):
-#         return f"Req_{self.request_type}{self.args}{self.id_}"
-
-@dataclass
-class LoglikelihoodInstance(Instance):
-    request_type: str = "loglikelihood"
-
-
-@dataclass
-class RollingLoglikelihoodInstance(Instance):
-    request_type: str = "loglikelihood_rolling"
-
-
-@dataclass
-class GenerationInstance(Instance):
-    request_type: str = "greedy_until"
lm_eval/api/task.py

@@ -10,7 +10,7 @@ import datasets
 import numpy as np

 from lm_eval.api import METRIC_REGISTRY, AGGREGATION_REGISTRY
-from lm_eval.api.instance import LoglikelihoodInstance, RollingLoglikelihoodInstance, GenerationInstance
+from lm_eval.api.instance import Instance
 from lm_eval.api.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
 from lm_eval import utils

@@ -460,7 +460,7 @@ class ConfigurableTask(Task):
     def construct_requests(self, doc, ctx, **kwargs):
         if self.OUTPUT_TYPE == "greedy_until":
-            return GenerationInstance(doc=doc, arguments=(ctx, "\n\n"), id_=0, **kwargs)
+            return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, "\n\n"), id_=0, **kwargs)

     def process_results(self, doc, results):

@@ -498,7 +498,8 @@ class MultipleChoiceTask(Task):
     def construct_requests(self, doc, ctx, **kwargs):
-        return [LoglikelihoodInstance(
+        return [Instance(
+            request_type="loglikelihood",
             doc=doc,
             arguments=(ctx, " {}".format(choice)),
             id_=i,

@@ -579,7 +580,7 @@ class PerplexityTask(Task, abc.ABC):
     def construct_requests(self, doc, ctx, **kwargs):
         assert not ctx
-        return RollingLoglikelihoodInstance(doc=doc, arguments=(self.doc_to_target(doc),), id_=0, **kwargs)
+        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(self.doc_to_target(doc),), id_=0, **kwargs)
         # req = rf.loglikelihood_rolling(self.doc_to_target(doc))
         # return req
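With the subclasses gone, nothing dispatches on the Python type of a request any more: the request_type string itself names the LM method to call, as the deleted comment in instance.py described ("the name of the LM() class function they call to get outputs"). A hypothetical sketch of that dispatch; run_requests and the lm parameter are illustrative, not harness API.

# hypothetical dispatch helper; not part of the harness API
from collections import defaultdict

def run_requests(lm, instances):
    # batch requests so each LM method runs once per request type
    by_type = defaultdict(list)
    for inst in instances:
        by_type[inst.request_type].append(inst)
    for request_type, reqs in by_type.items():
        # request_type names the LM() method to call, e.g. lm.loglikelihood
        method = getattr(lm, request_type)
        for req, resp in zip(reqs, method([r.arguments for r in reqs])):
            req.resps = [resp]  # raw model outputs, like the old resps list
    return instances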
lm_eval/tasks/gsm8k.py

@@ -18,7 +18,7 @@ Homepage: https://github.com/openai/grade-school-math
 """
 import re

 from lm_eval.api.task import Task
-from lm_eval.api.instance import GenerationInstance
+from lm_eval.api.instance import Instance
 from lm_eval.api.metrics import mean
 from lm_eval import utils

@@ -87,7 +87,7 @@ class GradeSchoolMath8K(Task):
         """
         # NOTE: The paper implements "verifiers" that assign a score to multiple
         # solutions and output the highest ranked solution.
-        return GenerationInstance(doc=doc, arguments=(ctx, ["\n"]), id_=0, **kwargs)
+        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, ["\n"]), id_=0, **kwargs)
         # completion = rf.greedy_until(ctx, ["\n"])
         # return completion
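For GSM8K the arguments tuple is (ctx, ["\n"]): the prompt plus a list of stop strings for greedy generation. A small illustrative sketch of what a backend does with that second element; the helper name is made up.

# illustrative only: truncate a completion at the first stop string
def truncate_at_stops(completion: str, until: list) -> str:
    for stop in until:
        completion = completion.split(stop)[0]
    return completion

print(truncate_at_stops("The answer is 42\nQuestion: ...", ["\n"]))
# -> "The answer is 42"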
lm_eval/tasks/lambada.py

@@ -13,7 +13,7 @@ in the broader discourse.
 Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
 """
 from lm_eval.api.task import Task
-from lm_eval.api.instance import LoglikelihoodInstance
+from lm_eval.api.instance import Instance
 from lm_eval.api.metrics import mean, perplexity

@@ -59,9 +59,7 @@ class LambadaBase(Task):
         return " " + doc["text"].rsplit(" ", 1)[1]

     def construct_requests(self, doc, ctx, **kwargs):
-        return LoglikelihoodInstance(doc=doc, arguments=(ctx, self.doc_to_target(doc)), **kwargs)
-
-        return ll, is_greedy
+        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, self.doc_to_target(doc)), **kwargs)

     def process_results(self, doc, results):
         # TODO: this ^ is a hack. filters should make it so that we only have one response per request that we score
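The doc_to_target line above splits off the document's final word and prepends the leading space the tokenizer expects; the unreachable return ll, is_greedy left over from the old API is dropped as dead code. A quick worked example of the target extraction, with an invented sample text:

# worked example of LambadaBase.doc_to_target (sample text invented)
doc = {"text": "He reached for the door and turned the handle"}
target = " " + doc["text"].rsplit(" ", 1)[1]
print(repr(target))  # ' handle'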