gaoqiong / lm-evaluation-harness · Commits

Commit 76e65788
Update interfaces

Authored Dec 27, 2020 by Leo Gao
Parent: 9edbc7c0

Showing 2 changed files with 89 additions and 24 deletions (+89 -24):

    lm_eval/base.py          +70 -15
    lm_eval/models/gpt2.py   +19 -9
lm_eval/base.py (view file @ 76e65788)
 import abc
 import random
+import collections


 class LM(abc.ABC):
     @abc.abstractmethod
-    def loglikelihood(self, context, continuation):
+    def loglikelihood(self, requests):
         """Compute log-likelihood of generating a continuation from a context

-        :param context: str
-            Context string
-        :param continuation: str
-            The continuation over which log likelihood will be calculated. If
-            there is a word boundary, the space should be in the continuation.
-            For example, context="hello" continuation=" world" is correct.
-        :return: float
+        :param requests: list
+            A list of pairs (context, continuation)
+            context: str
+                Context string
+            continuation: str
+                The continuation over which log likelihood will be calculated. If
+                there is a word boundary, the space should be in the continuation.
+                For example, context="hello" continuation=" world" is correct.
+        :return: list
+            A list of pairs (logprob, isgreedy)
+            logprob: float
+                The log probability of `continuation`
+            isgreedy: bool
+                Whether `continuation` would be generated by greedy sampling from `context`
         """
         pass

+    @abc.abstractmethod
+    def gen_greedy(self, requests):
+        """Generate greedily until a stopping sequence
+
+        :param requests: list
+            A list of pairs (context, until)
+            context: str
+                Context string
+            until: str
+                The string sequence to generate until. This string sequence may
+                span across multiple tokens, or may be part of one token.
+        :return: list
+            A list of continuation strings
+            continuation: str
+                The generated continuation.
+        """
+        pass
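To illustrate the new batched calling convention, here is a minimal sketch; `lm` stands in for an assumed LM implementation and the values are hypothetical, not part of this commit:

    # Each (context, continuation) pair yields one (logprob, isgreedy) pair.
    requests = [("hello", " world"), ("The capital of France is", " Paris")]
    results = lm.loglikelihood(requests)  # e.g. [(-2.31, False), (-0.07, True)]
    for (context, continuation), (logprob, isgreedy) in zip(requests, results):
        print(repr(continuation), logprob, isgreedy)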
...
@@ -80,20 +106,29 @@ class Dataset(abc.ABC):
     @abc.abstractmethod
     def doc_to_text(self, doc, include_target=True):
         pass

+    @abc.abstractmethod
+    def construct_requests(self, doc, nshot=0, prompt=False):
+        pass
+
     @abc.abstractmethod
-    def evaluate(self, docs, lm, provide_description, num_fewshot):
-        """Take an iterable of docs and evaluate, returning a dict with the following format:
+    def process_results(self, doc, results):
+        """Take a single document and the LM results and evaluate, returning a dict with the following format:

         {
-            "major": float,
-            "minor": dict,
+            "submetric": str,
+            "value": float,
             "higher_is_better": bool,
+            "aggregation": (list -> float),
         }

-        * `major` should be a single, representative number, for programmatic comparison
-        * `minor` should be a dictionary containing all relevant sub-metrics
+        * `submetric` should be the name of the metric
+        * `value` should be the value of the metric
         * `higher_is_better` determines whether a higher metric is better
+        * `aggregation` should be a function that takes a list of floats and
+          aggregates them into one float. This should be the same for all
+          submetrics of the same name; if it differs, an error should be
+          raised.
         """
         pass
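For concreteness, a dict returned by process_results might look like this (hypothetical values; `mean` is the aggregation helper added at the bottom of this file):

    {
        "submetric": "acc",          # name of the metric
        "value": 1.0,                # value for this single document
        "higher_is_better": True,
        "aggregation": mean,         # combines per-document values into one float
    }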
...
@@ -107,4 +142,24 @@ class Dataset(abc.ABC):
             map(self.doc_to_text, self.fewshot_examples(k=num_fewshot))
         ) + "\n\n"

         example = self.doc_to_text(doc, include_target=False).strip()
-        return description + labeled_examples + example
\ No newline at end of file
+        return description + labeled_examples + example
+
+
+def mean(arr):
+    return sum(arr) / len(arr)
+
+
+def median(arr):
+    return arr[len(arr) // 2]
+
+
+Request = collections.namedtuple('Request', ('type', 'args', 'kwargs'))
+
+
+class RequestFactory:
+    def __getattr__(self, attr):
+        def fn(*args, **kwargs):
+            return Request(attr, args, kwargs)
+        return fn
+
+
+rf = RequestFactory()
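RequestFactory turns any attribute access on `rf` into a constructor for a deferred Request, so task code can describe LM calls without executing them. A quick sketch of how `rf` behaves (the example arguments are hypothetical):

    req = rf.loglikelihood("hello", " world")
    # Request(type='loglikelihood', args=('hello', ' world'), kwargs={})
    req = rf.gen_greedy("Once upon a time", until="\n")
    # Request(type='gen_greedy', args=('Once upon a time',), kwargs={'until': '\n'})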
lm_eval/models/gpt2.py (view file @ 76e65788)
...
@@ -17,14 +17,24 @@ class GPT2LM(LM):
         args = utils.simple_parse_args_string(arg_string)
         return cls(device=args.get("device", "cpu"))

-    def loglikelihood(self, context, continuation, truncate=True):
-        # when too long to fit in context, truncate from the left
-        context_enc = self.tokenizer.encode(context)
-        continuation_enc = self.tokenizer.encode(continuation)
+    def loglikelihood(self, requests):
+        res = []
+        # TODO: vectorize properly
+        for context, continuation in requests:
+            # when too long to fit in context, truncate from the left
+            context_enc = self.tokenizer.encode(context)
+            continuation_enc = self.tokenizer.encode(continuation)

-        inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
-        ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
+            inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
+            ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)

-        cont_toks = inp[:, ctxlen:]  # [batch, seq]
-        logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
+            cont_toks = inp[:, ctxlen:]  # [batch, seq]
+            logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

-        return torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
+            # TODO: implement isgreedy
+            res.append((torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1), False))
+
+        return res
+
+    def gen_greedy(self, requests):
+        # TODO: implement
+        pass
\ No newline at end of file
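The ctxlen arithmetic above determines how many context tokens survive left-truncation to GPT-2's 1024-token window. A worked example with hypothetical lengths:

    # 1000 context tokens + 100 continuation tokens = 1100 tokens total; only
    # the last 1024 fit, so 76 context tokens are dropped from the left.
    context_len, continuation_len, max_len = 1000, 100, 1024
    ctxlen = context_len - max(0, context_len + continuation_len - max_len)
    assert ctxlen == 924  # context tokens remaining in `inp`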