Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
5e31e40e
Commit
5e31e40e
authored
Jun 09, 2021
by
Leo Gao
Browse files
Implement WikiText
parent
efa99cb2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
89 additions
and
124 deletions
+89
-124
lm_eval/base.py
lm_eval/base.py
+8
-8
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+3
-1
lm_eval/tasks/wikitext.py
lm_eval/tasks/wikitext.py
+78
-115
No files found.
lm_eval/base.py
View file @
5e31e40e
...
...
@@ -322,17 +322,17 @@ class PerplexityTask(Task, abc.ABC):
def construct_requests(self, doc, ctx):
    """Build the single rolling-loglikelihood request for one document.

    :param doc: a document as returned from *_docs
    :param ctx: must be falsy — perplexity tasks take no few-shot context
    :returns: the request produced by ``rf.loglikelihood_rolling``
    """
    # Perplexity is always zero-shot: a non-empty context would change
    # what is being scored, so fail loudly instead of silently ignoring it.
    assert not ctx
    # Score the *target* text (e.g. WikiText detokenizes the raw doc first).
    return rf.loglikelihood_rolling(self.doc_to_target(doc))
def process_results(self, doc, results):
    """Turn one document's LM result into perplexity submetric pairs.

    :param doc: the raw document as returned from *_docs
    :param results: one-element sequence holding the rolling loglikelihood
    :returns: dict mapping submetric name -> (numerator, denominator)
        pairs, later combined by ``weighted_mean`` in :meth:`aggregation`
    """
    loglikelihood, = results
    # NOTE(review): counts are taken on the raw doc, not doc_to_target(doc);
    # WikiText.count_words relies on this to count words *before*
    # detokenization — confirm before changing.
    n_words = self.count_words(doc)
    # Renamed from `bytes`: never shadow the builtin.
    n_bytes = self.count_bytes(doc)
    return {
        "word_perplexity": (loglikelihood, n_words),
        "byte_perplexity": (loglikelihood, n_bytes),
        # Sign is flipped for bits-per-byte; reuse n_bytes instead of
        # calling count_bytes a second time as the original did.
        "bits_per_byte": (-loglikelihood, n_bytes),
    }
def
aggregation
(
self
):
...
...
@@ -342,12 +342,12 @@ class PerplexityTask(Task, abc.ABC):
"bits_per_byte"
:
weighted_mean
}
def count_bytes(self, doc):
    """Return the length of *doc* in UTF-8 bytes.

    Used as the denominator for byte-perplexity and bits-per-byte.
    """
    encoded = doc.encode("utf-8")
    return len(encoded)
def count_words(self, doc):
    """Whitespace-delimited token count of *doc*.

    Downstream tasks with custom word boundaries should override this!
    """
    # NOTE(review): leading/trailing whitespace produces empty split
    # pieces, slightly inflating the count — preserved as-is, since
    # changing it would silently shift the word-perplexity metric.
    return len(re.split(r"\s+", doc))
