Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
6ec93da2
Commit
6ec93da2
authored
Apr 25, 2022
by
jon-tow
Browse files
Add `eos_token` property
parent
34f591af
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
5 deletions
+8
-5
lm_eval/base.py
lm_eval/base.py
+8
-5
No files found.
lm_eval/base.py
View file @
6ec93da2
...
@@ -348,7 +348,8 @@ class BaseLM(LM):
...
@@ -348,7 +348,8 @@ class BaseLM(LM):
if
isinstance
(
until
,
str
):
if
isinstance
(
until
,
str
):
until
=
[
until
]
until
=
[
until
]
(
primary_until
,)
=
self
.
tok_encode
(
until
[
0
])
# TODO: Come back to for generation `eos`.
primary_until
=
self
.
tok_encode
(
until
[
0
])[
0
]
context_enc
=
torch
.
tensor
(
context_enc
=
torch
.
tensor
(
[
self
.
tok_encode
(
context
)[
self
.
max_gen_toks
-
self
.
max_length
:]]
[
self
.
tok_encode
(
context
)[
self
.
max_gen_toks
-
self
.
max_length
:]]
...
@@ -616,7 +617,6 @@ class Task(abc.ABC):
...
@@ -616,7 +617,6 @@ class Task(abc.ABC):
)
)
fewshotex
=
rnd
.
sample
(
self
.
_fewshot_docs
,
num_fewshot
+
1
)
fewshotex
=
rnd
.
sample
(
self
.
_fewshot_docs
,
num_fewshot
+
1
)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex
=
[
x
for
x
in
fewshotex
if
x
!=
doc
][:
num_fewshot
]
fewshotex
=
[
x
for
x
in
fewshotex
if
x
!=
doc
][:
num_fewshot
]
...
@@ -639,6 +639,9 @@ class PromptSourceTask(Task):
...
@@ -639,6 +639,9 @@ class PromptSourceTask(Task):
super
().
__init__
(
data_dir
,
cache_dir
,
download_mode
)
super
().
__init__
(
data_dir
,
cache_dir
,
download_mode
)
self
.
prompt
=
prompt
self
.
prompt
=
prompt
def
eos_token
(
self
):
raise
NotImplementedError
()
def
doc_to_target
(
self
,
doc
):
def
doc_to_target
(
self
,
doc
):
_
,
target
=
self
.
prompt
.
apply
(
doc
)
_
,
target
=
self
.
prompt
.
apply
(
doc
)
return
f
"
{
target
}
"
return
f
"
{
target
}
"
...
@@ -659,7 +662,6 @@ class PromptSourceTask(Task):
...
@@ -659,7 +662,6 @@ class PromptSourceTask(Task):
part of the document for `doc`.
part of the document for `doc`.
"""
"""
_requests
=
[]
_requests
=
[]
answer_choices_list
=
self
.
prompt
.
get_answer_choices_list
(
doc
)
answer_choices_list
=
self
.
prompt
.
get_answer_choices_list
(
doc
)
if
answer_choices_list
:
if
answer_choices_list
:
for
answer_choice
in
answer_choices_list
:
for
answer_choice
in
answer_choices_list
:
...
@@ -667,8 +669,8 @@ class PromptSourceTask(Task):
...
@@ -667,8 +669,8 @@ class PromptSourceTask(Task):
_requests
.
append
(
ll_answer_choice
)
_requests
.
append
(
ll_answer_choice
)
else
:
else
:
# TODO(Albert): What is the stop symbol? Is it model specific?
# TODO(Albert): What is the stop symbol? Is it model specific?
ll_greedy
=
rf
.
greedy_until
(
ctx
,
[
"
\n
Q:"
])
cont_request
=
rf
.
greedy_until
(
ctx
,
[
self
.
eos_token
()
])
_requests
.
append
(
ll_greedy
)
_requests
.
append
(
cont_request
)
return
_requests
return
_requests
...
@@ -694,6 +696,7 @@ class PromptSourceTask(Task):
...
@@ -694,6 +696,7 @@ class PromptSourceTask(Task):
}
}
else
:
else
:
continuation
=
results
continuation
=
results
raise
NotImplementedError
()
# Map metric name to HF metric.
# Map metric name to HF metric.
# TODO(Albert): What is Other?
# TODO(Albert): What is Other?
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment