gaoqiong / lm-evaluation-harness

Commit a77f4be9 (unverified)
Authored Nov 01, 2023 by Stella Biderman; committed by GitHub on Nov 01, 2023

Merge pull request #536 from danny980521/update/klue_ynat

Update `KLUE-YNAT` prompt

Parents: a3b76ab1, d2dd333e

Showing 1 changed file with 46 additions and 47 deletions: lm_eval/tasks/klue.py (+46, -47)

lm_eval/tasks/klue.py @ a77f4be9
...
@@ -69,8 +69,7 @@ class STS(Task):
 
     def doc_to_text(self, doc):
         return "질문: 문장 1과 문장 2는 서로 유사한 의미를 가지나요?\n문장 1: {}\n문장 2: {}\n정답:".format(
-            general_detokenize(doc["sentence1"]),
-            general_detokenize(doc["sentence2"])
+            general_detokenize(doc["sentence1"]), general_detokenize(doc["sentence2"])
         )
 
     def doc_to_target(self, doc):
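
For reference, the STS change above only rewraps the call; the rendered prompt is unchanged. A minimal sketch of what it produces, assuming a hypothetical doc and skipping the general_detokenize helper (defined elsewhere in klue.py, not shown in this diff):

```python
# Illustrative only: how the STS prompt above renders for a made-up doc.
# The sentences are hypothetical; general_detokenize is omitted for brevity.
doc = {"sentence1": "오늘 날씨가 좋다.", "sentence2": "오늘은 날씨가 맑다."}

prompt = "질문: 문장 1과 문장 2는 서로 유사한 의미를 가지나요?\n문장 1: {}\n문장 2: {}\n정답:".format(
    doc["sentence1"], doc["sentence2"]
)
print(prompt)
# 질문: 문장 1과 문장 2는 서로 유사한 의미를 가지나요?
# 문장 1: 오늘 날씨가 좋다.
# 문장 2: 오늘은 날씨가 맑다.
# 정답:
```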
...
@@ -84,22 +83,13 @@ class STS(Task):
     def process_results(self, doc, results):
         pred = np.argmax(results)
         gold = doc["labels"]["binary-label"]
-        return {
-            "acc": pred == gold,
-            "f1": (gold, pred)
-        }
+        return {"acc": pred == gold, "f1": (gold, pred)}
 
     def higher_is_better(self):
-        return {
-            "acc": True,
-            "f1": True
-        }
+        return {"acc": True, "f1": True}
 
     def aggregation(self):
-        return {
-            "acc": mean,
-            "f1": f1_score
-        }
+        return {"acc": mean, "f1": f1_score}
 
 
 class YNAT(MultipleChoiceTask):
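
The STS metrics block above is reformatted only: process_results still emits a per-example accuracy flag and a (gold, pred) pair, and aggregation still maps them to mean and f1_score. A minimal sketch of how such (gold, pred) pairs could be aggregated, assuming an sklearn-style binary F1 (the harness's own f1_score helper is not shown in this diff and may differ):

```python
# Illustrative only: aggregating (gold, pred) pairs into a binary F1 score,
# assuming sklearn's f1_score; the harness's f1_score helper may differ.
from sklearn.metrics import f1_score as sk_f1_score


def f1_aggregation(items):
    # items: list of (gold, pred) tuples collected over the evaluation set
    golds, preds = zip(*items)
    return sk_f1_score(golds, preds)


print(f1_aggregation([(1, 1), (0, 1), (1, 1), (0, 0)]))  # 0.8
```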
...
@@ -118,7 +108,7 @@ class YNAT(MultipleChoiceTask):
 
     def training_docs(self):
         if self._training_docs is None:
-            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
         return self._training_docs
 
     def validation_docs(self):
...
@@ -128,32 +118,30 @@ class YNAT(MultipleChoiceTask):
         out_doc = {
             "title": doc["title"],
             "choices": ["과학", "경제", "사회", "생활", "세계", "스포츠", "정치"],
-            "gold": doc["label"]
+            "gold": doc["label"],
         }
         return out_doc
 
     def doc_to_text(self, doc):
-        return "{}".format(doc["title"])
+        return "질문: 다음의 제목을 가지는 뉴스는 어느 분야의 뉴스인가요?\n제목: {}\n분야:".format(doc["title"])
 
     def doc_to_target(self, doc):
-        return " ({})".format({0: "과학", 1: "경제", 2: "사회", 3: "생활", 4: "세계", 5: "스포츠", 6: "정치"}[doc["gold"]])
+        return " {}".format(
+            {0: "과학", 1: "경제", 2: "사회", 3: "생활", 4: "세계", 5: "스포츠", 6: "정치"}[
+                doc["gold"]
+            ]
+        )
 
     def process_results(self, doc, results):
         pred = np.argmax(results)
         gold = doc["gold"]
-        return {
-            "f1": (gold, pred)
-        }
+        return {"f1": (gold, pred)}
 
     def higher_is_better(self):
-        return {
-            "f1": True
-        }
+        return {"f1": True}
 
     def aggregation(self):
-        return {
-            "f1": macro_f1_score
-        }
+        return {"f1": macro_f1_score}
 
 
 class NLI(Task):
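
This hunk carries the substantive change of the commit: the KLUE-YNAT prompt now asks explicitly for the news category, and the target is the bare label name instead of a parenthesised one. A minimal sketch of the rendered example, using a hypothetical doc (title and label chosen purely for illustration):

```python
# Illustrative only: rendering the updated KLUE-YNAT prompt and target for a
# hypothetical doc, using the same format strings as the new code above.
LABELS = {0: "과학", 1: "경제", 2: "사회", 3: "생활", 4: "세계", 5: "스포츠", 6: "정치"}

doc = {"title": "유튜브 내달 2일까지 크리에이터 지원 공간 운영", "gold": 3}  # hypothetical

text = "질문: 다음의 제목을 가지는 뉴스는 어느 분야의 뉴스인가요?\n제목: {}\n분야:".format(doc["title"])
target = " {}".format(LABELS[doc["gold"]])

print(text + target)
# 질문: 다음의 제목을 가지는 뉴스는 어느 분야의 뉴스인가요?
# 제목: 유튜브 내달 2일까지 크리에이터 지원 공간 운영
# 분야: 생활
```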
...
@@ -232,7 +220,18 @@ class MRC(Task):
         return self.dataset["validation"]
 
     def doc_to_text(self, doc):
-        return "제목: " + doc["title"] + "\n\n" + "본문: " + doc["context"] + "\n\n" + "질문: " + doc["question"] + "\n\n" + "답:"
+        return (
+            "제목: "
+            + doc["title"]
+            + "\n\n"
+            + "본문: "
+            + doc["context"]
+            + "\n\n"
+            + "질문: "
+            + doc["question"]
+            + "\n\n"
+            + "답:"
+        )
 
     def doc_to_target(self, doc):
         answer = doc["answers"]["text"][0]
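
The MRC doc_to_text change above is purely stylistic: the parenthesised concatenation yields the same prompt string as the old single-line version. A minimal sketch with a hypothetical doc:

```python
# Illustrative only: the MRC prompt produced by the code above for a
# hypothetical doc; the old and new versions yield this exact same string.
doc = {
    "title": "제주 올레길",
    "context": "제주 올레길은 제주도 해안을 따라 걷는 도보 여행 코스이다.",
    "question": "올레길은 어디를 따라 걷는 코스인가요?",
}

prompt = (
    "제목: " + doc["title"] + "\n\n"
    + "본문: " + doc["context"] + "\n\n"
    + "질문: " + doc["question"] + "\n\n"
    + "답:"
)
print(prompt)
```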
...
@@ -241,7 +240,7 @@ class MRC(Task):
         return " " + answer
 
     def construct_requests(self, doc, ctx):
-        """
-        Uses RequestFactory to construct Requests and returns an iterable of
+        """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
 
         :param doc:
...
@@ -251,7 +250,7 @@ class MRC(Task):
         language description, as well as the few shot examples, and the question
         part of the document for `doc`.
         """
 
-        continuation = rf.greedy_until(ctx, {"until": ["\n"]})
+        continuation = rf.greedy_until(ctx, ["\n"])
         is_unanswerable = rf.loglikelihood(ctx, " " + "대답 불가")
         return continuation, is_unanswerable
...
@@ -270,15 +269,15 @@ class MRC(Task):
         no_answer_probability = exp(logprob_unanswerable)
 
         predictions = {
-            'id': doc['guid'],
-            'prediction_text': continuation,
-            'no_answer_probability': no_answer_probability,
+            "id": doc["guid"],
+            "prediction_text": continuation,
+            "no_answer_probability": no_answer_probability,
        }
 
         references = {
-            'id': doc['guid'],
-            'answers': doc['answers'],
-            'unanswerable': doc['is_impossible'],
+            "id": doc["guid"],
+            "answers": doc["answers"],
+            "unanswerable": doc["is_impossible"],
         }
 
         return {
...
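
The quoting change above leaves the structure intact: predictions carries the generated answer plus a no-answer probability, and references carries the gold answers and an unanswerable flag, which mirrors the input shape of SQuAD-v2-style scorers. A minimal sketch of scoring one such pair, assuming Hugging Face evaluate's "squad_v2" metric; the elided return block in klue.py may compute its metrics differently, and its extra `unanswerable` key is not part of that metric's reference format:

```python
# Illustrative only: feeding one predictions/references pair of the shape shown
# above into an off-the-shelf SQuAD-v2-style scorer, assuming the Hugging Face
# `evaluate` library; the task's actual metric computation is not shown here.
import evaluate

squad_v2 = evaluate.load("squad_v2")

predictions = [{
    "id": "klue-mrc-v1_dev_00001",      # hypothetical guid
    "prediction_text": "제주도",          # hypothetical model continuation
    "no_answer_probability": 0.05,
}]
references = [{
    "id": "klue-mrc-v1_dev_00001",
    "answers": {"text": ["제주도"], "answer_start": [10]},  # hypothetical gold answer
    # the harness's references also carry an `unanswerable` flag, which suggests
    # a custom scorer rather than the stock squad_v2 metric used here
}]

print(squad_v2.compute(predictions=predictions, references=references))
```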