req_ret_lens
=
{
...
...
lm_eval/tasks/__init__.py
View file @
5e31e40e
...
...
@@ -38,6 +38,7 @@ from . import hendrycks_math
from
.
import
cbt
from
.
import
lambada_cloze
from
.
import
pile
from
.
import
wikitext
########################################
# Translation tasks
...
...
@@ -95,6 +96,7 @@ TASK_REGISTRY = {
"drop"
:
drop
.
DROP
,
"lambada"
:
lambada
.
LAMBADA
,
"lambada_cloze"
:
lambada_cloze
.
LAMBADA_cloze
,
"wikitext"
:
wikitext
.
WikiText
,
# "cbt-cn": cbt.CBTCN, # disabled pending context length fix
# "cbt-ne": cbt.CBTNE, # disabled pending context length fix
...
...
@@ -113,7 +115,7 @@ TASK_REGISTRY = {
"arc_challenge"
:
arc
.
ARCChallenge
,
# "quac": quac.QuAC, # not implemented yet
"logiqa"
:
logiqa
.
LogiQA
,
"hellaswag"
:
hellaswag
.
HellaSwag
,
# not implemented yet
"hellaswag"
:
hellaswag
.
HellaSwag
,
"openbookqa"
:
openbookqa
.
OpenBookQA
,
# "sat": sat.SATAnalogies, # not implemented yet
"squad2"
:
squad
.
SQuAD2
,
...
...
lm_eval/tasks/wikitext.py
View file @
5e31e40e
from
.
common
import
HFTask
class
WikiText103
(
HFTask
):
import
os
import
re
from
lm_eval.base
import
rf
,
PerplexityTask
from
lm_eval.utils
import
sh
from
best_download
import
download_file
def wikitext_detokenizer(string):
    """Undo WikiText's tokenization artifacts, returning natural text.

    Reverses the Moses-style spacing found in the raw WikiText dumps:
    contraction splits, `@-@`/`@,@`/`@.@` number separators, padded
    punctuation and brackets, spaced heading markers, etc.
    """
    # contractions
    string = string.replace("s '", "s'")
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)

    # number separators
    for token, glue in ((" @-@ ", "-"), (" @,@ ", ","), (" @.@ ", ".")):
        string = string.replace(token, glue)

    # punctuation: drop the space *before* each mark, keep the one after
    for mark in (":", ";", ".", "!", "?", ","):
        string = string.replace(f" {mark} ", f"{mark} ")

    # double brackets / quotes: strip padding just inside each pair
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)

    # miscellaneous
    # heading markers: collapse longest-first so "= = = =" -> "===="
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))  # degree sign
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")

    return string
class WikiText(PerplexityTask):
    """Word-level perplexity over the raw WikiText-2 corpus.

    Documents are the top-level ("= Title =") sections of each split file;
    scoring runs on the detokenized text while word counts are taken on
    the original tokenized text.

    NOTE(review): reconstructed from a token-mangled diff view; the
    deleted HFTask-based stubs interleaved in the diff are dropped.
    """
    VERSION = 0

    def download(self):
        # Presence of the valid split is the "already downloaded" marker.
        if not os.path.exists('data/wikitext/wikitext-2-raw/wiki.valid.raw'):
            os.makedirs("data/wikitext/", exist_ok=True)
            download_file(
                "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip",
                "data/wikitext/wikitext-2-raw-v1.zip",
                "ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11",
            )
            sh("cd data/wikitext/ && unzip wikitext-2-raw-v1.zip")

    def fewshot_description(self):
        # TODO: figure out fewshot description
        return ""

    def has_validation_docs(self):
        return True

    def has_train_docs(self):
        return True

    def has_test_docs(self):
        return True

    def docs_for_split(self, split):
        """Yield one document per top-level heading of the raw split file."""
        buffered = []
        for line in open(f"data/wikitext/wikitext-2-raw/wiki.{split}.raw").read().split('\n'):
            # NOTE(review): the "= = =" replace below is dead code — the
            # preceding "= =" replace has already collapsed it. Harmless:
            # sub-headings still fail the startswith('= ') test, so only
            # top-level "= Title =" lines start a new document.
            rline = line.replace("= =", "==").replace("= = =", "===").strip()
            if rline.startswith('= ') and rline.strip().endswith(' ='):
                joined = '\n'.join(buffered)
                if joined.strip():
                    yield joined
                buffered = []
            buffered.append(line)
        # flush the final document
        yield '\n'.join(buffered)

    def validation_docs(self):
        return self.docs_for_split('valid')

    def train_docs(self):
        return self.docs_for_split('train')

    def test_docs(self):
        return self.docs_for_split('test')

    def doc_to_target(self, doc):
        # the scored text is the detokenized document
        return wikitext_detokenizer(doc)

    def count_words(self, doc):
        # count number of words in *original doc before detokenization*
        return len(re.split(r"\s+", doc))
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment