gaoqiong / lm-evaluation-harness · Commit e49cf8da
Authored Apr 25, 2022 by cjlovering

SST with PS integration. (It was already done.)

Parent: 31a019c2
Changes: 1 changed file with 30 additions and 78 deletions

lm_eval/tasks/glue.py (+30, -78)
@@ -67,7 +67,7 @@ class CoLA(PromptSourceTask):
     def validation_docs(self):
         return self.dataset["validation"]

     def process_results(self, doc, results):
         answer_choices_list = self.prompt.get_answer_choices_list(doc)
         pred = np.argmax(results)
         target = answer_choices_list.index(self.doc_to_target(doc).strip())
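(Aside, not part of the diff.) In the hunk above, `results` holds one log-likelihood per answer choice, so `np.argmax` picks the model's preferred choice, and the gold index is recovered by looking up the stripped `doc_to_target` string in the prompt's answer choices. A standalone sketch with made-up values, not the harness's actual plumbing:

import numpy as np

# Hypothetical CoLA-style inputs; the real values come from the PromptSource
# template and the LM's per-choice log-likelihoods.
answer_choices_list = ["unacceptable", "acceptable"]
results = [-4.2, -1.3]        # log-likelihood of each answer choice
gold_text = "acceptable"      # what self.doc_to_target(doc).strip() would return

pred = np.argmax(results)                      # index of the likeliest choice -> 1
target = answer_choices_list.index(gold_text)  # index of the gold choice -> 1
print(pred == target)                          # True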
@@ -79,19 +79,13 @@ class CoLA(PromptSourceTask):
         print(f"PRED: {pred}")
         print("*" * 80)
-        return {
-            "mcc": (target, pred)
-        }
+        return {"mcc": (target, pred)}

     def higher_is_better(self):
-        return {
-            "mcc": True
-        }
+        return {"mcc": True}

     def aggregation(self):
-        return {
-            "mcc": matthews_corrcoef
-        }
+        return {"mcc": matthews_corrcoef}


 class SST(PromptSourceTask):
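(Aside, not part of the diff.) Note that `process_results` emits a `(target, pred)` pair per document rather than a scalar; the `"mcc"` aggregator then scores all accumulated pairs at once. A minimal sketch, assuming the `matthews_corrcoef` referenced in glue.py behaves like sklearn's function applied to the unzipped pairs:

from sklearn.metrics import matthews_corrcoef

# Made-up (target, pred) pairs as accumulated from process_results.
items = [(1, 1), (0, 1), (1, 1), (0, 0)]
golds, preds = zip(*items)
print(matthews_corrcoef(golds, preds))  # Matthews correlation over the split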
@@ -116,16 +110,6 @@ class SST(PromptSourceTask):
     def validation_docs(self):
         return self.dataset["validation"]

-    def higher_is_better(self):
-        return {
-            "acc": True
-        }
-
-    def aggregation(self):
-        return {
-            "acc": mean
-        }
-

 # Inference Tasks
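(Aside, not part of the diff.) The deletion above matches the commit message: if `PromptSourceTask` already supplies accuracy-based `higher_is_better` and `aggregation` defaults, the SST overrides were redundant. A hedged sketch of what such base-class defaults would look like; this is a hypothetical simplification, not the actual base class:

from statistics import mean  # stand-in for the harness's own mean helper

class PromptSourceTask:
    # Assumed defaults that would make the deleted SST overrides unnecessary.
    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}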
@@ -160,19 +144,13 @@ class MNLI(PromptSourceTask):
     def process_results(self, doc, results):
         gold = doc["label"]
         pred = np.argmax(results)
-        return {
-            "acc": pred == gold
-        }
+        return {"acc": pred == gold}

     def higher_is_better(self):
-        return {
-            "acc": True
-        }
+        return {"acc": True}

     def aggregation(self):
-        return {
-            "acc": mean
-        }
+        return {"acc": mean}


 class MNLIMismatched(MNLI):
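(Aside, not part of the diff.) MNLI's per-document metric is a plain boolean, and the `"acc"` aggregator is `mean`, so the reported score is just the fraction of documents where the argmax choice equals the gold label. A standalone illustration with made-up log-likelihoods:

import numpy as np

docs = [(0, [-1.0, -2.0, -3.0]),  # (gold label, per-choice log-likelihoods)
        (2, [-3.0, -2.0, -1.0]),
        (1, [-0.5, -2.5, -2.0])]
accs = [int(np.argmax(results)) == gold for gold, results in docs]
print(sum(accs) / len(accs))  # 0.666...: two of three predictions match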
@@ -213,19 +191,13 @@ class QNLI(Task):
         ll_yes, ll_no = results
         pred = ll_no > ll_yes
         gold = doc["label"]
-        return {
-            "acc": pred == gold
-        }
+        return {"acc": pred == gold}

     def higher_is_better(self):
-        return {
-            "acc": True
-        }
+        return {"acc": True}

     def aggregation(self):
-        return {
-            "acc": mean
-        }
+        return {"acc": mean}


 class WNLI(PromptSourceTask):
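(Aside, not part of the diff.) QNLI is scored as a two-way yes/no decision: `results` carries the log-likelihoods of the two continuations, and since label 1 in GLUE's QNLI means not_entailment, the boolean `ll_no > ll_yes` compares directly against the integer gold label. A tiny sketch with made-up numbers:

ll_yes, ll_no = -2.3, -0.7   # made-up log-likelihoods for the yes/no continuations
pred = ll_no > ll_yes        # True; as an int this is 1
gold = 1                     # 1 = not_entailment in QNLI
print(pred == gold)          # True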
@@ -252,14 +224,10 @@ class WNLI(PromptSourceTask):
         return self.dataset["validation"]

     def higher_is_better(self):
-        return {
-            "acc": True
-        }
+        return {"acc": True}

     def aggregation(self):
-        return {
-            "acc": mean
-        }
+        return {"acc": mean}


 class RTE(PromptSourceTask):
@@ -285,14 +253,10 @@ class RTE(PromptSourceTask):
         return self.dataset["validation"]

     def higher_is_better(self):
-        return {
-            "acc": True
-        }
+        return {"acc": True}

     def aggregation(self):
-        return {
-            "acc": mean
-        }
+        return {"acc": mean}


 # Similarity and Paraphrase Tasks
@@ -330,16 +294,10 @@ class MRPC(Task):
         }

     def higher_is_better(self):
-        return {
-            "acc": True,
-            "f1": True
-        }
+        return {"acc": True, "f1": True}

     def aggregation(self):
-        return {
-            "acc": mean,
-            "f1": f1_score
-        }
+        return {"acc": mean, "f1": f1_score}


 class QQP(Task):
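(Aside, not part of the diff.) MRPC reports both accuracy and F1. Presumably, as with `"mcc"` above, the `f1_score` aggregator consumes accumulated `(gold, pred)` pairs while `mean` consumes booleans. A sketch assuming the harness's `f1_score` wraps sklearn's:

from sklearn.metrics import f1_score

items = [(1, 1), (0, 1), (1, 1), (1, 0)]  # made-up (gold, pred) pairs
golds, preds = zip(*items)
print(f1_score(golds, preds))  # F1 of the positive (paraphrase) class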
@@ -388,16 +346,10 @@ class QQP(Task):
         }

     def higher_is_better(self):
-        return {
-            "acc": True,
-            "f1": True
-        }
+        return {"acc": True, "f1": True}

     def aggregation(self):
-        return {
-            "acc": mean,
-            "f1": f1_score
-        }
+        return {"acc": mean, "f1": f1_score}


 class STSB(Task):
@@ -435,22 +387,22 @@ class STSB(Task):
         return " {}".format(doc["label"])

     def construct_requests(self, doc, ctx):
-        """
-        Uses RequestFactory to construct Requests and returns an iterable of
+        """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

         :param doc:
             The document as returned from training_docs, validation_docs, or test_docs.
         :param ctx: str
             The context string, generated by fewshot_context. This includes the natural
             language description, as well as the few shot examples, and the question
             part of the document for `doc`.
         """
         # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        raise NotImplementedError("Evaluation not implemented")

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
         the metric for that one document

         :param doc:
@@ -459,22 +411,22 @@ class STSB(Task):
...
@@ -459,22 +411,22 @@ class STSB(Task):
The results of the requests created in construct_requests.
The results of the requests created in construct_requests.
"""
"""
# TODO: implement evaluation.
# TODO: implement evaluation.
raise
NotImplementedError
(
'
Evaluation not implemented
'
)
raise
NotImplementedError
(
"
Evaluation not implemented
"
)
def
aggregation
(
self
):
def
aggregation
(
self
):
"""
"""
:returns: {str: [float] -> float}
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
functions that aggregate a list of metrics
"""
"""
# TODO: implement evaluation.
# TODO: implement evaluation.
raise
NotImplementedError
(
'
Evaluation not implemented
'
)
raise
NotImplementedError
(
"
Evaluation not implemented
"
)
def
higher_is_better
(
self
):
def
higher_is_better
(
self
):
"""
"""
:returns: {str: bool}
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
whether a higher value of the submetric is better
"""
"""
# TODO: implement evaluation.
# TODO: implement evaluation.
raise
NotImplementedError
(
'
Evaluation not implemented
'
)
raise
NotImplementedError
(
"
Evaluation not implemented
"
)