Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
99fdec0c
Commit
99fdec0c
authored
Jan 06, 2023
by
Albert Jiang
Browse files
implementing temperature sampling and maj@k for gpt2
parent
8c048e26
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
93 additions
and
5 deletions
+93
-5
lm_eval/base.py
lm_eval/base.py
+40
-1
lm_eval/models/gpt2.py
lm_eval/models/gpt2.py
+10
-4
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+1
-0
lm_eval/tasks/hendrycks_math.py
lm_eval/tasks/hendrycks_math.py
+42
-0
No files found.
lm_eval/base.py
View file @
99fdec0c
...
...
@@ -152,7 +152,7 @@ class BaseLM(LM):
pass
@
abstractmethod
def
_model_generate
(
self
,
context
,
max_length
,
eos_token_id
):
def
_model_generate
(
self
,
context
,
max_length
,
eos_token_id
,
temperature
=
0.
):
pass
@
abstractmethod
...
...
@@ -329,6 +329,44 @@ class BaseLM(LM):
return
re_ord
.
get_original
(
res
)
    def multiple_temperature_sample_until(self, requests, k=32, temperature=0.3):
        """Sample ``k`` continuations per request at the given temperature.

        :param requests: iterable of ``(context, until)`` pairs, where ``until``
            is a stop string or a list of stop strings.
        :param k: number of candidate samples drawn per request.
        :param temperature: sampling temperature forwarded to ``_model_generate``.
        :return: per-request lists of ``k`` decoded candidate strings, restored
            to the original request order.
        """
        res = []

        def _collate(x):
            # Sort key: tokenized context length (then the raw context string),
            # so similarly sized contexts are processed together.
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(re_ord.get_reordered()):
            if isinstance(until, str):
                until = [until]

            # NOTE(review): assumes the first stop string encodes to exactly one
            # token — the 1-tuple unpacking raises otherwise. TODO confirm all
            # callers pass single-token primary stop strings.
            (primary_until,) = self.tok_encode(until[0])

            # Keep only the last (max_length - max_gen_toks) context tokens so
            # context plus generation fits in the model window.
            context_enc = torch.tensor(
                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
            ).to(self.device)

            assert context_enc.shape[0] == 1
            # Replicate the single context k times to draw k samples in one batch.
            context_enc = context_enc.expand(k, context_enc.shape[1])

            cont = self._model_generate(
                context_enc,
                context_enc.shape[1] + self.max_gen_toks,
                primary_until,
                temperature=temperature,
            )

            # Strip the prompt: keep only the newly generated tokens.
            generated_tokens = cont[:, context_enc.shape[1] :]
            s = [self.tok_decode(candidate) for candidate in generated_tokens]

            # Truncate each candidate at the first occurrence of any stop string.
            for term in until:
                s = [candidate.split(term)[0] for candidate in s]

            # partial caching
            self.cache_hook.add_partial(
                "multiple_temperature_sample_until",
                (context, until, k, temperature),
                s,
            )

            res.append(s)

        return re_ord.get_original(res)
def
greedy_until
(
self
,
requests
):
# TODO: implement fully general `until` that handles until that are
# multiple tokens or that span multiple tokens correctly
...
...
@@ -843,6 +881,7 @@ class CachingLM:
# How many values each request type yields per document (used by CachingLM,
# per the surrounding class); None marks a variable-length / opaque result.
REQUEST_RETURN_LENGTHS = dict(
    loglikelihood=2,
    greedy_until=None,
    multiple_temperature_sample_until=None,
    loglikelihood_rolling=None,
)
...
...
lm_eval/models/gpt2.py
View file @
99fdec0c
...
...
@@ -121,10 +121,16 @@ class HFLM(BaseLM):
with
torch
.
no_grad
():
return
self
.
gpt2
(
inps
)[
0
][:,
:,
:
50257
]
def
_model_generate
(
self
,
context
,
max_length
,
eos_token_id
):
return
self
.
gpt2
.
generate
(
context
,
max_length
=
max_length
,
eos_token_id
=
eos_token_id
,
do_sample
=
False
)
def
_model_generate
(
self
,
context
,
max_length
,
eos_token_id
,
temperature
=
0.
):
assert
temperature
>=
0.
if
temperature
==
0.
:
return
self
.
gpt2
.
generate
(
context
,
max_length
=
max_length
,
eos_token_id
=
eos_token_id
,
do_sample
=
False
)
else
:
return
self
.
gpt2
.
generate
(
context
,
max_length
=
max_length
,
eos_token_id
=
eos_token_id
,
do_sample
=
True
,
temperature
=
temperature
)
# for backwards compatibility
...
...
lm_eval/tasks/__init__.py
View file @
99fdec0c
...
...
@@ -157,6 +157,7 @@ TASK_REGISTRY = {
"mutual_plus"
:
mutual
.
MuTualPlus
,
# math
"math_algebra"
:
hendrycks_math
.
MathAlgebra
,
"math_algebra_maj@k"
:
hendrycks_math
.
MathAlgebraMaj
,
"math_counting_and_prob"
:
hendrycks_math
.
MathCountingAndProbability
,
"math_geometry"
:
hendrycks_math
.
MathGeometry
,
"math_intermediate_algebra"
:
hendrycks_math
.
MathIntermediateAlgebra
,
...
...
lm_eval/tasks/hendrycks_math.py
View file @
99fdec0c
...
...
@@ -286,6 +286,48 @@ class MathAlgebra(Math):
DATASET_NAME
=
"algebra"
class MathAlgebraMaj(Math):
    """MATH algebra subset scored with majority voting (maj@k).

    Draws multiple temperature samples per problem and elects the most
    frequent extracted answer, instead of scoring a single greedy completion.
    """

    VERSION = 1
    DATASET_NAME = "algebra"

    def construct_requests(self, doc, ctx):
        """Request multiple temperature samples, each stopped at a newline."""
        return rf.multiple_temperature_sample_until(ctx, ["\n"])

    def process_results(self, doc, results):
        """Majority-vote the sampled candidates against the gold solution.

        :param doc: dataset document; ``doc["solution"]`` holds the gold answer.
        :param results: ``results[0]`` is the list of sampled candidate strings.
        :return: ``{"acc": 1}`` if the elected answer matches the gold answer
            per ``is_equiv``, else ``{"acc": 0}``.
        """
        candidates = results[0]

        answers = []
        for candidate in candidates:
            # If the candidate contains a $...$ math span, keep only the text
            # between the first and last "$"; otherwise use it verbatim.
            indices = [pos for pos, char in enumerate(candidate) if char == "$"]
            if len(indices) <= 1:
                answer = candidate
            else:
                answer = candidate[indices[0] + 1 : indices[-1]]
            try:
                answer = self.remove_boxed(self.last_boxed_only_string(answer))
            except Exception:
                # Unextractable candidate: record an abstention rather than
                # crashing on a malformed sample. (Fixed: was a bare `except:`,
                # which also swallowed KeyboardInterrupt/SystemExit.)
                answer = None
            answers.append(answer)

        # Tally votes; the first-seen non-None answer with the highest count wins.
        answer_votes = {}
        for answer in answers:
            answer_votes[answer] = answer_votes.get(answer, 0) + 1

        max_vote = 0
        elected = None
        for answer, vote in answer_votes.items():
            if vote > max_vote and answer is not None:
                elected = answer
                max_vote = vote

        retval = 0
        if self.is_equiv(
            elected, self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
        ):
            retval = 1
        return {"acc": retval}
class
MathCountingAndProbability
(
Math
):
VERSION
=
1
DATASET_NAME
=
"counting_and_probability"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment