Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
1b97e487
Unverified
Commit
1b97e487
authored
May 23, 2024
by
Jess
Committed by
GitHub
May 23, 2024
Browse files
Merge pull request #28 from JessicaOjo/africamgsm
revert xnli to multiple_choice
parents
692510cc
4583bb42
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
138 additions
and
57 deletions
+138
-57
lm_eval/tasks/afrimgsm/utils.py
lm_eval/tasks/afrimgsm/utils.py
+131
-16
lm_eval/tasks/afrixnli/manual/direct/afrixnli_manual_direct_yaml
.../tasks/afrixnli/manual/direct/afrixnli_manual_direct_yaml
+2
-10
lm_eval/tasks/afrixnli/manual/direct/utils.py
lm_eval/tasks/afrixnli/manual/direct/utils.py
+1
-10
lm_eval/tasks/afrixnli/manual/translate/afrixnli_manual_translate_yaml
.../afrixnli/manual/translate/afrixnli_manual_translate_yaml
+3
-11
lm_eval/tasks/afrixnli/manual/translate/utils.py
lm_eval/tasks/afrixnli/manual/translate/utils.py
+1
-10
No files found.
lm_eval/tasks/afrimgsm/utils.py
View file @
1b97e487
import
argparse
import
argparse
import
yaml
import
yaml
languages
=
[
'eng'
,
'amh'
,
'ibo'
,
'fra'
,
'sna'
,
'lin'
,
'wol'
,
'ewe'
,
'lug'
,
'xho'
,
'kin'
,
'twi'
,
'zul'
,
'orm'
,
'yor'
,
'hau'
,
'sot'
,
'swa'
]
languages
=
[
'eng'
,
'amh'
,
'ibo'
,
'fra'
,
'sna'
,
'lin'
,
'wol'
,
'ewe'
,
'lug'
,
'xho'
,
'kin'
,
'twi'
,
'zul'
,
'orm'
,
'yor'
,
'hau'
,
'sot'
,
'swa'
]
# Per-language regular expressions used to pull the final numeric answer out
# of a chain-of-thought completion. Every pattern is the language's
# "the answer is" phrase, a single space, and one capture group that accepts
# an optional minus sign followed by digits, dots and commas.
# NOTE: the French pattern previously had no space before the capture group
# ("La réponse est(") and therefore could not match "La réponse est 18";
# it now follows the same "<phrase> (<number>)" shape as every other entry.
languages_REGEX = {
    "eng": "The answer is (\\-?[0-9\\.\\,]+)",
    "amh": "መልሱ (\\-?[0-9\\.\\,]+)",
    "ibo": "Azịza ya bụ (\\-?[0-9\\.\\,]+)",
    "fra": "La réponse est (\\-?[0-9\\.\\,]+)",
    "sna": "Mhinduro kumubvunzo ndi (\\-?[0-9\\.\\,]+)",
    "lin": "Eyano ezali (\\-?[0-9\\.\\,]+)",
    "wol": "Tontu li (\\-?[0-9\\.\\,]+)",
    "ewe": "ŋuɖoɖoae nye (\\-?[0-9\\.\\,]+)",
    "lug": "Ansa eri (\\-?[0-9\\.\\,]+)",
    "xho": "Impendulo ngu (\\-?[0-9\\.\\,]+)",
    "kin": "Igisubizo ni (\\-?[0-9\\.\\,]+)",
    "twi": "Ne nnyiano yɛ (\\-?[0-9\\.\\,]+)",
    "zul": "Impendulo ithi (\\-?[0-9\\.\\,]+)",
    "orm": "Deebiin isaa (\\-?[0-9\\.\\,]+)",
    "yor": "Ìdáhùn náà ni (\\-?[0-9\\.\\,]+)",
    "hau": "Amsar ita ce (\\-?[0-9\\.\\,]+)",
    "sot": "Karabo ke (\\-?[0-9\\.\\,]+)",
    "swa": "Jibu ni (\\-?[0-9\\.\\,]+)",
}
LANGUAGES
=
{}
for
lang
in
languages
:
if
lang
==
'amh'
:
LANGUAGES
[
lang
]
=
{
# English
"QUESTION"
:
"ጥያቄ:"
,
"ANSWER"
:
"በቅደም ተከተል መልስ:"
,
"DIRECT"
:
"Answer:"
,
"REGEX"
:
languages_REGEX
[
lang
]}
elif
lang
==
'yor'
:
LANGUAGES
[
lang
]
=
{
# English
"QUESTION"
:
"Ìbéèrè:"
,
"ANSWER"
:
"Ìdáhùn lẹ́sẹsẹ:"
,
"DIRECT"
:
"Answer:"
,
"REGEX"
:
languages_REGEX
[
lang
]}
configs
=
{
else
:
"QUESTION"
:
"Question:"
,
LANGUAGES
[
lang
]
=
{
# English
"ANSWER"
:
"Step-by-Step Answer:"
,
"QUESTION"
:
"Question:"
,
"DIRECT"
:
"Answer:"
,
"ANSWER"
:
"Step-by-Step Answer:"
,
"REGEX"
:
"The answer is (
\\
-?[0-9
\\
.
\\
,]+)"
}
"DIRECT"
:
"Answer:"
,
"REGEX"
:
languages_REGEX
[
lang
]}
def add_regex_pattern(regex_pattern):
    """Build the ``filter_list`` config fragment for an answer regex.

    :param regex_pattern: Pattern for the "strict-match" filter, or ``None``.
    :return: ``{}`` when ``regex_pattern`` is ``None``; otherwise a dict with
        a single ``"filter_list"`` key holding two filter pipelines,
        "strict-match" (the given pattern) and "flexible-extract" (a generic
        number pattern with ``group_select`` of -1), each followed by a
        ``take_first`` step.
    """
    if regex_pattern is None:
        return {}

    # Each pipeline gets its own "take_first" dict on purpose: sharing one
    # object would make a YAML dump emit anchors/aliases.
    strict_pipeline = {
        "name": "strict-match",
        "filter": [
            {"function": "regex", "regex_pattern": f"{regex_pattern}"},
            {"function": "take_first"},
        ],
    }
    flexible_pipeline = {
        "name": "flexible-extract",
        "filter": [
            {
                "function": "regex",
                "regex_pattern": "(-?[$0-9.,]{2,})|(-?[0-9]+)",
                "group_select": -1,
            },
            {"function": "take_first"},
        ],
    }
    return {"filter_list": [strict_pipeline, flexible_pipeline]}
def
gen_lang_yamls
(
output_dir
:
str
,
overwrite
:
bool
,
mode
:
str
)
->
None
:
def
gen_lang_yamls
(
output_dir
:
str
,
overwrite
:
bool
,
mode
:
str
)
->
None
:
...
@@ -18,28 +91,70 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
...
@@ -18,28 +91,70 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
:param overwrite: Whether to overwrite files if they already exist.
:param overwrite: Whether to overwrite files if they already exist.
"""
"""
err
=
[]
err
=
[]
for
lang
in
languages
:
for
lang
in
LANGUAGES
.
keys
()
:
try
:
try
:
yaml_template
=
"cot_yaml"
filter_list
=
{}
DELIMITER
=
None
if
mode
==
"direct"
:
if
mode
==
"direct"
:
ANSWER
=
LANGUAGES
[
'eng'
][
"DIRECT"
]
QUESTION
=
LANGUAGES
[
'eng'
][
"QUESTION"
]
REGEX
=
None
task_name
=
f
"afrimgsm_direct_
{
lang
}
"
task_name
=
f
"afrimgsm_direct_
{
lang
}
"
yaml_template
=
"afrimgsm_common_yaml"
yaml_template
=
"direct_yaml"
if
mode
==
"direct-native"
:
ANSWER
=
LANGUAGES
[
lang
][
"DIRECT"
]
QUESTION
=
LANGUAGES
[
lang
][
"QUESTION"
]
REGEX
=
None
task_name
=
f
"afrimgsm_direct_native_
{
lang
}
"
yaml_template
=
"direct_native_yaml"
elif
mode
==
"native-cot"
:
elif
mode
==
"native-cot"
:
ANSWER
=
LANGUAGES
[
lang
][
"ANSWER"
]
REGEX
=
LANGUAGES
[
lang
][
"REGEX"
]
QUESTION
=
LANGUAGES
[
lang
][
"QUESTION"
]
task_name
=
f
"afrimgsm_native_cot_
{
lang
}
"
task_name
=
f
"afrimgsm_native_cot_
{
lang
}
"
yaml_template
=
"afrimgsm_common_yaml"
filter_list
=
add_regex_pattern
(
REGEX
)
DELIMITER
=
""
if
lang
in
[
"zh"
,
"ja"
]
else
None
elif
mode
==
"en-cot"
:
elif
mode
==
"en-cot"
:
ANSWER
=
LANGUAGES
[
"eng"
][
"ANSWER"
]
REGEX
=
LANGUAGES
[
"eng"
][
"REGEX"
]
QUESTION
=
LANGUAGES
[
"eng"
][
"QUESTION"
]
task_name
=
f
"afrimgsm_en_cot_
{
lang
}
"
task_name
=
f
"afrimgsm_en_cot_
{
lang
}
"
yaml_template
=
"afrimgsm_common_yaml"
elif
mode
==
"translate-direct"
:
ANSWER
=
LANGUAGES
[
'eng'
][
"DIRECT"
]
QUESTION
=
LANGUAGES
[
'eng'
][
"QUESTION"
]
REGEX
=
None
task_name
=
f
"translate_afrimgsm_direct_
{
lang
}
"
yaml_template
=
"translate_direct_yaml"
file_name
=
f
"
{
task_name
}
.yaml"
file_name
=
f
"
{
task_name
}
.yaml"
ANSWER_TO_SKIP
=
len
(
LANGUAGES
[
lang
][
"ANSWER"
])
+
1
with
open
(
with
open
(
f
"
{
output_dir
}
/
{
file_name
}
"
,
"w"
if
overwrite
else
"x"
,
encoding
=
"utf8"
f
"
{
output_dir
}
/
{
file_name
}
"
,
"w"
if
overwrite
else
"x"
,
encoding
=
"utf8"
)
as
f
:
)
as
f
:
f
.
write
(
"# Generated by utils.py
\n
"
)
f
.
write
(
"# Generated by utils.py
\n
"
)
yaml
.
dump
(
yaml
.
dump
(
{
{
"include"
:
yaml_template
,
"include"
:
yaml_template
,
"dataset_name"
:
lang
,
"dataset_name"
:
lang
,
"task"
:
f
"
{
task_name
}
"
"task"
:
f
"
{
task_name
}
"
,
"doc_to_text"
:
f
"""{{% if answer is not none %}}"""
f
"""{{{{question+"
\\
n
{
ANSWER
}
"}}}}"""
f
"""{{% else %}}"""
f
"""{{{{"
{
QUESTION
}
"+question+"
\\
n
{
ANSWER
}
"}}}}"""
f
"""{{% endif %}}"""
,
"doc_to_target"
:
f
"""{{% if answer is not none %}}"""
f
"""{{{{answer[
{
ANSWER_TO_SKIP
}
:]}}}}"""
f
"""{{% else %}}"""
f
"""{{{{answer_number|string}}}}"""
f
"""{{% endif %}}"""
,
**
filter_list
,
"generation_kwargs"
:
{
"until"
:
[
QUESTION
,
"</s>"
,
"<|im_end|>"
],
"do_sample"
:
False
,
},
**
({
"target_delimiter"
:
DELIMITER
}
if
DELIMITER
else
{}),
},
},
f
,
f
,
allow_unicode
=
True
,
allow_unicode
=
True
,
...
@@ -60,17 +175,17 @@ def main() -> None:
...
@@ -60,17 +175,17 @@ def main() -> None:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
parser
.
add_argument
(
"--overwrite"
,
"--overwrite"
,
default
=
Tru
e
,
default
=
Fals
e
,
action
=
"store_true"
,
action
=
"store_true"
,
help
=
"Overwrite files if they already exist"
,
help
=
"Overwrite files if they already exist"
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--output-dir"
,
default
=
".
/direct
"
,
help
=
"Directory to write yaml files to"
"--output-dir"
,
default
=
"."
,
help
=
"Directory to write yaml files to"
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--mode"
,
"--mode"
,
default
=
"native-cot"
,
default
=
"native-cot"
,
choices
=
[
"direct"
,
"direct-native"
,
"native-cot"
,
"en-cot"
,
"translate-direct"
],
choices
=
[
"direct"
,
"direct-native"
,
"native-cot"
,
"en-cot"
,
"translate-direct"
],
help
=
"Mode of chain-of-thought"
,
help
=
"Mode of chain-of-thought"
,
)
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -79,4 +194,4 @@ def main() -> None:
...
@@ -79,4 +194,4 @@ def main() -> None:
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
main
()
main
()
\ No newline at end of file
lm_eval/tasks/afrixnli/manual/direct/afrixnli_manual_direct_yaml
View file @
1b97e487
...
@@ -4,7 +4,7 @@ group:
...
@@ -4,7 +4,7 @@ group:
- afrixnli-manual
- afrixnli-manual
dataset_path: masakhane/afrixnli
dataset_path: masakhane/afrixnli
dataset_name: null
dataset_name: null
output_type:
generate_until
output_type:
multiple_choice
validation_split: validation
validation_split: validation
test_split: test
test_split: test
fewshot_split: validation
fewshot_split: validation
...
@@ -16,14 +16,6 @@ doc_to_choice:
...
@@ -16,14 +16,6 @@ doc_to_choice:
- "contradiction"
- "contradiction"
should_decontaminate: true
should_decontaminate: true
doc_to_decontamination_query: premise
doc_to_decontamination_query: premise
filter_list:
- name: "verbalizer_extract"
filter:
- function: verbalizer
verbalizer_dict: {
"entailment": ['encouragement', 'entitlement', 'entails', 'entailed', 'entailment'],
"contradiction": ['contradictory', 'contradicts', 'contradiction'],
"neutral": ['neutral']}
metric_list:
metric_list:
- metric: f1
- metric: f1
aggregation: !function utils.weighted_f1_score
aggregation: !function utils.weighted_f1_score
...
@@ -32,7 +24,7 @@ metric_list:
...
@@ -32,7 +24,7 @@ metric_list:
ignore_case: true
ignore_case: true
ignore_punctuation: true
ignore_punctuation: true
- metric: acc
- metric: acc
aggregation:
!function utils.manual_accuracy_score
aggregation:
mean
higher_is_better: true
higher_is_better: true
ignore_case: true
ignore_case: true
ignore_punctuation: true
ignore_punctuation: true
...
...
lm_eval/tasks/afrixnli/manual/direct/utils.py
View file @
1b97e487
from
sklearn.metrics
import
f1_score
,
accuracy_score
from
sklearn.metrics
import
f1_score
def
doc_to_text
(
doc
):
def
doc_to_text
(
doc
):
...
@@ -30,12 +30,3 @@ def weighted_f1_score(items):
...
@@ -30,12 +30,3 @@ def weighted_f1_score(items):
preds
=
unzipped_list
[
1
]
preds
=
unzipped_list
[
1
]
fscore
=
f1_score
(
golds
,
preds
,
average
=
"weighted"
)
fscore
=
f1_score
(
golds
,
preds
,
average
=
"weighted"
)
return
fscore
return
fscore
def manual_accuracy_score(items):
    """Aggregate a sequence of result tuples into one accuracy value.

    ``items`` is transposed with ``zip(*items)``; the first resulting column
    is passed to ``accuracy_score`` as the gold labels and the second as the
    predictions. Presumably each item is a ``(gold, prediction)`` pair --
    confirm against the metric caller.
    """
    # Transpose rows into columns, then pick out the two columns we need.
    columns = list(zip(*items))
    references, predictions = columns[0], columns[1]
    return accuracy_score(references, predictions)
lm_eval/tasks/afrixnli/manual/translate/afrixnli_manual_translate_yaml
View file @
1b97e487
...
@@ -5,7 +5,7 @@ group:
...
@@ -5,7 +5,7 @@ group:
- afrixnli-translate-test
- afrixnli-translate-test
dataset_path: masakhane/afrixnli-translate-test
dataset_path: masakhane/afrixnli-translate-test
dataset_name: null
dataset_name: null
output_type:
generate_until
output_type:
multiple_choice
test_split: test
test_split: test
doc_to_text: !function utils.doc_to_text
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
doc_to_target: !function utils.doc_to_target
...
@@ -15,23 +15,15 @@ doc_to_choice:
...
@@ -15,23 +15,15 @@ doc_to_choice:
- "contradiction"
- "contradiction"
should_decontaminate: true
should_decontaminate: true
doc_to_decontamination_query: premise
doc_to_decontamination_query: premise
filter_list:
- name: "verbalizer_extract"
filter:
- function: verbalizer
verbalizer_dict: {
"entailment": ['encouragement', 'entitlement', 'entails', 'entailed', 'entailment'],
"contradiction": ['contradictory', 'contradicts', 'contradiction'],
"neutral": ['neutral']}
metric_list:
metric_list:
- metric: f1
- metric: f1
aggregation:
aggregation:
!function utils.weighted_f1_score
average: weighted
average: weighted
higher_is_better: True
higher_is_better: True
ignore_case: true
ignore_case: true
ignore_punctuation: true
ignore_punctuation: true
- metric: acc
- metric: acc
aggregation:
!function utils.manual_accuracy_score
aggregation:
mean
higher_is_better: true
higher_is_better: true
ignore_case: true
ignore_case: true
ignore_punctuation: true
ignore_punctuation: true
...
...
lm_eval/tasks/afrixnli/manual/translate/utils.py
View file @
1b97e487
from
sklearn.metrics
import
f1_score
,
accuracy_score
from
sklearn.metrics
import
f1_score
def
doc_to_text
(
doc
):
def
doc_to_text
(
doc
):
...
@@ -30,12 +30,3 @@ def weighted_f1_score(items):
...
@@ -30,12 +30,3 @@ def weighted_f1_score(items):
preds
=
unzipped_list
[
1
]
preds
=
unzipped_list
[
1
]
fscore
=
f1_score
(
golds
,
preds
,
average
=
"weighted"
)
fscore
=
f1_score
(
golds
,
preds
,
average
=
"weighted"
)
return
fscore
return
fscore
def manual_accuracy_score(items):
    """Aggregate a sequence of result tuples into one accuracy value.

    ``items`` is transposed with ``zip(*items)``; the first resulting column
    is passed to ``accuracy_score`` as the gold labels and the second as the
    predictions. Presumably each item is a ``(gold, prediction)`` pair --
    confirm against the metric caller.
    """
    # Transpose rows into columns, then pick out the two columns we need.
    columns = list(zip(*items))
    references, predictions = columns[0], columns[1]
    return accuracy_score(references, predictions)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment